1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7 BasicBlock, ControlFlow,
8 analyses::{liveness::Liveness, uniformity::Uniformity},
9 gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
10};
11
12use super::Optimizer;
13
/// When `true`, the `Optimizer` dump additionally prints the GVN value table
/// (once, before the blocks) and the GVN sets of every block.
const DEBUG_GVN: bool = false;
15
16impl Display for Optimizer {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 let global_nums = self
20 .analysis_cache
21 .try_get::<GlobalValues>()
22 .unwrap_or_default();
23 let liveness = self
24 .analysis_cache
25 .try_get::<Liveness>()
26 .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
27 let uniformity = self
28 .analysis_cache
29 .try_get::<Uniformity>()
30 .unwrap_or_default();
31
32 if DEBUG_GVN {
33 writeln!(f, "# Value Table:")?;
34 writeln!(f, "{}", global_nums.borrow().values)?;
35 }
36
37 for node in self.program.node_indices() {
38 let id = node.index();
39 let bb = &self.program[node];
40 let uniform = match uniformity.is_block_uniform(node) {
41 true => "uniform ",
42 false => "",
43 };
44 writeln!(f, "{uniform}bb{id} {{")?;
45 if DEBUG_GVN {
46 let block_sets = &global_nums
47 .borrow()
48 .block_sets
49 .get(&node)
50 .cloned()
51 .unwrap_or_default();
52 writeln!(f, "{block_sets}")?;
53 }
54
55 if !bb.block_use.is_empty() {
56 writeln!(f, " Uses: {:?}", bb.block_use)?;
57 }
58 let live_vars = liveness.at_block(node).iter();
59 let live_vars = live_vars.map(|it| format!("local({it})"));
60 let live_vars = live_vars.collect::<Vec<_>>();
61 writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?;
62
63 for phi in bb.phi_nodes.borrow().iter() {
64 write!(f, " {} = phi ", phi.out)?;
65 for entry in &phi.entries {
66 write!(f, "[bb{}: ", entry.block.index())?;
67 write!(f, "{}]", entry.value)?;
68 }
69 let is_uniform = match uniformity.is_var_uniform(phi.out) {
70 true => " @ uniform",
71 false => "",
72 };
73 writeln!(f, ";{is_uniform}\n")?;
74 }
75 if !bb.phi_nodes.borrow().is_empty() {
76 writeln!(f)?;
77 }
78
79 for op in bb.ops.borrow_mut().values_mut() {
80 let op_fmt = op.to_string();
81 if op_fmt.is_empty() {
82 continue;
83 }
84
85 let is_uniform = match op
86 .out
87 .map(|out| uniformity.is_var_uniform(out))
88 .unwrap_or(false)
89 {
90 true => " @ uniform",
91 false => "",
92 };
93 writeln!(f, " {op_fmt};{is_uniform}")?;
94 }
95 match &*bb.control_flow.borrow() {
96 ControlFlow::IfElse {
97 cond,
98 then,
99 or_else,
100 merge,
101 } => {
102 writeln!(
103 f,
104 " {cond} ? bb{} : bb{}; merge: {}",
105 then.index(),
106 or_else.index(),
107 merge
108 .as_ref()
109 .map(|it| format!("bb{}", it.index()))
110 .unwrap_or("None".to_string())
111 )?;
112 }
113 super::ControlFlow::Switch {
114 value,
115 default,
116 branches,
117 ..
118 } => {
119 write!(f, " switch({value}) ")?;
120 for (val, block) in branches {
121 write!(f, "[{val}: bb{}] ", block.index())?;
122 }
123 writeln!(f, "[default: bb{}];", default.index())?;
124 }
125 super::ControlFlow::Loop {
126 body,
127 continue_target,
128 merge,
129 } => {
130 writeln!(
131 f,
132 " loop(continue: bb{}, merge: bb{})",
133 continue_target.index(),
134 merge.index()
135 )?;
136 writeln!(f, " branch bb{};", body.index())?
137 }
138 super::ControlFlow::LoopBreak {
139 break_cond,
140 body,
141 continue_target,
142 merge,
143 } => {
144 writeln!(
145 f,
146 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
147 break_cond,
148 body.index(),
149 continue_target.index(),
150 merge.index()
151 )?;
152 }
153 super::ControlFlow::Return => writeln!(f, " return;")?,
154 super::ControlFlow::None => {
155 let edge = self.program.edges(node).next();
156 let target = edge.map(|it| it.target().index()).unwrap_or(255);
157 writeln!(f, " branch bb{target};")?;
158 }
159 }
160 f.write_str("}\n\n")?;
161 }
162
163 Ok(())
164 }
165}
166
167impl Display for BlockSets {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
170 exp_gen.sort_by_key(|it| it.0);
171 let exp_gen = exp_gen
172 .into_iter()
173 .map(|(val, expr)| format!("{val}: {expr}"))
174 .collect::<Vec<_>>();
175 let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
176 phi_gen.sort_by_key(|it| it.0);
177 let phi_gen = phi_gen
178 .into_iter()
179 .map(|(val, expr)| format!("{val}: {expr}"))
180 .collect::<Vec<_>>();
181 let tmp_gen = self
182 .tmp_gen
183 .iter()
184 .map(|it| format!("{it}"))
185 .collect::<Vec<_>>();
186 let mut leaders = self.leaders.iter().collect::<Vec<_>>();
187 leaders.sort_by_key(|it| it.0);
188 let leaders = leaders
189 .into_iter()
190 .map(|(val, expr)| format!("{val}: {expr}"))
191 .collect::<Vec<_>>();
192 let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
193 antic_out.sort_by_key(|it| it.0);
194 let antic_out = antic_out
195 .into_iter()
196 .map(|(val, expr)| format!("{val}: {expr}"))
197 .collect::<Vec<_>>();
198 let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
199 antic_in.sort_by_key(|it| it.0);
200 let antic_in = antic_in
201 .into_iter()
202 .map(|(val, expr)| format!("{val}: {expr}"))
203 .collect::<Vec<_>>();
204
205 writeln!(f, " exp_gen: [{}]", exp_gen.join(", "))?;
206 writeln!(f, " phi_gen: [{}]", phi_gen.join(", "))?;
207 writeln!(f, " tmp_gen: [{}]", tmp_gen.join(", "))?;
208 writeln!(f, " leaders: [{}]", leaders.join(", "))?;
209 writeln!(f, " antic_in: [{}]", antic_in.join(", "))?;
210 writeln!(f, " antic_out: [{}]", antic_out.join(", "))
211 }
212}
213
214impl Display for ValueTable {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 let mut values = self.value_numbers.iter().collect::<Vec<_>>();
217 values.sort_by_key(|it| it.1);
218 writeln!(f, "values: [")?;
219 for (val, num) in values {
220 writeln!(f, " {num}: {val},")?;
221 }
222 writeln!(f, "]")?;
223 writeln!(f, "expressions: [")?;
224 let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
225 expressions.sort_by_key(|it| it.1);
226 for (expr, val) in expressions {
227 writeln!(f, " {val}: {expr},")?;
228 }
229 writeln!(f, "]")
230 }
231}
232
233impl Display for Value {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 match self {
236 Value::Constant(constant) => write!(f, "{constant}"),
237 Value::Local(local) => write!(f, "{local}"),
238 Value::Input(id, _) => write!(f, "input({id})"),
239 Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
240 Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
241 Value::Builtin(builtin) => write!(f, "{builtin:?}"),
242 Value::Output(id, _) => write!(f, "output({id})"),
243 }
244 }
245}
246
247impl Display for Local {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 match self.version {
250 0 => write!(f, "binding({})", self.id),
251 v => write!(f, "local({}).v{v}", self.id),
252 }
253 }
254}
255
256impl Display for Constant {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 match self {
259 Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
260 Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
261 Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
262 Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
263 Constant::Float(val, FloatKind::E2M1) => write!(f, "{}e2m1", val.0),
264 Constant::Float(val, FloatKind::E2M3) => write!(f, "{}e2m3", val.0),
265 Constant::Float(val, FloatKind::E3M2) => write!(f, "{}e3m2", val.0),
266 Constant::Float(val, FloatKind::E4M3) => write!(f, "{}e4m3", val.0),
267 Constant::Float(val, FloatKind::E5M2) => write!(f, "{}e5m2", val.0),
268 Constant::Float(val, FloatKind::UE8M0) => write!(f, "{}ue8m0", val.0),
269 Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
270 Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
271 Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
272 Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
273 Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
274 Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
275 Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
276 Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
277 Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
278 Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
279 Constant::Bool(val) => write!(f, "{val}"),
280 }
281 }
282}
283
284impl Display for Expression {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 match self {
287 Expression::Instruction(instruction) => write!(f, "{instruction}"),
288 Expression::Copy(val, _) => write!(f, "copy({val})"),
289 Expression::Value(value) => write!(f, "{value}"),
290 Expression::Volatile(value) => write!(f, "volatile({value})"),
291 Expression::Phi(entries) => write!(
292 f,
293 "phi({})",
294 entries
295 .iter()
296 .map(|(val, b)| format!("{val}: bb{}", b.index()))
297 .collect::<Vec<_>>()
298 .join(", ")
299 ),
300 }
301 }
302}
303
304impl Display for Instruction {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 write!(f, "{:?}: [{:?}]", self.op, self.args)
307 }
308}
309
310impl Display for BasicBlock {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 for phi in self.phi_nodes.borrow().iter() {
313 write!(f, " {} = phi ", phi.out)?;
314 for entry in &phi.entries {
315 write!(f, "[bb{}: ", entry.block.index())?;
316 write!(f, "{}]", entry.value)?;
317 }
318 writeln!(f, ";\n")?;
319 }
320 if !self.phi_nodes.borrow().is_empty() {
321 writeln!(f)?;
322 }
323
324 for op in self.ops.borrow_mut().values_mut() {
325 let op_fmt = op.to_string();
326 if op_fmt.is_empty() {
327 continue;
328 }
329
330 writeln!(f, " {op_fmt};")?;
331 }
332 match &*self.control_flow.borrow() {
333 ControlFlow::IfElse {
334 cond,
335 then,
336 or_else,
337 merge,
338 } => {
339 writeln!(
340 f,
341 " {cond} ? bb{} : bb{}; merge: {}",
342 then.index(),
343 or_else.index(),
344 merge
345 .as_ref()
346 .map(|it| format!("bb{}", it.index()))
347 .unwrap_or("None".to_string())
348 )?;
349 }
350 super::ControlFlow::Switch {
351 value,
352 default,
353 branches,
354 ..
355 } => {
356 write!(f, " switch({value}) ")?;
357 for (val, block) in branches {
358 write!(f, "[{val}: bb{}] ", block.index())?;
359 }
360 writeln!(f, "[default: bb{}];", default.index())?;
361 }
362 super::ControlFlow::Loop {
363 body,
364 continue_target,
365 merge,
366 } => {
367 writeln!(
368 f,
369 " loop(continue: bb{}, merge: bb{})",
370 continue_target.index(),
371 merge.index()
372 )?;
373 writeln!(f, " branch bb{};", body.index())?
374 }
375 super::ControlFlow::LoopBreak {
376 break_cond,
377 body,
378 continue_target,
379 merge,
380 } => {
381 writeln!(
382 f,
383 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
384 break_cond,
385 body.index(),
386 continue_target.index(),
387 merge.index()
388 )?;
389 }
390 super::ControlFlow::Return => writeln!(f, " return;")?,
391 super::ControlFlow::None => {
392 writeln!(f, " branch;")?;
393 }
394 }
395 Ok(())
396 }
397}