cubecl_ir/
processing.rs

1use core::fmt::Display;
2
3use alloc::string::ToString;
4use alloc::vec::Vec;
5
6use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
7
8use super::{
9    Arithmetic, Branch, CoopMma, ElemType, Instruction, Metadata, Operation, UIntKind, Variable,
10    VariableKind,
11};
12
13pub trait Processor: core::fmt::Debug {
14    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
15}
16
17/// Information necessary when compiling a scope.
18pub struct ScopeProcessing {
19    /// The variable declarations.
20    pub variables: Vec<Variable>,
21    /// The operations.
22    pub instructions: Vec<Instruction>,
23}
24
25impl Display for ScopeProcessing {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        writeln!(f, "{{")?;
28        for instruction in self.instructions.iter() {
29            let instruction_str = instruction.to_string();
30            if !instruction_str.is_empty() {
31                writeln!(f, "    {instruction_str}")?;
32            }
33        }
34        write!(f, "}}")?;
35        Ok(())
36    }
37}
38
39impl ScopeProcessing {
40    /// Optimize the [variables](Variable) and [operations](Operation).
41    ///
42    /// ## Notes:
43    ///
44    /// This should be called once right after the creation of the type.
45    /// If you built this type from the [scope process function](super::Scope::process), you don't have to
46    /// call it again.
47    pub fn optimize(self) -> Self {
48        self.sanitize_constant_scalars()
49    }
50
51    /// Make sure constant scalars are of the correct type so compilers don't have to do conversion
52    /// and handle edge cases such as indexing with a signed integer.
53    fn sanitize_constant_scalars(mut self) -> Self {
54        self.instructions
55            .iter_mut()
56            .for_each(|inst| match &mut inst.operation {
57                Operation::Copy(op) => {
58                    sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
59                }
60                Operation::Arithmetic(op) => match op {
61                    Arithmetic::Add(op) => {
62                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
63                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
64                    }
65                    Arithmetic::SaturatingAdd(op) => {
66                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68                    }
69                    Arithmetic::Fma(op) => {
70                        sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
71                        sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
72                        sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
73                    }
74                    Arithmetic::Sub(op) => {
75                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
76                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
77                    }
78                    Arithmetic::SaturatingSub(op) => {
79                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
80                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
81                    }
82                    Arithmetic::Mul(op) => {
83                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
84                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
85                    }
86                    Arithmetic::Div(op) => {
87                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
88                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
89                    }
90                    Arithmetic::MulHi(op) => {
91                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
92                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
93                    }
94                    Arithmetic::Abs(op) => {
95                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
96                    }
97                    Arithmetic::Exp(op) => {
98                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
99                    }
100                    Arithmetic::Log(op) => {
101                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
102                    }
103                    Arithmetic::Log1p(op) => {
104                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
105                    }
106                    Arithmetic::Cos(op) => {
107                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108                    }
109                    Arithmetic::Sin(op) => {
110                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111                    }
112                    Arithmetic::Tan(op) => {
113                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114                    }
115                    Arithmetic::Tanh(op) => {
116                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
117                    }
118                    Arithmetic::Sinh(op) => {
119                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
120                    }
121                    Arithmetic::Cosh(op) => {
122                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
123                    }
124                    Arithmetic::ArcCos(op) => {
125                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
126                    }
127                    Arithmetic::ArcSin(op) => {
128                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
129                    }
130                    Arithmetic::ArcTan(op) => {
131                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
132                    }
133                    Arithmetic::ArcSinh(op) => {
134                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
135                    }
136                    Arithmetic::ArcCosh(op) => {
137                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
138                    }
139                    Arithmetic::ArcTanh(op) => {
140                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
141                    }
142                    Arithmetic::Degrees(op) => {
143                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
144                    }
145                    Arithmetic::Radians(op) => {
146                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
147                    }
148                    Arithmetic::ArcTan2(op) => {
149                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
150                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
151                    }
152                    Arithmetic::Powf(op) => {
153                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
154                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
155                    }
156                    Arithmetic::Powi(op) => {
157                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
158                    }
159                    Arithmetic::Sqrt(op) => {
160                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
161                    }
162                    Arithmetic::InverseSqrt(op) => {
163                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
164                    }
165                    Arithmetic::Round(op) => {
166                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
167                    }
168                    Arithmetic::Floor(op) => {
169                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
170                    }
171                    Arithmetic::Ceil(op) => {
172                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
173                    }
174                    Arithmetic::Trunc(op) => {
175                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
176                    }
177                    Arithmetic::Erf(op) => {
178                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
179                    }
180                    Arithmetic::Recip(op) => {
181                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
182                    }
183                    Arithmetic::Clamp(op) => {
184                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
185                        sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
186                        sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
187                    }
188                    Arithmetic::Modulo(op) => {
189                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
190                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
191                    }
192                    Arithmetic::Neg(op) => {
193                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
194                    }
195                    Arithmetic::Max(op) => {
196                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
197                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
198                    }
199                    Arithmetic::Min(op) => {
200                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
201                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
202                    }
203                    Arithmetic::Remainder(op) => {
204                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
205                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
206                    }
207                    Arithmetic::Magnitude(op) => {
208                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
209                    }
210                    Arithmetic::Normalize(op) => {
211                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
212                    }
213                    Arithmetic::Dot(op) => {
214                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
215                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
216                    }
217                },
218                Operation::Comparison(op) => match op {
219                    Comparison::Greater(op) => {
220                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
221                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
222                    }
223                    Comparison::LowerEqual(op) => {
224                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
225                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
226                    }
227                    Comparison::GreaterEqual(op) => {
228                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
229                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
230                    }
231                    Comparison::Equal(op) => {
232                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
233                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
234                    }
235                    Comparison::NotEqual(op) => {
236                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
237                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
238                    }
239                    Comparison::Lower(op) => {
240                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
241                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
242                    }
243                    Comparison::IsNan(_op) | Comparison::IsInf(_op) => {
244                        // Nothing to do
245                    }
246                },
247                Operation::Bitwise(op) => match op {
248                    Bitwise::BitwiseAnd(op) => {
249                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
250                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
251                    }
252                    Bitwise::BitwiseOr(op) => {
253                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
254                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
255                    }
256                    Bitwise::BitwiseXor(op) => {
257                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
258                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
259                    }
260                    Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
261                        // Nothing to do
262                    }
263                    Bitwise::ReverseBits(op) => {
264                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
265                    }
266                    Bitwise::ShiftLeft(op) => {
267                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
268                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
269                    }
270                    Bitwise::ShiftRight(op) => {
271                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
272                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
273                    }
274                    Bitwise::BitwiseNot(op) => {
275                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
276                    }
277                },
278                Operation::Operator(op) => match op {
279                    Operator::Index(op) => {
280                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
281                        sanitize_constant_scalar_ref_elem(
282                            &mut op.index,
283                            ElemType::UInt(UIntKind::U32),
284                        );
285                    }
286                    Operator::UncheckedIndex(op) => {
287                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
288                        sanitize_constant_scalar_ref_elem(
289                            &mut op.index,
290                            ElemType::UInt(UIntKind::U32),
291                        );
292                    }
293                    Operator::IndexAssign(op) => {
294                        sanitize_constant_scalar_ref_elem(
295                            &mut op.index,
296                            ElemType::UInt(UIntKind::U32),
297                        );
298                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
299                    }
300                    Operator::UncheckedIndexAssign(op) => {
301                        sanitize_constant_scalar_ref_elem(
302                            &mut op.index,
303                            ElemType::UInt(UIntKind::U32),
304                        );
305                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
306                    }
307                    Operator::And(op) => {
308                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
309                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
310                    }
311                    Operator::Or(op) => {
312                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
313                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
314                    }
315                    Operator::Not(op) => {
316                        sanitize_constant_scalar_ref_elem(&mut op.input, ElemType::Bool);
317                    }
318                    Operator::InitLine(_) => {
319                        // TODO: Sanitize based on elem
320                    }
321                    Operator::CopyMemory(op) => {
322                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
323                        sanitize_constant_scalar_ref_elem(
324                            &mut op.in_index,
325                            ElemType::UInt(UIntKind::U32),
326                        );
327                        sanitize_constant_scalar_ref_elem(
328                            &mut op.out_index,
329                            ElemType::UInt(UIntKind::U32),
330                        );
331                    }
332                    Operator::CopyMemoryBulk(op) => {
333                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
334                        sanitize_constant_scalar_ref_elem(
335                            &mut op.in_index,
336                            ElemType::UInt(UIntKind::U32),
337                        );
338                        sanitize_constant_scalar_ref_elem(
339                            &mut op.out_index,
340                            ElemType::UInt(UIntKind::U32),
341                        );
342                    }
343                    Operator::Select(op) => {
344                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
345                        sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
346                        sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
347                    }
348                    Operator::Cast(_) => {}
349                    Operator::Reinterpret(_) => {}
350                },
351                Operation::Atomic(op) => match op {
352                    AtomicOp::Load(_) => {}
353                    AtomicOp::Store(_) => {}
354                    AtomicOp::Swap(op) => {
355                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
356                    }
357                    AtomicOp::CompareAndSwap(op) => {
358                        sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
359                        sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
360                    }
361                    AtomicOp::Add(op) => {
362                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
363                    }
364                    AtomicOp::Sub(op) => {
365                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
366                    }
367                    AtomicOp::Max(op) => {
368                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
369                    }
370                    AtomicOp::Min(op) => {
371                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
372                    }
373                    AtomicOp::And(op) => {
374                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
375                    }
376                    AtomicOp::Or(op) => {
377                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
378                    }
379                    AtomicOp::Xor(op) => {
380                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
381                    }
382                },
383                Operation::Metadata(op) => match op {
384                    Metadata::Stride { dim, .. } => {
385                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
386                    }
387                    Metadata::Shape { dim, .. } => {
388                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
389                    }
390                    Metadata::Length { .. }
391                    | Metadata::BufferLength { .. }
392                    | Metadata::Rank { .. } => {
393                        // Nothing to do
394                    }
395                },
396                Operation::Branch(op) => match op {
397                    Branch::If(op) => {
398                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
399                    }
400                    Branch::IfElse(op) => {
401                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
402                    }
403                    Branch::RangeLoop(op) => {
404                        sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
405                        sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
406                        if let Some(step) = &mut op.step {
407                            sanitize_constant_scalar_ref_elem(step, ElemType::UInt(UIntKind::U32));
408                        }
409                    }
410                    _ => {
411                        // Nothing to do.
412                    }
413                },
414                Operation::Synchronization(_) => {
415                    // Nothing to do.
416                }
417                Operation::Plane(_) => {
418                    // Nothing to do since no constant is possible.
419                }
420                Operation::CoopMma(op) => match op {
421                    CoopMma::Fill { value } => {
422                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
423                    }
424                    CoopMma::Load { value, stride, .. } => {
425                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
426                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
427                    }
428                    CoopMma::Execute { .. }
429                    | CoopMma::ExecuteManual { .. }
430                    | CoopMma::ExecuteScaled { .. } => {
431                        // Nothing to do.
432                    }
433                    CoopMma::Store { stride, .. } => {
434                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
435                    }
436                    CoopMma::Cast { .. } => {
437                        // Nothing to do.
438                    }
439                    CoopMma::RowIndex { lane_id, i, .. } => {
440                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
441                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
442                    }
443                    CoopMma::ColIndex { lane_id, i, .. } => {
444                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
445                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
446                    }
447                    CoopMma::LoadMatrix { .. } | CoopMma::StoreMatrix { .. } => {
448                        // Nothing to do
449                    }
450                },
451                Operation::NonSemantic(_) => {
452                    // Nothing to do.
453                }
454                Operation::Barrier(_) => {
455                    // Nothing to do
456                }
457                Operation::Tma(_) => {
458                    // Nothing to do
459                }
460                Operation::Marker(_) => {
461                    // Nothing to do
462                }
463            });
464        self
465    }
466}
467
468fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
469    if !reference.ty.is_semantic() {
470        let elem = reference.ty.elem_type();
471        sanitize_constant_scalar_ref_elem(var, elem);
472    }
473}
474
475fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: ElemType) {
476    if let VariableKind::ConstantScalar(scalar) = var.kind
477        && scalar.elem_type() != elem
478    {
479        *var = match scalar {
480            super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
481            super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
482            super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
483            super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
484        };
485    }
486}