1use std::{fmt::Display, rc::Rc};
2
3use cubecl_ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7 BasicBlock, ControlFlow,
8 analyses::{
9 liveness::{
10 Liveness,
11 shared::{SharedLiveness, SmemAllocation},
12 },
13 uniformity::Uniformity,
14 },
15 gvn::{BlockSets, Constant, Expression, GlobalValues, Instruction, Local, Value, ValueTable},
16};
17
18use super::Optimizer;
19
/// When `true`, the `Optimizer` pretty-printer additionally dumps the GVN value
/// table and the GVN sets of every basic block. Flip locally when debugging
/// global value numbering.
const DEBUG_GVN: bool = false;
21
22impl Display for Optimizer {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 let global_nums = self
26 .analysis_cache
27 .try_get::<GlobalValues>()
28 .unwrap_or_default();
29 let liveness = self
30 .analysis_cache
31 .try_get::<Liveness>()
32 .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
33 let shared_liveness = self
34 .analysis_cache
35 .try_get::<SharedLiveness>()
36 .unwrap_or_else(|| Rc::new(SharedLiveness::empty(self)));
37 let uniformity = self
38 .analysis_cache
39 .try_get::<Uniformity>()
40 .unwrap_or_default();
41
42 if DEBUG_GVN {
43 writeln!(f, "# Value Table:")?;
44 writeln!(f, "{}", global_nums.borrow().values)?;
45 }
46
47 let smems = shared_liveness
48 .allocations
49 .values()
50 .map(|it| format!(" {it}"));
51 let smems = smems.collect::<Vec<_>>().join(",\n");
52 writeln!(f, "Shared memories: [\n{smems}\n]\n")?;
53
54 for node in self.program.node_indices() {
55 let id = node.index();
56 let bb = &self.program[node];
57 let uniform = match uniformity.is_block_uniform(node) {
58 true => "uniform ",
59 false => "",
60 };
61 writeln!(f, "{uniform}bb{id} {{")?;
62 if DEBUG_GVN {
63 let block_sets = &global_nums
64 .borrow()
65 .block_sets
66 .get(&node)
67 .cloned()
68 .unwrap_or_default();
69 writeln!(f, "{block_sets}")?;
70 }
71
72 if !bb.block_use.is_empty() {
73 writeln!(f, " Uses: {:?}", bb.block_use)?;
74 }
75 let live_vars = liveness.at_block(node).iter();
76 let live_vars = live_vars.map(|it| format!("local({it})"));
77 let live_vars = live_vars.collect::<Vec<_>>();
78 writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?;
79 let live_shared = shared_liveness.at_block(node).iter();
80 let live_shared = live_shared.map(|it| format!("shared({it})"));
81 let live_shared = live_shared.collect::<Vec<_>>();
82 writeln!(
83 f,
84 " Live shared memories: [{}]\n",
85 live_shared.join(", ")
86 )?;
87
88 for phi in bb.phi_nodes.borrow().iter() {
89 write!(f, " {} = phi ", phi.out)?;
90 for entry in &phi.entries {
91 write!(f, "[bb{}: ", entry.block.index())?;
92 write!(f, "{}]", entry.value)?;
93 }
94 let is_uniform = match uniformity.is_var_uniform(phi.out) {
95 true => " @ uniform",
96 false => "",
97 };
98 writeln!(f, ";{is_uniform}\n")?;
99 }
100 if !bb.phi_nodes.borrow().is_empty() {
101 writeln!(f)?;
102 }
103
104 for op in bb.ops.borrow_mut().values_mut() {
105 let op_fmt = op.to_string();
106 if op_fmt.is_empty() {
107 continue;
108 }
109
110 let is_uniform = match op
111 .out
112 .map(|out| uniformity.is_var_uniform(out))
113 .unwrap_or(false)
114 {
115 true => " @ uniform",
116 false => "",
117 };
118 writeln!(f, " {op_fmt};{is_uniform}")?;
119 }
120 match &*bb.control_flow.borrow() {
121 ControlFlow::IfElse {
122 cond,
123 then,
124 or_else,
125 merge,
126 } => {
127 writeln!(
128 f,
129 " {cond} ? bb{} : bb{}; merge: {}",
130 then.index(),
131 or_else.index(),
132 merge
133 .as_ref()
134 .map(|it| format!("bb{}", it.index()))
135 .unwrap_or("None".to_string())
136 )?;
137 }
138 super::ControlFlow::Switch {
139 value,
140 default,
141 branches,
142 ..
143 } => {
144 write!(f, " switch({value}) ")?;
145 for (val, block) in branches {
146 write!(f, "[{val}: bb{}] ", block.index())?;
147 }
148 writeln!(f, "[default: bb{}];", default.index())?;
149 }
150 super::ControlFlow::Loop {
151 body,
152 continue_target,
153 merge,
154 } => {
155 writeln!(
156 f,
157 " loop(continue: bb{}, merge: bb{})",
158 continue_target.index(),
159 merge.index()
160 )?;
161 writeln!(f, " branch bb{};", body.index())?
162 }
163 super::ControlFlow::LoopBreak {
164 break_cond,
165 body,
166 continue_target,
167 merge,
168 } => {
169 writeln!(
170 f,
171 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
172 break_cond,
173 body.index(),
174 continue_target.index(),
175 merge.index()
176 )?;
177 }
178 super::ControlFlow::Return => writeln!(f, " return;")?,
179 super::ControlFlow::None => {
180 let edge = self.program.edges(node).next();
181 let target = edge.map(|it| it.target().index()).unwrap_or(255);
182 writeln!(f, " branch bb{target};")?;
183 }
184 }
185 f.write_str("}\n\n")?;
186 }
187
188 Ok(())
189 }
190}
191
192impl Display for BlockSets {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
195 exp_gen.sort_by_key(|it| it.0);
196 let exp_gen = exp_gen
197 .into_iter()
198 .map(|(val, expr)| format!("{val}: {expr}"))
199 .collect::<Vec<_>>();
200 let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
201 phi_gen.sort_by_key(|it| it.0);
202 let phi_gen = phi_gen
203 .into_iter()
204 .map(|(val, expr)| format!("{val}: {expr}"))
205 .collect::<Vec<_>>();
206 let tmp_gen = self
207 .tmp_gen
208 .iter()
209 .map(|it| format!("{it}"))
210 .collect::<Vec<_>>();
211 let mut leaders = self.leaders.iter().collect::<Vec<_>>();
212 leaders.sort_by_key(|it| it.0);
213 let leaders = leaders
214 .into_iter()
215 .map(|(val, expr)| format!("{val}: {expr}"))
216 .collect::<Vec<_>>();
217 let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
218 antic_out.sort_by_key(|it| it.0);
219 let antic_out = antic_out
220 .into_iter()
221 .map(|(val, expr)| format!("{val}: {expr}"))
222 .collect::<Vec<_>>();
223 let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
224 antic_in.sort_by_key(|it| it.0);
225 let antic_in = antic_in
226 .into_iter()
227 .map(|(val, expr)| format!("{val}: {expr}"))
228 .collect::<Vec<_>>();
229
230 writeln!(f, " exp_gen: [{}]", exp_gen.join(", "))?;
231 writeln!(f, " phi_gen: [{}]", phi_gen.join(", "))?;
232 writeln!(f, " tmp_gen: [{}]", tmp_gen.join(", "))?;
233 writeln!(f, " leaders: [{}]", leaders.join(", "))?;
234 writeln!(f, " antic_in: [{}]", antic_in.join(", "))?;
235 writeln!(f, " antic_out: [{}]", antic_out.join(", "))
236 }
237}
238
239impl Display for ValueTable {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 let mut values = self.value_numbers.iter().collect::<Vec<_>>();
242 values.sort_by_key(|it| it.1);
243 writeln!(f, "values: [")?;
244 for (val, num) in values {
245 writeln!(f, " {num}: {val},")?;
246 }
247 writeln!(f, "]")?;
248 writeln!(f, "expressions: [")?;
249 let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
250 expressions.sort_by_key(|it| it.1);
251 for (expr, val) in expressions {
252 writeln!(f, " {val}: {expr},")?;
253 }
254 writeln!(f, "]")
255 }
256}
257
258impl Display for Value {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 match self {
261 Value::Constant(constant) => write!(f, "{constant}"),
262 Value::Local(local) => write!(f, "{local}"),
263 Value::Input(id, _) => write!(f, "input({id})"),
264 Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
265 Value::ConstArray(id, _, _, _) => write!(f, "const_array({id})"),
266 Value::Builtin(builtin) => write!(f, "{builtin:?}"),
267 Value::Output(id, _) => write!(f, "output({id})"),
268 }
269 }
270}
271
272impl Display for Local {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 match self.version {
275 0 => write!(f, "binding({})", self.id),
276 v => write!(f, "local({}).v{v}", self.id),
277 }
278 }
279}
280
281impl Display for Constant {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 match self {
284 Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
285 Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
286 Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
287 Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
288 Constant::Float(val, FloatKind::E2M1) => write!(f, "{}e2m1", val.0),
289 Constant::Float(val, FloatKind::E2M3) => write!(f, "{}e2m3", val.0),
290 Constant::Float(val, FloatKind::E3M2) => write!(f, "{}e3m2", val.0),
291 Constant::Float(val, FloatKind::E4M3) => write!(f, "{}e4m3", val.0),
292 Constant::Float(val, FloatKind::E5M2) => write!(f, "{}e5m2", val.0),
293 Constant::Float(val, FloatKind::UE8M0) => write!(f, "{}ue8m0", val.0),
294 Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
295 Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
296 Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
297 Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
298 Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
299 Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
300 Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
301 Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
302 Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
303 Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
304 Constant::Bool(val) => write!(f, "{val}"),
305 }
306 }
307}
308
309impl Display for Expression {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 match self {
312 Expression::Instruction(instruction) => write!(f, "{instruction}"),
313 Expression::Copy(val, _) => write!(f, "copy({val})"),
314 Expression::Value(value) => write!(f, "{value}"),
315 Expression::Volatile(value) => write!(f, "volatile({value})"),
316 Expression::Phi(entries) => write!(
317 f,
318 "phi({})",
319 entries
320 .iter()
321 .map(|(val, b)| format!("{val}: bb{}", b.index()))
322 .collect::<Vec<_>>()
323 .join(", ")
324 ),
325 }
326 }
327}
328
329impl Display for Instruction {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 write!(f, "{:?}: [{:?}]", self.op, self.args)
332 }
333}
334
335impl Display for BasicBlock {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 for phi in self.phi_nodes.borrow().iter() {
338 write!(f, " {} = phi ", phi.out)?;
339 for entry in &phi.entries {
340 write!(f, "[bb{}: ", entry.block.index())?;
341 write!(f, "{}]", entry.value)?;
342 }
343 writeln!(f, ";\n")?;
344 }
345 if !self.phi_nodes.borrow().is_empty() {
346 writeln!(f)?;
347 }
348
349 for op in self.ops.borrow_mut().values_mut() {
350 let op_fmt = op.to_string();
351 if op_fmt.is_empty() {
352 continue;
353 }
354
355 writeln!(f, " {op_fmt};")?;
356 }
357 match &*self.control_flow.borrow() {
358 ControlFlow::IfElse {
359 cond,
360 then,
361 or_else,
362 merge,
363 } => {
364 writeln!(
365 f,
366 " {cond} ? bb{} : bb{}; merge: {}",
367 then.index(),
368 or_else.index(),
369 merge
370 .as_ref()
371 .map(|it| format!("bb{}", it.index()))
372 .unwrap_or("None".to_string())
373 )?;
374 }
375 super::ControlFlow::Switch {
376 value,
377 default,
378 branches,
379 ..
380 } => {
381 write!(f, " switch({value}) ")?;
382 for (val, block) in branches {
383 write!(f, "[{val}: bb{}] ", block.index())?;
384 }
385 writeln!(f, "[default: bb{}];", default.index())?;
386 }
387 super::ControlFlow::Loop {
388 body,
389 continue_target,
390 merge,
391 } => {
392 writeln!(
393 f,
394 " loop(continue: bb{}, merge: bb{})",
395 continue_target.index(),
396 merge.index()
397 )?;
398 writeln!(f, " branch bb{};", body.index())?
399 }
400 super::ControlFlow::LoopBreak {
401 break_cond,
402 body,
403 continue_target,
404 merge,
405 } => {
406 writeln!(
407 f,
408 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
409 break_cond,
410 body.index(),
411 continue_target.index(),
412 merge.index()
413 )?;
414 }
415 super::ControlFlow::Return => writeln!(f, " return;")?,
416 super::ControlFlow::None => {
417 writeln!(f, " branch;")?;
418 }
419 }
420 Ok(())
421 }
422}
423
424impl Display for SmemAllocation {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 write!(
427 f,
428 "shared(id: {}, offset: {}, length: {}, align: {}, ty: {})",
429 self.smem.id, self.offset, self.smem.length, self.smem.align, self.smem.ty
430 )
431 }
432}