cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7    BasicBlock, ControlFlow,
8    analyses::{
9        liveness::{
10            Liveness,
11            shared::{SharedLiveness, SmemAllocation},
12        },
13        uniformity::Uniformity,
14    },
15    gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
16};
17
18use super::Optimizer;
19
/// When `true`, the `Optimizer` display output additionally prints the GVN value table and the
/// per-block GVN sets. Disabled by default.
const DEBUG_GVN: bool = false;
21
22/// Debug display for the program state.
23impl Display for Optimizer {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        let global_nums = self
26            .analysis_cache
27            .try_get::<GlobalValues>()
28            .unwrap_or_default();
29        let liveness = self
30            .analysis_cache
31            .try_get::<Liveness>()
32            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
33        let shared_liveness = self
34            .analysis_cache
35            .try_get::<SharedLiveness>()
36            .unwrap_or_else(|| Rc::new(SharedLiveness::empty(self)));
37        let uniformity = self
38            .analysis_cache
39            .try_get::<Uniformity>()
40            .unwrap_or_default();
41
42        if DEBUG_GVN {
43            writeln!(f, "# Value Table:")?;
44            writeln!(f, "{}", global_nums.borrow().values)?;
45        }
46
47        let smems = shared_liveness
48            .allocations
49            .values()
50            .map(|it| format!("    {it}"));
51        let smems = smems.collect::<Vec<_>>().join(",\n");
52        writeln!(f, "Shared memories: [\n{smems}\n]\n")?;
53
54        for node in self.program.node_indices() {
55            let id = node.index();
56            let bb = &self.program[node];
57            let uniform = match uniformity.is_block_uniform(node) {
58                true => "uniform ",
59                false => "",
60            };
61            writeln!(f, "{uniform}bb{id} {{")?;
62            if DEBUG_GVN {
63                let block_sets = &global_nums
64                    .borrow()
65                    .block_sets
66                    .get(&node)
67                    .cloned()
68                    .unwrap_or_default();
69                writeln!(f, "{block_sets}")?;
70            }
71
72            if !bb.block_use.is_empty() {
73                writeln!(f, "    Uses: {:?}", bb.block_use)?;
74            }
75            let live_vars = liveness.at_block(node).iter();
76            let live_vars = live_vars.map(|it| format!("local({it})"));
77            let live_vars = live_vars.collect::<Vec<_>>();
78            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
79            let live_shared = shared_liveness.at_block(node).iter();
80            let live_shared = live_shared.map(|it| format!("shared({it})"));
81            let live_shared = live_shared.collect::<Vec<_>>();
82            writeln!(
83                f,
84                "    Live shared memories: [{}]\n",
85                live_shared.join(", ")
86            )?;
87
88            for phi in bb.phi_nodes.borrow().iter() {
89                write!(f, "    {} = phi ", phi.out)?;
90                for entry in &phi.entries {
91                    write!(f, "[bb{}: ", entry.block.index())?;
92                    write!(f, "{}]", entry.value)?;
93                }
94                let is_uniform = match uniformity.is_var_uniform(phi.out) {
95                    true => " @ uniform",
96                    false => "",
97                };
98                writeln!(f, ";{is_uniform}\n")?;
99            }
100            if !bb.phi_nodes.borrow().is_empty() {
101                writeln!(f)?;
102            }
103
104            for op in bb.ops.borrow_mut().values_mut() {
105                let op_fmt = op.to_string();
106                if op_fmt.is_empty() {
107                    continue;
108                }
109
110                let is_uniform = match op
111                    .out
112                    .map(|out| uniformity.is_var_uniform(out))
113                    .unwrap_or(false)
114                {
115                    true => " @ uniform",
116                    false => "",
117                };
118                writeln!(f, "    {op_fmt};{is_uniform}")?;
119            }
120            match &*bb.control_flow.borrow() {
121                ControlFlow::IfElse {
122                    cond,
123                    then,
124                    or_else,
125                    merge,
126                } => {
127                    writeln!(
128                        f,
129                        "    {cond} ? bb{} : bb{}; merge: {}",
130                        then.index(),
131                        or_else.index(),
132                        merge
133                            .as_ref()
134                            .map(|it| format!("bb{}", it.index()))
135                            .unwrap_or("None".to_string())
136                    )?;
137                }
138                super::ControlFlow::Switch {
139                    value,
140                    default,
141                    branches,
142                    ..
143                } => {
144                    write!(f, "    switch({value}) ")?;
145                    for (val, block) in branches {
146                        write!(f, "[{val}: bb{}] ", block.index())?;
147                    }
148                    writeln!(f, "[default: bb{}];", default.index())?;
149                }
150                super::ControlFlow::Loop {
151                    body,
152                    continue_target,
153                    merge,
154                } => {
155                    writeln!(
156                        f,
157                        "    loop(continue: bb{}, merge: bb{})",
158                        continue_target.index(),
159                        merge.index()
160                    )?;
161                    writeln!(f, "    branch bb{};", body.index())?
162                }
163                super::ControlFlow::LoopBreak {
164                    break_cond,
165                    body,
166                    continue_target,
167                    merge,
168                } => {
169                    writeln!(
170                        f,
171                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
172                        break_cond,
173                        body.index(),
174                        continue_target.index(),
175                        merge.index()
176                    )?;
177                }
178                super::ControlFlow::Return => writeln!(f, "    return;")?,
179                super::ControlFlow::None => {
180                    let edge = self.program.edges(node).next();
181                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
182                    writeln!(f, "    branch bb{target};")?;
183                }
184            }
185            f.write_str("}\n\n")?;
186        }
187
188        Ok(())
189    }
190}
191
192impl Display for BlockSets {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
195        exp_gen.sort_by_key(|it| it.0);
196        let exp_gen = exp_gen
197            .into_iter()
198            .map(|(val, expr)| format!("{val}: {expr}"))
199            .collect::<Vec<_>>();
200        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
201        phi_gen.sort_by_key(|it| it.0);
202        let phi_gen = phi_gen
203            .into_iter()
204            .map(|(val, expr)| format!("{val}: {expr}"))
205            .collect::<Vec<_>>();
206        let tmp_gen = self
207            .tmp_gen
208            .iter()
209            .map(|it| format!("{it}"))
210            .collect::<Vec<_>>();
211        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
212        leaders.sort_by_key(|it| it.0);
213        let leaders = leaders
214            .into_iter()
215            .map(|(val, expr)| format!("{val}: {expr}"))
216            .collect::<Vec<_>>();
217        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
218        antic_out.sort_by_key(|it| it.0);
219        let antic_out = antic_out
220            .into_iter()
221            .map(|(val, expr)| format!("{val}: {expr}"))
222            .collect::<Vec<_>>();
223        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
224        antic_in.sort_by_key(|it| it.0);
225        let antic_in = antic_in
226            .into_iter()
227            .map(|(val, expr)| format!("{val}: {expr}"))
228            .collect::<Vec<_>>();
229
230        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
231        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
232        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
233        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
234        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
235        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
236    }
237}
238
239impl Display for ValueTable {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
242        values.sort_by_key(|it| it.1);
243        writeln!(f, "values: [")?;
244        for (val, num) in values {
245            writeln!(f, "    {num}: {val},")?;
246        }
247        writeln!(f, "]")?;
248        writeln!(f, "expressions: [")?;
249        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
250        expressions.sort_by_key(|it| it.1);
251        for (expr, val) in expressions {
252            writeln!(f, "    {val}: {expr},")?;
253        }
254        writeln!(f, "]")
255    }
256}
257
258impl Display for Value {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        match self {
261            Value::Constant(constant) => write!(f, "{constant}"),
262            Value::Local(local) => write!(f, "{local}"),
263            Value::Input(id, _) => write!(f, "input({id})"),
264            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
265            Value::ConstArray(id, _, _, _) => write!(f, "const_array({id})"),
266            Value::Builtin(builtin) => write!(f, "{builtin:?}"),
267            Value::Output(id, _) => write!(f, "output({id})"),
268        }
269    }
270}
271
272impl Display for Local {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        match self.version {
275            0 => write!(f, "binding({})", self.id),
276            v => write!(f, "local({}).v{v}", self.id),
277        }
278    }
279}
280
281impl Display for Constant {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        match self {
284            Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
285            Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
286            Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
287            Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
288            Constant::Float(val, FloatKind::E2M1) => write!(f, "{}e2m1", val.0),
289            Constant::Float(val, FloatKind::E2M3) => write!(f, "{}e2m3", val.0),
290            Constant::Float(val, FloatKind::E3M2) => write!(f, "{}e3m2", val.0),
291            Constant::Float(val, FloatKind::E4M3) => write!(f, "{}e4m3", val.0),
292            Constant::Float(val, FloatKind::E5M2) => write!(f, "{}e5m2", val.0),
293            Constant::Float(val, FloatKind::UE8M0) => write!(f, "{}ue8m0", val.0),
294            Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
295            Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
296            Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
297            Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
298            Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
299            Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
300            Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
301            Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
302            Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
303            Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
304            Constant::Bool(val) => write!(f, "{val}"),
305        }
306    }
307}
308
309impl Display for Expression {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        match self {
312            Expression::Instruction(instruction) => write!(f, "{instruction}"),
313            Expression::Copy(val, _) => write!(f, "copy({val})"),
314            Expression::Value(value) => write!(f, "{value}"),
315            Expression::Volatile(value) => write!(f, "volatile({value})"),
316            Expression::Phi(entries) => write!(
317                f,
318                "phi({})",
319                entries
320                    .iter()
321                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
322                    .collect::<Vec<_>>()
323                    .join(", ")
324            ),
325        }
326    }
327}
328
329impl Display for Instruction {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        write!(f, "{:?}: [{:?}]", self.op, self.args)
332    }
333}
334
335impl Display for BasicBlock {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        for phi in self.phi_nodes.borrow().iter() {
338            write!(f, "    {} = phi ", phi.out)?;
339            for entry in &phi.entries {
340                write!(f, "[bb{}: ", entry.block.index())?;
341                write!(f, "{}]", entry.value)?;
342            }
343            writeln!(f, ";\n")?;
344        }
345        if !self.phi_nodes.borrow().is_empty() {
346            writeln!(f)?;
347        }
348
349        for op in self.ops.borrow_mut().values_mut() {
350            let op_fmt = op.to_string();
351            if op_fmt.is_empty() {
352                continue;
353            }
354
355            writeln!(f, "    {op_fmt};")?;
356        }
357        match &*self.control_flow.borrow() {
358            ControlFlow::IfElse {
359                cond,
360                then,
361                or_else,
362                merge,
363            } => {
364                writeln!(
365                    f,
366                    "    {cond} ? bb{} : bb{}; merge: {}",
367                    then.index(),
368                    or_else.index(),
369                    merge
370                        .as_ref()
371                        .map(|it| format!("bb{}", it.index()))
372                        .unwrap_or("None".to_string())
373                )?;
374            }
375            super::ControlFlow::Switch {
376                value,
377                default,
378                branches,
379                ..
380            } => {
381                write!(f, "    switch({value}) ")?;
382                for (val, block) in branches {
383                    write!(f, "[{val}: bb{}] ", block.index())?;
384                }
385                writeln!(f, "[default: bb{}];", default.index())?;
386            }
387            super::ControlFlow::Loop {
388                body,
389                continue_target,
390                merge,
391            } => {
392                writeln!(
393                    f,
394                    "    loop(continue: bb{}, merge: bb{})",
395                    continue_target.index(),
396                    merge.index()
397                )?;
398                writeln!(f, "    branch bb{};", body.index())?
399            }
400            super::ControlFlow::LoopBreak {
401                break_cond,
402                body,
403                continue_target,
404                merge,
405            } => {
406                writeln!(
407                    f,
408                    "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
409                    break_cond,
410                    body.index(),
411                    continue_target.index(),
412                    merge.index()
413                )?;
414            }
415            super::ControlFlow::Return => writeln!(f, "    return;")?,
416            super::ControlFlow::None => {
417                writeln!(f, "    branch;")?;
418            }
419        }
420        Ok(())
421    }
422}
423
424impl Display for SmemAllocation {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        write!(
427            f,
428            "shared(id: {}, offset: {}, length: {}, align: {}, ty: {})",
429            self.smem.id, self.offset, self.smem.length, self.smem.align, self.smem.ty
430        )
431    }
432}