cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7    BasicBlock, ControlFlow,
8    analyses::{liveness::Liveness, uniformity::Uniformity},
9    gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
10};
11
12use super::Optimizer;
13
14const DEBUG_GVN: bool = false;
15
16/// Debug display for the program state.
17impl Display for Optimizer {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        let global_nums = self
20            .analysis_cache
21            .try_get::<GlobalValues>()
22            .unwrap_or_default();
23        let liveness = self
24            .analysis_cache
25            .try_get::<Liveness>()
26            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
27        let uniformity = self
28            .analysis_cache
29            .try_get::<Uniformity>()
30            .unwrap_or_default();
31
32        if DEBUG_GVN {
33            writeln!(f, "# Value Table:")?;
34            writeln!(f, "{}", global_nums.borrow().values)?;
35        }
36
37        for node in self.program.node_indices() {
38            let id = node.index();
39            let bb = &self.program[node];
40            let uniform = match uniformity.is_block_uniform(node) {
41                true => "uniform ",
42                false => "",
43            };
44            writeln!(f, "{uniform}bb{id} {{")?;
45            if DEBUG_GVN {
46                let block_sets = &global_nums
47                    .borrow()
48                    .block_sets
49                    .get(&node)
50                    .cloned()
51                    .unwrap_or_default();
52                writeln!(f, "{block_sets}")?;
53            }
54
55            if !bb.block_use.is_empty() {
56                writeln!(f, "    Uses: {:?}", bb.block_use)?;
57            }
58            let live_vars = liveness.at_block(node).iter();
59            let live_vars = live_vars.map(|it| format!("local({it})"));
60            let live_vars = live_vars.collect::<Vec<_>>();
61            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
62
63            for phi in bb.phi_nodes.borrow().iter() {
64                write!(f, "    {} = phi ", phi.out)?;
65                for entry in &phi.entries {
66                    write!(f, "[bb{}: ", entry.block.index())?;
67                    write!(f, "{}]", entry.value)?;
68                }
69                let is_uniform = match uniformity.is_var_uniform(phi.out) {
70                    true => " @ uniform",
71                    false => "",
72                };
73                writeln!(f, ";{is_uniform}\n")?;
74            }
75            if !bb.phi_nodes.borrow().is_empty() {
76                writeln!(f)?;
77            }
78
79            for op in bb.ops.borrow_mut().values_mut() {
80                let op_fmt = op.to_string();
81                if op_fmt.is_empty() {
82                    continue;
83                }
84
85                let is_uniform = match op
86                    .out
87                    .map(|out| uniformity.is_var_uniform(out))
88                    .unwrap_or(false)
89                {
90                    true => " @ uniform",
91                    false => "",
92                };
93                writeln!(f, "    {op_fmt};{is_uniform}")?;
94            }
95            match &*bb.control_flow.borrow() {
96                ControlFlow::IfElse {
97                    cond,
98                    then,
99                    or_else,
100                    merge,
101                } => {
102                    writeln!(
103                        f,
104                        "    {cond} ? bb{} : bb{}; merge: {}",
105                        then.index(),
106                        or_else.index(),
107                        merge
108                            .as_ref()
109                            .map(|it| format!("bb{}", it.index()))
110                            .unwrap_or("None".to_string())
111                    )?;
112                }
113                super::ControlFlow::Switch {
114                    value,
115                    default,
116                    branches,
117                    ..
118                } => {
119                    write!(f, "    switch({value}) ")?;
120                    for (val, block) in branches {
121                        write!(f, "[{val}: bb{}] ", block.index())?;
122                    }
123                    writeln!(f, "[default: bb{}];", default.index())?;
124                }
125                super::ControlFlow::Loop {
126                    body,
127                    continue_target,
128                    merge,
129                } => {
130                    writeln!(
131                        f,
132                        "    loop(continue: bb{}, merge: bb{})",
133                        continue_target.index(),
134                        merge.index()
135                    )?;
136                    writeln!(f, "    branch bb{};", body.index())?
137                }
138                super::ControlFlow::LoopBreak {
139                    break_cond,
140                    body,
141                    continue_target,
142                    merge,
143                } => {
144                    writeln!(
145                        f,
146                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
147                        break_cond,
148                        body.index(),
149                        continue_target.index(),
150                        merge.index()
151                    )?;
152                }
153                super::ControlFlow::Return => writeln!(f, "    return;")?,
154                super::ControlFlow::None => {
155                    let edge = self.program.edges(node).next();
156                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
157                    writeln!(f, "    branch bb{target};")?;
158                }
159            }
160            f.write_str("}\n\n")?;
161        }
162
163        Ok(())
164    }
165}
166
167impl Display for BlockSets {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
170        exp_gen.sort_by_key(|it| it.0);
171        let exp_gen = exp_gen
172            .into_iter()
173            .map(|(val, expr)| format!("{val}: {expr}"))
174            .collect::<Vec<_>>();
175        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
176        phi_gen.sort_by_key(|it| it.0);
177        let phi_gen = phi_gen
178            .into_iter()
179            .map(|(val, expr)| format!("{val}: {expr}"))
180            .collect::<Vec<_>>();
181        let tmp_gen = self
182            .tmp_gen
183            .iter()
184            .map(|it| format!("{it}"))
185            .collect::<Vec<_>>();
186        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
187        leaders.sort_by_key(|it| it.0);
188        let leaders = leaders
189            .into_iter()
190            .map(|(val, expr)| format!("{val}: {expr}"))
191            .collect::<Vec<_>>();
192        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
193        antic_out.sort_by_key(|it| it.0);
194        let antic_out = antic_out
195            .into_iter()
196            .map(|(val, expr)| format!("{val}: {expr}"))
197            .collect::<Vec<_>>();
198        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
199        antic_in.sort_by_key(|it| it.0);
200        let antic_in = antic_in
201            .into_iter()
202            .map(|(val, expr)| format!("{val}: {expr}"))
203            .collect::<Vec<_>>();
204
205        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
206        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
207        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
208        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
209        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
210        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
211    }
212}
213
214impl Display for ValueTable {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
217        values.sort_by_key(|it| it.1);
218        writeln!(f, "values: [")?;
219        for (val, num) in values {
220            writeln!(f, "    {num}: {val},")?;
221        }
222        writeln!(f, "]")?;
223        writeln!(f, "expressions: [")?;
224        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
225        expressions.sort_by_key(|it| it.1);
226        for (expr, val) in expressions {
227            writeln!(f, "    {val}: {expr},")?;
228        }
229        writeln!(f, "]")
230    }
231}
232
233impl Display for Value {
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        match self {
236            Value::Constant(constant) => write!(f, "{constant}"),
237            Value::Local(local) => write!(f, "{local}"),
238            Value::Input(id, _) => write!(f, "input({id})"),
239            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
240            Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
241            Value::Builtin(builtin) => write!(f, "{builtin:?}"),
242            Value::Output(id, _) => write!(f, "output({id})"),
243        }
244    }
245}
246
247impl Display for Local {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        match self.version {
250            0 => write!(f, "binding({})", self.id),
251            v => write!(f, "local({}).v{v}", self.id),
252        }
253    }
254}
255
256impl Display for Constant {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        match self {
259            Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
260            Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
261            Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
262            Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
263            Constant::Float(val, FloatKind::E2M1) => write!(f, "{}e2m1", val.0),
264            Constant::Float(val, FloatKind::E2M3) => write!(f, "{}e2m3", val.0),
265            Constant::Float(val, FloatKind::E3M2) => write!(f, "{}e3m2", val.0),
266            Constant::Float(val, FloatKind::E4M3) => write!(f, "{}e4m3", val.0),
267            Constant::Float(val, FloatKind::E5M2) => write!(f, "{}e5m2", val.0),
268            Constant::Float(val, FloatKind::UE8M0) => write!(f, "{}ue8m0", val.0),
269            Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
270            Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
271            Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
272            Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
273            Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
274            Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
275            Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
276            Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
277            Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
278            Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
279            Constant::Bool(val) => write!(f, "{val}"),
280        }
281    }
282}
283
284impl Display for Expression {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        match self {
287            Expression::Instruction(instruction) => write!(f, "{instruction}"),
288            Expression::Copy(val, _) => write!(f, "copy({val})"),
289            Expression::Value(value) => write!(f, "{value}"),
290            Expression::Volatile(value) => write!(f, "volatile({value})"),
291            Expression::Phi(entries) => write!(
292                f,
293                "phi({})",
294                entries
295                    .iter()
296                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
297                    .collect::<Vec<_>>()
298                    .join(", ")
299            ),
300        }
301    }
302}
303
304impl Display for Instruction {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        write!(f, "{:?}: [{:?}]", self.op, self.args)
307    }
308}
309
310impl Display for BasicBlock {
311    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312        for phi in self.phi_nodes.borrow().iter() {
313            write!(f, "    {} = phi ", phi.out)?;
314            for entry in &phi.entries {
315                write!(f, "[bb{}: ", entry.block.index())?;
316                write!(f, "{}]", entry.value)?;
317            }
318            writeln!(f, ";\n")?;
319        }
320        if !self.phi_nodes.borrow().is_empty() {
321            writeln!(f)?;
322        }
323
324        for op in self.ops.borrow_mut().values_mut() {
325            let op_fmt = op.to_string();
326            if op_fmt.is_empty() {
327                continue;
328            }
329
330            writeln!(f, "    {op_fmt};")?;
331        }
332        match &*self.control_flow.borrow() {
333            ControlFlow::IfElse {
334                cond,
335                then,
336                or_else,
337                merge,
338            } => {
339                writeln!(
340                    f,
341                    "    {cond} ? bb{} : bb{}; merge: {}",
342                    then.index(),
343                    or_else.index(),
344                    merge
345                        .as_ref()
346                        .map(|it| format!("bb{}", it.index()))
347                        .unwrap_or("None".to_string())
348                )?;
349            }
350            super::ControlFlow::Switch {
351                value,
352                default,
353                branches,
354                ..
355            } => {
356                write!(f, "    switch({value}) ")?;
357                for (val, block) in branches {
358                    write!(f, "[{val}: bb{}] ", block.index())?;
359                }
360                writeln!(f, "[default: bb{}];", default.index())?;
361            }
362            super::ControlFlow::Loop {
363                body,
364                continue_target,
365                merge,
366            } => {
367                writeln!(
368                    f,
369                    "    loop(continue: bb{}, merge: bb{})",
370                    continue_target.index(),
371                    merge.index()
372                )?;
373                writeln!(f, "    branch bb{};", body.index())?
374            }
375            super::ControlFlow::LoopBreak {
376                break_cond,
377                body,
378                continue_target,
379                merge,
380            } => {
381                writeln!(
382                    f,
383                    "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
384                    break_cond,
385                    body.index(),
386                    continue_target.index(),
387                    merge.index()
388                )?;
389            }
390            super::ControlFlow::Return => writeln!(f, "    return;")?,
391            super::ControlFlow::None => {
392                writeln!(f, "    branch;")?;
393            }
394        }
395        Ok(())
396    }
397}