cubecl_ir/
processing.rs

1use core::fmt::Display;
2
3use alloc::string::ToString;
4use alloc::vec::Vec;
5
6use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
7
8use super::{
9    Arithmetic, Branch, CoopMma, ElemType, Instruction, Metadata, Operation, UIntKind, Variable,
10    VariableKind,
11};
12
13pub trait Processor: core::fmt::Debug {
14    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
15}
16
17/// Information necessary when compiling a scope.
18pub struct ScopeProcessing {
19    /// The variable declarations.
20    pub variables: Vec<Variable>,
21    /// The operations.
22    pub instructions: Vec<Instruction>,
23}
24
25impl Display for ScopeProcessing {
26    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
27        writeln!(f, "{{")?;
28        for instruction in self.instructions.iter() {
29            let instruction_str = instruction.to_string();
30            if !instruction_str.is_empty() {
31                writeln!(f, "    {instruction_str}")?;
32            }
33        }
34        write!(f, "}}")?;
35        Ok(())
36    }
37}
38
39impl ScopeProcessing {
40    /// Optimize the [variables](Variable) and [operations](Operation).
41    ///
42    /// ## Notes:
43    ///
44    /// This should be called once right after the creation of the type.
45    /// If you built this type from the [scope process function](super::Scope::process), you don't have to
46    /// call it again.
47    pub fn optimize(self) -> Self {
48        self.sanitize_constant_scalars()
49    }
50
51    /// Make sure constant scalars are of the correct type so compilers don't have to do conversion
52    /// and handle edge cases such as indexing with a signed integer.
53    fn sanitize_constant_scalars(mut self) -> Self {
54        self.instructions
55            .iter_mut()
56            .for_each(|inst| match &mut inst.operation {
57                Operation::Copy(op) => {
58                    sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
59                }
60                Operation::Arithmetic(op) => match op {
61                    Arithmetic::Add(op) => {
62                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
63                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
64                    }
65                    Arithmetic::SaturatingAdd(op) => {
66                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68                    }
69                    Arithmetic::Fma(op) => {
70                        sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
71                        sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
72                        sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
73                    }
74                    Arithmetic::Sub(op) => {
75                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
76                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
77                    }
78                    Arithmetic::SaturatingSub(op) => {
79                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
80                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
81                    }
82                    Arithmetic::Mul(op) => {
83                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
84                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
85                    }
86                    Arithmetic::Div(op) => {
87                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
88                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
89                    }
90                    Arithmetic::MulHi(op) => {
91                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
92                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
93                    }
94                    Arithmetic::Abs(op) => {
95                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
96                    }
97                    Arithmetic::Exp(op) => {
98                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
99                    }
100                    Arithmetic::Log(op) => {
101                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
102                    }
103                    Arithmetic::Log1p(op) => {
104                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
105                    }
106                    Arithmetic::Cos(op) => {
107                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108                    }
109                    Arithmetic::Sin(op) => {
110                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111                    }
112                    Arithmetic::Tanh(op) => {
113                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114                    }
115                    Arithmetic::Powf(op) => {
116                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
117                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
118                    }
119                    Arithmetic::Powi(op) => {
120                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
121                    }
122                    Arithmetic::Sqrt(op) => {
123                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
124                    }
125                    Arithmetic::InverseSqrt(op) => {
126                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
127                    }
128                    Arithmetic::Round(op) => {
129                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
130                    }
131                    Arithmetic::Floor(op) => {
132                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
133                    }
134                    Arithmetic::Ceil(op) => {
135                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
136                    }
137                    Arithmetic::Trunc(op) => {
138                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
139                    }
140                    Arithmetic::Erf(op) => {
141                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
142                    }
143                    Arithmetic::Recip(op) => {
144                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
145                    }
146                    Arithmetic::Clamp(op) => {
147                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
148                        sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
149                        sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
150                    }
151                    Arithmetic::Modulo(op) => {
152                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
153                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
154                    }
155                    Arithmetic::Neg(op) => {
156                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
157                    }
158                    Arithmetic::Max(op) => {
159                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
160                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
161                    }
162                    Arithmetic::Min(op) => {
163                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
164                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
165                    }
166                    Arithmetic::Remainder(op) => {
167                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
168                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
169                    }
170                    Arithmetic::Magnitude(op) => {
171                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
172                    }
173                    Arithmetic::Normalize(op) => {
174                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
175                    }
176                    Arithmetic::Dot(op) => {
177                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
178                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
179                    }
180                },
181                Operation::Comparison(op) => match op {
182                    Comparison::Greater(op) => {
183                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
184                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
185                    }
186                    Comparison::LowerEqual(op) => {
187                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
188                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
189                    }
190                    Comparison::GreaterEqual(op) => {
191                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
192                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
193                    }
194                    Comparison::Equal(op) => {
195                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
196                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
197                    }
198                    Comparison::NotEqual(op) => {
199                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
200                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
201                    }
202                    Comparison::Lower(op) => {
203                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
204                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
205                    }
206                    Comparison::IsNan(_op) | Comparison::IsInf(_op) => {
207                        // Nothing to do
208                    }
209                },
210                Operation::Bitwise(op) => match op {
211                    Bitwise::BitwiseAnd(op) => {
212                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
213                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
214                    }
215                    Bitwise::BitwiseOr(op) => {
216                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
217                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
218                    }
219                    Bitwise::BitwiseXor(op) => {
220                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
221                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
222                    }
223                    Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
224                        // Nothing to do
225                    }
226                    Bitwise::ReverseBits(op) => {
227                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
228                    }
229                    Bitwise::ShiftLeft(op) => {
230                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
231                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
232                    }
233                    Bitwise::ShiftRight(op) => {
234                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
235                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
236                    }
237                    Bitwise::BitwiseNot(op) => {
238                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
239                    }
240                },
241                Operation::Operator(op) => match op {
242                    Operator::Index(op) => {
243                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
244                        sanitize_constant_scalar_ref_elem(
245                            &mut op.index,
246                            ElemType::UInt(UIntKind::U32),
247                        );
248                    }
249                    Operator::UncheckedIndex(op) => {
250                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
251                        sanitize_constant_scalar_ref_elem(
252                            &mut op.index,
253                            ElemType::UInt(UIntKind::U32),
254                        );
255                    }
256                    Operator::IndexAssign(op) => {
257                        sanitize_constant_scalar_ref_elem(
258                            &mut op.index,
259                            ElemType::UInt(UIntKind::U32),
260                        );
261                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
262                    }
263                    Operator::UncheckedIndexAssign(op) => {
264                        sanitize_constant_scalar_ref_elem(
265                            &mut op.index,
266                            ElemType::UInt(UIntKind::U32),
267                        );
268                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
269                    }
270                    Operator::And(op) => {
271                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
272                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
273                    }
274                    Operator::Or(op) => {
275                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
276                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
277                    }
278                    Operator::Not(op) => {
279                        sanitize_constant_scalar_ref_elem(&mut op.input, ElemType::Bool);
280                    }
281                    Operator::InitLine(_) => {
282                        // TODO: Sanitize based on elem
283                    }
284                    Operator::CopyMemory(op) => {
285                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
286                        sanitize_constant_scalar_ref_elem(
287                            &mut op.in_index,
288                            ElemType::UInt(UIntKind::U32),
289                        );
290                        sanitize_constant_scalar_ref_elem(
291                            &mut op.out_index,
292                            ElemType::UInt(UIntKind::U32),
293                        );
294                    }
295                    Operator::CopyMemoryBulk(op) => {
296                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
297                        sanitize_constant_scalar_ref_elem(
298                            &mut op.in_index,
299                            ElemType::UInt(UIntKind::U32),
300                        );
301                        sanitize_constant_scalar_ref_elem(
302                            &mut op.out_index,
303                            ElemType::UInt(UIntKind::U32),
304                        );
305                    }
306                    Operator::Select(op) => {
307                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
308                        sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
309                        sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
310                    }
311                    Operator::Cast(_) => {}
312                    Operator::Reinterpret(_) => {}
313                },
314                Operation::Atomic(op) => match op {
315                    AtomicOp::Load(_) => {}
316                    AtomicOp::Store(_) => {}
317                    AtomicOp::Swap(op) => {
318                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
319                    }
320                    AtomicOp::CompareAndSwap(op) => {
321                        sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
322                        sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
323                    }
324                    AtomicOp::Add(op) => {
325                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
326                    }
327                    AtomicOp::Sub(op) => {
328                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
329                    }
330                    AtomicOp::Max(op) => {
331                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
332                    }
333                    AtomicOp::Min(op) => {
334                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
335                    }
336                    AtomicOp::And(op) => {
337                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
338                    }
339                    AtomicOp::Or(op) => {
340                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
341                    }
342                    AtomicOp::Xor(op) => {
343                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
344                    }
345                },
346                Operation::Metadata(op) => match op {
347                    Metadata::Stride { dim, .. } => {
348                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
349                    }
350                    Metadata::Shape { dim, .. } => {
351                        sanitize_constant_scalar_ref_elem(dim, ElemType::UInt(UIntKind::U32));
352                    }
353                    Metadata::Length { .. }
354                    | Metadata::BufferLength { .. }
355                    | Metadata::Rank { .. } => {
356                        // Nothing to do
357                    }
358                },
359                Operation::Branch(op) => match op {
360                    Branch::If(op) => {
361                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
362                    }
363                    Branch::IfElse(op) => {
364                        sanitize_constant_scalar_ref_elem(&mut op.cond, ElemType::Bool);
365                    }
366                    Branch::RangeLoop(op) => {
367                        sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
368                        sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
369                        if let Some(step) = &mut op.step {
370                            sanitize_constant_scalar_ref_elem(step, ElemType::UInt(UIntKind::U32));
371                        }
372                    }
373                    _ => {
374                        // Nothing to do.
375                    }
376                },
377                Operation::Synchronization(_) => {
378                    // Nothing to do.
379                }
380                Operation::Plane(_) => {
381                    // Nothing to do since no constant is possible.
382                }
383                Operation::CoopMma(op) => match op {
384                    CoopMma::Fill { value } => {
385                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
386                    }
387                    CoopMma::Load { value, stride, .. } => {
388                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
389                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
390                    }
391                    CoopMma::Execute { .. }
392                    | CoopMma::ExecuteManual { .. }
393                    | CoopMma::ExecuteScaled { .. } => {
394                        // Nothing to do.
395                    }
396                    CoopMma::Store { stride, .. } => {
397                        sanitize_constant_scalar_ref_elem(stride, ElemType::UInt(UIntKind::U32));
398                    }
399                    CoopMma::Cast { .. } => {
400                        // Nothing to do.
401                    }
402                    CoopMma::RowIndex { lane_id, i, .. } => {
403                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
404                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
405                    }
406                    CoopMma::ColIndex { lane_id, i, .. } => {
407                        sanitize_constant_scalar_ref_elem(lane_id, ElemType::UInt(UIntKind::U32));
408                        sanitize_constant_scalar_ref_elem(i, ElemType::UInt(UIntKind::U32));
409                    }
410                },
411                Operation::NonSemantic(_) => {
412                    // Nothing to do.
413                }
414                Operation::Barrier(_) => {
415                    // Nothing to do
416                }
417                Operation::Tma(_) => {
418                    // Nothing to do
419                }
420                Operation::Marker(_) => {
421                    // Nothing to do
422                }
423            });
424        self
425    }
426}
427
428fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
429    if !reference.ty.is_semantic() {
430        let elem = reference.ty.elem_type();
431        sanitize_constant_scalar_ref_elem(var, elem);
432    }
433}
434
435fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: ElemType) {
436    if let VariableKind::ConstantScalar(scalar) = var.kind
437        && scalar.elem_type() != elem
438    {
439        *var = match scalar {
440            super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
441            super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
442            super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
443            super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
444        };
445    }
446}