cubecl_core/post_processing/
unroll.rs

1use cubecl_ir::{
2    Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
3    IndexAssignOperator, IndexOperator, Instruction, MatrixLayout, Metadata, Operation,
4    OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
5};
6use hashbrown::HashMap;
7
8/// The action that should be performed on an instruction, returned by [`IrTransformer::maybe_transform`]
9pub enum TransformAction {
10    /// The transformer doesn't apply to this instruction
11    Ignore,
12    /// Replace this instruction with one or more other instructions
13    Replace(Vec<Instruction>),
14}
15
16#[derive(new, Debug)]
17pub struct UnrollProcessor {
18    max_line_size: u32,
19}
20
21struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
22
23impl Mappings {
24    fn get(
25        &mut self,
26        alloc: &Allocator,
27        var: Variable,
28        unroll_factor: u32,
29        line_size: u32,
30    ) -> Vec<Variable> {
31        self.0
32            .entry(var)
33            .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
34            .iter()
35            .map(|it| **it)
36            .collect()
37    }
38}
39
40impl UnrollProcessor {
41    fn maybe_transform(
42        &self,
43        alloc: &Allocator,
44        inst: &Instruction,
45        mappings: &mut Mappings,
46    ) -> TransformAction {
47        if matches!(inst.operation, Operation::Free(_)) {
48            return TransformAction::Ignore;
49        }
50
51        if inst.operation.args().is_none() {
52            // Detect unhandled ops that can't be reflected
53            match &inst.operation {
54                Operation::CoopMma(op) => match op {
55                    // Stride is in scalar elems
56                    CoopMma::Load {
57                        value,
58                        stride,
59                        offset,
60                        layout,
61                    } if value.line_size() > self.max_line_size => {
62                        return TransformAction::Replace(self.transform_cmma_load(
63                            alloc,
64                            inst.out(),
65                            value,
66                            stride,
67                            offset,
68                            layout,
69                        ));
70                    }
71                    CoopMma::Store {
72                        mat,
73                        stride,
74                        offset,
75                        layout,
76                    } if inst.out().line_size() > self.max_line_size => {
77                        return TransformAction::Replace(self.transform_cmma_store(
78                            alloc,
79                            inst.out(),
80                            mat,
81                            stride,
82                            offset,
83                            layout,
84                        ));
85                    }
86                    _ => return TransformAction::Ignore,
87                },
88                Operation::Branch(_) | Operation::NonSemantic(_) => {
89                    return TransformAction::Ignore;
90                }
91                _ => {
92                    panic!("Need special handling for unrolling non-reflectable operations")
93                }
94            }
95        }
96
97        let args = inst.operation.args().unwrap_or_default();
98        if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
99            || args.iter().any(|arg| arg.line_size() > self.max_line_size)
100        {
101            let line_size = max_line_size(&inst.out, &args);
102            let unroll_factor = line_size / self.max_line_size;
103
104            match &inst.operation {
105                Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
106                    self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
107                ),
108                Operation::Operator(Operator::CopyMemory(op)) => {
109                    TransformAction::Replace(self.transform_memcpy(
110                        alloc,
111                        &CopyMemoryBulkOperator {
112                            out_index: op.out_index,
113                            input: op.input,
114                            in_index: op.in_index,
115                            len: 1,
116                            offset_input: 0.into(),
117                            offset_out: 0.into(),
118                        },
119                        inst.out(),
120                        unroll_factor,
121                    ))
122                }
123                Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
124                    TransformAction::Replace(self.transform_array_index(
125                        alloc,
126                        inst.out(),
127                        op,
128                        Operator::Index,
129                        unroll_factor,
130                        mappings,
131                    ))
132                }
133                Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
134                    TransformAction::Replace(self.transform_array_index(
135                        alloc,
136                        inst.out(),
137                        op,
138                        Operator::UncheckedIndex,
139                        unroll_factor,
140                        mappings,
141                    ))
142                }
143                Operation::Operator(Operator::Index(op)) => {
144                    TransformAction::Replace(self.transform_composite_index(
145                        alloc,
146                        inst.out(),
147                        op,
148                        Operator::Index,
149                        unroll_factor,
150                        mappings,
151                    ))
152                }
153                Operation::Operator(Operator::UncheckedIndex(op)) => {
154                    TransformAction::Replace(self.transform_composite_index(
155                        alloc,
156                        inst.out(),
157                        op,
158                        Operator::UncheckedIndex,
159                        unroll_factor,
160                        mappings,
161                    ))
162                }
163                Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
164                    TransformAction::Replace(self.transform_array_index_assign(
165                        alloc,
166                        inst.out(),
167                        op,
168                        Operator::IndexAssign,
169                        unroll_factor,
170                        mappings,
171                    ))
172                }
173                Operation::Operator(Operator::UncheckedIndexAssign(op))
174                    if inst.out().is_array() =>
175                {
176                    TransformAction::Replace(self.transform_array_index_assign(
177                        alloc,
178                        inst.out(),
179                        op,
180                        Operator::UncheckedIndexAssign,
181                        unroll_factor,
182                        mappings,
183                    ))
184                }
185                Operation::Operator(Operator::IndexAssign(op)) => {
186                    TransformAction::Replace(self.transform_composite_index_assign(
187                        alloc,
188                        inst.out(),
189                        op,
190                        Operator::IndexAssign,
191                        unroll_factor,
192                        mappings,
193                    ))
194                }
195                Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
196                    TransformAction::Replace(self.transform_composite_index_assign(
197                        alloc,
198                        inst.out(),
199                        op,
200                        Operator::UncheckedIndexAssign,
201                        unroll_factor,
202                        mappings,
203                    ))
204                }
205                Operation::Metadata(op) => {
206                    TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
207                }
208                _ => TransformAction::Replace(self.transform_basic(
209                    alloc,
210                    inst,
211                    args,
212                    unroll_factor,
213                    mappings,
214                )),
215            }
216        } else {
217            TransformAction::Ignore
218        }
219    }
220
221    /// Transform CMMA load offset and array
222    fn transform_cmma_load(
223        &self,
224        alloc: &Allocator,
225        out: Variable,
226        value: &Variable,
227        stride: &Variable,
228        offset: &Variable,
229        layout: &Option<MatrixLayout>,
230    ) -> Vec<Instruction> {
231        let line_size = value.line_size();
232        let unroll_factor = line_size / self.max_line_size;
233
234        let value = unroll_array(*value, self.max_line_size, unroll_factor);
235        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
236        let load = Instruction::new(
237            Operation::CoopMma(CoopMma::Load {
238                value,
239                stride: *stride,
240                offset: *offset,
241                layout: *layout,
242            }),
243            out,
244        );
245        vec![mul, load]
246    }
247
248    /// Transform CMMA store offset and array
249    fn transform_cmma_store(
250        &self,
251        alloc: &Allocator,
252        out: Variable,
253        mat: &Variable,
254        stride: &Variable,
255        offset: &Variable,
256        layout: &MatrixLayout,
257    ) -> Vec<Instruction> {
258        let line_size = out.line_size();
259        let unroll_factor = line_size / self.max_line_size;
260
261        let out = unroll_array(out, self.max_line_size, unroll_factor);
262        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
263        let store = Instruction::new(
264            Operation::CoopMma(CoopMma::Store {
265                mat: *mat,
266                stride: *stride,
267                offset: *offset,
268                layout: *layout,
269            }),
270            out,
271        );
272        vec![mul, store]
273    }
274
275    /// Transforms memcpy into one with higher length and adjusted indices/offsets
276    fn transform_memcpy(
277        &self,
278        alloc: &Allocator,
279        op: &CopyMemoryBulkOperator,
280        out: Variable,
281        unroll_factor: u32,
282    ) -> Vec<Instruction> {
283        let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
284        let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
285        let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
286        let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
287
288        let input = unroll_array(op.input, self.max_line_size, unroll_factor);
289        let out = unroll_array(out, self.max_line_size, unroll_factor);
290
291        vec![
292            mul1,
293            mul2,
294            mul3,
295            mul4,
296            Instruction::new(
297                Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
298                    input,
299                    in_index: *in_index,
300                    out_index: *out_index,
301                    len: op.len * unroll_factor,
302                    offset_input: *offset_input,
303                    offset_out: *offset_out,
304                }),
305                out,
306            ),
307        ]
308    }
309
310    /// Transforms indexing into multiple index operations, each offset by 1 from the base. The base
311    /// is also multiplied by the unroll factor to compensate for the lower actual vectorization.
312    fn transform_array_index(
313        &self,
314        alloc: &Allocator,
315        out: Variable,
316        op: &IndexOperator,
317        operator: impl Fn(IndexOperator) -> Operator,
318        unroll_factor: u32,
319        mappings: &mut Mappings,
320    ) -> Vec<Instruction> {
321        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
322        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
323
324        let list = unroll_array(op.list, self.max_line_size, unroll_factor);
325
326        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
327        let mut instructions = vec![mul];
328        instructions.extend((0..unroll_factor as usize).flat_map(|i| {
329            let (add, idx) = indices.next().unwrap();
330            let index = Instruction::new(
331                operator(IndexOperator {
332                    list,
333                    index: *idx,
334                    line_size: 0,
335                    unroll_factor,
336                }),
337                out[i],
338            );
339            [add, index]
340        }));
341
342        instructions
343    }
344
345    /// Transforms index assign into multiple index assign operations, each offset by 1 from the base.
346    /// The base is also multiplied by the unroll factor to compensate for the lower actual vectorization.
347    fn transform_array_index_assign(
348        &self,
349        alloc: &Allocator,
350        out: Variable,
351        op: &IndexAssignOperator,
352        operator: impl Fn(IndexAssignOperator) -> Operator,
353        unroll_factor: u32,
354        mappings: &mut Mappings,
355    ) -> Vec<Instruction> {
356        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
357        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
358
359        let out = unroll_array(out, self.max_line_size, unroll_factor);
360
361        let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
362
363        let mut instructions = vec![mul];
364        instructions.extend((0..unroll_factor as usize).flat_map(|i| {
365            let (add, idx) = indices.next().unwrap();
366            let index = Instruction::new(
367                operator(IndexAssignOperator {
368                    index: *idx,
369                    line_size: 0,
370                    value: value[i],
371                    unroll_factor,
372                }),
373                out,
374            );
375
376            [add, index]
377        }));
378
379        instructions
380    }
381
382    /// Transforms a composite index (i.e. `Line`) that always returns a scalar. Translates the index
383    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
384    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
385    /// work.
386    fn transform_composite_index(
387        &self,
388        alloc: &Allocator,
389        out: Variable,
390        op: &IndexOperator,
391        operator: impl Fn(IndexOperator) -> Operator,
392        unroll_factor: u32,
393        mappings: &mut Mappings,
394    ) -> Vec<Instruction> {
395        let index = op
396            .index
397            .as_const()
398            .expect("Can't unroll non-constant vector index")
399            .as_u32();
400
401        let unroll_idx = index / self.max_line_size;
402        let sub_idx = index % self.max_line_size;
403
404        let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
405
406        vec![Instruction::new(
407            operator(IndexOperator {
408                list: value[unroll_idx as usize],
409                index: sub_idx.into(),
410                line_size: 1,
411                unroll_factor,
412            }),
413            out,
414        )]
415    }
416
417    /// Transforms a composite index assign (i.e. `Line`) that always takes a scalar. Translates the index
418    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
419    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
420    /// work.
421    fn transform_composite_index_assign(
422        &self,
423        alloc: &Allocator,
424        out: Variable,
425        op: &IndexAssignOperator,
426        operator: impl Fn(IndexAssignOperator) -> Operator,
427        unroll_factor: u32,
428        mappings: &mut Mappings,
429    ) -> Vec<Instruction> {
430        let index = op
431            .index
432            .as_const()
433            .expect("Can't unroll non-constant vector index")
434            .as_u32();
435
436        let unroll_idx = index / self.max_line_size;
437        let sub_idx = index % self.max_line_size;
438
439        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
440
441        vec![Instruction::new(
442            operator(IndexAssignOperator {
443                index: sub_idx.into(),
444                line_size: 1,
445                value: op.value,
446                unroll_factor,
447            }),
448            out[unroll_idx as usize],
449        )]
450    }
451
452    /// Transforms metadata by just replacing the type of the buffer. The values are already
453    /// properly calculated on the CPU.
454    fn transform_metadata(
455        &self,
456        out: Variable,
457        op: &Metadata,
458        args: Vec<Variable>,
459    ) -> Vec<Instruction> {
460        let op_code = op.op_code();
461        let args = args
462            .into_iter()
463            .map(|mut var| {
464                if var.line_size() > self.max_line_size {
465                    var.ty = var.ty.line(self.max_line_size);
466                }
467                var
468            })
469            .collect::<Vec<_>>();
470        let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
471        vec![Instruction::new(operation, out)]
472    }
473
474    /// Transforms generic instructions, i.e. comparison, arithmetic. Unrolls each vectorized variable
475    /// to `unroll_factor` replacements, and executes the operation `unroll_factor` times.
476    fn transform_basic(
477        &self,
478        alloc: &Allocator,
479        inst: &Instruction,
480        args: Vec<Variable>,
481        unroll_factor: u32,
482        mappings: &mut Mappings,
483    ) -> Vec<Instruction> {
484        let op_code = inst.operation.op_code();
485        let out = inst
486            .out
487            .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
488        let args = args
489            .into_iter()
490            .map(|arg| {
491                if arg.line_size() > 1 {
492                    mappings.get(alloc, arg, unroll_factor, self.max_line_size)
493                } else {
494                    // Preserve scalars
495                    vec![arg]
496                }
497            })
498            .collect::<Vec<_>>();
499
500        (0..unroll_factor as usize)
501            .map(|i| {
502                let out = out.as_ref().map(|out| out[i]);
503                let args = args
504                    .iter()
505                    .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
506                    .collect::<Vec<_>>();
507                let operation = Operation::from_code_and_args(op_code, &args)
508                    .expect("Failed to reconstruct operation");
509                Instruction {
510                    out,
511                    source_loc: inst.source_loc.clone(),
512                    operation,
513                }
514            })
515            .collect()
516    }
517
518    fn transform_instructions(
519        &self,
520        allocator: &Allocator,
521        instructions: Vec<Instruction>,
522        mappings: &mut Mappings,
523    ) -> Vec<Instruction> {
524        let mut new_instructions = Vec::with_capacity(instructions.len());
525
526        for mut instruction in instructions {
527            if let Operation::Branch(branch) = &mut instruction.operation {
528                match branch {
529                    Branch::If(op) => {
530                        op.scope.instructions = self.transform_instructions(
531                            allocator,
532                            op.scope.instructions.drain(..).collect(),
533                            mappings,
534                        );
535                    }
536                    Branch::IfElse(op) => {
537                        op.scope_if.instructions = self.transform_instructions(
538                            allocator,
539                            op.scope_if.instructions.drain(..).collect(),
540                            mappings,
541                        );
542                        op.scope_else.instructions = self.transform_instructions(
543                            allocator,
544                            op.scope_else.instructions.drain(..).collect(),
545                            mappings,
546                        );
547                    }
548                    Branch::Switch(op) => {
549                        for (_, case) in &mut op.cases {
550                            case.instructions = self.transform_instructions(
551                                allocator,
552                                case.instructions.drain(..).collect(),
553                                mappings,
554                            );
555                        }
556                        op.scope_default.instructions = self.transform_instructions(
557                            allocator,
558                            op.scope_default.instructions.drain(..).collect(),
559                            mappings,
560                        );
561                    }
562                    Branch::RangeLoop(op) => {
563                        op.scope.instructions = self.transform_instructions(
564                            allocator,
565                            op.scope.instructions.drain(..).collect(),
566                            mappings,
567                        );
568                    }
569                    Branch::Loop(op) => {
570                        op.scope.instructions = self.transform_instructions(
571                            allocator,
572                            op.scope.instructions.drain(..).collect(),
573                            mappings,
574                        );
575                    }
576                    _ => {}
577                }
578            }
579            match self.maybe_transform(allocator, &instruction, mappings) {
580                TransformAction::Ignore => {
581                    new_instructions.push(instruction);
582                }
583                TransformAction::Replace(replacement) => {
584                    new_instructions.extend(replacement);
585                }
586            }
587        }
588
589        new_instructions
590    }
591}
592
593impl Processor for UnrollProcessor {
594    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
595        let mut mappings = Mappings(Default::default());
596
597        let instructions =
598            self.transform_instructions(&allocator, processing.instructions, &mut mappings);
599
600        ScopeProcessing {
601            variables: processing.variables,
602            instructions,
603        }
604    }
605}
606
607fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> u32 {
608    let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
609    line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
610}
611
612fn create_unrolled(
613    allocator: &Allocator,
614    var: &Variable,
615    max_line_size: u32,
616    unroll_factor: u32,
617) -> Vec<ExpandElement> {
618    // Preserve scalars
619    if var.line_size() == 1 {
620        return vec![ExpandElement::Plain(*var); unroll_factor as usize];
621    }
622
623    let item = Type::new(var.storage_type()).line(max_line_size);
624    (0..unroll_factor as usize)
625        .map(|_| match var.kind {
626            VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
627                allocator.create_local_mut(item)
628            }
629            VariableKind::LocalConst { .. } => allocator.create_local(item),
630            other => panic!("Out must be local, found {other:?}"),
631        })
632        .collect()
633}
634
635fn add_index(alloc: &Allocator, idx: Variable, i: u32) -> (Instruction, ExpandElement) {
636    let add_idx = alloc.create_local(idx.ty);
637    let add = Instruction::new(
638        Arithmetic::Add(BinaryOperator {
639            lhs: idx,
640            rhs: i.into(),
641        }),
642        *add_idx,
643    );
644    (add, add_idx)
645}
646
647fn mul_index(alloc: &Allocator, idx: Variable, unroll_factor: u32) -> (Instruction, ExpandElement) {
648    let mul_idx = alloc.create_local(idx.ty);
649    let mul = Instruction::new(
650        Arithmetic::Mul(BinaryOperator {
651            lhs: idx,
652            rhs: unroll_factor.into(),
653        }),
654        *mul_idx,
655    );
656    (mul, mul_idx)
657}
658
659fn unroll_array(mut var: Variable, max_line_size: u32, factor: u32) -> Variable {
660    var.ty = var.ty.line(max_line_size);
661
662    match &mut var.kind {
663        VariableKind::LocalArray { unroll_factor, .. }
664        | VariableKind::ConstantArray { unroll_factor, .. }
665        | VariableKind::SharedMemory { unroll_factor, .. } => {
666            *unroll_factor = factor;
667        }
668        _ => {}
669    }
670
671    var
672}