cubecl_ir/
processing.rs

1use core::fmt::Display;
2
3use alloc::vec::Vec;
4
5use crate::{Allocator, AtomicOp, Bitwise, Comparison, Operator};
6
7use super::{
8    Arithmetic, Branch, CoopMma, Elem, Instruction, Metadata, Operation, UIntKind, Variable,
9    VariableKind,
10};
11
12pub trait Processor {
13    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing;
14}
15
16/// Information necessary when compiling a scope.
17pub struct ScopeProcessing {
18    /// The variable declarations.
19    pub variables: Vec<Variable>,
20    /// The operations.
21    pub instructions: Vec<Instruction>,
22}
23
24impl Display for ScopeProcessing {
25    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26        for inst in self.instructions.iter() {
27            f.write_fmt(format_args!("{inst}\n"))?;
28        }
29
30        Ok(())
31    }
32}
33
34impl ScopeProcessing {
35    /// Optimize the [variables](Variable) and [operations](Operation).
36    ///
37    /// ## Notes:
38    ///
39    /// This should be called once right after the creation of the type.
40    /// If you built this type from the [scope process function](super::Scope::process), you don't have to
41    /// call it again.
42    pub fn optimize(self) -> Self {
43        self.sanitize_constant_scalars()
44    }
45
46    /// Make sure constant scalars are of the correct type so compilers don't have to do conversion
47    /// and handle edge cases such as indexing with a signed integer.
48    fn sanitize_constant_scalars(mut self) -> Self {
49        self.instructions
50            .iter_mut()
51            .for_each(|inst| match &mut inst.operation {
52                Operation::Copy(op) => {
53                    sanitize_constant_scalar_ref_var(op, &inst.out.unwrap());
54                }
55                Operation::Arithmetic(op) => match op {
56                    Arithmetic::Add(op) => {
57                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
58                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
59                    }
60                    Arithmetic::Fma(op) => {
61                        sanitize_constant_scalar_ref_var(&mut op.a, &inst.out.unwrap());
62                        sanitize_constant_scalar_ref_var(&mut op.b, &inst.out.unwrap());
63                        sanitize_constant_scalar_ref_var(&mut op.c, &inst.out.unwrap());
64                    }
65                    Arithmetic::Sub(op) => {
66                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
67                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
68                    }
69                    Arithmetic::Mul(op) => {
70                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
71                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
72                    }
73                    Arithmetic::Div(op) => {
74                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
75                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
76                    }
77                    Arithmetic::MulHi(op) => {
78                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
79                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
80                    }
81                    Arithmetic::Abs(op) => {
82                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
83                    }
84                    Arithmetic::Exp(op) => {
85                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
86                    }
87                    Arithmetic::Log(op) => {
88                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
89                    }
90                    Arithmetic::Log1p(op) => {
91                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
92                    }
93                    Arithmetic::Cos(op) => {
94                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
95                    }
96                    Arithmetic::Sin(op) => {
97                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
98                    }
99                    Arithmetic::Tanh(op) => {
100                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
101                    }
102                    Arithmetic::Powf(op) => {
103                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
104                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
105                    }
106                    Arithmetic::Sqrt(op) => {
107                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
108                    }
109                    Arithmetic::Round(op) => {
110                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
111                    }
112                    Arithmetic::Floor(op) => {
113                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
114                    }
115                    Arithmetic::Ceil(op) => {
116                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
117                    }
118                    Arithmetic::Erf(op) => {
119                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
120                    }
121                    Arithmetic::Recip(op) => {
122                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
123                    }
124                    Arithmetic::Clamp(op) => {
125                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
126                        sanitize_constant_scalar_ref_var(&mut op.min_value, &inst.out.unwrap());
127                        sanitize_constant_scalar_ref_var(&mut op.max_value, &inst.out.unwrap());
128                    }
129                    Arithmetic::Modulo(op) => {
130                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
131                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
132                    }
133                    Arithmetic::Neg(op) => {
134                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap())
135                    }
136                    Arithmetic::Max(op) => {
137                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
138                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
139                    }
140                    Arithmetic::Min(op) => {
141                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
142                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
143                    }
144                    Arithmetic::Remainder(op) => {
145                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
146                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
147                    }
148                    Arithmetic::Magnitude(op) => {
149                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
150                    }
151                    Arithmetic::Normalize(op) => {
152                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
153                    }
154                    Arithmetic::Dot(op) => {
155                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
156                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
157                    }
158                },
159                Operation::Comparison(op) => match op {
160                    Comparison::Greater(op) => {
161                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
162                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
163                    }
164                    Comparison::LowerEqual(op) => {
165                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
166                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
167                    }
168                    Comparison::GreaterEqual(op) => {
169                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
170                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
171                    }
172                    Comparison::Equal(op) => {
173                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
174                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
175                    }
176                    Comparison::NotEqual(op) => {
177                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
178                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
179                    }
180                    Comparison::Lower(op) => {
181                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
182                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
183                    }
184                },
185                Operation::Bitwise(op) => match op {
186                    Bitwise::BitwiseAnd(op) => {
187                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
188                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
189                    }
190                    Bitwise::BitwiseOr(op) => {
191                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
192                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
193                    }
194                    Bitwise::BitwiseXor(op) => {
195                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
196                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
197                    }
198                    Bitwise::CountOnes(_) | Bitwise::LeadingZeros(_) | Bitwise::FindFirstSet(_) => {
199                        // Nothing to do
200                    }
201                    Bitwise::ReverseBits(op) => {
202                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
203                    }
204                    Bitwise::ShiftLeft(op) => {
205                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
206                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
207                    }
208                    Bitwise::ShiftRight(op) => {
209                        sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
210                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
211                    }
212                    Bitwise::BitwiseNot(op) => {
213                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
214                    }
215                },
216                Operation::Operator(op) => match op {
217                    Operator::Index(op) => {
218                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
219                        sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
220                    }
221                    Operator::UncheckedIndex(op) => {
222                        sanitize_constant_scalar_ref_var(&mut op.list, &inst.out.unwrap());
223                        sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
224                    }
225                    Operator::IndexAssign(op) => {
226                        sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
227                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
228                    }
229                    Operator::UncheckedIndexAssign(op) => {
230                        sanitize_constant_scalar_ref_elem(&mut op.index, Elem::UInt(UIntKind::U32));
231                        sanitize_constant_scalar_ref_var(&mut op.value, &inst.out.unwrap());
232                    }
233                    Operator::And(op) => {
234                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
235                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
236                    }
237                    Operator::Or(op) => {
238                        sanitize_constant_scalar_ref_var(&mut op.lhs, &op.rhs);
239                        sanitize_constant_scalar_ref_var(&mut op.rhs, &op.lhs);
240                    }
241                    Operator::Not(op) => {
242                        sanitize_constant_scalar_ref_elem(&mut op.input, Elem::Bool);
243                    }
244                    Operator::InitLine(_) => {
245                        // TODO: Sanitize based on elem
246                    }
247                    Operator::CopyMemory(op) => {
248                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
249                        sanitize_constant_scalar_ref_elem(
250                            &mut op.in_index,
251                            Elem::UInt(UIntKind::U32),
252                        );
253                        sanitize_constant_scalar_ref_elem(
254                            &mut op.out_index,
255                            Elem::UInt(UIntKind::U32),
256                        );
257                    }
258                    Operator::CopyMemoryBulk(op) => {
259                        sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
260                        sanitize_constant_scalar_ref_elem(
261                            &mut op.in_index,
262                            Elem::UInt(UIntKind::U32),
263                        );
264                        sanitize_constant_scalar_ref_elem(
265                            &mut op.out_index,
266                            Elem::UInt(UIntKind::U32),
267                        );
268                    }
269                    Operator::Select(op) => {
270                        sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
271                        sanitize_constant_scalar_ref_var(&mut op.then, &inst.out.unwrap());
272                        sanitize_constant_scalar_ref_var(&mut op.or_else, &inst.out.unwrap());
273                    }
274                    Operator::Cast(_) => {}
275                    Operator::Reinterpret(_) => {}
276                },
277                Operation::Atomic(op) => match op {
278                    AtomicOp::Load(_) => {}
279                    AtomicOp::Store(_) => {}
280                    AtomicOp::Swap(op) => {
281                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
282                    }
283                    AtomicOp::CompareAndSwap(op) => {
284                        sanitize_constant_scalar_ref_var(&mut op.cmp, &inst.out.unwrap());
285                        sanitize_constant_scalar_ref_var(&mut op.val, &inst.out.unwrap());
286                    }
287                    AtomicOp::Add(op) => {
288                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
289                    }
290                    AtomicOp::Sub(op) => {
291                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
292                    }
293                    AtomicOp::Max(op) => {
294                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
295                    }
296                    AtomicOp::Min(op) => {
297                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
298                    }
299                    AtomicOp::And(op) => {
300                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
301                    }
302                    AtomicOp::Or(op) => {
303                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
304                    }
305                    AtomicOp::Xor(op) => {
306                        sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
307                    }
308                },
309                Operation::Metadata(op) => match op {
310                    Metadata::Stride { dim, .. } => {
311                        sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
312                    }
313                    Metadata::Shape { dim, .. } => {
314                        sanitize_constant_scalar_ref_elem(dim, Elem::UInt(UIntKind::U32));
315                    }
316                    Metadata::Length { .. }
317                    | Metadata::BufferLength { .. }
318                    | Metadata::Rank { .. } => {
319                        // Nothing to do
320                    }
321                },
322                Operation::Branch(op) => match op {
323                    Branch::If(op) => {
324                        sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
325                    }
326                    Branch::IfElse(op) => {
327                        sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
328                    }
329                    Branch::RangeLoop(op) => {
330                        sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
331                        sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
332                        if let Some(step) = &mut op.step {
333                            sanitize_constant_scalar_ref_elem(step, Elem::UInt(UIntKind::U32));
334                        }
335                    }
336                    _ => {
337                        // Nothing to do.
338                    }
339                },
340                Operation::Synchronization(_) => {
341                    // Nothing to do.
342                }
343                Operation::Plane(_) => {
344                    // Nothing to do since no constant is possible.
345                }
346                Operation::CoopMma(op) => match op {
347                    CoopMma::Fill { value } => {
348                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
349                    }
350                    CoopMma::Load { value, stride, .. } => {
351                        sanitize_constant_scalar_ref_var(value, &inst.out.unwrap());
352                        sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
353                    }
354                    CoopMma::Execute { .. } => {
355                        // Nothing to do.
356                    }
357                    CoopMma::Store { stride, .. } => {
358                        sanitize_constant_scalar_ref_elem(stride, Elem::UInt(UIntKind::U32));
359                    }
360                    CoopMma::Cast { .. } => {
361                        // Nothing to do.
362                    }
363                },
364                Operation::NonSemantic(_) => {
365                    // Nothing to do.
366                }
367                Operation::Barrier(_) => {
368                    // Nothing to do
369                }
370                Operation::Tma(_) => {
371                    // Nothing to do
372                }
373            });
374        self
375    }
376}
377
378fn sanitize_constant_scalar_ref_var(var: &mut Variable, reference: &Variable) {
379    let elem = reference.item.elem();
380    sanitize_constant_scalar_ref_elem(var, elem);
381}
382
383fn sanitize_constant_scalar_ref_elem(var: &mut Variable, elem: Elem) {
384    if let VariableKind::ConstantScalar(scalar) = var.kind {
385        if scalar.elem() != elem {
386            *var = match scalar {
387                super::ConstantScalarValue::Int(val, _) => elem.constant_from_i64(val),
388                super::ConstantScalarValue::Float(val, _) => elem.constant_from_f64(val),
389                super::ConstantScalarValue::UInt(val, _) => elem.constant_from_u64(val),
390                super::ConstantScalarValue::Bool(val) => elem.constant_from_bool(val),
391            };
392        }
393    }
394}