cubecl_opt/
debug.rs

1use std::{fmt::Display, rc::Rc};
2
3use petgraph::visit::EdgeRef;
4
5use crate::{
6    BasicBlock, ControlFlow,
7    analyses::{
8        liveness::{
9            Liveness,
10            shared::{SharedLiveness, SmemAllocation},
11        },
12        uniformity::Uniformity,
13    },
14    gvn::{BlockSets, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
15};
16
17use super::Optimizer;
18
/// When `true`, the `Optimizer` debug display additionally prints the GVN
/// value table and the GVN block sets of every basic block.
const DEBUG_GVN: bool = false;
20
21/// Debug display for the program state.
22impl Display for Optimizer {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        let global_nums = self
25            .analysis_cache
26            .try_get::<GlobalValues>()
27            .unwrap_or_default();
28        let liveness = self
29            .analysis_cache
30            .try_get::<Liveness>()
31            .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
32        let shared_liveness = self
33            .analysis_cache
34            .try_get::<SharedLiveness>()
35            .unwrap_or_else(|| Rc::new(SharedLiveness::empty(self)));
36        let uniformity = self
37            .analysis_cache
38            .try_get::<Uniformity>()
39            .unwrap_or_default();
40
41        if DEBUG_GVN {
42            writeln!(f, "# Value Table:")?;
43            writeln!(f, "{}", global_nums.borrow().values)?;
44        }
45
46        let smems = shared_liveness
47            .allocations
48            .values()
49            .map(|it| format!("    {it}"));
50        let smems = smems.collect::<Vec<_>>().join(",\n");
51        writeln!(f, "Shared memories: [\n{smems}\n]\n")?;
52
53        for node in self.program.node_indices() {
54            let id = node.index();
55            let bb = &self.program[node];
56            let uniform = match uniformity.is_block_uniform(node) {
57                true => "uniform ",
58                false => "",
59            };
60            writeln!(f, "{uniform}bb{id} {{")?;
61            if DEBUG_GVN {
62                let block_sets = &global_nums
63                    .borrow()
64                    .block_sets
65                    .get(&node)
66                    .cloned()
67                    .unwrap_or_default();
68                writeln!(f, "{block_sets}")?;
69            }
70
71            if !bb.block_use.is_empty() {
72                writeln!(f, "    Uses: {:?}", bb.block_use)?;
73            }
74            let live_vars = liveness.at_block(node).iter();
75            let live_vars = live_vars.map(|it| format!("local({it})"));
76            let live_vars = live_vars.collect::<Vec<_>>();
77            writeln!(f, "    Live variables: [{}]\n", live_vars.join(", "))?;
78            let live_shared = shared_liveness.at_block(node).iter();
79            let live_shared = live_shared.map(|it| format!("shared({it})"));
80            let live_shared = live_shared.collect::<Vec<_>>();
81            writeln!(
82                f,
83                "    Live shared memories: [{}]\n",
84                live_shared.join(", ")
85            )?;
86
87            for phi in bb.phi_nodes.borrow().iter() {
88                write!(f, "    {} = phi ", phi.out)?;
89                for entry in &phi.entries {
90                    write!(f, "[bb{}: ", entry.block.index())?;
91                    write!(f, "{}]", entry.value)?;
92                }
93                let is_uniform = match uniformity.is_var_uniform(phi.out) {
94                    true => " @ uniform",
95                    false => "",
96                };
97                writeln!(f, ";{is_uniform}\n")?;
98            }
99            if !bb.phi_nodes.borrow().is_empty() {
100                writeln!(f)?;
101            }
102
103            for op in bb.ops.borrow_mut().values_mut() {
104                let op_fmt = op.to_string();
105                if op_fmt.is_empty() {
106                    continue;
107                }
108
109                let is_uniform = match op
110                    .out
111                    .map(|out| uniformity.is_var_uniform(out))
112                    .unwrap_or(false)
113                {
114                    true => " @ uniform",
115                    false => "",
116                };
117                writeln!(f, "    {op_fmt};{is_uniform}")?;
118            }
119            match &*bb.control_flow.borrow() {
120                ControlFlow::IfElse {
121                    cond,
122                    then,
123                    or_else,
124                    merge,
125                } => {
126                    writeln!(
127                        f,
128                        "    {cond} ? bb{} : bb{}; merge: {}",
129                        then.index(),
130                        or_else.index(),
131                        merge
132                            .as_ref()
133                            .map(|it| format!("bb{}", it.index()))
134                            .unwrap_or("None".to_string())
135                    )?;
136                }
137                super::ControlFlow::Switch {
138                    value,
139                    default,
140                    branches,
141                    ..
142                } => {
143                    write!(f, "    switch({value}) ")?;
144                    for (val, block) in branches {
145                        write!(f, "[{val}: bb{}] ", block.index())?;
146                    }
147                    writeln!(f, "[default: bb{}];", default.index())?;
148                }
149                super::ControlFlow::Loop {
150                    body,
151                    continue_target,
152                    merge,
153                } => {
154                    writeln!(
155                        f,
156                        "    loop(continue: bb{}, merge: bb{})",
157                        continue_target.index(),
158                        merge.index()
159                    )?;
160                    writeln!(f, "    branch bb{};", body.index())?
161                }
162                super::ControlFlow::LoopBreak {
163                    break_cond,
164                    body,
165                    continue_target,
166                    merge,
167                } => {
168                    writeln!(
169                        f,
170                        "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
171                        break_cond,
172                        body.index(),
173                        continue_target.index(),
174                        merge.index()
175                    )?;
176                }
177                super::ControlFlow::Return => writeln!(f, "    return;")?,
178                super::ControlFlow::None => {
179                    let edge = self.program.edges(node).next();
180                    let target = edge.map(|it| it.target().index()).unwrap_or(255);
181                    writeln!(f, "    branch bb{target};")?;
182                }
183            }
184            f.write_str("}\n\n")?;
185        }
186
187        Ok(())
188    }
189}
190
191impl Display for BlockSets {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
194        exp_gen.sort_by_key(|it| it.0);
195        let exp_gen = exp_gen
196            .into_iter()
197            .map(|(val, expr)| format!("{val}: {expr}"))
198            .collect::<Vec<_>>();
199        let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
200        phi_gen.sort_by_key(|it| it.0);
201        let phi_gen = phi_gen
202            .into_iter()
203            .map(|(val, expr)| format!("{val}: {expr}"))
204            .collect::<Vec<_>>();
205        let tmp_gen = self
206            .tmp_gen
207            .iter()
208            .map(|it| format!("{it}"))
209            .collect::<Vec<_>>();
210        let mut leaders = self.leaders.iter().collect::<Vec<_>>();
211        leaders.sort_by_key(|it| it.0);
212        let leaders = leaders
213            .into_iter()
214            .map(|(val, expr)| format!("{val}: {expr}"))
215            .collect::<Vec<_>>();
216        let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
217        antic_out.sort_by_key(|it| it.0);
218        let antic_out = antic_out
219            .into_iter()
220            .map(|(val, expr)| format!("{val}: {expr}"))
221            .collect::<Vec<_>>();
222        let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
223        antic_in.sort_by_key(|it| it.0);
224        let antic_in = antic_in
225            .into_iter()
226            .map(|(val, expr)| format!("{val}: {expr}"))
227            .collect::<Vec<_>>();
228
229        writeln!(f, "    exp_gen: [{}]", exp_gen.join(", "))?;
230        writeln!(f, "    phi_gen: [{}]", phi_gen.join(", "))?;
231        writeln!(f, "    tmp_gen: [{}]", tmp_gen.join(", "))?;
232        writeln!(f, "    leaders: [{}]", leaders.join(", "))?;
233        writeln!(f, "    antic_in: [{}]", antic_in.join(", "))?;
234        writeln!(f, "    antic_out: [{}]", antic_out.join(", "))
235    }
236}
237
238impl Display for ValueTable {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        let mut values = self.value_numbers.iter().collect::<Vec<_>>();
241        values.sort_by_key(|it| it.1);
242        writeln!(f, "values: [")?;
243        for (val, num) in values {
244            writeln!(f, "    {num}: {val},")?;
245        }
246        writeln!(f, "]")?;
247        writeln!(f, "expressions: [")?;
248        let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
249        expressions.sort_by_key(|it| it.1);
250        for (expr, val) in expressions {
251            writeln!(f, "    {val}: {expr},")?;
252        }
253        writeln!(f, "]")
254    }
255}
256
257impl Display for Value {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Value::Constant(constant, _) => write!(f, "{constant}"),
261            Value::Local(local) => write!(f, "{local}"),
262            Value::Input(id, _) => write!(f, "input({id})"),
263            Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
264            Value::ConstArray(id, _, _, _) => write!(f, "const_array({id})"),
265            Value::Builtin(builtin, _) => write!(f, "{builtin:?}"),
266            Value::Output(id, _) => write!(f, "output({id})"),
267        }
268    }
269}
270
271impl Display for Local {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        match self.version {
274            0 => write!(f, "binding({})", self.id),
275            v => write!(f, "local({}).v{v}", self.id),
276        }
277    }
278}
279
280impl Display for Expression {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        match self {
283            Expression::Instruction(instruction) => write!(f, "{instruction}"),
284            Expression::Copy(val, _) => write!(f, "copy({val})"),
285            Expression::Value(value) => write!(f, "{value}"),
286            Expression::Volatile(value) => write!(f, "volatile({value})"),
287            Expression::Phi(entries) => write!(
288                f,
289                "phi({})",
290                entries
291                    .iter()
292                    .map(|(val, b)| format!("{val}: bb{}", b.index()))
293                    .collect::<Vec<_>>()
294                    .join(", ")
295            ),
296        }
297    }
298}
299
300impl Display for Instruction {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        write!(f, "{:?}: [{:?}]", self.op, self.args)
303    }
304}
305
306impl Display for BasicBlock {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        for phi in self.phi_nodes.borrow().iter() {
309            write!(f, "    {} = phi ", phi.out)?;
310            for entry in &phi.entries {
311                write!(f, "[bb{}: ", entry.block.index())?;
312                write!(f, "{}]", entry.value)?;
313            }
314            writeln!(f, ";\n")?;
315        }
316        if !self.phi_nodes.borrow().is_empty() {
317            writeln!(f)?;
318        }
319
320        for op in self.ops.borrow_mut().values_mut() {
321            let op_fmt = op.to_string();
322            if op_fmt.is_empty() {
323                continue;
324            }
325
326            writeln!(f, "    {op_fmt};")?;
327        }
328        match &*self.control_flow.borrow() {
329            ControlFlow::IfElse {
330                cond,
331                then,
332                or_else,
333                merge,
334            } => {
335                writeln!(
336                    f,
337                    "    {cond} ? bb{} : bb{}; merge: {}",
338                    then.index(),
339                    or_else.index(),
340                    merge
341                        .as_ref()
342                        .map(|it| format!("bb{}", it.index()))
343                        .unwrap_or("None".to_string())
344                )?;
345            }
346            super::ControlFlow::Switch {
347                value,
348                default,
349                branches,
350                ..
351            } => {
352                write!(f, "    switch({value}) ")?;
353                for (val, block) in branches {
354                    write!(f, "[{val}: bb{}] ", block.index())?;
355                }
356                writeln!(f, "[default: bb{}];", default.index())?;
357            }
358            super::ControlFlow::Loop {
359                body,
360                continue_target,
361                merge,
362            } => {
363                writeln!(
364                    f,
365                    "    loop(continue: bb{}, merge: bb{})",
366                    continue_target.index(),
367                    merge.index()
368                )?;
369                writeln!(f, "    branch bb{};", body.index())?
370            }
371            super::ControlFlow::LoopBreak {
372                break_cond,
373                body,
374                continue_target,
375                merge,
376            } => {
377                writeln!(
378                    f,
379                    "    loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
380                    break_cond,
381                    body.index(),
382                    continue_target.index(),
383                    merge.index()
384                )?;
385            }
386            super::ControlFlow::Return => writeln!(f, "    return;")?,
387            super::ControlFlow::None => {
388                writeln!(f, "    branch;")?;
389            }
390        }
391        Ok(())
392    }
393}
394
395impl Display for SmemAllocation {
396    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397        match self.smem {
398            crate::SharedMemory::Array {
399                id,
400                length,
401                ty,
402                align,
403            } => {
404                write!(
405                    f,
406                    "shared_array(id: {id}, offset: {}, length: {length}, align: {align}, ty: {ty})",
407                    self.offset,
408                )
409            }
410            crate::SharedMemory::Value { id, ty, align } => {
411                write!(
412                    f,
413                    "shared(id: {id}, offset: {}, align: {align}, ty: {ty})",
414                    self.offset,
415                )
416            }
417        }
418    }
419}