1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7 ControlFlow,
8 analyses::{const_len::Slices, liveness::Liveness, uniformity::Uniformity},
9 gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
10};
11
12use super::Optimizer;
13
14const DEBUG_GVN: bool = false;
15
16impl Display for Optimizer {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 let slices = self.analysis_cache.try_get::<Slices>().unwrap_or_default();
20
21 f.write_str("Slices:\n")?;
22 for (var_id, slice) in slices.iter() {
23 let end_op = slice.end_op.as_ref().map(|it| format!("{it}"));
24 writeln!(
25 f,
26 "slice{var_id:?}: {{ start: {}, end: {}, end_op: {}, const_len: {:?} }}",
27 slice.start,
28 slice.end,
29 end_op.unwrap_or("None".to_string()),
30 slice.const_len
31 )?;
32 }
33 f.write_str("\n\n")?;
34
35 let global_nums = self
36 .analysis_cache
37 .try_get::<GlobalValues>()
38 .unwrap_or_default();
39 let liveness = self
40 .analysis_cache
41 .try_get::<Liveness>()
42 .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
43 let uniformity = self
44 .analysis_cache
45 .try_get::<Uniformity>()
46 .unwrap_or_default();
47
48 if DEBUG_GVN {
49 writeln!(f, "# Value Table:")?;
50 writeln!(f, "{}", global_nums.borrow().values)?;
51 }
52
53 for node in self.program.node_indices() {
54 let id = node.index();
55 let bb = &self.program[node];
56 let uniform = match uniformity.is_block_uniform(node) {
57 true => "uniform ",
58 false => "",
59 };
60 writeln!(f, "{uniform}bb{id} {{")?;
61 if DEBUG_GVN {
62 let block_sets = &global_nums
63 .borrow()
64 .block_sets
65 .get(&node)
66 .cloned()
67 .unwrap_or_default();
68 writeln!(f, "{block_sets}")?;
69 }
70
71 if !bb.block_use.is_empty() {
72 writeln!(f, " Uses: {:?}", bb.block_use)?;
73 }
74 let live_vars = liveness.at_block(node).iter();
75 let live_vars = live_vars.map(|it| format!("local({})", it));
76 let live_vars = live_vars.collect::<Vec<_>>();
77 writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?;
78
79 for phi in bb.phi_nodes.borrow().iter() {
80 write!(f, " {} = phi ", phi.out)?;
81 for entry in &phi.entries {
82 write!(f, "[bb{}: ", entry.block.index())?;
83 write!(f, "{}]", entry.value)?;
84 }
85 let is_uniform = match uniformity.is_var_uniform(phi.out) {
86 true => " @ uniform",
87 false => "",
88 };
89 writeln!(f, ";{is_uniform}\n")?;
90 }
91 if !bb.phi_nodes.borrow().is_empty() {
92 writeln!(f)?;
93 }
94
95 for op in bb.ops.borrow_mut().values_mut() {
96 let op_fmt = op.to_string();
97 if op_fmt.is_empty() {
98 continue;
99 }
100
101 let is_uniform = match op
102 .out
103 .map(|out| uniformity.is_var_uniform(out))
104 .unwrap_or(false)
105 {
106 true => " @ uniform",
107 false => "",
108 };
109 writeln!(f, " {op_fmt};{is_uniform}")?;
110 }
111 match &*bb.control_flow.borrow() {
112 ControlFlow::IfElse {
113 cond,
114 then,
115 or_else,
116 merge,
117 } => {
118 writeln!(
119 f,
120 " {cond} ? bb{} : bb{}; merge: {}",
121 then.index(),
122 or_else.index(),
123 merge
124 .as_ref()
125 .map(|it| format!("bb{}", it.index()))
126 .unwrap_or("None".to_string())
127 )?;
128 }
129 super::ControlFlow::Switch {
130 value,
131 default,
132 branches,
133 ..
134 } => {
135 write!(f, " switch({value}) ")?;
136 for (val, block) in branches {
137 write!(f, "[{val}: bb{}] ", block.index())?;
138 }
139 writeln!(f, "[default: bb{}];", default.index())?;
140 }
141 super::ControlFlow::Loop {
142 body,
143 continue_target,
144 merge,
145 } => {
146 writeln!(
147 f,
148 " loop(continue: bb{}, merge: bb{})",
149 continue_target.index(),
150 merge.index()
151 )?;
152 writeln!(f, " branch bb{};", body.index())?
153 }
154 super::ControlFlow::LoopBreak {
155 break_cond,
156 body,
157 continue_target,
158 merge,
159 } => {
160 writeln!(
161 f,
162 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
163 break_cond,
164 body.index(),
165 continue_target.index(),
166 merge.index()
167 )?;
168 }
169 super::ControlFlow::Return => writeln!(f, " return;")?,
170 super::ControlFlow::None => {
171 let edge = self.program.edges(node).next();
172 let target = edge.map(|it| it.target().index()).unwrap_or(255);
173 writeln!(f, " branch bb{target};")?;
174 }
175 }
176 f.write_str("}\n\n")?;
177 }
178
179 Ok(())
180 }
181}
182
183impl Display for BlockSets {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
186 exp_gen.sort_by_key(|it| it.0);
187 let exp_gen = exp_gen
188 .into_iter()
189 .map(|(val, expr)| format!("{val}: {expr}"))
190 .collect::<Vec<_>>();
191 let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
192 phi_gen.sort_by_key(|it| it.0);
193 let phi_gen = phi_gen
194 .into_iter()
195 .map(|(val, expr)| format!("{val}: {expr}"))
196 .collect::<Vec<_>>();
197 let tmp_gen = self
198 .tmp_gen
199 .iter()
200 .map(|it| format!("{it}"))
201 .collect::<Vec<_>>();
202 let mut leaders = self.leaders.iter().collect::<Vec<_>>();
203 leaders.sort_by_key(|it| it.0);
204 let leaders = leaders
205 .into_iter()
206 .map(|(val, expr)| format!("{val}: {expr}"))
207 .collect::<Vec<_>>();
208 let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
209 antic_out.sort_by_key(|it| it.0);
210 let antic_out = antic_out
211 .into_iter()
212 .map(|(val, expr)| format!("{val}: {expr}"))
213 .collect::<Vec<_>>();
214 let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
215 antic_in.sort_by_key(|it| it.0);
216 let antic_in = antic_in
217 .into_iter()
218 .map(|(val, expr)| format!("{val}: {expr}"))
219 .collect::<Vec<_>>();
220
221 writeln!(f, " exp_gen: [{}]", exp_gen.join(", "))?;
222 writeln!(f, " phi_gen: [{}]", phi_gen.join(", "))?;
223 writeln!(f, " tmp_gen: [{}]", tmp_gen.join(", "))?;
224 writeln!(f, " leaders: [{}]", leaders.join(", "))?;
225 writeln!(f, " antic_in: [{}]", antic_in.join(", "))?;
226 writeln!(f, " antic_out: [{}]", antic_out.join(", "))
227 }
228}
229
230impl Display for ValueTable {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 let mut values = self.value_numbers.iter().collect::<Vec<_>>();
233 values.sort_by_key(|it| it.1);
234 writeln!(f, "values: [")?;
235 for (val, num) in values {
236 writeln!(f, " {num}: {val},")?;
237 }
238 writeln!(f, "]")?;
239 writeln!(f, "expressions: [")?;
240 let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
241 expressions.sort_by_key(|it| it.1);
242 for (expr, val) in expressions {
243 writeln!(f, " {val}: {expr},")?;
244 }
245 writeln!(f, "]")
246 }
247}
248
249impl Display for Value {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 match self {
252 Value::Constant(constant) => write!(f, "{constant}"),
253 Value::Local(local) => write!(f, "{local}"),
254 Value::Input(id, _) => write!(f, "input({id})"),
255 Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
256 Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
257 Value::Builtin(builtin) => write!(f, "{builtin:?}"),
258 Value::Output(id, _) => write!(f, "output({id})"),
259 Value::Slice(id, _) => write!(f, "slice({id})"),
260 }
261 }
262}
263
264impl Display for Local {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 match self.version {
267 0 => write!(f, "binding({})", self.id),
268 v => write!(f, "local({}).v{v}", self.id),
269 }
270 }
271}
272
273impl Display for Constant {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 match self {
276 Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
277 Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
278 Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
279 Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
280 Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
281 Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
282 Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
283 Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
284 Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
285 Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
286 Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
287 Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
288 Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
289 Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
290 Constant::Bool(val) => write!(f, "{val}"),
291 }
292 }
293}
294
295impl Display for Expression {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 match self {
298 Expression::Instruction(instruction) => write!(f, "{instruction}"),
299 Expression::Copy(val, _) => write!(f, "copy({val})"),
300 Expression::Value(value) => write!(f, "{value}"),
301 Expression::Volatile(value) => write!(f, "volatile({value})"),
302 Expression::Phi(entries) => write!(
303 f,
304 "phi({})",
305 entries
306 .iter()
307 .map(|(val, b)| format!("{val}: bb{}", b.index()))
308 .collect::<Vec<_>>()
309 .join(", ")
310 ),
311 }
312 }
313}
314
315impl Display for Instruction {
316 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317 write!(f, "{:?}: [{:?}]", self.op, self.args)
318 }
319}