1use alloc::{vec, vec::Vec};
2use cubecl_ir::{
3 Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator,
4 IndexAssignOperator, IndexOperator, Instruction, ManagedVariable, MatrixLayout, Metadata,
5 Operation, OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable,
6 VariableKind, VectorSize,
7};
8use hashbrown::HashMap;
9
/// Outcome of inspecting a single instruction during the unroll pass.
pub enum TransformAction {
    /// Keep the instruction exactly as it is.
    Ignore,
    /// Emit the contained instructions in place of the inspected one.
    Replace(Vec<Instruction>),
}
17
/// Processor that splits instructions working on vectors wider than `max_vector_size` into
/// several instructions working on vectors of at most that size.
#[derive(new, Debug)]
pub struct UnrollProcessor {
    /// Widest vector size left untouched; instructions whose output or arguments exceed it are
    /// unrolled.
    max_vector_size: VectorSize,
}
22
/// Cache associating an original variable with the variables that replace it once unrolled.
struct Mappings(HashMap<Variable, Vec<ManagedVariable>>);
24
25impl Mappings {
26 fn get(
27 &mut self,
28 alloc: &Allocator,
29 var: Variable,
30 unroll_factor: usize,
31 vector_size: VectorSize,
32 ) -> Vec<Variable> {
33 self.0
34 .entry(var)
35 .or_insert_with(|| create_unrolled(alloc, &var, vector_size, unroll_factor))
36 .iter()
37 .map(|it| **it)
38 .collect()
39 }
40}
41
42impl UnrollProcessor {
43 fn maybe_transform(
44 &self,
45 alloc: &Allocator,
46 inst: &Instruction,
47 mappings: &mut Mappings,
48 ) -> TransformAction {
49 if matches!(inst.operation, Operation::Marker(_)) {
50 return TransformAction::Ignore;
51 }
52
53 if inst.operation.args().is_none() {
54 match &inst.operation {
56 Operation::CoopMma(op) => match op {
57 CoopMma::Load {
59 value,
60 stride,
61 offset,
62 layout,
63 } if value.vector_size() > self.max_vector_size => {
64 return TransformAction::Replace(self.transform_cmma_load(
65 alloc,
66 inst.out(),
67 value,
68 stride,
69 offset,
70 layout,
71 ));
72 }
73 CoopMma::Store {
74 mat,
75 stride,
76 offset,
77 layout,
78 } if inst.out().vector_size() > self.max_vector_size => {
79 return TransformAction::Replace(self.transform_cmma_store(
80 alloc,
81 inst.out(),
82 mat,
83 stride,
84 offset,
85 layout,
86 ));
87 }
88 _ => return TransformAction::Ignore,
89 },
90 Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
91 return TransformAction::Ignore;
92 }
93 _ => {
94 panic!("Need special handling for unrolling non-reflectable operations")
95 }
96 }
97 }
98
99 let args = inst.operation.args().unwrap_or_default();
100 if (inst.out.is_some() && inst.ty().vector_size() > self.max_vector_size)
101 || args
102 .iter()
103 .any(|arg| arg.vector_size() > self.max_vector_size)
104 {
105 let vector_size = max_vector_size(&inst.out, &args);
106 let unroll_factor = vector_size / self.max_vector_size;
107
108 match &inst.operation {
109 Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
110 self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
111 ),
112 Operation::Operator(Operator::CopyMemory(op)) => {
113 TransformAction::Replace(self.transform_memcpy(
114 alloc,
115 &CopyMemoryBulkOperator {
116 out_index: op.out_index,
117 input: op.input,
118 in_index: op.in_index,
119 len: 1,
120 offset_input: 0.into(),
121 offset_out: 0.into(),
122 },
123 inst.out(),
124 unroll_factor,
125 ))
126 }
127 Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
128 TransformAction::Replace(self.transform_array_index(
129 alloc,
130 inst.out(),
131 op,
132 Operator::Index,
133 unroll_factor,
134 mappings,
135 ))
136 }
137 Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
138 TransformAction::Replace(self.transform_array_index(
139 alloc,
140 inst.out(),
141 op,
142 Operator::UncheckedIndex,
143 unroll_factor,
144 mappings,
145 ))
146 }
147 Operation::Operator(Operator::Index(op)) => {
148 TransformAction::Replace(self.transform_composite_index(
149 alloc,
150 inst.out(),
151 op,
152 Operator::Index,
153 unroll_factor,
154 mappings,
155 ))
156 }
157 Operation::Operator(Operator::UncheckedIndex(op)) => {
158 TransformAction::Replace(self.transform_composite_index(
159 alloc,
160 inst.out(),
161 op,
162 Operator::UncheckedIndex,
163 unroll_factor,
164 mappings,
165 ))
166 }
167 Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
168 TransformAction::Replace(self.transform_array_index_assign(
169 alloc,
170 inst.out(),
171 op,
172 Operator::IndexAssign,
173 unroll_factor,
174 mappings,
175 ))
176 }
177 Operation::Operator(Operator::UncheckedIndexAssign(op))
178 if inst.out().is_array() =>
179 {
180 TransformAction::Replace(self.transform_array_index_assign(
181 alloc,
182 inst.out(),
183 op,
184 Operator::UncheckedIndexAssign,
185 unroll_factor,
186 mappings,
187 ))
188 }
189 Operation::Operator(Operator::IndexAssign(op)) => {
190 TransformAction::Replace(self.transform_composite_index_assign(
191 alloc,
192 inst.out(),
193 op,
194 Operator::IndexAssign,
195 unroll_factor,
196 mappings,
197 ))
198 }
199 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
200 TransformAction::Replace(self.transform_composite_index_assign(
201 alloc,
202 inst.out(),
203 op,
204 Operator::UncheckedIndexAssign,
205 unroll_factor,
206 mappings,
207 ))
208 }
209 Operation::Metadata(op) => {
210 TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
211 }
212 _ => TransformAction::Replace(self.transform_basic(
213 alloc,
214 inst,
215 args,
216 unroll_factor,
217 mappings,
218 )),
219 }
220 } else {
221 TransformAction::Ignore
222 }
223 }
224
225 fn transform_cmma_load(
227 &self,
228 alloc: &Allocator,
229 out: Variable,
230 value: &Variable,
231 stride: &Variable,
232 offset: &Variable,
233 layout: &Option<MatrixLayout>,
234 ) -> Vec<Instruction> {
235 let vector_size = value.vector_size();
236 let unroll_factor = vector_size / self.max_vector_size;
237
238 let value = unroll_array(*value, self.max_vector_size, unroll_factor);
239 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
240 let load = Instruction::new(
241 Operation::CoopMma(CoopMma::Load {
242 value,
243 stride: *stride,
244 offset: *offset,
245 layout: *layout,
246 }),
247 out,
248 );
249 vec![mul, load]
250 }
251
252 fn transform_cmma_store(
254 &self,
255 alloc: &Allocator,
256 out: Variable,
257 mat: &Variable,
258 stride: &Variable,
259 offset: &Variable,
260 layout: &MatrixLayout,
261 ) -> Vec<Instruction> {
262 let vector_size = out.vector_size();
263 let unroll_factor = vector_size / self.max_vector_size;
264
265 let out = unroll_array(out, self.max_vector_size, unroll_factor);
266 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
267 let store = Instruction::new(
268 Operation::CoopMma(CoopMma::Store {
269 mat: *mat,
270 stride: *stride,
271 offset: *offset,
272 layout: *layout,
273 }),
274 out,
275 );
276 vec![mul, store]
277 }
278
279 fn transform_memcpy(
281 &self,
282 alloc: &Allocator,
283 op: &CopyMemoryBulkOperator,
284 out: Variable,
285 unroll_factor: usize,
286 ) -> Vec<Instruction> {
287 let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
288 let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
289 let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
290 let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
291
292 let input = unroll_array(op.input, self.max_vector_size, unroll_factor);
293 let out = unroll_array(out, self.max_vector_size, unroll_factor);
294
295 vec![
296 mul1,
297 mul2,
298 mul3,
299 mul4,
300 Instruction::new(
301 Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
302 input,
303 in_index: *in_index,
304 out_index: *out_index,
305 len: op.len * unroll_factor,
306 offset_input: *offset_input,
307 offset_out: *offset_out,
308 }),
309 out,
310 ),
311 ]
312 }
313
314 fn transform_array_index(
317 &self,
318 alloc: &Allocator,
319 out: Variable,
320 op: &IndexOperator,
321 operator: impl Fn(IndexOperator) -> Operator,
322 unroll_factor: usize,
323 mappings: &mut Mappings,
324 ) -> Vec<Instruction> {
325 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
326 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
327
328 let list = unroll_array(op.list, self.max_vector_size, unroll_factor);
329
330 let out = mappings.get(alloc, out, unroll_factor, self.max_vector_size);
331 let mut instructions = vec![mul];
332 instructions.extend((0..unroll_factor).flat_map(|i| {
333 let (add, idx) = indices.next().unwrap();
334 let index = Instruction::new(
335 operator(IndexOperator {
336 list,
337 index: *idx,
338 vector_size: 0,
339 unroll_factor,
340 }),
341 out[i],
342 );
343 [add, index]
344 }));
345
346 instructions
347 }
348
349 fn transform_array_index_assign(
352 &self,
353 alloc: &Allocator,
354 out: Variable,
355 op: &IndexAssignOperator,
356 operator: impl Fn(IndexAssignOperator) -> Operator,
357 unroll_factor: usize,
358 mappings: &mut Mappings,
359 ) -> Vec<Instruction> {
360 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
361 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
362
363 let out = unroll_array(out, self.max_vector_size, unroll_factor);
364
365 let value = mappings.get(alloc, op.value, unroll_factor, self.max_vector_size);
366
367 let mut instructions = vec![mul];
368 instructions.extend((0..unroll_factor).flat_map(|i| {
369 let (add, idx) = indices.next().unwrap();
370 let index = Instruction::new(
371 operator(IndexAssignOperator {
372 index: *idx,
373 vector_size: 0,
374 value: value[i],
375 unroll_factor,
376 }),
377 out,
378 );
379
380 [add, index]
381 }));
382
383 instructions
384 }
385
386 fn transform_composite_index(
391 &self,
392 alloc: &Allocator,
393 out: Variable,
394 op: &IndexOperator,
395 operator: impl Fn(IndexOperator) -> Operator,
396 unroll_factor: usize,
397 mappings: &mut Mappings,
398 ) -> Vec<Instruction> {
399 let index = op
400 .index
401 .as_const()
402 .expect("Can't unroll non-constant vector index")
403 .as_usize();
404
405 let unroll_idx = index / self.max_vector_size;
406 let sub_idx = index % self.max_vector_size;
407
408 let value = mappings.get(alloc, op.list, unroll_factor, self.max_vector_size);
409
410 vec![Instruction::new(
411 operator(IndexOperator {
412 list: value[unroll_idx],
413 index: sub_idx.into(),
414 vector_size: 1,
415 unroll_factor,
416 }),
417 out,
418 )]
419 }
420
421 fn transform_composite_index_assign(
426 &self,
427 alloc: &Allocator,
428 out: Variable,
429 op: &IndexAssignOperator,
430 operator: impl Fn(IndexAssignOperator) -> Operator,
431 unroll_factor: usize,
432 mappings: &mut Mappings,
433 ) -> Vec<Instruction> {
434 let index = op
435 .index
436 .as_const()
437 .expect("Can't unroll non-constant vector index")
438 .as_usize();
439
440 let unroll_idx = index / self.max_vector_size;
441 let sub_idx = index % self.max_vector_size;
442
443 let out = mappings.get(alloc, out, unroll_factor, self.max_vector_size);
444
445 vec![Instruction::new(
446 operator(IndexAssignOperator {
447 index: sub_idx.into(),
448 vector_size: 1,
449 value: op.value,
450 unroll_factor,
451 }),
452 out[unroll_idx],
453 )]
454 }
455
456 fn transform_metadata(
459 &self,
460 out: Variable,
461 op: &Metadata,
462 args: Vec<Variable>,
463 ) -> Vec<Instruction> {
464 let op_code = op.op_code();
465 let args = args
466 .into_iter()
467 .map(|mut var| {
468 if var.vector_size() > self.max_vector_size {
469 var.ty = var.ty.with_vector_size(self.max_vector_size);
470 }
471 var
472 })
473 .collect::<Vec<_>>();
474 let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
475 vec![Instruction::new(operation, out)]
476 }
477
478 fn transform_basic(
481 &self,
482 alloc: &Allocator,
483 inst: &Instruction,
484 args: Vec<Variable>,
485 unroll_factor: usize,
486 mappings: &mut Mappings,
487 ) -> Vec<Instruction> {
488 let op_code = inst.operation.op_code();
489 let out = inst
490 .out
491 .map(|out| mappings.get(alloc, out, unroll_factor, self.max_vector_size));
492 let args = args
493 .into_iter()
494 .map(|arg| {
495 if arg.vector_size() > 1 {
496 mappings.get(alloc, arg, unroll_factor, self.max_vector_size)
497 } else {
498 vec![arg]
500 }
501 })
502 .collect::<Vec<_>>();
503
504 (0..unroll_factor)
505 .map(|i| {
506 let out = out.as_ref().map(|out| out[i]);
507 let args = args
508 .iter()
509 .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
510 .collect::<Vec<_>>();
511 let operation = Operation::from_code_and_args(op_code, &args)
512 .expect("Failed to reconstruct operation");
513 Instruction {
514 out,
515 source_loc: inst.source_loc.clone(),
516 modes: inst.modes,
517 operation,
518 }
519 })
520 .collect()
521 }
522
523 fn transform_instructions(
524 &self,
525 allocator: &Allocator,
526 instructions: Vec<Instruction>,
527 mappings: &mut Mappings,
528 ) -> Vec<Instruction> {
529 let mut new_instructions = Vec::with_capacity(instructions.len());
530
531 for mut instruction in instructions {
532 if let Operation::Branch(branch) = &mut instruction.operation {
533 match branch {
534 Branch::If(op) => {
535 op.scope.instructions = self.transform_instructions(
536 allocator,
537 op.scope.instructions.drain(..).collect(),
538 mappings,
539 );
540 }
541 Branch::IfElse(op) => {
542 op.scope_if.instructions = self.transform_instructions(
543 allocator,
544 op.scope_if.instructions.drain(..).collect(),
545 mappings,
546 );
547 op.scope_else.instructions = self.transform_instructions(
548 allocator,
549 op.scope_else.instructions.drain(..).collect(),
550 mappings,
551 );
552 }
553 Branch::Switch(op) => {
554 for (_, case) in &mut op.cases {
555 case.instructions = self.transform_instructions(
556 allocator,
557 case.instructions.drain(..).collect(),
558 mappings,
559 );
560 }
561 op.scope_default.instructions = self.transform_instructions(
562 allocator,
563 op.scope_default.instructions.drain(..).collect(),
564 mappings,
565 );
566 }
567 Branch::RangeLoop(op) => {
568 op.scope.instructions = self.transform_instructions(
569 allocator,
570 op.scope.instructions.drain(..).collect(),
571 mappings,
572 );
573 }
574 Branch::Loop(op) => {
575 op.scope.instructions = self.transform_instructions(
576 allocator,
577 op.scope.instructions.drain(..).collect(),
578 mappings,
579 );
580 }
581 _ => {}
582 }
583 }
584 match self.maybe_transform(allocator, &instruction, mappings) {
585 TransformAction::Ignore => {
586 new_instructions.push(instruction);
587 }
588 TransformAction::Replace(replacement) => {
589 new_instructions.extend(replacement);
590 }
591 }
592 }
593
594 new_instructions
595 }
596}
597
598impl Processor for UnrollProcessor {
599 fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
600 let mut mappings = Mappings(Default::default());
601
602 let instructions =
603 self.transform_instructions(&allocator, processing.instructions, &mut mappings);
604
605 ScopeProcessing {
606 variables: processing.variables,
607 instructions,
608 typemap: processing.typemap.clone(),
609 }
610 }
611}
612
613fn max_vector_size(out: &Option<Variable>, args: &[Variable]) -> VectorSize {
614 let vector_size = args.iter().map(|it| it.vector_size()).max().unwrap();
615 vector_size.max(out.map(|out| out.vector_size()).unwrap_or(1))
616}
617
618fn create_unrolled(
619 allocator: &Allocator,
620 var: &Variable,
621 max_vector_size: VectorSize,
622 unroll_factor: usize,
623) -> Vec<ManagedVariable> {
624 if var.vector_size() == 1 {
626 return vec![ManagedVariable::Plain(*var); unroll_factor];
627 }
628
629 let item = Type::new(var.storage_type()).with_vector_size(max_vector_size);
630 (0..unroll_factor)
631 .map(|_| match var.kind {
632 VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
633 allocator.create_local_mut(item)
634 }
635 VariableKind::Shared { .. } => {
636 let id = allocator.new_local_index();
637 let shared = VariableKind::Shared { id };
638 ManagedVariable::Plain(Variable::new(shared, item))
639 }
640 VariableKind::LocalConst { .. } => allocator.create_local(item),
641 other => panic!("Out must be local, found {other:?}"),
642 })
643 .collect()
644}
645
646fn add_index(alloc: &Allocator, idx: Variable, i: usize) -> (Instruction, ManagedVariable) {
647 let add_idx = alloc.create_local(idx.ty);
648 let add = Instruction::new(
649 Arithmetic::Add(BinaryOperator {
650 lhs: idx,
651 rhs: i.into(),
652 }),
653 *add_idx,
654 );
655 (add, add_idx)
656}
657
658fn mul_index(
659 alloc: &Allocator,
660 idx: Variable,
661 unroll_factor: usize,
662) -> (Instruction, ManagedVariable) {
663 let mul_idx = alloc.create_local(idx.ty);
664 let mul = Instruction::new(
665 Arithmetic::Mul(BinaryOperator {
666 lhs: idx,
667 rhs: unroll_factor.into(),
668 }),
669 *mul_idx,
670 );
671 (mul, mul_idx)
672}
673
674fn unroll_array(mut var: Variable, max_vector_size: VectorSize, factor: usize) -> Variable {
675 var.ty = var.ty.with_vector_size(max_vector_size);
676
677 match &mut var.kind {
678 VariableKind::LocalArray { unroll_factor, .. }
679 | VariableKind::ConstantArray { unroll_factor, .. }
680 | VariableKind::SharedArray { unroll_factor, .. } => {
681 *unroll_factor = factor;
682 }
683 _ => {}
684 }
685
686 var
687}