1use cubecl_ir::{
2 Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
3 IndexAssignOperator, IndexOperator, Instruction, MatrixLayout, Metadata, Operation,
4 OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
5};
6use hashbrown::HashMap;
7
8pub enum TransformAction {
10 Ignore,
12 Replace(Vec<Instruction>),
14}
15
16#[derive(new, Debug)]
17pub struct UnrollProcessor {
18 max_line_size: u32,
19}
20
21struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
22
23impl Mappings {
24 fn get(
25 &mut self,
26 alloc: &Allocator,
27 var: Variable,
28 unroll_factor: u32,
29 line_size: u32,
30 ) -> Vec<Variable> {
31 self.0
32 .entry(var)
33 .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
34 .iter()
35 .map(|it| **it)
36 .collect()
37 }
38}
39
40impl UnrollProcessor {
41 fn maybe_transform(
42 &self,
43 alloc: &Allocator,
44 inst: &Instruction,
45 mappings: &mut Mappings,
46 ) -> TransformAction {
47 if matches!(inst.operation, Operation::Free(_)) {
48 return TransformAction::Ignore;
49 }
50
51 if inst.operation.args().is_none() {
52 match &inst.operation {
54 Operation::CoopMma(op) => match op {
55 CoopMma::Load {
57 value,
58 stride,
59 offset,
60 layout,
61 } if value.line_size() > self.max_line_size => {
62 return TransformAction::Replace(self.transform_cmma_load(
63 alloc,
64 inst.out(),
65 value,
66 stride,
67 offset,
68 layout,
69 ));
70 }
71 CoopMma::Store {
72 mat,
73 stride,
74 offset,
75 layout,
76 } if inst.out().line_size() > self.max_line_size => {
77 return TransformAction::Replace(self.transform_cmma_store(
78 alloc,
79 inst.out(),
80 mat,
81 stride,
82 offset,
83 layout,
84 ));
85 }
86 _ => return TransformAction::Ignore,
87 },
88 Operation::Branch(_) | Operation::NonSemantic(_) => {
89 return TransformAction::Ignore;
90 }
91 _ => {
92 panic!("Need special handling for unrolling non-reflectable operations")
93 }
94 }
95 }
96
97 let args = inst.operation.args().unwrap_or_default();
98 if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
99 || args.iter().any(|arg| arg.line_size() > self.max_line_size)
100 {
101 let line_size = max_line_size(&inst.out, &args);
102 let unroll_factor = line_size / self.max_line_size;
103
104 match &inst.operation {
105 Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
106 self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
107 ),
108 Operation::Operator(Operator::CopyMemory(op)) => {
109 TransformAction::Replace(self.transform_memcpy(
110 alloc,
111 &CopyMemoryBulkOperator {
112 out_index: op.out_index,
113 input: op.input,
114 in_index: op.in_index,
115 len: 1,
116 offset_input: 0.into(),
117 offset_out: 0.into(),
118 },
119 inst.out(),
120 unroll_factor,
121 ))
122 }
123 Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
124 TransformAction::Replace(self.transform_array_index(
125 alloc,
126 inst.out(),
127 op,
128 Operator::Index,
129 unroll_factor,
130 mappings,
131 ))
132 }
133 Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
134 TransformAction::Replace(self.transform_array_index(
135 alloc,
136 inst.out(),
137 op,
138 Operator::UncheckedIndex,
139 unroll_factor,
140 mappings,
141 ))
142 }
143 Operation::Operator(Operator::Index(op)) => {
144 TransformAction::Replace(self.transform_composite_index(
145 alloc,
146 inst.out(),
147 op,
148 Operator::Index,
149 unroll_factor,
150 mappings,
151 ))
152 }
153 Operation::Operator(Operator::UncheckedIndex(op)) => {
154 TransformAction::Replace(self.transform_composite_index(
155 alloc,
156 inst.out(),
157 op,
158 Operator::UncheckedIndex,
159 unroll_factor,
160 mappings,
161 ))
162 }
163 Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
164 TransformAction::Replace(self.transform_array_index_assign(
165 alloc,
166 inst.out(),
167 op,
168 Operator::IndexAssign,
169 unroll_factor,
170 mappings,
171 ))
172 }
173 Operation::Operator(Operator::UncheckedIndexAssign(op))
174 if inst.out().is_array() =>
175 {
176 TransformAction::Replace(self.transform_array_index_assign(
177 alloc,
178 inst.out(),
179 op,
180 Operator::UncheckedIndexAssign,
181 unroll_factor,
182 mappings,
183 ))
184 }
185 Operation::Operator(Operator::IndexAssign(op)) => {
186 TransformAction::Replace(self.transform_composite_index_assign(
187 alloc,
188 inst.out(),
189 op,
190 Operator::IndexAssign,
191 unroll_factor,
192 mappings,
193 ))
194 }
195 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
196 TransformAction::Replace(self.transform_composite_index_assign(
197 alloc,
198 inst.out(),
199 op,
200 Operator::UncheckedIndexAssign,
201 unroll_factor,
202 mappings,
203 ))
204 }
205 Operation::Metadata(op) => {
206 TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
207 }
208 _ => TransformAction::Replace(self.transform_basic(
209 alloc,
210 inst,
211 args,
212 unroll_factor,
213 mappings,
214 )),
215 }
216 } else {
217 TransformAction::Ignore
218 }
219 }
220
221 fn transform_cmma_load(
223 &self,
224 alloc: &Allocator,
225 out: Variable,
226 value: &Variable,
227 stride: &Variable,
228 offset: &Variable,
229 layout: &Option<MatrixLayout>,
230 ) -> Vec<Instruction> {
231 let line_size = value.line_size();
232 let unroll_factor = line_size / self.max_line_size;
233
234 let value = unroll_array(*value, self.max_line_size, unroll_factor);
235 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
236 let load = Instruction::new(
237 Operation::CoopMma(CoopMma::Load {
238 value,
239 stride: *stride,
240 offset: *offset,
241 layout: *layout,
242 }),
243 out,
244 );
245 vec![mul, load]
246 }
247
248 fn transform_cmma_store(
250 &self,
251 alloc: &Allocator,
252 out: Variable,
253 mat: &Variable,
254 stride: &Variable,
255 offset: &Variable,
256 layout: &MatrixLayout,
257 ) -> Vec<Instruction> {
258 let line_size = out.line_size();
259 let unroll_factor = line_size / self.max_line_size;
260
261 let out = unroll_array(out, self.max_line_size, unroll_factor);
262 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
263 let store = Instruction::new(
264 Operation::CoopMma(CoopMma::Store {
265 mat: *mat,
266 stride: *stride,
267 offset: *offset,
268 layout: *layout,
269 }),
270 out,
271 );
272 vec![mul, store]
273 }
274
275 fn transform_memcpy(
277 &self,
278 alloc: &Allocator,
279 op: &CopyMemoryBulkOperator,
280 out: Variable,
281 unroll_factor: u32,
282 ) -> Vec<Instruction> {
283 let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
284 let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
285 let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
286 let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
287
288 let input = unroll_array(op.input, self.max_line_size, unroll_factor);
289 let out = unroll_array(out, self.max_line_size, unroll_factor);
290
291 vec![
292 mul1,
293 mul2,
294 mul3,
295 mul4,
296 Instruction::new(
297 Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
298 input,
299 in_index: *in_index,
300 out_index: *out_index,
301 len: op.len * unroll_factor,
302 offset_input: *offset_input,
303 offset_out: *offset_out,
304 }),
305 out,
306 ),
307 ]
308 }
309
310 fn transform_array_index(
313 &self,
314 alloc: &Allocator,
315 out: Variable,
316 op: &IndexOperator,
317 operator: impl Fn(IndexOperator) -> Operator,
318 unroll_factor: u32,
319 mappings: &mut Mappings,
320 ) -> Vec<Instruction> {
321 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
322 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
323
324 let list = unroll_array(op.list, self.max_line_size, unroll_factor);
325
326 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
327 let mut instructions = vec![mul];
328 instructions.extend((0..unroll_factor as usize).flat_map(|i| {
329 let (add, idx) = indices.next().unwrap();
330 let index = Instruction::new(
331 operator(IndexOperator {
332 list,
333 index: *idx,
334 line_size: 0,
335 unroll_factor,
336 }),
337 out[i],
338 );
339 [add, index]
340 }));
341
342 instructions
343 }
344
345 fn transform_array_index_assign(
348 &self,
349 alloc: &Allocator,
350 out: Variable,
351 op: &IndexAssignOperator,
352 operator: impl Fn(IndexAssignOperator) -> Operator,
353 unroll_factor: u32,
354 mappings: &mut Mappings,
355 ) -> Vec<Instruction> {
356 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
357 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
358
359 let out = unroll_array(out, self.max_line_size, unroll_factor);
360
361 let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
362
363 let mut instructions = vec![mul];
364 instructions.extend((0..unroll_factor as usize).flat_map(|i| {
365 let (add, idx) = indices.next().unwrap();
366 let index = Instruction::new(
367 operator(IndexAssignOperator {
368 index: *idx,
369 line_size: 0,
370 value: value[i],
371 unroll_factor,
372 }),
373 out,
374 );
375
376 [add, index]
377 }));
378
379 instructions
380 }
381
382 fn transform_composite_index(
387 &self,
388 alloc: &Allocator,
389 out: Variable,
390 op: &IndexOperator,
391 operator: impl Fn(IndexOperator) -> Operator,
392 unroll_factor: u32,
393 mappings: &mut Mappings,
394 ) -> Vec<Instruction> {
395 let index = op
396 .index
397 .as_const()
398 .expect("Can't unroll non-constant vector index")
399 .as_u32();
400
401 let unroll_idx = index / self.max_line_size;
402 let sub_idx = index % self.max_line_size;
403
404 let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
405
406 vec![Instruction::new(
407 operator(IndexOperator {
408 list: value[unroll_idx as usize],
409 index: sub_idx.into(),
410 line_size: 1,
411 unroll_factor,
412 }),
413 out,
414 )]
415 }
416
417 fn transform_composite_index_assign(
422 &self,
423 alloc: &Allocator,
424 out: Variable,
425 op: &IndexAssignOperator,
426 operator: impl Fn(IndexAssignOperator) -> Operator,
427 unroll_factor: u32,
428 mappings: &mut Mappings,
429 ) -> Vec<Instruction> {
430 let index = op
431 .index
432 .as_const()
433 .expect("Can't unroll non-constant vector index")
434 .as_u32();
435
436 let unroll_idx = index / self.max_line_size;
437 let sub_idx = index % self.max_line_size;
438
439 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
440
441 vec![Instruction::new(
442 operator(IndexAssignOperator {
443 index: sub_idx.into(),
444 line_size: 1,
445 value: op.value,
446 unroll_factor,
447 }),
448 out[unroll_idx as usize],
449 )]
450 }
451
452 fn transform_metadata(
455 &self,
456 out: Variable,
457 op: &Metadata,
458 args: Vec<Variable>,
459 ) -> Vec<Instruction> {
460 let op_code = op.op_code();
461 let args = args
462 .into_iter()
463 .map(|mut var| {
464 if var.line_size() > self.max_line_size {
465 var.ty = var.ty.line(self.max_line_size);
466 }
467 var
468 })
469 .collect::<Vec<_>>();
470 let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
471 vec![Instruction::new(operation, out)]
472 }
473
474 fn transform_basic(
477 &self,
478 alloc: &Allocator,
479 inst: &Instruction,
480 args: Vec<Variable>,
481 unroll_factor: u32,
482 mappings: &mut Mappings,
483 ) -> Vec<Instruction> {
484 let op_code = inst.operation.op_code();
485 let out = inst
486 .out
487 .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
488 let args = args
489 .into_iter()
490 .map(|arg| {
491 if arg.line_size() > 1 {
492 mappings.get(alloc, arg, unroll_factor, self.max_line_size)
493 } else {
494 vec![arg]
496 }
497 })
498 .collect::<Vec<_>>();
499
500 (0..unroll_factor as usize)
501 .map(|i| {
502 let out = out.as_ref().map(|out| out[i]);
503 let args = args
504 .iter()
505 .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
506 .collect::<Vec<_>>();
507 let operation = Operation::from_code_and_args(op_code, &args)
508 .expect("Failed to reconstruct operation");
509 Instruction {
510 out,
511 source_loc: inst.source_loc.clone(),
512 operation,
513 }
514 })
515 .collect()
516 }
517
518 fn transform_instructions(
519 &self,
520 allocator: &Allocator,
521 instructions: Vec<Instruction>,
522 mappings: &mut Mappings,
523 ) -> Vec<Instruction> {
524 let mut new_instructions = Vec::with_capacity(instructions.len());
525
526 for mut instruction in instructions {
527 if let Operation::Branch(branch) = &mut instruction.operation {
528 match branch {
529 Branch::If(op) => {
530 op.scope.instructions = self.transform_instructions(
531 allocator,
532 op.scope.instructions.drain(..).collect(),
533 mappings,
534 );
535 }
536 Branch::IfElse(op) => {
537 op.scope_if.instructions = self.transform_instructions(
538 allocator,
539 op.scope_if.instructions.drain(..).collect(),
540 mappings,
541 );
542 op.scope_else.instructions = self.transform_instructions(
543 allocator,
544 op.scope_else.instructions.drain(..).collect(),
545 mappings,
546 );
547 }
548 Branch::Switch(op) => {
549 for (_, case) in &mut op.cases {
550 case.instructions = self.transform_instructions(
551 allocator,
552 case.instructions.drain(..).collect(),
553 mappings,
554 );
555 }
556 op.scope_default.instructions = self.transform_instructions(
557 allocator,
558 op.scope_default.instructions.drain(..).collect(),
559 mappings,
560 );
561 }
562 Branch::RangeLoop(op) => {
563 op.scope.instructions = self.transform_instructions(
564 allocator,
565 op.scope.instructions.drain(..).collect(),
566 mappings,
567 );
568 }
569 Branch::Loop(op) => {
570 op.scope.instructions = self.transform_instructions(
571 allocator,
572 op.scope.instructions.drain(..).collect(),
573 mappings,
574 );
575 }
576 _ => {}
577 }
578 }
579 match self.maybe_transform(allocator, &instruction, mappings) {
580 TransformAction::Ignore => {
581 new_instructions.push(instruction);
582 }
583 TransformAction::Replace(replacement) => {
584 new_instructions.extend(replacement);
585 }
586 }
587 }
588
589 new_instructions
590 }
591}
592
593impl Processor for UnrollProcessor {
594 fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
595 let mut mappings = Mappings(Default::default());
596
597 let instructions =
598 self.transform_instructions(&allocator, processing.instructions, &mut mappings);
599
600 ScopeProcessing {
601 variables: processing.variables,
602 instructions,
603 }
604 }
605}
606
607fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> u32 {
608 let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
609 line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
610}
611
612fn create_unrolled(
613 allocator: &Allocator,
614 var: &Variable,
615 max_line_size: u32,
616 unroll_factor: u32,
617) -> Vec<ExpandElement> {
618 if var.line_size() == 1 {
620 return vec![ExpandElement::Plain(*var); unroll_factor as usize];
621 }
622
623 let item = Type::new(var.storage_type()).line(max_line_size);
624 (0..unroll_factor as usize)
625 .map(|_| match var.kind {
626 VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
627 allocator.create_local_mut(item)
628 }
629 VariableKind::LocalConst { .. } => allocator.create_local(item),
630 other => panic!("Out must be local, found {other:?}"),
631 })
632 .collect()
633}
634
635fn add_index(alloc: &Allocator, idx: Variable, i: u32) -> (Instruction, ExpandElement) {
636 let add_idx = alloc.create_local(idx.ty);
637 let add = Instruction::new(
638 Arithmetic::Add(BinaryOperator {
639 lhs: idx,
640 rhs: i.into(),
641 }),
642 *add_idx,
643 );
644 (add, add_idx)
645}
646
647fn mul_index(alloc: &Allocator, idx: Variable, unroll_factor: u32) -> (Instruction, ExpandElement) {
648 let mul_idx = alloc.create_local(idx.ty);
649 let mul = Instruction::new(
650 Arithmetic::Mul(BinaryOperator {
651 lhs: idx,
652 rhs: unroll_factor.into(),
653 }),
654 *mul_idx,
655 );
656 (mul, mul_idx)
657}
658
659fn unroll_array(mut var: Variable, max_line_size: u32, factor: u32) -> Variable {
660 var.ty = var.ty.line(max_line_size);
661
662 match &mut var.kind {
663 VariableKind::LocalArray { unroll_factor, .. }
664 | VariableKind::ConstantArray { unroll_factor, .. }
665 | VariableKind::SharedMemory { unroll_factor, .. } => {
666 *unroll_factor = factor;
667 }
668 _ => {}
669 }
670
671 var
672}