// cubecl_core/post_processing/unroll.rs

1use cubecl_ir::{
2    Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
3    IndexAssignOperator, IndexOperator, Instruction, MatrixLayout, Metadata, Operation,
4    OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
5};
6use hashbrown::HashMap;
7
8/// The action that should be performed on an instruction, returned by [`IrTransformer::maybe_transform`]
9pub enum TransformAction {
10    /// The transformer doesn't apply to this instruction
11    Ignore,
12    /// Replace this instruction with one or more other instructions
13    Replace(Vec<Instruction>),
14}
15
16#[derive(new, Debug)]
17pub struct UnrollProcessor {
18    max_line_size: u32,
19}
20
21struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
22
23impl Mappings {
24    fn get(
25        &mut self,
26        alloc: &Allocator,
27        var: Variable,
28        unroll_factor: u32,
29        line_size: u32,
30    ) -> Vec<Variable> {
31        self.0
32            .entry(var)
33            .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
34            .iter()
35            .map(|it| **it)
36            .collect()
37    }
38}
39
40impl UnrollProcessor {
41    fn maybe_transform(
42        &self,
43        alloc: &Allocator,
44        inst: &Instruction,
45        mappings: &mut Mappings,
46    ) -> TransformAction {
47        if matches!(inst.operation, Operation::Marker(_)) {
48            return TransformAction::Ignore;
49        }
50
51        if inst.operation.args().is_none() {
52            // Detect unhandled ops that can't be reflected
53            match &inst.operation {
54                Operation::CoopMma(op) => match op {
55                    // Stride is in scalar elems
56                    CoopMma::Load {
57                        value,
58                        stride,
59                        offset,
60                        layout,
61                    } if value.line_size() > self.max_line_size => {
62                        return TransformAction::Replace(self.transform_cmma_load(
63                            alloc,
64                            inst.out(),
65                            value,
66                            stride,
67                            offset,
68                            layout,
69                        ));
70                    }
71                    CoopMma::Store {
72                        mat,
73                        stride,
74                        offset,
75                        layout,
76                    } if inst.out().line_size() > self.max_line_size => {
77                        return TransformAction::Replace(self.transform_cmma_store(
78                            alloc,
79                            inst.out(),
80                            mat,
81                            stride,
82                            offset,
83                            layout,
84                        ));
85                    }
86                    _ => return TransformAction::Ignore,
87                },
88                Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
89                    return TransformAction::Ignore;
90                }
91                _ => {
92                    panic!("Need special handling for unrolling non-reflectable operations")
93                }
94            }
95        }
96
97        let args = inst.operation.args().unwrap_or_default();
98        if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
99            || args.iter().any(|arg| arg.line_size() > self.max_line_size)
100        {
101            let line_size = max_line_size(&inst.out, &args);
102            let unroll_factor = line_size / self.max_line_size;
103
104            match &inst.operation {
105                Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
106                    self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
107                ),
108                Operation::Operator(Operator::CopyMemory(op)) => {
109                    TransformAction::Replace(self.transform_memcpy(
110                        alloc,
111                        &CopyMemoryBulkOperator {
112                            out_index: op.out_index,
113                            input: op.input,
114                            in_index: op.in_index,
115                            len: 1,
116                            offset_input: 0.into(),
117                            offset_out: 0.into(),
118                        },
119                        inst.out(),
120                        unroll_factor,
121                    ))
122                }
123                Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
124                    TransformAction::Replace(self.transform_array_index(
125                        alloc,
126                        inst.out(),
127                        op,
128                        Operator::Index,
129                        unroll_factor,
130                        mappings,
131                    ))
132                }
133                Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
134                    TransformAction::Replace(self.transform_array_index(
135                        alloc,
136                        inst.out(),
137                        op,
138                        Operator::UncheckedIndex,
139                        unroll_factor,
140                        mappings,
141                    ))
142                }
143                Operation::Operator(Operator::Index(op)) => {
144                    TransformAction::Replace(self.transform_composite_index(
145                        alloc,
146                        inst.out(),
147                        op,
148                        Operator::Index,
149                        unroll_factor,
150                        mappings,
151                    ))
152                }
153                Operation::Operator(Operator::UncheckedIndex(op)) => {
154                    TransformAction::Replace(self.transform_composite_index(
155                        alloc,
156                        inst.out(),
157                        op,
158                        Operator::UncheckedIndex,
159                        unroll_factor,
160                        mappings,
161                    ))
162                }
163                Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
164                    TransformAction::Replace(self.transform_array_index_assign(
165                        alloc,
166                        inst.out(),
167                        op,
168                        Operator::IndexAssign,
169                        unroll_factor,
170                        mappings,
171                    ))
172                }
173                Operation::Operator(Operator::UncheckedIndexAssign(op))
174                    if inst.out().is_array() =>
175                {
176                    TransformAction::Replace(self.transform_array_index_assign(
177                        alloc,
178                        inst.out(),
179                        op,
180                        Operator::UncheckedIndexAssign,
181                        unroll_factor,
182                        mappings,
183                    ))
184                }
185                Operation::Operator(Operator::IndexAssign(op)) => {
186                    TransformAction::Replace(self.transform_composite_index_assign(
187                        alloc,
188                        inst.out(),
189                        op,
190                        Operator::IndexAssign,
191                        unroll_factor,
192                        mappings,
193                    ))
194                }
195                Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
196                    TransformAction::Replace(self.transform_composite_index_assign(
197                        alloc,
198                        inst.out(),
199                        op,
200                        Operator::UncheckedIndexAssign,
201                        unroll_factor,
202                        mappings,
203                    ))
204                }
205                Operation::Metadata(op) => {
206                    TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
207                }
208                _ => TransformAction::Replace(self.transform_basic(
209                    alloc,
210                    inst,
211                    args,
212                    unroll_factor,
213                    mappings,
214                )),
215            }
216        } else {
217            TransformAction::Ignore
218        }
219    }
220
221    /// Transform CMMA load offset and array
222    fn transform_cmma_load(
223        &self,
224        alloc: &Allocator,
225        out: Variable,
226        value: &Variable,
227        stride: &Variable,
228        offset: &Variable,
229        layout: &Option<MatrixLayout>,
230    ) -> Vec<Instruction> {
231        let line_size = value.line_size();
232        let unroll_factor = line_size / self.max_line_size;
233
234        let value = unroll_array(*value, self.max_line_size, unroll_factor);
235        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
236        let load = Instruction::new(
237            Operation::CoopMma(CoopMma::Load {
238                value,
239                stride: *stride,
240                offset: *offset,
241                layout: *layout,
242            }),
243            out,
244        );
245        vec![mul, load]
246    }
247
248    /// Transform CMMA store offset and array
249    fn transform_cmma_store(
250        &self,
251        alloc: &Allocator,
252        out: Variable,
253        mat: &Variable,
254        stride: &Variable,
255        offset: &Variable,
256        layout: &MatrixLayout,
257    ) -> Vec<Instruction> {
258        let line_size = out.line_size();
259        let unroll_factor = line_size / self.max_line_size;
260
261        let out = unroll_array(out, self.max_line_size, unroll_factor);
262        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
263        let store = Instruction::new(
264            Operation::CoopMma(CoopMma::Store {
265                mat: *mat,
266                stride: *stride,
267                offset: *offset,
268                layout: *layout,
269            }),
270            out,
271        );
272        vec![mul, store]
273    }
274
275    /// Transforms memcpy into one with higher length and adjusted indices/offsets
276    fn transform_memcpy(
277        &self,
278        alloc: &Allocator,
279        op: &CopyMemoryBulkOperator,
280        out: Variable,
281        unroll_factor: u32,
282    ) -> Vec<Instruction> {
283        let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
284        let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
285        let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
286        let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
287
288        let input = unroll_array(op.input, self.max_line_size, unroll_factor);
289        let out = unroll_array(out, self.max_line_size, unroll_factor);
290
291        vec![
292            mul1,
293            mul2,
294            mul3,
295            mul4,
296            Instruction::new(
297                Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
298                    input,
299                    in_index: *in_index,
300                    out_index: *out_index,
301                    len: op.len * unroll_factor,
302                    offset_input: *offset_input,
303                    offset_out: *offset_out,
304                }),
305                out,
306            ),
307        ]
308    }
309
310    /// Transforms indexing into multiple index operations, each offset by 1 from the base. The base
311    /// is also multiplied by the unroll factor to compensate for the lower actual vectorization.
312    fn transform_array_index(
313        &self,
314        alloc: &Allocator,
315        out: Variable,
316        op: &IndexOperator,
317        operator: impl Fn(IndexOperator) -> Operator,
318        unroll_factor: u32,
319        mappings: &mut Mappings,
320    ) -> Vec<Instruction> {
321        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
322        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
323
324        let list = unroll_array(op.list, self.max_line_size, unroll_factor);
325
326        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
327        let mut instructions = vec![mul];
328        instructions.extend((0..unroll_factor as usize).flat_map(|i| {
329            let (add, idx) = indices.next().unwrap();
330            let index = Instruction::new(
331                operator(IndexOperator {
332                    list,
333                    index: *idx,
334                    line_size: 0,
335                    unroll_factor,
336                }),
337                out[i],
338            );
339            [add, index]
340        }));
341
342        instructions
343    }
344
345    /// Transforms index assign into multiple index assign operations, each offset by 1 from the base.
346    /// The base is also multiplied by the unroll factor to compensate for the lower actual vectorization.
347    fn transform_array_index_assign(
348        &self,
349        alloc: &Allocator,
350        out: Variable,
351        op: &IndexAssignOperator,
352        operator: impl Fn(IndexAssignOperator) -> Operator,
353        unroll_factor: u32,
354        mappings: &mut Mappings,
355    ) -> Vec<Instruction> {
356        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
357        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
358
359        let out = unroll_array(out, self.max_line_size, unroll_factor);
360
361        let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
362
363        let mut instructions = vec![mul];
364        instructions.extend((0..unroll_factor as usize).flat_map(|i| {
365            let (add, idx) = indices.next().unwrap();
366            let index = Instruction::new(
367                operator(IndexAssignOperator {
368                    index: *idx,
369                    line_size: 0,
370                    value: value[i],
371                    unroll_factor,
372                }),
373                out,
374            );
375
376            [add, index]
377        }));
378
379        instructions
380    }
381
382    /// Transforms a composite index (i.e. `Line`) that always returns a scalar. Translates the index
383    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
384    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
385    /// work.
386    fn transform_composite_index(
387        &self,
388        alloc: &Allocator,
389        out: Variable,
390        op: &IndexOperator,
391        operator: impl Fn(IndexOperator) -> Operator,
392        unroll_factor: u32,
393        mappings: &mut Mappings,
394    ) -> Vec<Instruction> {
395        let index = op
396            .index
397            .as_const()
398            .expect("Can't unroll non-constant vector index")
399            .as_u32();
400
401        let unroll_idx = index / self.max_line_size;
402        let sub_idx = index % self.max_line_size;
403
404        let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
405
406        vec![Instruction::new(
407            operator(IndexOperator {
408                list: value[unroll_idx as usize],
409                index: sub_idx.into(),
410                line_size: 1,
411                unroll_factor,
412            }),
413            out,
414        )]
415    }
416
417    /// Transforms a composite index assign (i.e. `Line`) that always takes a scalar. Translates the index
418    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
419    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
420    /// work.
421    fn transform_composite_index_assign(
422        &self,
423        alloc: &Allocator,
424        out: Variable,
425        op: &IndexAssignOperator,
426        operator: impl Fn(IndexAssignOperator) -> Operator,
427        unroll_factor: u32,
428        mappings: &mut Mappings,
429    ) -> Vec<Instruction> {
430        let index = op
431            .index
432            .as_const()
433            .expect("Can't unroll non-constant vector index")
434            .as_u32();
435
436        let unroll_idx = index / self.max_line_size;
437        let sub_idx = index % self.max_line_size;
438
439        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
440
441        vec![Instruction::new(
442            operator(IndexAssignOperator {
443                index: sub_idx.into(),
444                line_size: 1,
445                value: op.value,
446                unroll_factor,
447            }),
448            out[unroll_idx as usize],
449        )]
450    }
451
452    /// Transforms metadata by just replacing the type of the buffer. The values are already
453    /// properly calculated on the CPU.
454    fn transform_metadata(
455        &self,
456        out: Variable,
457        op: &Metadata,
458        args: Vec<Variable>,
459    ) -> Vec<Instruction> {
460        let op_code = op.op_code();
461        let args = args
462            .into_iter()
463            .map(|mut var| {
464                if var.line_size() > self.max_line_size {
465                    var.ty = var.ty.line(self.max_line_size);
466                }
467                var
468            })
469            .collect::<Vec<_>>();
470        let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
471        vec![Instruction::new(operation, out)]
472    }
473
474    /// Transforms generic instructions, i.e. comparison, arithmetic. Unrolls each vectorized variable
475    /// to `unroll_factor` replacements, and executes the operation `unroll_factor` times.
476    fn transform_basic(
477        &self,
478        alloc: &Allocator,
479        inst: &Instruction,
480        args: Vec<Variable>,
481        unroll_factor: u32,
482        mappings: &mut Mappings,
483    ) -> Vec<Instruction> {
484        let op_code = inst.operation.op_code();
485        let out = inst
486            .out
487            .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
488        let args = args
489            .into_iter()
490            .map(|arg| {
491                if arg.line_size() > 1 {
492                    mappings.get(alloc, arg, unroll_factor, self.max_line_size)
493                } else {
494                    // Preserve scalars
495                    vec![arg]
496                }
497            })
498            .collect::<Vec<_>>();
499
500        (0..unroll_factor as usize)
501            .map(|i| {
502                let out = out.as_ref().map(|out| out[i]);
503                let args = args
504                    .iter()
505                    .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
506                    .collect::<Vec<_>>();
507                let operation = Operation::from_code_and_args(op_code, &args)
508                    .expect("Failed to reconstruct operation");
509                Instruction {
510                    out,
511                    source_loc: inst.source_loc.clone(),
512                    modes: inst.modes,
513                    operation,
514                }
515            })
516            .collect()
517    }
518
519    fn transform_instructions(
520        &self,
521        allocator: &Allocator,
522        instructions: Vec<Instruction>,
523        mappings: &mut Mappings,
524    ) -> Vec<Instruction> {
525        let mut new_instructions = Vec::with_capacity(instructions.len());
526
527        for mut instruction in instructions {
528            if let Operation::Branch(branch) = &mut instruction.operation {
529                match branch {
530                    Branch::If(op) => {
531                        op.scope.instructions = self.transform_instructions(
532                            allocator,
533                            op.scope.instructions.drain(..).collect(),
534                            mappings,
535                        );
536                    }
537                    Branch::IfElse(op) => {
538                        op.scope_if.instructions = self.transform_instructions(
539                            allocator,
540                            op.scope_if.instructions.drain(..).collect(),
541                            mappings,
542                        );
543                        op.scope_else.instructions = self.transform_instructions(
544                            allocator,
545                            op.scope_else.instructions.drain(..).collect(),
546                            mappings,
547                        );
548                    }
549                    Branch::Switch(op) => {
550                        for (_, case) in &mut op.cases {
551                            case.instructions = self.transform_instructions(
552                                allocator,
553                                case.instructions.drain(..).collect(),
554                                mappings,
555                            );
556                        }
557                        op.scope_default.instructions = self.transform_instructions(
558                            allocator,
559                            op.scope_default.instructions.drain(..).collect(),
560                            mappings,
561                        );
562                    }
563                    Branch::RangeLoop(op) => {
564                        op.scope.instructions = self.transform_instructions(
565                            allocator,
566                            op.scope.instructions.drain(..).collect(),
567                            mappings,
568                        );
569                    }
570                    Branch::Loop(op) => {
571                        op.scope.instructions = self.transform_instructions(
572                            allocator,
573                            op.scope.instructions.drain(..).collect(),
574                            mappings,
575                        );
576                    }
577                    _ => {}
578                }
579            }
580            match self.maybe_transform(allocator, &instruction, mappings) {
581                TransformAction::Ignore => {
582                    new_instructions.push(instruction);
583                }
584                TransformAction::Replace(replacement) => {
585                    new_instructions.extend(replacement);
586                }
587            }
588        }
589
590        new_instructions
591    }
592}
593
594impl Processor for UnrollProcessor {
595    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
596        let mut mappings = Mappings(Default::default());
597
598        let instructions =
599            self.transform_instructions(&allocator, processing.instructions, &mut mappings);
600
601        ScopeProcessing {
602            variables: processing.variables,
603            instructions,
604        }
605    }
606}
607
608fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> u32 {
609    let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
610    line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
611}
612
613fn create_unrolled(
614    allocator: &Allocator,
615    var: &Variable,
616    max_line_size: u32,
617    unroll_factor: u32,
618) -> Vec<ExpandElement> {
619    // Preserve scalars
620    if var.line_size() == 1 {
621        return vec![ExpandElement::Plain(*var); unroll_factor as usize];
622    }
623
624    let item = Type::new(var.storage_type()).line(max_line_size);
625    (0..unroll_factor as usize)
626        .map(|_| match var.kind {
627            VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
628                allocator.create_local_mut(item)
629            }
630            VariableKind::Shared { .. } => {
631                let id = allocator.new_local_index();
632                let shared = VariableKind::Shared { id };
633                ExpandElement::Plain(Variable::new(shared, item))
634            }
635            VariableKind::LocalConst { .. } => allocator.create_local(item),
636            other => panic!("Out must be local, found {other:?}"),
637        })
638        .collect()
639}
640
641fn add_index(alloc: &Allocator, idx: Variable, i: u32) -> (Instruction, ExpandElement) {
642    let add_idx = alloc.create_local(idx.ty);
643    let add = Instruction::new(
644        Arithmetic::Add(BinaryOperator {
645            lhs: idx,
646            rhs: i.into(),
647        }),
648        *add_idx,
649    );
650    (add, add_idx)
651}
652
653fn mul_index(alloc: &Allocator, idx: Variable, unroll_factor: u32) -> (Instruction, ExpandElement) {
654    let mul_idx = alloc.create_local(idx.ty);
655    let mul = Instruction::new(
656        Arithmetic::Mul(BinaryOperator {
657            lhs: idx,
658            rhs: unroll_factor.into(),
659        }),
660        *mul_idx,
661    );
662    (mul, mul_idx)
663}
664
665fn unroll_array(mut var: Variable, max_line_size: u32, factor: u32) -> Variable {
666    var.ty = var.ty.line(max_line_size);
667
668    match &mut var.kind {
669        VariableKind::LocalArray { unroll_factor, .. }
670        | VariableKind::ConstantArray { unroll_factor, .. }
671        | VariableKind::SharedArray { unroll_factor, .. } => {
672            *unroll_factor = factor;
673        }
674        _ => {}
675    }
676
677    var
678}