// cubecl_core/post_processing/unroll.rs

use alloc::{vec, vec::Vec};
use core::mem;
use cubecl_ir::{
    Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator,
    IndexAssignOperator, IndexOperator, Instruction, ManagedVariable, MatrixLayout, Metadata,
    Operation, OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable,
    VariableKind, VectorSize,
};
use hashbrown::HashMap;
9
10/// The action that should be performed on an instruction, returned by ``IrTransformer::maybe_transform``
11pub enum TransformAction {
12    /// The transformer doesn't apply to this instruction
13    Ignore,
14    /// Replace this instruction with one or more other instructions
15    Replace(Vec<Instruction>),
16}
17
18#[derive(new, Debug)]
19pub struct UnrollProcessor {
20    max_vector_size: VectorSize,
21}
22
23struct Mappings(HashMap<Variable, Vec<ManagedVariable>>);
24
25impl Mappings {
26    fn get(
27        &mut self,
28        alloc: &Allocator,
29        var: Variable,
30        unroll_factor: usize,
31        vector_size: VectorSize,
32    ) -> Vec<Variable> {
33        self.0
34            .entry(var)
35            .or_insert_with(|| create_unrolled(alloc, &var, vector_size, unroll_factor))
36            .iter()
37            .map(|it| **it)
38            .collect()
39    }
40}
41
42impl UnrollProcessor {
43    fn maybe_transform(
44        &self,
45        alloc: &Allocator,
46        inst: &Instruction,
47        mappings: &mut Mappings,
48    ) -> TransformAction {
49        if matches!(inst.operation, Operation::Marker(_)) {
50            return TransformAction::Ignore;
51        }
52
53        if inst.operation.args().is_none() {
54            // Detect unhandled ops that can't be reflected
55            match &inst.operation {
56                Operation::CoopMma(op) => match op {
57                    // Stride is in scalar elems
58                    CoopMma::Load {
59                        value,
60                        stride,
61                        offset,
62                        layout,
63                    } if value.vector_size() > self.max_vector_size => {
64                        return TransformAction::Replace(self.transform_cmma_load(
65                            alloc,
66                            inst.out(),
67                            value,
68                            stride,
69                            offset,
70                            layout,
71                        ));
72                    }
73                    CoopMma::Store {
74                        mat,
75                        stride,
76                        offset,
77                        layout,
78                    } if inst.out().vector_size() > self.max_vector_size => {
79                        return TransformAction::Replace(self.transform_cmma_store(
80                            alloc,
81                            inst.out(),
82                            mat,
83                            stride,
84                            offset,
85                            layout,
86                        ));
87                    }
88                    _ => return TransformAction::Ignore,
89                },
90                Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
91                    return TransformAction::Ignore;
92                }
93                _ => {
94                    panic!("Need special handling for unrolling non-reflectable operations")
95                }
96            }
97        }
98
99        let args = inst.operation.args().unwrap_or_default();
100        if (inst.out.is_some() && inst.ty().vector_size() > self.max_vector_size)
101            || args
102                .iter()
103                .any(|arg| arg.vector_size() > self.max_vector_size)
104        {
105            let vector_size = max_vector_size(&inst.out, &args);
106            let unroll_factor = vector_size / self.max_vector_size;
107
108            match &inst.operation {
109                Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
110                    self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
111                ),
112                Operation::Operator(Operator::CopyMemory(op)) => {
113                    TransformAction::Replace(self.transform_memcpy(
114                        alloc,
115                        &CopyMemoryBulkOperator {
116                            out_index: op.out_index,
117                            input: op.input,
118                            in_index: op.in_index,
119                            len: 1,
120                            offset_input: 0.into(),
121                            offset_out: 0.into(),
122                        },
123                        inst.out(),
124                        unroll_factor,
125                    ))
126                }
127                Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
128                    TransformAction::Replace(self.transform_array_index(
129                        alloc,
130                        inst.out(),
131                        op,
132                        Operator::Index,
133                        unroll_factor,
134                        mappings,
135                    ))
136                }
137                Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
138                    TransformAction::Replace(self.transform_array_index(
139                        alloc,
140                        inst.out(),
141                        op,
142                        Operator::UncheckedIndex,
143                        unroll_factor,
144                        mappings,
145                    ))
146                }
147                Operation::Operator(Operator::Index(op)) => {
148                    TransformAction::Replace(self.transform_composite_index(
149                        alloc,
150                        inst.out(),
151                        op,
152                        Operator::Index,
153                        unroll_factor,
154                        mappings,
155                    ))
156                }
157                Operation::Operator(Operator::UncheckedIndex(op)) => {
158                    TransformAction::Replace(self.transform_composite_index(
159                        alloc,
160                        inst.out(),
161                        op,
162                        Operator::UncheckedIndex,
163                        unroll_factor,
164                        mappings,
165                    ))
166                }
167                Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
168                    TransformAction::Replace(self.transform_array_index_assign(
169                        alloc,
170                        inst.out(),
171                        op,
172                        Operator::IndexAssign,
173                        unroll_factor,
174                        mappings,
175                    ))
176                }
177                Operation::Operator(Operator::UncheckedIndexAssign(op))
178                    if inst.out().is_array() =>
179                {
180                    TransformAction::Replace(self.transform_array_index_assign(
181                        alloc,
182                        inst.out(),
183                        op,
184                        Operator::UncheckedIndexAssign,
185                        unroll_factor,
186                        mappings,
187                    ))
188                }
189                Operation::Operator(Operator::IndexAssign(op)) => {
190                    TransformAction::Replace(self.transform_composite_index_assign(
191                        alloc,
192                        inst.out(),
193                        op,
194                        Operator::IndexAssign,
195                        unroll_factor,
196                        mappings,
197                    ))
198                }
199                Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
200                    TransformAction::Replace(self.transform_composite_index_assign(
201                        alloc,
202                        inst.out(),
203                        op,
204                        Operator::UncheckedIndexAssign,
205                        unroll_factor,
206                        mappings,
207                    ))
208                }
209                Operation::Metadata(op) => {
210                    TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
211                }
212                _ => TransformAction::Replace(self.transform_basic(
213                    alloc,
214                    inst,
215                    args,
216                    unroll_factor,
217                    mappings,
218                )),
219            }
220        } else {
221            TransformAction::Ignore
222        }
223    }
224
225    /// Transform CMMA load offset and array
226    fn transform_cmma_load(
227        &self,
228        alloc: &Allocator,
229        out: Variable,
230        value: &Variable,
231        stride: &Variable,
232        offset: &Variable,
233        layout: &Option<MatrixLayout>,
234    ) -> Vec<Instruction> {
235        let vector_size = value.vector_size();
236        let unroll_factor = vector_size / self.max_vector_size;
237
238        let value = unroll_array(*value, self.max_vector_size, unroll_factor);
239        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
240        let load = Instruction::new(
241            Operation::CoopMma(CoopMma::Load {
242                value,
243                stride: *stride,
244                offset: *offset,
245                layout: *layout,
246            }),
247            out,
248        );
249        vec![mul, load]
250    }
251
252    /// Transform CMMA store offset and array
253    fn transform_cmma_store(
254        &self,
255        alloc: &Allocator,
256        out: Variable,
257        mat: &Variable,
258        stride: &Variable,
259        offset: &Variable,
260        layout: &MatrixLayout,
261    ) -> Vec<Instruction> {
262        let vector_size = out.vector_size();
263        let unroll_factor = vector_size / self.max_vector_size;
264
265        let out = unroll_array(out, self.max_vector_size, unroll_factor);
266        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
267        let store = Instruction::new(
268            Operation::CoopMma(CoopMma::Store {
269                mat: *mat,
270                stride: *stride,
271                offset: *offset,
272                layout: *layout,
273            }),
274            out,
275        );
276        vec![mul, store]
277    }
278
279    /// Transforms memcpy into one with higher length and adjusted indices/offsets
280    fn transform_memcpy(
281        &self,
282        alloc: &Allocator,
283        op: &CopyMemoryBulkOperator,
284        out: Variable,
285        unroll_factor: usize,
286    ) -> Vec<Instruction> {
287        let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
288        let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
289        let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
290        let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
291
292        let input = unroll_array(op.input, self.max_vector_size, unroll_factor);
293        let out = unroll_array(out, self.max_vector_size, unroll_factor);
294
295        vec![
296            mul1,
297            mul2,
298            mul3,
299            mul4,
300            Instruction::new(
301                Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
302                    input,
303                    in_index: *in_index,
304                    out_index: *out_index,
305                    len: op.len * unroll_factor,
306                    offset_input: *offset_input,
307                    offset_out: *offset_out,
308                }),
309                out,
310            ),
311        ]
312    }
313
314    /// Transforms indexing into multiple index operations, each offset by 1 from the base. The base
315    /// is also multiplied by the unroll factor to compensate for the lower actual vectorization.
316    fn transform_array_index(
317        &self,
318        alloc: &Allocator,
319        out: Variable,
320        op: &IndexOperator,
321        operator: impl Fn(IndexOperator) -> Operator,
322        unroll_factor: usize,
323        mappings: &mut Mappings,
324    ) -> Vec<Instruction> {
325        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
326        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
327
328        let list = unroll_array(op.list, self.max_vector_size, unroll_factor);
329
330        let out = mappings.get(alloc, out, unroll_factor, self.max_vector_size);
331        let mut instructions = vec![mul];
332        instructions.extend((0..unroll_factor).flat_map(|i| {
333            let (add, idx) = indices.next().unwrap();
334            let index = Instruction::new(
335                operator(IndexOperator {
336                    list,
337                    index: *idx,
338                    vector_size: 0,
339                    unroll_factor,
340                }),
341                out[i],
342            );
343            [add, index]
344        }));
345
346        instructions
347    }
348
349    /// Transforms index assign into multiple index assign operations, each offset by 1 from the base.
350    /// The base is also multiplied by the unroll factor to compensate for the lower actual vectorization.
351    fn transform_array_index_assign(
352        &self,
353        alloc: &Allocator,
354        out: Variable,
355        op: &IndexAssignOperator,
356        operator: impl Fn(IndexAssignOperator) -> Operator,
357        unroll_factor: usize,
358        mappings: &mut Mappings,
359    ) -> Vec<Instruction> {
360        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
361        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
362
363        let out = unroll_array(out, self.max_vector_size, unroll_factor);
364
365        let value = mappings.get(alloc, op.value, unroll_factor, self.max_vector_size);
366
367        let mut instructions = vec![mul];
368        instructions.extend((0..unroll_factor).flat_map(|i| {
369            let (add, idx) = indices.next().unwrap();
370            let index = Instruction::new(
371                operator(IndexAssignOperator {
372                    index: *idx,
373                    vector_size: 0,
374                    value: value[i],
375                    unroll_factor,
376                }),
377                out,
378            );
379
380            [add, index]
381        }));
382
383        instructions
384    }
385
386    /// Transforms a composite index (i.e. `Vector`) that always returns a scalar. Translates the index
387    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
388    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
389    /// work.
390    fn transform_composite_index(
391        &self,
392        alloc: &Allocator,
393        out: Variable,
394        op: &IndexOperator,
395        operator: impl Fn(IndexOperator) -> Operator,
396        unroll_factor: usize,
397        mappings: &mut Mappings,
398    ) -> Vec<Instruction> {
399        let index = op
400            .index
401            .as_const()
402            .expect("Can't unroll non-constant vector index")
403            .as_usize();
404
405        let unroll_idx = index / self.max_vector_size;
406        let sub_idx = index % self.max_vector_size;
407
408        let value = mappings.get(alloc, op.list, unroll_factor, self.max_vector_size);
409
410        vec![Instruction::new(
411            operator(IndexOperator {
412                list: value[unroll_idx],
413                index: sub_idx.into(),
414                vector_size: 1,
415                unroll_factor,
416            }),
417            out,
418        )]
419    }
420
421    /// Transforms a composite index assign (i.e. `Vector`) that always takes a scalar. Translates the index
422    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
423    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
424    /// work.
425    fn transform_composite_index_assign(
426        &self,
427        alloc: &Allocator,
428        out: Variable,
429        op: &IndexAssignOperator,
430        operator: impl Fn(IndexAssignOperator) -> Operator,
431        unroll_factor: usize,
432        mappings: &mut Mappings,
433    ) -> Vec<Instruction> {
434        let index = op
435            .index
436            .as_const()
437            .expect("Can't unroll non-constant vector index")
438            .as_usize();
439
440        let unroll_idx = index / self.max_vector_size;
441        let sub_idx = index % self.max_vector_size;
442
443        let out = mappings.get(alloc, out, unroll_factor, self.max_vector_size);
444
445        vec![Instruction::new(
446            operator(IndexAssignOperator {
447                index: sub_idx.into(),
448                vector_size: 1,
449                value: op.value,
450                unroll_factor,
451            }),
452            out[unroll_idx],
453        )]
454    }
455
456    /// Transforms metadata by just replacing the type of the buffer. The values are already
457    /// properly calculated on the CPU.
458    fn transform_metadata(
459        &self,
460        out: Variable,
461        op: &Metadata,
462        args: Vec<Variable>,
463    ) -> Vec<Instruction> {
464        let op_code = op.op_code();
465        let args = args
466            .into_iter()
467            .map(|mut var| {
468                if var.vector_size() > self.max_vector_size {
469                    var.ty = var.ty.with_vector_size(self.max_vector_size);
470                }
471                var
472            })
473            .collect::<Vec<_>>();
474        let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
475        vec![Instruction::new(operation, out)]
476    }
477
478    /// Transforms generic instructions, i.e. comparison, arithmetic. Unrolls each vectorized variable
479    /// to `unroll_factor` replacements, and executes the operation `unroll_factor` times.
480    fn transform_basic(
481        &self,
482        alloc: &Allocator,
483        inst: &Instruction,
484        args: Vec<Variable>,
485        unroll_factor: usize,
486        mappings: &mut Mappings,
487    ) -> Vec<Instruction> {
488        let op_code = inst.operation.op_code();
489        let out = inst
490            .out
491            .map(|out| mappings.get(alloc, out, unroll_factor, self.max_vector_size));
492        let args = args
493            .into_iter()
494            .map(|arg| {
495                if arg.vector_size() > 1 {
496                    mappings.get(alloc, arg, unroll_factor, self.max_vector_size)
497                } else {
498                    // Preserve scalars
499                    vec![arg]
500                }
501            })
502            .collect::<Vec<_>>();
503
504        (0..unroll_factor)
505            .map(|i| {
506                let out = out.as_ref().map(|out| out[i]);
507                let args = args
508                    .iter()
509                    .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
510                    .collect::<Vec<_>>();
511                let operation = Operation::from_code_and_args(op_code, &args)
512                    .expect("Failed to reconstruct operation");
513                Instruction {
514                    out,
515                    source_loc: inst.source_loc.clone(),
516                    modes: inst.modes,
517                    operation,
518                }
519            })
520            .collect()
521    }
522
523    fn transform_instructions(
524        &self,
525        allocator: &Allocator,
526        instructions: Vec<Instruction>,
527        mappings: &mut Mappings,
528    ) -> Vec<Instruction> {
529        let mut new_instructions = Vec::with_capacity(instructions.len());
530
531        for mut instruction in instructions {
532            if let Operation::Branch(branch) = &mut instruction.operation {
533                match branch {
534                    Branch::If(op) => {
535                        op.scope.instructions = self.transform_instructions(
536                            allocator,
537                            op.scope.instructions.drain(..).collect(),
538                            mappings,
539                        );
540                    }
541                    Branch::IfElse(op) => {
542                        op.scope_if.instructions = self.transform_instructions(
543                            allocator,
544                            op.scope_if.instructions.drain(..).collect(),
545                            mappings,
546                        );
547                        op.scope_else.instructions = self.transform_instructions(
548                            allocator,
549                            op.scope_else.instructions.drain(..).collect(),
550                            mappings,
551                        );
552                    }
553                    Branch::Switch(op) => {
554                        for (_, case) in &mut op.cases {
555                            case.instructions = self.transform_instructions(
556                                allocator,
557                                case.instructions.drain(..).collect(),
558                                mappings,
559                            );
560                        }
561                        op.scope_default.instructions = self.transform_instructions(
562                            allocator,
563                            op.scope_default.instructions.drain(..).collect(),
564                            mappings,
565                        );
566                    }
567                    Branch::RangeLoop(op) => {
568                        op.scope.instructions = self.transform_instructions(
569                            allocator,
570                            op.scope.instructions.drain(..).collect(),
571                            mappings,
572                        );
573                    }
574                    Branch::Loop(op) => {
575                        op.scope.instructions = self.transform_instructions(
576                            allocator,
577                            op.scope.instructions.drain(..).collect(),
578                            mappings,
579                        );
580                    }
581                    _ => {}
582                }
583            }
584            match self.maybe_transform(allocator, &instruction, mappings) {
585                TransformAction::Ignore => {
586                    new_instructions.push(instruction);
587                }
588                TransformAction::Replace(replacement) => {
589                    new_instructions.extend(replacement);
590                }
591            }
592        }
593
594        new_instructions
595    }
596}
597
598impl Processor for UnrollProcessor {
599    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
600        let mut mappings = Mappings(Default::default());
601
602        let instructions =
603            self.transform_instructions(&allocator, processing.instructions, &mut mappings);
604
605        ScopeProcessing {
606            variables: processing.variables,
607            instructions,
608            typemap: processing.typemap.clone(),
609        }
610    }
611}
612
613fn max_vector_size(out: &Option<Variable>, args: &[Variable]) -> VectorSize {
614    let vector_size = args.iter().map(|it| it.vector_size()).max().unwrap();
615    vector_size.max(out.map(|out| out.vector_size()).unwrap_or(1))
616}
617
618fn create_unrolled(
619    allocator: &Allocator,
620    var: &Variable,
621    max_vector_size: VectorSize,
622    unroll_factor: usize,
623) -> Vec<ManagedVariable> {
624    // Preserve scalars
625    if var.vector_size() == 1 {
626        return vec![ManagedVariable::Plain(*var); unroll_factor];
627    }
628
629    let item = Type::new(var.storage_type()).with_vector_size(max_vector_size);
630    (0..unroll_factor)
631        .map(|_| match var.kind {
632            VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
633                allocator.create_local_mut(item)
634            }
635            VariableKind::Shared { .. } => {
636                let id = allocator.new_local_index();
637                let shared = VariableKind::Shared { id };
638                ManagedVariable::Plain(Variable::new(shared, item))
639            }
640            VariableKind::LocalConst { .. } => allocator.create_local(item),
641            other => panic!("Out must be local, found {other:?}"),
642        })
643        .collect()
644}
645
646fn add_index(alloc: &Allocator, idx: Variable, i: usize) -> (Instruction, ManagedVariable) {
647    let add_idx = alloc.create_local(idx.ty);
648    let add = Instruction::new(
649        Arithmetic::Add(BinaryOperator {
650            lhs: idx,
651            rhs: i.into(),
652        }),
653        *add_idx,
654    );
655    (add, add_idx)
656}
657
658fn mul_index(
659    alloc: &Allocator,
660    idx: Variable,
661    unroll_factor: usize,
662) -> (Instruction, ManagedVariable) {
663    let mul_idx = alloc.create_local(idx.ty);
664    let mul = Instruction::new(
665        Arithmetic::Mul(BinaryOperator {
666            lhs: idx,
667            rhs: unroll_factor.into(),
668        }),
669        *mul_idx,
670    );
671    (mul, mul_idx)
672}
673
674fn unroll_array(mut var: Variable, max_vector_size: VectorSize, factor: usize) -> Variable {
675    var.ty = var.ty.with_vector_size(max_vector_size);
676
677    match &mut var.kind {
678        VariableKind::LocalArray { unroll_factor, .. }
679        | VariableKind::ConstantArray { unroll_factor, .. }
680        | VariableKind::SharedArray { unroll_factor, .. } => {
681            *unroll_factor = factor;
682        }
683        _ => {}
684    }
685
686    var
687}