cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use cubecl_core::ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7    analyses::{const_len::Slices, integer_range::Ranges, liveness::Liveness},
8    gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, OpId, Value, ValueTable},
9    ControlFlow,
10};
11
12use super::Optimizer;
13
14const DEBUG_GVN: bool = false;
15
16/// Debug display for the program state.
17impl Display for Optimizer {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        let slices = self.analysis_cache.try_get::<Slices>().unwrap_or_default();
20        let ranges = self.analysis_cache.try_get::<Ranges>().unwrap_or_default();
21
22        f.write_str("Slices:\n")?;
23        for (var_id, slice) in slices.iter() {
24            let end_op = slice.end_op.as_ref().map(|it| format!("{it}"));
25            writeln!(
26                f,
27                "slice{var_id:?}: {{ start: {}, end: {}, end_op: {}, const_len: {:?} }}",
28                slice.start,
29                slice.end,
30                end_op.unwrap_or("None".to_string()),
31                slice.const_len
32            )?;
33        }
34        f.write_str("\n\n")?;
35
36        let global_nums = self
37            .analysis_cache
38            .try_get::<GvnState>()
39            .unwrap_or_default();
40        let liveness = self
41            .analysis_cache
42            .try_get::<Liveness>()
43            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
44
45        if DEBUG_GVN {
46            writeln!(f, "# Value Table:")?;
47            writeln!(f, "{}", global_nums.values)?;
48        }
49
50        for node in self.program.node_indices() {
51            let id = node.index();
52            let bb = &self.program[node];
53            writeln!(f, "bb{id} {{")?;
54            if DEBUG_GVN {
55                let block_sets = &global_nums
56                    .block_sets
57                    .get(&node)
58                    .cloned()
59                    .unwrap_or_default();
60                writeln!(f, "{block_sets}")?;
61            }
62
63            if !bb.block_use.is_empty() {
64                writeln!(f, "    Uses: {:?}", bb.block_use)?;
65            }
66            let live_vars = liveness.at_block(node).iter();
67            let live_vars = live_vars.map(|it| format!("local({})", it));
68            let live_vars = live_vars.collect::<Vec<_>>();
69            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
70
71            for phi in bb.phi_nodes.borrow().iter() {
72                write!(f, "    {} = phi ", phi.out)?;
73                for entry in &phi.entries {
74                    write!(f, "[bb{}: ", entry.block.index())?;
75                    write!(f, "{}]", entry.value)?;
76                }
77                f.write_str(";\n")?;
78            }
79            if !bb.phi_nodes.borrow().is_empty() {
80                writeln!(f)?;
81            }
82
83            for op in bb.ops.borrow_mut().values_mut() {
84                let range = op.out.map(|var| ranges.range_of(self, &var));
85                let range = range.map(|it| format!(" range: {it};")).unwrap_or_default();
86
87                writeln!(f, "    {op};{range}")?;
88            }
89            match &*bb.control_flow.borrow() {
90                ControlFlow::IfElse {
91                    cond,
92                    then,
93                    or_else,
94                    merge,
95                } => {
96                    writeln!(
97                        f,
98                        "    {cond} ? bb{} : bb{}; merge: {}",
99                        then.index(),
100                        or_else.index(),
101                        merge
102                            .as_ref()
103                            .map(|it| format!("bb{}", it.index()))
104                            .unwrap_or("None".to_string())
105                    )?;
106                }
107                super::ControlFlow::Switch {
108                    value,
109                    default,
110                    branches,
111                    ..
112                } => {
113                    write!(f, "    switch({value}) ")?;
114                    for (val, block) in branches {
115                        write!(f, "[{val}: bb{}] ", block.index())?;
116                    }
117                    writeln!(f, "[default: bb{}];", default.index())?;
118                }
119                super::ControlFlow::Loop {
120                    body,
121                    continue_target,
122                    merge,
123                } => {
124                    writeln!(
125                        f,
126                        "    loop(continue: bb{}, merge: bb{})",
127                        continue_target.index(),
128                        merge.index()
129                    )?;
130                    writeln!(f, "    branch bb{};", body.index())?
131                }
132                super::ControlFlow::LoopBreak {
133                    break_cond,
134                    body,
135                    continue_target,
136                    merge,
137                } => {
138                    writeln!(
139                        f,
140                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
141                        break_cond,
142                        body.index(),
143                        continue_target.index(),
144                        merge.index()
145                    )?;
146                }
147                super::ControlFlow::Return => writeln!(f, "    return;")?,
148                super::ControlFlow::None => {
149                    let edge = self.program.edges(node).next();
150                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
151                    writeln!(f, "    branch bb{target};")?;
152                }
153            }
154            f.write_str("}\n\n")?;
155        }
156
157        Ok(())
158    }
159}
160
161impl Display for BlockSets {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
164        exp_gen.sort_by_key(|it| it.0);
165        let exp_gen = exp_gen
166            .into_iter()
167            .map(|(val, expr)| format!("{val}: {expr}"))
168            .collect::<Vec<_>>();
169        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
170        phi_gen.sort_by_key(|it| it.0);
171        let phi_gen = phi_gen
172            .into_iter()
173            .map(|(val, expr)| format!("{val}: {expr}"))
174            .collect::<Vec<_>>();
175        let tmp_gen = self
176            .tmp_gen
177            .iter()
178            .map(|it| format!("{it}"))
179            .collect::<Vec<_>>();
180        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
181        leaders.sort_by_key(|it| it.0);
182        let leaders = leaders
183            .into_iter()
184            .map(|(val, expr)| format!("{val}: {expr}"))
185            .collect::<Vec<_>>();
186        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
187        antic_out.sort_by_key(|it| it.0);
188        let antic_out = antic_out
189            .into_iter()
190            .map(|(val, expr)| format!("{val}: {expr}"))
191            .collect::<Vec<_>>();
192        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
193        antic_in.sort_by_key(|it| it.0);
194        let antic_in = antic_in
195            .into_iter()
196            .map(|(val, expr)| format!("{val}: {expr}"))
197            .collect::<Vec<_>>();
198
199        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
200        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
201        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
202        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
203        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
204        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
205    }
206}
207
208impl Display for ValueTable {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
211        values.sort_by_key(|it| it.1);
212        writeln!(f, "values: [")?;
213        for (val, num) in values {
214            writeln!(f, "    {num}: {val},")?;
215        }
216        writeln!(f, "]")?;
217        writeln!(f, "expressions: [")?;
218        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
219        expressions.sort_by_key(|it| it.1);
220        for (expr, val) in expressions {
221            writeln!(f, "    {val}: {expr},")?;
222        }
223        writeln!(f, "]")
224    }
225}
226
227impl Display for Value {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        match self {
230            Value::Constant(constant) => write!(f, "{constant}"),
231            Value::Local(local) => write!(f, "{local}"),
232            Value::Input(id, _) => write!(f, "input({id})"),
233            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
234            Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
235            Value::Builtin(builtin) => write!(f, "{builtin:?}"),
236            Value::Output(id, _) => write!(f, "output({id})"),
237            Value::Slice(id, _) => write!(f, "slice({id})"),
238        }
239    }
240}
241
242impl Display for Local {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        match self.version {
245            0 => write!(f, "binding({})", self.id),
246            v => write!(f, "local({}).v{v}", self.id),
247        }
248    }
249}
250
251impl Display for Constant {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        match self {
254            Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
255            Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
256            Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
257            Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
258            Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
259            Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
260            Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
261            Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
262            Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
263            Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
264            Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
265            Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
266            Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
267            Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
268            Constant::Bool(val) => write!(f, "{val}"),
269        }
270    }
271}
272
273impl Display for Expression {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        match self {
276            Expression::Instruction(instruction) => write!(f, "{instruction}"),
277            Expression::Copy(val, _) => write!(f, "copy({val})"),
278            Expression::Value(value) => write!(f, "{value}"),
279            Expression::Volatile(value) => write!(f, "volatile({value})"),
280            Expression::Phi(entries) => write!(
281                f,
282                "phi({})",
283                entries
284                    .iter()
285                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
286                    .collect::<Vec<_>>()
287                    .join(", ")
288            ),
289        }
290    }
291}
292
293impl Display for Instruction {
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        let args = &self.args;
296        match self.op {
297            OpId::Add => write!(f, "{} + {}", args[0], args[1]),
298            OpId::Fma => write!(f, "fma({}, {}, {})", args[0], args[1], args[2]),
299            OpId::Sub => write!(f, "{} - {}", args[0], args[1]),
300            OpId::Mul => write!(f, "{} * {}", args[0], args[1]),
301            OpId::Div => write!(f, "{} / {}", args[0], args[1]),
302            OpId::Abs => write!(f, "{}.abs()", args[0]),
303            OpId::Exp => write!(f, "{}.exp()", args[0]),
304            OpId::Log => write!(f, "{}.log()", args[0]),
305            OpId::Log1p => write!(f, "{}.log1p()", args[0]),
306            OpId::Cos => write!(f, "{}.cos()", args[0]),
307            OpId::Sin => write!(f, "{}.sin()", args[0]),
308            OpId::Tanh => write!(f, "{}.tanh()", args[0]),
309            OpId::Powf => write!(f, "{}.powf()", args[0]),
310            OpId::Sqrt => write!(f, "{}.sqrt()", args[0]),
311            OpId::Round => write!(f, "{}.round()", args[0]),
312            OpId::Floor => write!(f, "{}.floor()", args[0]),
313            OpId::Ceil => write!(f, "{}.ceil()", args[0]),
314            OpId::Erf => write!(f, "{}.erf()", args[0]),
315            OpId::Recip => write!(f, "1.0 / {}", args[0]),
316            OpId::Equal => write!(f, "{} == {}", args[0], args[1]),
317            OpId::NotEqual => write!(f, "{} != {}", args[0], args[1]),
318            OpId::Lower => write!(f, "{} < {}", args[0], args[1]),
319            OpId::Clamp => write!(f, "clamp({}, {}, {})", args[0], args[1], args[2]),
320            OpId::Greater => write!(f, "{} > {}", args[0], args[1]),
321            OpId::LowerEqual => write!(f, "{} <= {}", args[0], args[1]),
322            OpId::GreaterEqual => write!(f, "{} >= {}", args[0], args[1]),
323            OpId::Modulo => write!(f, "{} % {}", args[0], args[1]),
324            OpId::Index => write!(f, "{}[{}]", args[0], args[1]),
325            OpId::InitLine => write!(
326                f,
327                "vec{}({})",
328                args.len(),
329                args.iter()
330                    .map(|it| it.to_string())
331                    .collect::<Vec<_>>()
332                    .join(", ")
333            ),
334            OpId::And => write!(f, "{} && {}", args[0], args[1]),
335            OpId::Or => write!(f, "{} || {}", args[0], args[1]),
336            OpId::Not => write!(f, "!{}", args[0]),
337            OpId::Neg => write!(f, "-{}", args[0]),
338            OpId::Max => write!(f, "max({}, {})", args[0], args[1]),
339            OpId::Min => write!(f, "min({}, {})", args[0], args[1]),
340            OpId::BitwiseAnd => write!(f, "{} & {}", args[0], args[1]),
341            OpId::BitwiseOr => write!(f, "{} | {}", args[0], args[1]),
342            OpId::BitwiseXor => write!(f, "{} ^ {}", args[0], args[1]),
343            OpId::CountOnes => write!(f, "{}.count_ones()", args[0]),
344            OpId::ReverseBits => write!(f, "{}.reverse_bits()", args[0]),
345            OpId::ShiftLeft => write!(f, "{} << {}", args[0], args[1]),
346            OpId::ShiftRight => write!(f, "{} >> {}", args[0], args[1]),
347            OpId::Remainder => write!(f, "{} % {}", args[0], args[1]),
348            OpId::Magnitude => write!(f, "{}.length()", args[0]),
349            OpId::Normalize => write!(f, "{}.normalize()", args[0]),
350            OpId::Dot => write!(f, "dot({}, {})", args[0], args[1]),
351            OpId::Select => write!(f, "select({}, {}, {})", args[0], args[1], args[2]),
352            OpId::Bitcast => write!(f, "bitcast<{}>({})", self.item, args[0]),
353            OpId::Rank => write!(f, "{}.rank()", args[0]),
354            OpId::Length => write!(f, "{}.len()", args[0]),
355            OpId::BufferLength => write!(f, "buffer_len({})", args[0]),
356            OpId::Shape => write!(f, "{}.shape[{}]", args[0], args[1]),
357            OpId::Stride => write!(f, "{}.stride[{}]", args[0], args[1]),
358            OpId::Cast => write!(f, "cast<{}>({})", self.item, args[0]),
359        }
360    }
361}