cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7    ControlFlow,
8    analyses::{const_len::Slices, liveness::Liveness, uniformity::Uniformity},
9    gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
10};
11
12use super::Optimizer;
13
14const DEBUG_GVN: bool = false;
15
16/// Debug display for the program state.
17impl Display for Optimizer {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        let slices = self.analysis_cache.try_get::<Slices>().unwrap_or_default();
20
21        f.write_str("Slices:\n")?;
22        for (var_id, slice) in slices.iter() {
23            let end_op = slice.end_op.as_ref().map(|it| format!("{it}"));
24            writeln!(
25                f,
26                "slice{var_id:?}: {{ start: {}, end: {}, end_op: {}, const_len: {:?} }}",
27                slice.start,
28                slice.end,
29                end_op.unwrap_or("None".to_string()),
30                slice.const_len
31            )?;
32        }
33        f.write_str("\n\n")?;
34
35        let global_nums = self
36            .analysis_cache
37            .try_get::<GlobalValues>()
38            .unwrap_or_default();
39        let liveness = self
40            .analysis_cache
41            .try_get::<Liveness>()
42            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
43        let uniformity = self
44            .analysis_cache
45            .try_get::<Uniformity>()
46            .unwrap_or_default();
47
48        if DEBUG_GVN {
49            writeln!(f, "# Value Table:")?;
50            writeln!(f, "{}", global_nums.borrow().values)?;
51        }
52
53        for node in self.program.node_indices() {
54            let id = node.index();
55            let bb = &self.program[node];
56            let uniform = match uniformity.is_block_uniform(node) {
57                true => "uniform ",
58                false => "",
59            };
60            writeln!(f, "{uniform}bb{id} {{")?;
61            if DEBUG_GVN {
62                let block_sets = &global_nums
63                    .borrow()
64                    .block_sets
65                    .get(&node)
66                    .cloned()
67                    .unwrap_or_default();
68                writeln!(f, "{block_sets}")?;
69            }
70
71            if !bb.block_use.is_empty() {
72                writeln!(f, "    Uses: {:?}", bb.block_use)?;
73            }
74            let live_vars = liveness.at_block(node).iter();
75            let live_vars = live_vars.map(|it| format!("local({})", it));
76            let live_vars = live_vars.collect::<Vec<_>>();
77            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
78
79            for phi in bb.phi_nodes.borrow().iter() {
80                write!(f, "    {} = phi ", phi.out)?;
81                for entry in &phi.entries {
82                    write!(f, "[bb{}: ", entry.block.index())?;
83                    write!(f, "{}]", entry.value)?;
84                }
85                let is_uniform = match uniformity.is_var_uniform(phi.out) {
86                    true => " @ uniform",
87                    false => "",
88                };
89                writeln!(f, ";{is_uniform}\n")?;
90            }
91            if !bb.phi_nodes.borrow().is_empty() {
92                writeln!(f)?;
93            }
94
95            for op in bb.ops.borrow_mut().values_mut() {
96                let op_fmt = op.to_string();
97                if op_fmt.is_empty() {
98                    continue;
99                }
100
101                let is_uniform = match op
102                    .out
103                    .map(|out| uniformity.is_var_uniform(out))
104                    .unwrap_or(false)
105                {
106                    true => " @ uniform",
107                    false => "",
108                };
109                writeln!(f, "    {op_fmt};{is_uniform}")?;
110            }
111            match &*bb.control_flow.borrow() {
112                ControlFlow::IfElse {
113                    cond,
114                    then,
115                    or_else,
116                    merge,
117                } => {
118                    writeln!(
119                        f,
120                        "    {cond} ? bb{} : bb{}; merge: {}",
121                        then.index(),
122                        or_else.index(),
123                        merge
124                            .as_ref()
125                            .map(|it| format!("bb{}", it.index()))
126                            .unwrap_or("None".to_string())
127                    )?;
128                }
129                super::ControlFlow::Switch {
130                    value,
131                    default,
132                    branches,
133                    ..
134                } => {
135                    write!(f, "    switch({value}) ")?;
136                    for (val, block) in branches {
137                        write!(f, "[{val}: bb{}] ", block.index())?;
138                    }
139                    writeln!(f, "[default: bb{}];", default.index())?;
140                }
141                super::ControlFlow::Loop {
142                    body,
143                    continue_target,
144                    merge,
145                } => {
146                    writeln!(
147                        f,
148                        "    loop(continue: bb{}, merge: bb{})",
149                        continue_target.index(),
150                        merge.index()
151                    )?;
152                    writeln!(f, "    branch bb{};", body.index())?
153                }
154                super::ControlFlow::LoopBreak {
155                    break_cond,
156                    body,
157                    continue_target,
158                    merge,
159                } => {
160                    writeln!(
161                        f,
162                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
163                        break_cond,
164                        body.index(),
165                        continue_target.index(),
166                        merge.index()
167                    )?;
168                }
169                super::ControlFlow::Return => writeln!(f, "    return;")?,
170                super::ControlFlow::None => {
171                    let edge = self.program.edges(node).next();
172                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
173                    writeln!(f, "    branch bb{target};")?;
174                }
175            }
176            f.write_str("}\n\n")?;
177        }
178
179        Ok(())
180    }
181}
182
183impl Display for BlockSets {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
186        exp_gen.sort_by_key(|it| it.0);
187        let exp_gen = exp_gen
188            .into_iter()
189            .map(|(val, expr)| format!("{val}: {expr}"))
190            .collect::<Vec<_>>();
191        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
192        phi_gen.sort_by_key(|it| it.0);
193        let phi_gen = phi_gen
194            .into_iter()
195            .map(|(val, expr)| format!("{val}: {expr}"))
196            .collect::<Vec<_>>();
197        let tmp_gen = self
198            .tmp_gen
199            .iter()
200            .map(|it| format!("{it}"))
201            .collect::<Vec<_>>();
202        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
203        leaders.sort_by_key(|it| it.0);
204        let leaders = leaders
205            .into_iter()
206            .map(|(val, expr)| format!("{val}: {expr}"))
207            .collect::<Vec<_>>();
208        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
209        antic_out.sort_by_key(|it| it.0);
210        let antic_out = antic_out
211            .into_iter()
212            .map(|(val, expr)| format!("{val}: {expr}"))
213            .collect::<Vec<_>>();
214        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
215        antic_in.sort_by_key(|it| it.0);
216        let antic_in = antic_in
217            .into_iter()
218            .map(|(val, expr)| format!("{val}: {expr}"))
219            .collect::<Vec<_>>();
220
221        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
222        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
223        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
224        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
225        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
226        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
227    }
228}
229
230impl Display for ValueTable {
231    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
233        values.sort_by_key(|it| it.1);
234        writeln!(f, "values: [")?;
235        for (val, num) in values {
236            writeln!(f, "    {num}: {val},")?;
237        }
238        writeln!(f, "]")?;
239        writeln!(f, "expressions: [")?;
240        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
241        expressions.sort_by_key(|it| it.1);
242        for (expr, val) in expressions {
243            writeln!(f, "    {val}: {expr},")?;
244        }
245        writeln!(f, "]")
246    }
247}
248
249impl Display for Value {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        match self {
252            Value::Constant(constant) => write!(f, "{constant}"),
253            Value::Local(local) => write!(f, "{local}"),
254            Value::Input(id, _) => write!(f, "input({id})"),
255            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
256            Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
257            Value::Builtin(builtin) => write!(f, "{builtin:?}"),
258            Value::Output(id, _) => write!(f, "output({id})"),
259            Value::Slice(id, _) => write!(f, "slice({id})"),
260        }
261    }
262}
263
264impl Display for Local {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        match self.version {
267            0 => write!(f, "binding({})", self.id),
268            v => write!(f, "local({}).v{v}", self.id),
269        }
270    }
271}
272
273impl Display for Constant {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        match self {
276            Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
277            Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
278            Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
279            Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
280            Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
281            Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
282            Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
283            Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
284            Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
285            Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
286            Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
287            Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
288            Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
289            Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
290            Constant::Bool(val) => write!(f, "{val}"),
291        }
292    }
293}
294
295impl Display for Expression {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        match self {
298            Expression::Instruction(instruction) => write!(f, "{instruction}"),
299            Expression::Copy(val, _) => write!(f, "copy({val})"),
300            Expression::Value(value) => write!(f, "{value}"),
301            Expression::Volatile(value) => write!(f, "volatile({value})"),
302            Expression::Phi(entries) => write!(
303                f,
304                "phi({})",
305                entries
306                    .iter()
307                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
308                    .collect::<Vec<_>>()
309                    .join(", ")
310            ),
311        }
312    }
313}
314
315impl Display for Instruction {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        write!(f, "{:?}: [{:?}]", self.op, self.args)
318    }
319}