Skip to main content

cubecl_core/post_processing/
unroll.rs

1use alloc::{vec, vec::Vec};
2use cubecl_ir::{
3    Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
4    IndexAssignOperator, IndexOperator, Instruction, LineSize, MatrixLayout, Metadata, Operation,
5    OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
6};
7use hashbrown::HashMap;
8
9/// The action that should be performed on an instruction, returned by ``IrTransformer::maybe_transform``
10pub enum TransformAction {
11    /// The transformer doesn't apply to this instruction
12    Ignore,
13    /// Replace this instruction with one or more other instructions
14    Replace(Vec<Instruction>),
15}
16
17#[derive(new, Debug)]
18pub struct UnrollProcessor {
19    max_line_size: LineSize,
20}
21
22struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
23
24impl Mappings {
25    fn get(
26        &mut self,
27        alloc: &Allocator,
28        var: Variable,
29        unroll_factor: usize,
30        line_size: LineSize,
31    ) -> Vec<Variable> {
32        self.0
33            .entry(var)
34            .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
35            .iter()
36            .map(|it| **it)
37            .collect()
38    }
39}
40
41impl UnrollProcessor {
42    fn maybe_transform(
43        &self,
44        alloc: &Allocator,
45        inst: &Instruction,
46        mappings: &mut Mappings,
47    ) -> TransformAction {
48        if matches!(inst.operation, Operation::Marker(_)) {
49            return TransformAction::Ignore;
50        }
51
52        if inst.operation.args().is_none() {
53            // Detect unhandled ops that can't be reflected
54            match &inst.operation {
55                Operation::CoopMma(op) => match op {
56                    // Stride is in scalar elems
57                    CoopMma::Load {
58                        value,
59                        stride,
60                        offset,
61                        layout,
62                    } if value.line_size() > self.max_line_size => {
63                        return TransformAction::Replace(self.transform_cmma_load(
64                            alloc,
65                            inst.out(),
66                            value,
67                            stride,
68                            offset,
69                            layout,
70                        ));
71                    }
72                    CoopMma::Store {
73                        mat,
74                        stride,
75                        offset,
76                        layout,
77                    } if inst.out().line_size() > self.max_line_size => {
78                        return TransformAction::Replace(self.transform_cmma_store(
79                            alloc,
80                            inst.out(),
81                            mat,
82                            stride,
83                            offset,
84                            layout,
85                        ));
86                    }
87                    _ => return TransformAction::Ignore,
88                },
89                Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
90                    return TransformAction::Ignore;
91                }
92                _ => {
93                    panic!("Need special handling for unrolling non-reflectable operations")
94                }
95            }
96        }
97
98        let args = inst.operation.args().unwrap_or_default();
99        if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
100            || args.iter().any(|arg| arg.line_size() > self.max_line_size)
101        {
102            let line_size = max_line_size(&inst.out, &args);
103            let unroll_factor = line_size / self.max_line_size;
104
105            match &inst.operation {
106                Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
107                    self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
108                ),
109                Operation::Operator(Operator::CopyMemory(op)) => {
110                    TransformAction::Replace(self.transform_memcpy(
111                        alloc,
112                        &CopyMemoryBulkOperator {
113                            out_index: op.out_index,
114                            input: op.input,
115                            in_index: op.in_index,
116                            len: 1,
117                            offset_input: 0.into(),
118                            offset_out: 0.into(),
119                        },
120                        inst.out(),
121                        unroll_factor,
122                    ))
123                }
124                Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
125                    TransformAction::Replace(self.transform_array_index(
126                        alloc,
127                        inst.out(),
128                        op,
129                        Operator::Index,
130                        unroll_factor,
131                        mappings,
132                    ))
133                }
134                Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
135                    TransformAction::Replace(self.transform_array_index(
136                        alloc,
137                        inst.out(),
138                        op,
139                        Operator::UncheckedIndex,
140                        unroll_factor,
141                        mappings,
142                    ))
143                }
144                Operation::Operator(Operator::Index(op)) => {
145                    TransformAction::Replace(self.transform_composite_index(
146                        alloc,
147                        inst.out(),
148                        op,
149                        Operator::Index,
150                        unroll_factor,
151                        mappings,
152                    ))
153                }
154                Operation::Operator(Operator::UncheckedIndex(op)) => {
155                    TransformAction::Replace(self.transform_composite_index(
156                        alloc,
157                        inst.out(),
158                        op,
159                        Operator::UncheckedIndex,
160                        unroll_factor,
161                        mappings,
162                    ))
163                }
164                Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
165                    TransformAction::Replace(self.transform_array_index_assign(
166                        alloc,
167                        inst.out(),
168                        op,
169                        Operator::IndexAssign,
170                        unroll_factor,
171                        mappings,
172                    ))
173                }
174                Operation::Operator(Operator::UncheckedIndexAssign(op))
175                    if inst.out().is_array() =>
176                {
177                    TransformAction::Replace(self.transform_array_index_assign(
178                        alloc,
179                        inst.out(),
180                        op,
181                        Operator::UncheckedIndexAssign,
182                        unroll_factor,
183                        mappings,
184                    ))
185                }
186                Operation::Operator(Operator::IndexAssign(op)) => {
187                    TransformAction::Replace(self.transform_composite_index_assign(
188                        alloc,
189                        inst.out(),
190                        op,
191                        Operator::IndexAssign,
192                        unroll_factor,
193                        mappings,
194                    ))
195                }
196                Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
197                    TransformAction::Replace(self.transform_composite_index_assign(
198                        alloc,
199                        inst.out(),
200                        op,
201                        Operator::UncheckedIndexAssign,
202                        unroll_factor,
203                        mappings,
204                    ))
205                }
206                Operation::Metadata(op) => {
207                    TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
208                }
209                _ => TransformAction::Replace(self.transform_basic(
210                    alloc,
211                    inst,
212                    args,
213                    unroll_factor,
214                    mappings,
215                )),
216            }
217        } else {
218            TransformAction::Ignore
219        }
220    }
221
222    /// Transform CMMA load offset and array
223    fn transform_cmma_load(
224        &self,
225        alloc: &Allocator,
226        out: Variable,
227        value: &Variable,
228        stride: &Variable,
229        offset: &Variable,
230        layout: &Option<MatrixLayout>,
231    ) -> Vec<Instruction> {
232        let line_size = value.line_size();
233        let unroll_factor = line_size / self.max_line_size;
234
235        let value = unroll_array(*value, self.max_line_size, unroll_factor);
236        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
237        let load = Instruction::new(
238            Operation::CoopMma(CoopMma::Load {
239                value,
240                stride: *stride,
241                offset: *offset,
242                layout: *layout,
243            }),
244            out,
245        );
246        vec![mul, load]
247    }
248
249    /// Transform CMMA store offset and array
250    fn transform_cmma_store(
251        &self,
252        alloc: &Allocator,
253        out: Variable,
254        mat: &Variable,
255        stride: &Variable,
256        offset: &Variable,
257        layout: &MatrixLayout,
258    ) -> Vec<Instruction> {
259        let line_size = out.line_size();
260        let unroll_factor = line_size / self.max_line_size;
261
262        let out = unroll_array(out, self.max_line_size, unroll_factor);
263        let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
264        let store = Instruction::new(
265            Operation::CoopMma(CoopMma::Store {
266                mat: *mat,
267                stride: *stride,
268                offset: *offset,
269                layout: *layout,
270            }),
271            out,
272        );
273        vec![mul, store]
274    }
275
276    /// Transforms memcpy into one with higher length and adjusted indices/offsets
277    fn transform_memcpy(
278        &self,
279        alloc: &Allocator,
280        op: &CopyMemoryBulkOperator,
281        out: Variable,
282        unroll_factor: usize,
283    ) -> Vec<Instruction> {
284        let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
285        let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
286        let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
287        let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
288
289        let input = unroll_array(op.input, self.max_line_size, unroll_factor);
290        let out = unroll_array(out, self.max_line_size, unroll_factor);
291
292        vec![
293            mul1,
294            mul2,
295            mul3,
296            mul4,
297            Instruction::new(
298                Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
299                    input,
300                    in_index: *in_index,
301                    out_index: *out_index,
302                    len: op.len * unroll_factor,
303                    offset_input: *offset_input,
304                    offset_out: *offset_out,
305                }),
306                out,
307            ),
308        ]
309    }
310
311    /// Transforms indexing into multiple index operations, each offset by 1 from the base. The base
312    /// is also multiplied by the unroll factor to compensate for the lower actual vectorization.
313    fn transform_array_index(
314        &self,
315        alloc: &Allocator,
316        out: Variable,
317        op: &IndexOperator,
318        operator: impl Fn(IndexOperator) -> Operator,
319        unroll_factor: usize,
320        mappings: &mut Mappings,
321    ) -> Vec<Instruction> {
322        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
323        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
324
325        let list = unroll_array(op.list, self.max_line_size, unroll_factor);
326
327        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
328        let mut instructions = vec![mul];
329        instructions.extend((0..unroll_factor).flat_map(|i| {
330            let (add, idx) = indices.next().unwrap();
331            let index = Instruction::new(
332                operator(IndexOperator {
333                    list,
334                    index: *idx,
335                    line_size: 0,
336                    unroll_factor,
337                }),
338                out[i],
339            );
340            [add, index]
341        }));
342
343        instructions
344    }
345
346    /// Transforms index assign into multiple index assign operations, each offset by 1 from the base.
347    /// The base is also multiplied by the unroll factor to compensate for the lower actual vectorization.
348    fn transform_array_index_assign(
349        &self,
350        alloc: &Allocator,
351        out: Variable,
352        op: &IndexAssignOperator,
353        operator: impl Fn(IndexAssignOperator) -> Operator,
354        unroll_factor: usize,
355        mappings: &mut Mappings,
356    ) -> Vec<Instruction> {
357        let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
358        let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
359
360        let out = unroll_array(out, self.max_line_size, unroll_factor);
361
362        let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
363
364        let mut instructions = vec![mul];
365        instructions.extend((0..unroll_factor).flat_map(|i| {
366            let (add, idx) = indices.next().unwrap();
367            let index = Instruction::new(
368                operator(IndexAssignOperator {
369                    index: *idx,
370                    line_size: 0,
371                    value: value[i],
372                    unroll_factor,
373                }),
374                out,
375            );
376
377            [add, index]
378        }));
379
380        instructions
381    }
382
383    /// Transforms a composite index (i.e. `Line`) that always returns a scalar. Translates the index
384    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
385    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
386    /// work.
387    fn transform_composite_index(
388        &self,
389        alloc: &Allocator,
390        out: Variable,
391        op: &IndexOperator,
392        operator: impl Fn(IndexOperator) -> Operator,
393        unroll_factor: usize,
394        mappings: &mut Mappings,
395    ) -> Vec<Instruction> {
396        let index = op
397            .index
398            .as_const()
399            .expect("Can't unroll non-constant vector index")
400            .as_usize();
401
402        let unroll_idx = index / self.max_line_size;
403        let sub_idx = index % self.max_line_size;
404
405        let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
406
407        vec![Instruction::new(
408            operator(IndexOperator {
409                list: value[unroll_idx],
410                index: sub_idx.into(),
411                line_size: 1,
412                unroll_factor,
413            }),
414            out,
415        )]
416    }
417
418    /// Transforms a composite index assign (i.e. `Line`) that always takes a scalar. Translates the index
419    /// to a local index and an unroll index, then indexes the proper variable. Note that this requires
420    /// the index to be constant - it needs to be decomposed at compile time, otherwise it wouldn't
421    /// work.
422    fn transform_composite_index_assign(
423        &self,
424        alloc: &Allocator,
425        out: Variable,
426        op: &IndexAssignOperator,
427        operator: impl Fn(IndexAssignOperator) -> Operator,
428        unroll_factor: usize,
429        mappings: &mut Mappings,
430    ) -> Vec<Instruction> {
431        let index = op
432            .index
433            .as_const()
434            .expect("Can't unroll non-constant vector index")
435            .as_usize();
436
437        let unroll_idx = index / self.max_line_size;
438        let sub_idx = index % self.max_line_size;
439
440        let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
441
442        vec![Instruction::new(
443            operator(IndexAssignOperator {
444                index: sub_idx.into(),
445                line_size: 1,
446                value: op.value,
447                unroll_factor,
448            }),
449            out[unroll_idx],
450        )]
451    }
452
453    /// Transforms metadata by just replacing the type of the buffer. The values are already
454    /// properly calculated on the CPU.
455    fn transform_metadata(
456        &self,
457        out: Variable,
458        op: &Metadata,
459        args: Vec<Variable>,
460    ) -> Vec<Instruction> {
461        let op_code = op.op_code();
462        let args = args
463            .into_iter()
464            .map(|mut var| {
465                if var.line_size() > self.max_line_size {
466                    var.ty = var.ty.line(self.max_line_size);
467                }
468                var
469            })
470            .collect::<Vec<_>>();
471        let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
472        vec![Instruction::new(operation, out)]
473    }
474
475    /// Transforms generic instructions, i.e. comparison, arithmetic. Unrolls each vectorized variable
476    /// to `unroll_factor` replacements, and executes the operation `unroll_factor` times.
477    fn transform_basic(
478        &self,
479        alloc: &Allocator,
480        inst: &Instruction,
481        args: Vec<Variable>,
482        unroll_factor: usize,
483        mappings: &mut Mappings,
484    ) -> Vec<Instruction> {
485        let op_code = inst.operation.op_code();
486        let out = inst
487            .out
488            .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
489        let args = args
490            .into_iter()
491            .map(|arg| {
492                if arg.line_size() > 1 {
493                    mappings.get(alloc, arg, unroll_factor, self.max_line_size)
494                } else {
495                    // Preserve scalars
496                    vec![arg]
497                }
498            })
499            .collect::<Vec<_>>();
500
501        (0..unroll_factor)
502            .map(|i| {
503                let out = out.as_ref().map(|out| out[i]);
504                let args = args
505                    .iter()
506                    .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
507                    .collect::<Vec<_>>();
508                let operation = Operation::from_code_and_args(op_code, &args)
509                    .expect("Failed to reconstruct operation");
510                Instruction {
511                    out,
512                    source_loc: inst.source_loc.clone(),
513                    modes: inst.modes,
514                    operation,
515                }
516            })
517            .collect()
518    }
519
520    fn transform_instructions(
521        &self,
522        allocator: &Allocator,
523        instructions: Vec<Instruction>,
524        mappings: &mut Mappings,
525    ) -> Vec<Instruction> {
526        let mut new_instructions = Vec::with_capacity(instructions.len());
527
528        for mut instruction in instructions {
529            if let Operation::Branch(branch) = &mut instruction.operation {
530                match branch {
531                    Branch::If(op) => {
532                        op.scope.instructions = self.transform_instructions(
533                            allocator,
534                            op.scope.instructions.drain(..).collect(),
535                            mappings,
536                        );
537                    }
538                    Branch::IfElse(op) => {
539                        op.scope_if.instructions = self.transform_instructions(
540                            allocator,
541                            op.scope_if.instructions.drain(..).collect(),
542                            mappings,
543                        );
544                        op.scope_else.instructions = self.transform_instructions(
545                            allocator,
546                            op.scope_else.instructions.drain(..).collect(),
547                            mappings,
548                        );
549                    }
550                    Branch::Switch(op) => {
551                        for (_, case) in &mut op.cases {
552                            case.instructions = self.transform_instructions(
553                                allocator,
554                                case.instructions.drain(..).collect(),
555                                mappings,
556                            );
557                        }
558                        op.scope_default.instructions = self.transform_instructions(
559                            allocator,
560                            op.scope_default.instructions.drain(..).collect(),
561                            mappings,
562                        );
563                    }
564                    Branch::RangeLoop(op) => {
565                        op.scope.instructions = self.transform_instructions(
566                            allocator,
567                            op.scope.instructions.drain(..).collect(),
568                            mappings,
569                        );
570                    }
571                    Branch::Loop(op) => {
572                        op.scope.instructions = self.transform_instructions(
573                            allocator,
574                            op.scope.instructions.drain(..).collect(),
575                            mappings,
576                        );
577                    }
578                    _ => {}
579                }
580            }
581            match self.maybe_transform(allocator, &instruction, mappings) {
582                TransformAction::Ignore => {
583                    new_instructions.push(instruction);
584                }
585                TransformAction::Replace(replacement) => {
586                    new_instructions.extend(replacement);
587                }
588            }
589        }
590
591        new_instructions
592    }
593}
594
595impl Processor for UnrollProcessor {
596    fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
597        let mut mappings = Mappings(Default::default());
598
599        let instructions =
600            self.transform_instructions(&allocator, processing.instructions, &mut mappings);
601
602        ScopeProcessing {
603            variables: processing.variables,
604            instructions,
605            typemap: processing.typemap.clone(),
606        }
607    }
608}
609
610fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> LineSize {
611    let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
612    line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
613}
614
615fn create_unrolled(
616    allocator: &Allocator,
617    var: &Variable,
618    max_line_size: LineSize,
619    unroll_factor: usize,
620) -> Vec<ExpandElement> {
621    // Preserve scalars
622    if var.line_size() == 1 {
623        return vec![ExpandElement::Plain(*var); unroll_factor];
624    }
625
626    let item = Type::new(var.storage_type()).line(max_line_size);
627    (0..unroll_factor)
628        .map(|_| match var.kind {
629            VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
630                allocator.create_local_mut(item)
631            }
632            VariableKind::Shared { .. } => {
633                let id = allocator.new_local_index();
634                let shared = VariableKind::Shared { id };
635                ExpandElement::Plain(Variable::new(shared, item))
636            }
637            VariableKind::LocalConst { .. } => allocator.create_local(item),
638            other => panic!("Out must be local, found {other:?}"),
639        })
640        .collect()
641}
642
643fn add_index(alloc: &Allocator, idx: Variable, i: usize) -> (Instruction, ExpandElement) {
644    let add_idx = alloc.create_local(idx.ty);
645    let add = Instruction::new(
646        Arithmetic::Add(BinaryOperator {
647            lhs: idx,
648            rhs: i.into(),
649        }),
650        *add_idx,
651    );
652    (add, add_idx)
653}
654
655fn mul_index(
656    alloc: &Allocator,
657    idx: Variable,
658    unroll_factor: usize,
659) -> (Instruction, ExpandElement) {
660    let mul_idx = alloc.create_local(idx.ty);
661    let mul = Instruction::new(
662        Arithmetic::Mul(BinaryOperator {
663            lhs: idx,
664            rhs: unroll_factor.into(),
665        }),
666        *mul_idx,
667    );
668    (mul, mul_idx)
669}
670
671fn unroll_array(mut var: Variable, max_line_size: LineSize, factor: usize) -> Variable {
672    var.ty = var.ty.line(max_line_size);
673
674    match &mut var.kind {
675        VariableKind::LocalArray { unroll_factor, .. }
676        | VariableKind::ConstantArray { unroll_factor, .. }
677        | VariableKind::SharedArray { unroll_factor, .. } => {
678            *unroll_factor = factor;
679        }
680        _ => {}
681    }
682
683    var
684}