1use cubecl_ir::{
2 Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
3 IndexAssignOperator, IndexOperator, Instruction, MatrixLayout, Metadata, Operation,
4 OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
5};
6use hashbrown::HashMap;
7
8pub enum TransformAction {
10 Ignore,
12 Replace(Vec<Instruction>),
14}
15
16#[derive(new, Debug)]
17pub struct UnrollProcessor {
18 max_line_size: u32,
19}
20
21struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
22
23impl Mappings {
24 fn get(
25 &mut self,
26 alloc: &Allocator,
27 var: Variable,
28 unroll_factor: u32,
29 line_size: u32,
30 ) -> Vec<Variable> {
31 self.0
32 .entry(var)
33 .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
34 .iter()
35 .map(|it| **it)
36 .collect()
37 }
38}
39
40impl UnrollProcessor {
41 fn maybe_transform(
42 &self,
43 alloc: &Allocator,
44 inst: &Instruction,
45 mappings: &mut Mappings,
46 ) -> TransformAction {
47 if matches!(inst.operation, Operation::Marker(_)) {
48 return TransformAction::Ignore;
49 }
50
51 if inst.operation.args().is_none() {
52 match &inst.operation {
54 Operation::CoopMma(op) => match op {
55 CoopMma::Load {
57 value,
58 stride,
59 offset,
60 layout,
61 } if value.line_size() > self.max_line_size => {
62 return TransformAction::Replace(self.transform_cmma_load(
63 alloc,
64 inst.out(),
65 value,
66 stride,
67 offset,
68 layout,
69 ));
70 }
71 CoopMma::Store {
72 mat,
73 stride,
74 offset,
75 layout,
76 } if inst.out().line_size() > self.max_line_size => {
77 return TransformAction::Replace(self.transform_cmma_store(
78 alloc,
79 inst.out(),
80 mat,
81 stride,
82 offset,
83 layout,
84 ));
85 }
86 _ => return TransformAction::Ignore,
87 },
88 Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
89 return TransformAction::Ignore;
90 }
91 _ => {
92 panic!("Need special handling for unrolling non-reflectable operations")
93 }
94 }
95 }
96
97 let args = inst.operation.args().unwrap_or_default();
98 if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
99 || args.iter().any(|arg| arg.line_size() > self.max_line_size)
100 {
101 let line_size = max_line_size(&inst.out, &args);
102 let unroll_factor = line_size / self.max_line_size;
103
104 match &inst.operation {
105 Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
106 self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
107 ),
108 Operation::Operator(Operator::CopyMemory(op)) => {
109 TransformAction::Replace(self.transform_memcpy(
110 alloc,
111 &CopyMemoryBulkOperator {
112 out_index: op.out_index,
113 input: op.input,
114 in_index: op.in_index,
115 len: 1,
116 offset_input: 0.into(),
117 offset_out: 0.into(),
118 },
119 inst.out(),
120 unroll_factor,
121 ))
122 }
123 Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
124 TransformAction::Replace(self.transform_array_index(
125 alloc,
126 inst.out(),
127 op,
128 Operator::Index,
129 unroll_factor,
130 mappings,
131 ))
132 }
133 Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
134 TransformAction::Replace(self.transform_array_index(
135 alloc,
136 inst.out(),
137 op,
138 Operator::UncheckedIndex,
139 unroll_factor,
140 mappings,
141 ))
142 }
143 Operation::Operator(Operator::Index(op)) => {
144 TransformAction::Replace(self.transform_composite_index(
145 alloc,
146 inst.out(),
147 op,
148 Operator::Index,
149 unroll_factor,
150 mappings,
151 ))
152 }
153 Operation::Operator(Operator::UncheckedIndex(op)) => {
154 TransformAction::Replace(self.transform_composite_index(
155 alloc,
156 inst.out(),
157 op,
158 Operator::UncheckedIndex,
159 unroll_factor,
160 mappings,
161 ))
162 }
163 Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
164 TransformAction::Replace(self.transform_array_index_assign(
165 alloc,
166 inst.out(),
167 op,
168 Operator::IndexAssign,
169 unroll_factor,
170 mappings,
171 ))
172 }
173 Operation::Operator(Operator::UncheckedIndexAssign(op))
174 if inst.out().is_array() =>
175 {
176 TransformAction::Replace(self.transform_array_index_assign(
177 alloc,
178 inst.out(),
179 op,
180 Operator::UncheckedIndexAssign,
181 unroll_factor,
182 mappings,
183 ))
184 }
185 Operation::Operator(Operator::IndexAssign(op)) => {
186 TransformAction::Replace(self.transform_composite_index_assign(
187 alloc,
188 inst.out(),
189 op,
190 Operator::IndexAssign,
191 unroll_factor,
192 mappings,
193 ))
194 }
195 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
196 TransformAction::Replace(self.transform_composite_index_assign(
197 alloc,
198 inst.out(),
199 op,
200 Operator::UncheckedIndexAssign,
201 unroll_factor,
202 mappings,
203 ))
204 }
205 Operation::Metadata(op) => {
206 TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
207 }
208 _ => TransformAction::Replace(self.transform_basic(
209 alloc,
210 inst,
211 args,
212 unroll_factor,
213 mappings,
214 )),
215 }
216 } else {
217 TransformAction::Ignore
218 }
219 }
220
221 fn transform_cmma_load(
223 &self,
224 alloc: &Allocator,
225 out: Variable,
226 value: &Variable,
227 stride: &Variable,
228 offset: &Variable,
229 layout: &Option<MatrixLayout>,
230 ) -> Vec<Instruction> {
231 let line_size = value.line_size();
232 let unroll_factor = line_size / self.max_line_size;
233
234 let value = unroll_array(*value, self.max_line_size, unroll_factor);
235 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
236 let load = Instruction::new(
237 Operation::CoopMma(CoopMma::Load {
238 value,
239 stride: *stride,
240 offset: *offset,
241 layout: *layout,
242 }),
243 out,
244 );
245 vec![mul, load]
246 }
247
248 fn transform_cmma_store(
250 &self,
251 alloc: &Allocator,
252 out: Variable,
253 mat: &Variable,
254 stride: &Variable,
255 offset: &Variable,
256 layout: &MatrixLayout,
257 ) -> Vec<Instruction> {
258 let line_size = out.line_size();
259 let unroll_factor = line_size / self.max_line_size;
260
261 let out = unroll_array(out, self.max_line_size, unroll_factor);
262 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
263 let store = Instruction::new(
264 Operation::CoopMma(CoopMma::Store {
265 mat: *mat,
266 stride: *stride,
267 offset: *offset,
268 layout: *layout,
269 }),
270 out,
271 );
272 vec![mul, store]
273 }
274
275 fn transform_memcpy(
277 &self,
278 alloc: &Allocator,
279 op: &CopyMemoryBulkOperator,
280 out: Variable,
281 unroll_factor: u32,
282 ) -> Vec<Instruction> {
283 let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
284 let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
285 let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
286 let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
287
288 let input = unroll_array(op.input, self.max_line_size, unroll_factor);
289 let out = unroll_array(out, self.max_line_size, unroll_factor);
290
291 vec![
292 mul1,
293 mul2,
294 mul3,
295 mul4,
296 Instruction::new(
297 Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
298 input,
299 in_index: *in_index,
300 out_index: *out_index,
301 len: op.len * unroll_factor,
302 offset_input: *offset_input,
303 offset_out: *offset_out,
304 }),
305 out,
306 ),
307 ]
308 }
309
310 fn transform_array_index(
313 &self,
314 alloc: &Allocator,
315 out: Variable,
316 op: &IndexOperator,
317 operator: impl Fn(IndexOperator) -> Operator,
318 unroll_factor: u32,
319 mappings: &mut Mappings,
320 ) -> Vec<Instruction> {
321 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
322 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
323
324 let list = unroll_array(op.list, self.max_line_size, unroll_factor);
325
326 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
327 let mut instructions = vec![mul];
328 instructions.extend((0..unroll_factor as usize).flat_map(|i| {
329 let (add, idx) = indices.next().unwrap();
330 let index = Instruction::new(
331 operator(IndexOperator {
332 list,
333 index: *idx,
334 line_size: 0,
335 unroll_factor,
336 }),
337 out[i],
338 );
339 [add, index]
340 }));
341
342 instructions
343 }
344
345 fn transform_array_index_assign(
348 &self,
349 alloc: &Allocator,
350 out: Variable,
351 op: &IndexAssignOperator,
352 operator: impl Fn(IndexAssignOperator) -> Operator,
353 unroll_factor: u32,
354 mappings: &mut Mappings,
355 ) -> Vec<Instruction> {
356 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
357 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
358
359 let out = unroll_array(out, self.max_line_size, unroll_factor);
360
361 let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
362
363 let mut instructions = vec![mul];
364 instructions.extend((0..unroll_factor as usize).flat_map(|i| {
365 let (add, idx) = indices.next().unwrap();
366 let index = Instruction::new(
367 operator(IndexAssignOperator {
368 index: *idx,
369 line_size: 0,
370 value: value[i],
371 unroll_factor,
372 }),
373 out,
374 );
375
376 [add, index]
377 }));
378
379 instructions
380 }
381
382 fn transform_composite_index(
387 &self,
388 alloc: &Allocator,
389 out: Variable,
390 op: &IndexOperator,
391 operator: impl Fn(IndexOperator) -> Operator,
392 unroll_factor: u32,
393 mappings: &mut Mappings,
394 ) -> Vec<Instruction> {
395 let index = op
396 .index
397 .as_const()
398 .expect("Can't unroll non-constant vector index")
399 .as_u32();
400
401 let unroll_idx = index / self.max_line_size;
402 let sub_idx = index % self.max_line_size;
403
404 let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
405
406 vec![Instruction::new(
407 operator(IndexOperator {
408 list: value[unroll_idx as usize],
409 index: sub_idx.into(),
410 line_size: 1,
411 unroll_factor,
412 }),
413 out,
414 )]
415 }
416
417 fn transform_composite_index_assign(
422 &self,
423 alloc: &Allocator,
424 out: Variable,
425 op: &IndexAssignOperator,
426 operator: impl Fn(IndexAssignOperator) -> Operator,
427 unroll_factor: u32,
428 mappings: &mut Mappings,
429 ) -> Vec<Instruction> {
430 let index = op
431 .index
432 .as_const()
433 .expect("Can't unroll non-constant vector index")
434 .as_u32();
435
436 let unroll_idx = index / self.max_line_size;
437 let sub_idx = index % self.max_line_size;
438
439 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
440
441 vec![Instruction::new(
442 operator(IndexAssignOperator {
443 index: sub_idx.into(),
444 line_size: 1,
445 value: op.value,
446 unroll_factor,
447 }),
448 out[unroll_idx as usize],
449 )]
450 }
451
452 fn transform_metadata(
455 &self,
456 out: Variable,
457 op: &Metadata,
458 args: Vec<Variable>,
459 ) -> Vec<Instruction> {
460 let op_code = op.op_code();
461 let args = args
462 .into_iter()
463 .map(|mut var| {
464 if var.line_size() > self.max_line_size {
465 var.ty = var.ty.line(self.max_line_size);
466 }
467 var
468 })
469 .collect::<Vec<_>>();
470 let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
471 vec![Instruction::new(operation, out)]
472 }
473
474 fn transform_basic(
477 &self,
478 alloc: &Allocator,
479 inst: &Instruction,
480 args: Vec<Variable>,
481 unroll_factor: u32,
482 mappings: &mut Mappings,
483 ) -> Vec<Instruction> {
484 let op_code = inst.operation.op_code();
485 let out = inst
486 .out
487 .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
488 let args = args
489 .into_iter()
490 .map(|arg| {
491 if arg.line_size() > 1 {
492 mappings.get(alloc, arg, unroll_factor, self.max_line_size)
493 } else {
494 vec![arg]
496 }
497 })
498 .collect::<Vec<_>>();
499
500 (0..unroll_factor as usize)
501 .map(|i| {
502 let out = out.as_ref().map(|out| out[i]);
503 let args = args
504 .iter()
505 .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
506 .collect::<Vec<_>>();
507 let operation = Operation::from_code_and_args(op_code, &args)
508 .expect("Failed to reconstruct operation");
509 Instruction {
510 out,
511 source_loc: inst.source_loc.clone(),
512 modes: inst.modes,
513 operation,
514 }
515 })
516 .collect()
517 }
518
519 fn transform_instructions(
520 &self,
521 allocator: &Allocator,
522 instructions: Vec<Instruction>,
523 mappings: &mut Mappings,
524 ) -> Vec<Instruction> {
525 let mut new_instructions = Vec::with_capacity(instructions.len());
526
527 for mut instruction in instructions {
528 if let Operation::Branch(branch) = &mut instruction.operation {
529 match branch {
530 Branch::If(op) => {
531 op.scope.instructions = self.transform_instructions(
532 allocator,
533 op.scope.instructions.drain(..).collect(),
534 mappings,
535 );
536 }
537 Branch::IfElse(op) => {
538 op.scope_if.instructions = self.transform_instructions(
539 allocator,
540 op.scope_if.instructions.drain(..).collect(),
541 mappings,
542 );
543 op.scope_else.instructions = self.transform_instructions(
544 allocator,
545 op.scope_else.instructions.drain(..).collect(),
546 mappings,
547 );
548 }
549 Branch::Switch(op) => {
550 for (_, case) in &mut op.cases {
551 case.instructions = self.transform_instructions(
552 allocator,
553 case.instructions.drain(..).collect(),
554 mappings,
555 );
556 }
557 op.scope_default.instructions = self.transform_instructions(
558 allocator,
559 op.scope_default.instructions.drain(..).collect(),
560 mappings,
561 );
562 }
563 Branch::RangeLoop(op) => {
564 op.scope.instructions = self.transform_instructions(
565 allocator,
566 op.scope.instructions.drain(..).collect(),
567 mappings,
568 );
569 }
570 Branch::Loop(op) => {
571 op.scope.instructions = self.transform_instructions(
572 allocator,
573 op.scope.instructions.drain(..).collect(),
574 mappings,
575 );
576 }
577 _ => {}
578 }
579 }
580 match self.maybe_transform(allocator, &instruction, mappings) {
581 TransformAction::Ignore => {
582 new_instructions.push(instruction);
583 }
584 TransformAction::Replace(replacement) => {
585 new_instructions.extend(replacement);
586 }
587 }
588 }
589
590 new_instructions
591 }
592}
593
594impl Processor for UnrollProcessor {
595 fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
596 let mut mappings = Mappings(Default::default());
597
598 let instructions =
599 self.transform_instructions(&allocator, processing.instructions, &mut mappings);
600
601 ScopeProcessing {
602 variables: processing.variables,
603 instructions,
604 }
605 }
606}
607
608fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> u32 {
609 let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
610 line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
611}
612
613fn create_unrolled(
614 allocator: &Allocator,
615 var: &Variable,
616 max_line_size: u32,
617 unroll_factor: u32,
618) -> Vec<ExpandElement> {
619 if var.line_size() == 1 {
621 return vec![ExpandElement::Plain(*var); unroll_factor as usize];
622 }
623
624 let item = Type::new(var.storage_type()).line(max_line_size);
625 (0..unroll_factor as usize)
626 .map(|_| match var.kind {
627 VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
628 allocator.create_local_mut(item)
629 }
630 VariableKind::Shared { .. } => {
631 let id = allocator.new_local_index();
632 let shared = VariableKind::Shared { id };
633 ExpandElement::Plain(Variable::new(shared, item))
634 }
635 VariableKind::LocalConst { .. } => allocator.create_local(item),
636 other => panic!("Out must be local, found {other:?}"),
637 })
638 .collect()
639}
640
641fn add_index(alloc: &Allocator, idx: Variable, i: u32) -> (Instruction, ExpandElement) {
642 let add_idx = alloc.create_local(idx.ty);
643 let add = Instruction::new(
644 Arithmetic::Add(BinaryOperator {
645 lhs: idx,
646 rhs: i.into(),
647 }),
648 *add_idx,
649 );
650 (add, add_idx)
651}
652
653fn mul_index(alloc: &Allocator, idx: Variable, unroll_factor: u32) -> (Instruction, ExpandElement) {
654 let mul_idx = alloc.create_local(idx.ty);
655 let mul = Instruction::new(
656 Arithmetic::Mul(BinaryOperator {
657 lhs: idx,
658 rhs: unroll_factor.into(),
659 }),
660 *mul_idx,
661 );
662 (mul, mul_idx)
663}
664
665fn unroll_array(mut var: Variable, max_line_size: u32, factor: u32) -> Variable {
666 var.ty = var.ty.line(max_line_size);
667
668 match &mut var.kind {
669 VariableKind::LocalArray { unroll_factor, .. }
670 | VariableKind::ConstantArray { unroll_factor, .. }
671 | VariableKind::SharedArray { unroll_factor, .. } => {
672 *unroll_factor = factor;
673 }
674 _ => {}
675 }
676
677 var
678}