Skip to main content

cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use petgraph::visit::EdgeRef;
4
5use crate::{
6    BasicBlock, ControlFlow,
7    analyses::{
8        liveness::{
9            Liveness,
10            shared::{SharedLiveness, SmemAllocation},
11        },
12        uniformity::Uniformity,
13    },
14    gvn::{BlockSets, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
15};
16
17use super::Optimizer;
18
19const DEBUG_GVN: bool = false;
20
21/// Debug display for the program state.
22impl Display for Optimizer {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        let global_nums = self
25            .analysis_cache
26            .try_get::<GlobalValues>()
27            .unwrap_or_default();
28        let liveness = self
29            .analysis_cache
30            .try_get::<Liveness>()
31            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
32        let shared_liveness = self
33            .analysis_cache
34            .try_get::<SharedLiveness>()
35            .unwrap_or_else(|| Rc::new(SharedLiveness::empty(self)));
36        let uniformity = self
37            .analysis_cache
38            .try_get::<Uniformity>()
39            .unwrap_or_default();
40
41        if DEBUG_GVN {
42            writeln!(f, "# Value Table:")?;
43            writeln!(f, "{}", global_nums.borrow().values)?;
44        }
45
46        let smems = shared_liveness
47            .allocations
48            .values()
49            .map(|it| format!("    {it}"));
50        let smems = smems.collect::<Vec<_>>().join(",\n");
51        writeln!(f, "Shared memories: [\n{smems}\n]\n")?;
52
53        for node in self.program.node_indices() {
54            let id = node.index();
55            let bb = &self.program[node];
56            let uniform = match uniformity.is_block_uniform(node) {
57                true => "uniform ",
58                false => "",
59            };
60            writeln!(f, "{uniform}bb{id} {{")?;
61            if DEBUG_GVN {
62                let block_sets = &global_nums
63                    .borrow()
64                    .block_sets
65                    .get(&node)
66                    .cloned()
67                    .unwrap_or_default();
68                writeln!(f, "{block_sets}")?;
69            }
70
71            if !bb.block_use.is_empty() {
72                writeln!(f, "    Uses: {:?}", bb.block_use)?;
73            }
74            let live_vars = liveness.at_block(node).iter();
75            let live_vars = live_vars.map(|it| format!("local({it})"));
76            let live_vars = live_vars.collect::<Vec<_>>();
77            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
78            let live_shared = shared_liveness.at_block(node).iter();
79            let live_shared = live_shared.map(|it| format!("shared({it})"));
80            let live_shared = live_shared.collect::<Vec<_>>();
81            writeln!(
82                f,
83                "    Live shared memories: [{}]\n",
84                live_shared.join(", ")
85            )?;
86
87            for phi in bb.phi_nodes.borrow().iter() {
88                write!(f, "    {} = phi ", phi.out)?;
89                for entry in &phi.entries {
90                    write!(f, "[bb{}: ", entry.block.index())?;
91                    write!(f, "{}]", entry.value)?;
92                }
93                let is_uniform = match uniformity.is_var_uniform(phi.out) {
94                    true => " @ uniform",
95                    false => "",
96                };
97                writeln!(f, ";{is_uniform}\n")?;
98            }
99            if !bb.phi_nodes.borrow().is_empty() {
100                writeln!(f)?;
101            }
102
103            for op in bb.ops.borrow_mut().values_mut() {
104                let op_fmt = op.to_string();
105                if op_fmt.is_empty() {
106                    continue;
107                }
108
109                let is_uniform = match op.out.is_some_and(|out| uniformity.is_var_uniform(out)) {
110                    true => " @ uniform",
111                    false => "",
112                };
113                writeln!(f, "    {op_fmt};{is_uniform}")?;
114            }
115            match &*bb.control_flow.borrow() {
116                ControlFlow::IfElse {
117                    cond,
118                    then,
119                    or_else,
120                    merge,
121                } => {
122                    writeln!(
123                        f,
124                        "    {cond} ? bb{} : bb{}; merge: {}",
125                        then.index(),
126                        or_else.index(),
127                        merge
128                            .as_ref()
129                            .map(|it| format!("bb{}", it.index()))
130                            .unwrap_or("None".to_string())
131                    )?;
132                }
133                super::ControlFlow::Switch {
134                    value,
135                    default,
136                    branches,
137                    ..
138                } => {
139                    write!(f, "    switch({value}) ")?;
140                    for (val, block) in branches {
141                        write!(f, "[{val}: bb{}] ", block.index())?;
142                    }
143                    writeln!(f, "[default: bb{}];", default.index())?;
144                }
145                super::ControlFlow::Loop {
146                    body,
147                    continue_target,
148                    merge,
149                } => {
150                    writeln!(
151                        f,
152                        "    loop(continue: bb{}, merge: bb{})",
153                        continue_target.index(),
154                        merge.index()
155                    )?;
156                    writeln!(f, "    branch bb{};", body.index())?
157                }
158                super::ControlFlow::LoopBreak {
159                    break_cond,
160                    body,
161                    continue_target,
162                    merge,
163                } => {
164                    writeln!(
165                        f,
166                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
167                        break_cond,
168                        body.index(),
169                        continue_target.index(),
170                        merge.index()
171                    )?;
172                }
173                super::ControlFlow::Return => writeln!(f, "    return;")?,
174                super::ControlFlow::Unreachable => writeln!(f, "    unreachable;")?,
175                super::ControlFlow::None => {
176                    let edge = self.program.edges(node).next();
177                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
178                    writeln!(f, "    branch bb{target};")?;
179                }
180            }
181            f.write_str("}\n\n")?;
182        }
183
184        Ok(())
185    }
186}
187
188impl Display for BlockSets {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
191        exp_gen.sort_by_key(|it| it.0);
192        let exp_gen = exp_gen
193            .into_iter()
194            .map(|(val, expr)| format!("{val}: {expr}"))
195            .collect::<Vec<_>>();
196        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
197        phi_gen.sort_by_key(|it| it.0);
198        let phi_gen = phi_gen
199            .into_iter()
200            .map(|(val, expr)| format!("{val}: {expr}"))
201            .collect::<Vec<_>>();
202        let tmp_gen = self
203            .tmp_gen
204            .iter()
205            .map(|it| format!("{it}"))
206            .collect::<Vec<_>>();
207        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
208        leaders.sort_by_key(|it| it.0);
209        let leaders = leaders
210            .into_iter()
211            .map(|(val, expr)| format!("{val}: {expr}"))
212            .collect::<Vec<_>>();
213        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
214        antic_out.sort_by_key(|it| it.0);
215        let antic_out = antic_out
216            .into_iter()
217            .map(|(val, expr)| format!("{val}: {expr}"))
218            .collect::<Vec<_>>();
219        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
220        antic_in.sort_by_key(|it| it.0);
221        let antic_in = antic_in
222            .into_iter()
223            .map(|(val, expr)| format!("{val}: {expr}"))
224            .collect::<Vec<_>>();
225
226        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
227        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
228        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
229        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
230        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
231        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
232    }
233}
234
235impl Display for ValueTable {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
238        values.sort_by_key(|it| it.1);
239        writeln!(f, "values: [")?;
240        for (val, num) in values {
241            writeln!(f, "    {num}: {val},")?;
242        }
243        writeln!(f, "]")?;
244        writeln!(f, "expressions: [")?;
245        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
246        expressions.sort_by_key(|it| it.1);
247        for (expr, val) in expressions {
248            writeln!(f, "    {val}: {expr},")?;
249        }
250        writeln!(f, "]")
251    }
252}
253
254impl Display for Value {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        match self {
257            Value::Constant(constant, _) => write!(f, "{constant}"),
258            Value::Local(local) => write!(f, "{local}"),
259            Value::Input(id, _) => write!(f, "input({id})"),
260            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
261            Value::ConstArray(id, _, _, _) => write!(f, "const_array({id})"),
262            Value::Builtin(builtin, _) => write!(f, "{builtin:?}"),
263            Value::Output(id, _) => write!(f, "output({id})"),
264        }
265    }
266}
267
268impl Display for Local {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        match self.version {
271            0 => write!(f, "binding({})", self.id),
272            v => write!(f, "local({}).v{v}", self.id),
273        }
274    }
275}
276
277impl Display for Expression {
278    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
279        match self {
280            Expression::Instruction(instruction) => write!(f, "{instruction}"),
281            Expression::Copy(val, _) => write!(f, "copy({val})"),
282            Expression::Value(value) => write!(f, "{value}"),
283            Expression::Volatile(value) => write!(f, "volatile({value})"),
284            Expression::Phi(entries) => write!(
285                f,
286                "phi({})",
287                entries
288                    .iter()
289                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
290                    .collect::<Vec<_>>()
291                    .join(", ")
292            ),
293        }
294    }
295}
296
297impl Display for Instruction {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        write!(f, "{:?}: [{:?}]", self.op, self.args)
300    }
301}
302
303impl Display for BasicBlock {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        for phi in self.phi_nodes.borrow().iter() {
306            write!(f, "    {} = phi ", phi.out)?;
307            for entry in &phi.entries {
308                write!(f, "[bb{}: ", entry.block.index())?;
309                write!(f, "{}]", entry.value)?;
310            }
311            writeln!(f, ";\n")?;
312        }
313        if !self.phi_nodes.borrow().is_empty() {
314            writeln!(f)?;
315        }
316
317        for op in self.ops.borrow_mut().values_mut() {
318            let op_fmt = op.to_string();
319            if op_fmt.is_empty() {
320                continue;
321            }
322
323            writeln!(f, "    {op_fmt};")?;
324        }
325        match &*self.control_flow.borrow() {
326            ControlFlow::IfElse {
327                cond,
328                then,
329                or_else,
330                merge,
331            } => {
332                writeln!(
333                    f,
334                    "    {cond} ? bb{} : bb{}; merge: {}",
335                    then.index(),
336                    or_else.index(),
337                    merge
338                        .as_ref()
339                        .map(|it| format!("bb{}", it.index()))
340                        .unwrap_or("None".to_string())
341                )?;
342            }
343            super::ControlFlow::Switch {
344                value,
345                default,
346                branches,
347                ..
348            } => {
349                write!(f, "    switch({value}) ")?;
350                for (val, block) in branches {
351                    write!(f, "[{val}: bb{}] ", block.index())?;
352                }
353                writeln!(f, "[default: bb{}];", default.index())?;
354            }
355            super::ControlFlow::Loop {
356                body,
357                continue_target,
358                merge,
359            } => {
360                writeln!(
361                    f,
362                    "    loop(continue: bb{}, merge: bb{})",
363                    continue_target.index(),
364                    merge.index()
365                )?;
366                writeln!(f, "    branch bb{};", body.index())?
367            }
368            super::ControlFlow::LoopBreak {
369                break_cond,
370                body,
371                continue_target,
372                merge,
373            } => {
374                writeln!(
375                    f,
376                    "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
377                    break_cond,
378                    body.index(),
379                    continue_target.index(),
380                    merge.index()
381                )?;
382            }
383            super::ControlFlow::Return => writeln!(f, "    return;")?,
384            super::ControlFlow::Unreachable => writeln!(f, "    unreachable;")?,
385            super::ControlFlow::None => {
386                writeln!(f, "    branch;")?;
387            }
388        }
389        Ok(())
390    }
391}
392
393impl Display for SmemAllocation {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        match self.smem {
396            crate::SharedMemory::Array {
397                id,
398                length,
399                ty,
400                align,
401            } => {
402                write!(
403                    f,
404                    "shared_array(id: {id}, offset: {}, length: {length}, align: {align}, ty: {ty})",
405                    self.offset,
406                )
407            }
408            crate::SharedMemory::Value { id, ty, align } => {
409                write!(
410                    f,
411                    "shared(id: {id}, offset: {}, align: {align}, ty: {ty})",
412                    self.offset,
413                )
414            }
415        }
416    }
417}