cubecl_ir/
processing.rs

1use core::fmt::Display;
2
3use alloc::string::ToString;
4use alloc::vec::Vec;
5
6use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
7
8use super::{
9    Arithmetic, Branch, CoopMma, ElemType, Instruction, Metadata, Operation, UIntKind, Variable,
10    VariableKind,
11};
12
13pub trait Processor: core::fmt::Debug {
14    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
15}
16
17/// Information necessary when compiling a scope.
18pub struct ScopeProcessing {
19    /// The variable declarations.
20    pub variables: Vec<Variable>,
21    /// The operations.
22    pub instructions: Vec<Instruction>,
23}
24
25impl Display for ScopeProcessing {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        writeln!(f, "{{")?;
28        for instruction in self.instructions.iter() {
29            let instruction_str = instruction.to_string();
30            if !instruction_str.is_empty() {
31                writeln!(f, "    {instruction_str}")?;
32            }
33        }
34        write!(f, "}}")?;
35        Ok(())
36    }
37}
38
39impl ScopeProcessing {
40    /// Optimize the [variables](Variable) and [operations](Operation).
41    ///
42    /// ## Notes:
43    ///
44    /// This should be called once right after the creation of the type.
45    /// If you built this type from the [scope process function](super::Scope::process), you don't have to
46    /// call it again.
47    pub fn optimize(self) -> Self {
48        self.sanitize_constant_scalars()
49    }
50
51    /// Make sure constant scalars are of the correct type so compilers don't have to do conversion
52    /// and handle edge cases such as indexing with a signed integer.
53    fn sanitize_constant_scalars(mut self) -> Self {
54        self.instructions
55            .iter_mut()
56            .for_each(|inst| match &mut inst.operation {
57                Operation::Copy(op) => {
58                    sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
59                }
60                Operation::Arithmetic(op) => match op {
61                    Arithmetic::Add(op) => {
62                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
63                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
64                    }
65                    Arithmetic::SaturatingAdd(op) => {
66                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68                    }
69                    Arithmetic::Fma(op) => {
70                        sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
71                        sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
72                        sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
73                    }
74                    Arithmetic::Sub(op) => {
75                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
76                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
77                    }
78                    Arithmetic::SaturatingSub(op) => {
79                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
80                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
81                    }
82                    Arithmetic::Mul(op) => {
83                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
84                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
85                    }
86                    Arithmetic::Div(op) => {
87                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
88                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
89                    }
90                    Arithmetic::MulHi(op) => {
91                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
92                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
93                    }
94                    Arithmetic::Abs(op) => {
95                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
96                    }
97                    Arithmetic::Exp(op) => {
98                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
99                    }
100                    Arithmetic::Log(op) => {
101                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
102                    }
103                    Arithmetic::Log1p(op) => {
104                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
105                    }
106                    Arithmetic::Cos(op) => {
107                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108                    }
109                    Arithmetic::Sin(op) => {
110                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111                    }
112                    Arithmetic::Tan(op) => {
113                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114                    }
115                    Arithmetic::Tanh(op) => {
116                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
117                    }
118                    Arithmetic::Sinh(op) => {
119                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
120                    }
121                    Arithmetic::Cosh(op) => {
122                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
123                    }
124                    Arithmetic::ArcCos(op) => {
125                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
126                    }
127                    Arithmetic::ArcSin(op) => {
128                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
129                    }
130                    Arithmetic::ArcTan(op) => {
131                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
132                    }
133                    Arithmetic::ArcSinh(op) => {
134                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
135                    }
136                    Arithmetic::ArcCosh(op) => {
137                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
138                    }
139                    Arithmetic::ArcTanh(op) => {
140                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
141                    }
142                    Arithmetic::Degrees(op) => {
143                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
144                    }
145                    Arithmetic::Radians(op) => {
146                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
147                    }
148                    Arithmetic::ArcTan2(op) => {
149                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
150                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
151                    }
152                    Arithmetic::Powf(op) => {
153                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
154                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
155                    }
156                    Arithmetic::Powi(op) => {
157                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
158                    }
159                    Arithmetic::Hypot(op) => {
160                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
161                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
162                    }
163                    Arithmetic::Rhypot(op) => {
164                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
165                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
166                    }
167                    Arithmetic::Sqrt(op) => {
168                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
169                    }
170                    Arithmetic::InverseSqrt(op) => {
171                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
172                    }
173                    Arithmetic::Round(op) => {
174                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
175                    }
176                    Arithmetic::Floor(op) => {
177                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
178                    }
179                    Arithmetic::Ceil(op) => {
180                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
181                    }
182                    Arithmetic::Trunc(op) => {
183                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
184                    }
185                    Arithmetic::Erf(op) => {
186                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
187                    }
188                    Arithmetic::Recip(op) => {
189                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
190                    }
191                    Arithmetic::Clamp(op) => {
192                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
193                        sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
194                        sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
195                    }
196                    Arithmetic::Modulo(op) => {
197                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
198                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
199                    }
200                    Arithmetic::Neg(op) => {
201                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
202                    }
203                    Arithmetic::Max(op) => {
204                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
205                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
206                    }
207                    Arithmetic::Min(op) => {
208                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
209                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
210                    }
211                    Arithmetic::Remainder(op) => {
212                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
213                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
214                    }
215                    Arithmetic::Magnitude(op) => {
216                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
217                    }
218                    Arithmetic::Normalize(op) => {
219                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
220                    }
221                    Arithmetic::Dot(op) => {
222                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
223                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
224                    }
225                },
226                Operation::Comparison(op) => match op {
227                    Comparison::Greater(op) => {
228                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
229                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
230                    }
231                    Comparison::LowerEqual(op) => {
232                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
233                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
234                    }
235                    Comparison::GreaterEqual(op) => {
236                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
237                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
238                    }
239                    Comparison::Equal(op) => {
240                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
241                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
242                    }
243                    Comparison::NotEqual(op) => {
244                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
245                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
246                    }
247                    Comparison::Lower(op) => {
248                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
249                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
250                    }
251                    Comparison::IsNan(_op) | Comparison::IsInf(_op) => {
252                        // Nothing to do
253                    }
254                },
255                Operation::Bitwise(op) => match op {
256                    Bitwise::BitwiseAnd(op) => {
257                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
258                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
259                    }
260                    Bitwise::BitwiseOr(op) => {
261                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
262                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
263                    }
264                    Bitwise::BitwiseXor(op) => {
265                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
266                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
267                    }
268                    Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
269                        // Nothing to do
270                    }
271                    Bitwise::ReverseBits(op) => {
272                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
273                    }
274                    Bitwise::ShiftLeft(op) => {
275                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
276                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
277                    }
278                    Bitwise::ShiftRight(op) => {
279                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
280                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
281                    }
282                    Bitwise::BitwiseNot(op) => {
283                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
284                    }
285                },
286                Operation::Operator(op) => match op {
287                    Operator::Index(op) => {
288                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
289                        sanitize_constant_scalar_ref_elem(
290                            &mut op.index,
291                            ElemType::UInt(UIntKind::U32),
292                        );
293                    }
294                    Operator::UncheckedIndex(op) => {
295                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
296                        sanitize_constant_scalar_ref_elem(
297                            &mut op.index,
298                            ElemType::UInt(UIntKind::U32),
299                        );
300                    }
301                    Operator::IndexAssign(op) => {
302                        sanitize_constant_scalar_ref_elem(
303                            &mut op.index,
304                            ElemType::UInt(UIntKind::U32),
305                        );
306                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
307                    }
308                    Operator::UncheckedIndexAssign(op) => {
309                        sanitize_constant_scalar_ref_elem(
310                            &mut op.index,
311                            ElemType::UInt(UIntKind::U32),
312                        );
313                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
314                    }
315                    Operator::And(op) => {
316                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
317                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
318                    }
319                    Operator::Or(op) => {
320                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
321                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
322                    }
323                    Operator::Not(op) => {
324                        sanitize_constant_scalar_ref_elem(&mut op.input, ElemType::Bool);
325                    }
326                    Operator::InitLine(_) => {
327                        // TODO: Sanitize based on elem
328                    }
329                    Operator::CopyMemory(op) => {
330                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
331                        sanitize_constant_scalar_ref_elem(
332                            &mut op.in_index,
333                            ElemType::UInt(UIntKind::U32),
334                        );
335                        sanitize_constant_scalar_ref_elem(
336                            &mut op.out_index,
337                            ElemType::UInt(UIntKind::U32),
338                        );
339                    }
340                    Operator::CopyMemoryBulk(op) => {
341                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
342                        sanitize_constant_scalar_ref_elem(
343                            &mut op.in_index,
344                            ElemType::UInt(UIntKind::U32),
345                        );
346                        sanitize_constant_scalar_ref_elem(
347                            &mut op.out_index,
348                            ElemType::UInt(UIntKind::U32),
349                        );
350                    }
351                    Operator::Select(op) => {
352                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
353                        sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
354                        sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
355                    }
356                    Operator::Cast(_) => {}
357                    Operator::Reinterpret(_) => {}
358                },
359                Operation::Atomic(op) => match op {
360                    AtomicOp::Load(_) => {}
361                    AtomicOp::Store(_) => {}
362                    AtomicOp::Swap(op) => {
363                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
364                    }
365                    AtomicOp::CompareAndSwap(op) => {
366                        sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
367                        sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
368                    }
369                    AtomicOp::Add(op) => {
370                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
371                    }
372                    AtomicOp::Sub(op) => {
373                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
374                    }
375                    AtomicOp::Max(op) => {
376                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
377                    }
378                    AtomicOp::Min(op) => {
379                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
380                    }
381                    AtomicOp::And(op) => {
382                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
383                    }
384                    AtomicOp::Or(op) => {
385                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
386                    }
387                    AtomicOp::Xor(op) => {
388                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
389                    }
390                },
391                Operation::Metadata(op) => match op {
392                    Metadata::Stride { dim, .. } => {
393                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
394                    }
395                    Metadata::Shape { dim, .. } => {
396                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
397                    }
398                    Metadata::Length { .. }
399                    | Metadata::BufferLength { .. }
400                    | Metadata::Rank { .. } => {
401                        // Nothing to do
402                    }
403                },
404                Operation::Branch(op) => match op {
405                    Branch::If(op) => {
406                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
407                    }
408                    Branch::IfElse(op) => {
409                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
410                    }
411                    Branch::RangeLoop(op) => {
412                        sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
413                        sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
414                        if let Some(step) = &mut op.step {
415                            sanitize_constant_scalar_ref_elem(step, ElemType::UInt(UIntKind::U32));
416                        }
417                    }
418                    _ => {
419                        // Nothing to do.
420                    }
421                },
422                Operation::Synchronization(_) => {
423                    // Nothing to do.
424                }
425                Operation::Plane(_) => {
426                    // Nothing to do since no constant is possible.
427                }
428                Operation::CoopMma(op) => match op {
429                    CoopMma::Fill { value } => {
430                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
431                    }
432                    CoopMma::Load { value, stride, .. } => {
433                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
434                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
435                    }
436                    CoopMma::Execute { .. }
437                    | CoopMma::ExecuteManual { .. }
438                    | CoopMma::ExecuteScaled { .. } => {
439                        // Nothing to do.
440                    }
441                    CoopMma::Store { stride, .. } => {
442                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
443                    }
444                    CoopMma::Cast { .. } => {
445                        // Nothing to do.
446                    }
447                    CoopMma::RowIndex { lane_id, i, .. } => {
448                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
449                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
450                    }
451                    CoopMma::ColIndex { lane_id, i, .. } => {
452                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
453                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
454                    }
455                    CoopMma::LoadMatrix { .. } | CoopMma::StoreMatrix { .. } => {
456                        // Nothing to do
457                    }
458                },
459                Operation::NonSemantic(_) => {
460                    // Nothing to do.
461                }
462                Operation::Barrier(_) => {
463                    // Nothing to do
464                }
465                Operation::Tma(_) => {
466                    // Nothing to do
467                }
468                Operation::Marker(_) => {
469                    // Nothing to do
470                }
471            });
472        self
473    }
474}
475
476fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
477    if !reference.ty.is_semantic() {
478        let elem = reference.ty.elem_type();
479        sanitize_constant_scalar_ref_elem(var, elem);
480    }
481}
482
483fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: ElemType) {
484    if let VariableKind::ConstantScalar(scalar) = var.kind
485        && scalar.elem_type() != elem
486    {
487        *var = match scalar {
488            super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
489            super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
490            super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
491            super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
492        };
493    }
494}