1use std::{fmt::Display, rc::Rc};
2
3use cubecl_core::ir::{FloatKind, IntKind, UIntKind};
4use petgraph::visit::EdgeRef;
5
6use crate::{
7 analyses::{const_len::Slices, integer_range::Ranges, liveness::Liveness},
8 gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, OpId, Value, ValueTable},
9 ControlFlow,
10};
11
12use super::Optimizer;
13
/// Set to `true` to additionally print GVN debugging data (the global value table and
/// each block's `BlockSets`) when an `Optimizer` is formatted with `Display`.
const DEBUG_GVN: bool = false;
15
16impl Display for Optimizer {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 let slices = self.analysis_cache.try_get::<Slices>().unwrap_or_default();
20 let ranges = self.analysis_cache.try_get::<Ranges>().unwrap_or_default();
21
22 f.write_str("Slices:\n")?;
23 for (var_id, slice) in slices.iter() {
24 let end_op = slice.end_op.as_ref().map(|it| format!("{it}"));
25 writeln!(
26 f,
27 "slice{var_id:?}: {{ start: {}, end: {}, end_op: {}, const_len: {:?} }}",
28 slice.start,
29 slice.end,
30 end_op.unwrap_or("None".to_string()),
31 slice.const_len
32 )?;
33 }
34 f.write_str("\n\n")?;
35
36 let global_nums = self
37 .analysis_cache
38 .try_get::<GvnState>()
39 .unwrap_or_default();
40 let liveness = self
41 .analysis_cache
42 .try_get::<Liveness>()
43 .unwrap_or_else(|| Rc::new(Liveness::empty(self)));
44
45 if DEBUG_GVN {
46 writeln!(f, "# Value Table:")?;
47 writeln!(f, "{}", global_nums.values)?;
48 }
49
50 for node in self.program.node_indices() {
51 let id = node.index();
52 let bb = &self.program[node];
53 writeln!(f, "bb{id} {{")?;
54 if DEBUG_GVN {
55 let block_sets = &global_nums
56 .block_sets
57 .get(&node)
58 .cloned()
59 .unwrap_or_default();
60 writeln!(f, "{block_sets}")?;
61 }
62
63 if !bb.block_use.is_empty() {
64 writeln!(f, " Uses: {:?}", bb.block_use)?;
65 }
66 let live_vars = liveness.at_block(node).iter();
67 let live_vars = live_vars.map(|it| format!("local({})", it));
68 let live_vars = live_vars.collect::<Vec<_>>();
69 writeln!(f, " Live variables: [{}]\n", live_vars.join(", "))?;
70
71 for phi in bb.phi_nodes.borrow().iter() {
72 write!(f, " {} = phi ", phi.out)?;
73 for entry in &phi.entries {
74 write!(f, "[bb{}: ", entry.block.index())?;
75 write!(f, "{}]", entry.value)?;
76 }
77 f.write_str(";\n")?;
78 }
79 if !bb.phi_nodes.borrow().is_empty() {
80 writeln!(f)?;
81 }
82
83 for op in bb.ops.borrow_mut().values_mut() {
84 let range = op.out.map(|var| ranges.range_of(self, &var));
85 let range = range.map(|it| format!(" range: {it};")).unwrap_or_default();
86
87 writeln!(f, " {op};{range}")?;
88 }
89 match &*bb.control_flow.borrow() {
90 ControlFlow::IfElse {
91 cond,
92 then,
93 or_else,
94 merge,
95 } => {
96 writeln!(
97 f,
98 " {cond} ? bb{} : bb{}; merge: {}",
99 then.index(),
100 or_else.index(),
101 merge
102 .as_ref()
103 .map(|it| format!("bb{}", it.index()))
104 .unwrap_or("None".to_string())
105 )?;
106 }
107 super::ControlFlow::Switch {
108 value,
109 default,
110 branches,
111 ..
112 } => {
113 write!(f, " switch({value}) ")?;
114 for (val, block) in branches {
115 write!(f, "[{val}: bb{}] ", block.index())?;
116 }
117 writeln!(f, "[default: bb{}];", default.index())?;
118 }
119 super::ControlFlow::Loop {
120 body,
121 continue_target,
122 merge,
123 } => {
124 writeln!(
125 f,
126 " loop(continue: bb{}, merge: bb{})",
127 continue_target.index(),
128 merge.index()
129 )?;
130 writeln!(f, " branch bb{};", body.index())?
131 }
132 super::ControlFlow::LoopBreak {
133 break_cond,
134 body,
135 continue_target,
136 merge,
137 } => {
138 writeln!(
139 f,
140 " loop(cond: {}, body: bb{} continue: bb{}, break: bb{})",
141 break_cond,
142 body.index(),
143 continue_target.index(),
144 merge.index()
145 )?;
146 }
147 super::ControlFlow::Return => writeln!(f, " return;")?,
148 super::ControlFlow::None => {
149 let edge = self.program.edges(node).next();
150 let target = edge.map(|it| it.target().index()).unwrap_or(255);
151 writeln!(f, " branch bb{target};")?;
152 }
153 }
154 f.write_str("}\n\n")?;
155 }
156
157 Ok(())
158 }
159}
160
161impl Display for BlockSets {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 let mut exp_gen = self.exp_gen.iter().collect::<Vec<_>>();
164 exp_gen.sort_by_key(|it| it.0);
165 let exp_gen = exp_gen
166 .into_iter()
167 .map(|(val, expr)| format!("{val}: {expr}"))
168 .collect::<Vec<_>>();
169 let mut phi_gen = self.phi_gen.iter().collect::<Vec<_>>();
170 phi_gen.sort_by_key(|it| it.0);
171 let phi_gen = phi_gen
172 .into_iter()
173 .map(|(val, expr)| format!("{val}: {expr}"))
174 .collect::<Vec<_>>();
175 let tmp_gen = self
176 .tmp_gen
177 .iter()
178 .map(|it| format!("{it}"))
179 .collect::<Vec<_>>();
180 let mut leaders = self.leaders.iter().collect::<Vec<_>>();
181 leaders.sort_by_key(|it| it.0);
182 let leaders = leaders
183 .into_iter()
184 .map(|(val, expr)| format!("{val}: {expr}"))
185 .collect::<Vec<_>>();
186 let mut antic_out = self.antic_out.iter().collect::<Vec<_>>();
187 antic_out.sort_by_key(|it| it.0);
188 let antic_out = antic_out
189 .into_iter()
190 .map(|(val, expr)| format!("{val}: {expr}"))
191 .collect::<Vec<_>>();
192 let mut antic_in = self.antic_in.iter().collect::<Vec<_>>();
193 antic_in.sort_by_key(|it| it.0);
194 let antic_in = antic_in
195 .into_iter()
196 .map(|(val, expr)| format!("{val}: {expr}"))
197 .collect::<Vec<_>>();
198
199 writeln!(f, " exp_gen: [{}]", exp_gen.join(", "))?;
200 writeln!(f, " phi_gen: [{}]", phi_gen.join(", "))?;
201 writeln!(f, " tmp_gen: [{}]", tmp_gen.join(", "))?;
202 writeln!(f, " leaders: [{}]", leaders.join(", "))?;
203 writeln!(f, " antic_in: [{}]", antic_in.join(", "))?;
204 writeln!(f, " antic_out: [{}]", antic_out.join(", "))
205 }
206}
207
208impl Display for ValueTable {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 let mut values = self.value_numbers.iter().collect::<Vec<_>>();
211 values.sort_by_key(|it| it.1);
212 writeln!(f, "values: [")?;
213 for (val, num) in values {
214 writeln!(f, " {num}: {val},")?;
215 }
216 writeln!(f, "]")?;
217 writeln!(f, "expressions: [")?;
218 let mut expressions = self.expression_numbers.iter().collect::<Vec<_>>();
219 expressions.sort_by_key(|it| it.1);
220 for (expr, val) in expressions {
221 writeln!(f, " {val}: {expr},")?;
222 }
223 writeln!(f, "]")
224 }
225}
226
227impl Display for Value {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 match self {
230 Value::Constant(constant) => write!(f, "{constant}"),
231 Value::Local(local) => write!(f, "{local}"),
232 Value::Input(id, _) => write!(f, "input({id})"),
233 Value::Scalar(id, elem) => write!(f, "scalar({elem}, {id})"),
234 Value::ConstArray(id, _, _) => write!(f, "const_array({id})"),
235 Value::Builtin(builtin) => write!(f, "{builtin:?}"),
236 Value::Output(id, _) => write!(f, "output({id})"),
237 Value::Slice(id, _) => write!(f, "slice({id})"),
238 }
239 }
240}
241
242impl Display for Local {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 match self.version {
245 0 => write!(f, "binding({})", self.id),
246 v => write!(f, "local({}).v{v}", self.id),
247 }
248 }
249}
250
251impl Display for Constant {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 match self {
254 Constant::Int(val, IntKind::I8) => write!(f, "{val}i8"),
255 Constant::Int(val, IntKind::I16) => write!(f, "{val}i16"),
256 Constant::Int(val, IntKind::I32) => write!(f, "{val}i32"),
257 Constant::Int(val, IntKind::I64) => write!(f, "{val}i64"),
258 Constant::Float(val, FloatKind::BF16) => write!(f, "{}bf16", val.0),
259 Constant::Float(val, FloatKind::F16) => write!(f, "{}f16", val.0),
260 Constant::Float(val, FloatKind::Flex32) => write!(f, "{}minf16", val.0),
261 Constant::Float(val, FloatKind::TF32) => write!(f, "{}tf32", val.0),
262 Constant::Float(val, FloatKind::F32) => write!(f, "{}f32", val.0),
263 Constant::Float(val, FloatKind::F64) => write!(f, "{}f64", val.0),
264 Constant::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
265 Constant::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
266 Constant::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
267 Constant::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
268 Constant::Bool(val) => write!(f, "{val}"),
269 }
270 }
271}
272
273impl Display for Expression {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 match self {
276 Expression::Instruction(instruction) => write!(f, "{instruction}"),
277 Expression::Copy(val, _) => write!(f, "copy({val})"),
278 Expression::Value(value) => write!(f, "{value}"),
279 Expression::Volatile(value) => write!(f, "volatile({value})"),
280 Expression::Phi(entries) => write!(
281 f,
282 "phi({})",
283 entries
284 .iter()
285 .map(|(val, b)| format!("{val}: bb{}", b.index()))
286 .collect::<Vec<_>>()
287 .join(", ")
288 ),
289 }
290 }
291}
292
293impl Display for Instruction {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 let args = &self.args;
296 match self.op {
297 OpId::Add => write!(f, "{} + {}", args[0], args[1]),
298 OpId::Fma => write!(f, "fma({}, {}, {})", args[0], args[1], args[2]),
299 OpId::Sub => write!(f, "{} - {}", args[0], args[1]),
300 OpId::Mul => write!(f, "{} * {}", args[0], args[1]),
301 OpId::Div => write!(f, "{} / {}", args[0], args[1]),
302 OpId::Abs => write!(f, "{}.abs()", args[0]),
303 OpId::Exp => write!(f, "{}.exp()", args[0]),
304 OpId::Log => write!(f, "{}.log()", args[0]),
305 OpId::Log1p => write!(f, "{}.log1p()", args[0]),
306 OpId::Cos => write!(f, "{}.cos()", args[0]),
307 OpId::Sin => write!(f, "{}.sin()", args[0]),
308 OpId::Tanh => write!(f, "{}.tanh()", args[0]),
309 OpId::Powf => write!(f, "{}.powf()", args[0]),
310 OpId::Sqrt => write!(f, "{}.sqrt()", args[0]),
311 OpId::Round => write!(f, "{}.round()", args[0]),
312 OpId::Floor => write!(f, "{}.floor()", args[0]),
313 OpId::Ceil => write!(f, "{}.ceil()", args[0]),
314 OpId::Erf => write!(f, "{}.erf()", args[0]),
315 OpId::Recip => write!(f, "1.0 / {}", args[0]),
316 OpId::Equal => write!(f, "{} == {}", args[0], args[1]),
317 OpId::NotEqual => write!(f, "{} != {}", args[0], args[1]),
318 OpId::Lower => write!(f, "{} < {}", args[0], args[1]),
319 OpId::Clamp => write!(f, "clamp({}, {}, {})", args[0], args[1], args[2]),
320 OpId::Greater => write!(f, "{} > {}", args[0], args[1]),
321 OpId::LowerEqual => write!(f, "{} <= {}", args[0], args[1]),
322 OpId::GreaterEqual => write!(f, "{} >= {}", args[0], args[1]),
323 OpId::Modulo => write!(f, "{} % {}", args[0], args[1]),
324 OpId::Index => write!(f, "{}[{}]", args[0], args[1]),
325 OpId::InitLine => write!(
326 f,
327 "vec{}({})",
328 args.len(),
329 args.iter()
330 .map(|it| it.to_string())
331 .collect::<Vec<_>>()
332 .join(", ")
333 ),
334 OpId::And => write!(f, "{} && {}", args[0], args[1]),
335 OpId::Or => write!(f, "{} || {}", args[0], args[1]),
336 OpId::Not => write!(f, "!{}", args[0]),
337 OpId::Neg => write!(f, "-{}", args[0]),
338 OpId::Max => write!(f, "max({}, {})", args[0], args[1]),
339 OpId::Min => write!(f, "min({}, {})", args[0], args[1]),
340 OpId::BitwiseAnd => write!(f, "{} & {}", args[0], args[1]),
341 OpId::BitwiseOr => write!(f, "{} | {}", args[0], args[1]),
342 OpId::BitwiseXor => write!(f, "{} ^ {}", args[0], args[1]),
343 OpId::CountOnes => write!(f, "{}.count_ones()", args[0]),
344 OpId::ReverseBits => write!(f, "{}.reverse_bits()", args[0]),
345 OpId::ShiftLeft => write!(f, "{} << {}", args[0], args[1]),
346 OpId::ShiftRight => write!(f, "{} >> {}", args[0], args[1]),
347 OpId::Remainder => write!(f, "{} % {}", args[0], args[1]),
348 OpId::Magnitude => write!(f, "{}.length()", args[0]),
349 OpId::Normalize => write!(f, "{}.normalize()", args[0]),
350 OpId::Dot => write!(f, "dot({}, {})", args[0], args[1]),
351 OpId::Select => write!(f, "select({}, {}, {})", args[0], args[1], args[2]),
352 OpId::Bitcast => write!(f, "bitcast<{}>({})", self.item, args[0]),
353 OpId::Rank => write!(f, "{}.rank()", args[0]),
354 OpId::Length => write!(f, "{}.len()", args[0]),
355 OpId::BufferLength => write!(f, "buffer_len({})", args[0]),
356 OpId::Shape => write!(f, "{}.shape[{}]", args[0], args[1]),
357 OpId::Stride => write!(f, "{}.stride[{}]", args[0], args[1]),
358 OpId::Cast => write!(f, "cast<{}>({})", self.item, args[0]),
359 }
360 }
361}