1use alloc::{vec, vec::Vec};
2use cubecl_ir::{
3 Allocator, Arithmetic, BinaryOperator, Branch, CoopMma, CopyMemoryBulkOperator, ExpandElement,
4 IndexAssignOperator, IndexOperator, Instruction, LineSize, MatrixLayout, Metadata, Operation,
5 OperationReflect, Operator, Processor, ScopeProcessing, Type, Variable, VariableKind,
6};
7use hashbrown::HashMap;
8
/// Decision produced for a single instruction by the unroll pass.
pub enum TransformAction {
    /// Leave the instruction untouched.
    Ignore,
    /// Substitute the instruction with the contained instructions, in order.
    Replace(Vec<Instruction>),
}
16
/// Processor that rewrites instructions whose output or arguments have a line
/// size larger than `max_line_size` into instructions using `max_line_size`.
#[derive(new, Debug)]
pub struct UnrollProcessor {
    /// Largest line size an instruction may keep without being unrolled.
    max_line_size: LineSize,
}
21
/// Cache from an original variable to the variables that replace it once unrolled.
struct Mappings(HashMap<Variable, Vec<ExpandElement>>);
23
24impl Mappings {
25 fn get(
26 &mut self,
27 alloc: &Allocator,
28 var: Variable,
29 unroll_factor: usize,
30 line_size: LineSize,
31 ) -> Vec<Variable> {
32 self.0
33 .entry(var)
34 .or_insert_with(|| create_unrolled(alloc, &var, line_size, unroll_factor))
35 .iter()
36 .map(|it| **it)
37 .collect()
38 }
39}
40
41impl UnrollProcessor {
42 fn maybe_transform(
43 &self,
44 alloc: &Allocator,
45 inst: &Instruction,
46 mappings: &mut Mappings,
47 ) -> TransformAction {
48 if matches!(inst.operation, Operation::Marker(_)) {
49 return TransformAction::Ignore;
50 }
51
52 if inst.operation.args().is_none() {
53 match &inst.operation {
55 Operation::CoopMma(op) => match op {
56 CoopMma::Load {
58 value,
59 stride,
60 offset,
61 layout,
62 } if value.line_size() > self.max_line_size => {
63 return TransformAction::Replace(self.transform_cmma_load(
64 alloc,
65 inst.out(),
66 value,
67 stride,
68 offset,
69 layout,
70 ));
71 }
72 CoopMma::Store {
73 mat,
74 stride,
75 offset,
76 layout,
77 } if inst.out().line_size() > self.max_line_size => {
78 return TransformAction::Replace(self.transform_cmma_store(
79 alloc,
80 inst.out(),
81 mat,
82 stride,
83 offset,
84 layout,
85 ));
86 }
87 _ => return TransformAction::Ignore,
88 },
89 Operation::Branch(_) | Operation::NonSemantic(_) | Operation::Marker(_) => {
90 return TransformAction::Ignore;
91 }
92 _ => {
93 panic!("Need special handling for unrolling non-reflectable operations")
94 }
95 }
96 }
97
98 let args = inst.operation.args().unwrap_or_default();
99 if (inst.out.is_some() && inst.ty().line_size() > self.max_line_size)
100 || args.iter().any(|arg| arg.line_size() > self.max_line_size)
101 {
102 let line_size = max_line_size(&inst.out, &args);
103 let unroll_factor = line_size / self.max_line_size;
104
105 match &inst.operation {
106 Operation::Operator(Operator::CopyMemoryBulk(op)) => TransformAction::Replace(
107 self.transform_memcpy(alloc, op, inst.out(), unroll_factor),
108 ),
109 Operation::Operator(Operator::CopyMemory(op)) => {
110 TransformAction::Replace(self.transform_memcpy(
111 alloc,
112 &CopyMemoryBulkOperator {
113 out_index: op.out_index,
114 input: op.input,
115 in_index: op.in_index,
116 len: 1,
117 offset_input: 0.into(),
118 offset_out: 0.into(),
119 },
120 inst.out(),
121 unroll_factor,
122 ))
123 }
124 Operation::Operator(Operator::Index(op)) if op.list.is_array() => {
125 TransformAction::Replace(self.transform_array_index(
126 alloc,
127 inst.out(),
128 op,
129 Operator::Index,
130 unroll_factor,
131 mappings,
132 ))
133 }
134 Operation::Operator(Operator::UncheckedIndex(op)) if op.list.is_array() => {
135 TransformAction::Replace(self.transform_array_index(
136 alloc,
137 inst.out(),
138 op,
139 Operator::UncheckedIndex,
140 unroll_factor,
141 mappings,
142 ))
143 }
144 Operation::Operator(Operator::Index(op)) => {
145 TransformAction::Replace(self.transform_composite_index(
146 alloc,
147 inst.out(),
148 op,
149 Operator::Index,
150 unroll_factor,
151 mappings,
152 ))
153 }
154 Operation::Operator(Operator::UncheckedIndex(op)) => {
155 TransformAction::Replace(self.transform_composite_index(
156 alloc,
157 inst.out(),
158 op,
159 Operator::UncheckedIndex,
160 unroll_factor,
161 mappings,
162 ))
163 }
164 Operation::Operator(Operator::IndexAssign(op)) if inst.out().is_array() => {
165 TransformAction::Replace(self.transform_array_index_assign(
166 alloc,
167 inst.out(),
168 op,
169 Operator::IndexAssign,
170 unroll_factor,
171 mappings,
172 ))
173 }
174 Operation::Operator(Operator::UncheckedIndexAssign(op))
175 if inst.out().is_array() =>
176 {
177 TransformAction::Replace(self.transform_array_index_assign(
178 alloc,
179 inst.out(),
180 op,
181 Operator::UncheckedIndexAssign,
182 unroll_factor,
183 mappings,
184 ))
185 }
186 Operation::Operator(Operator::IndexAssign(op)) => {
187 TransformAction::Replace(self.transform_composite_index_assign(
188 alloc,
189 inst.out(),
190 op,
191 Operator::IndexAssign,
192 unroll_factor,
193 mappings,
194 ))
195 }
196 Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
197 TransformAction::Replace(self.transform_composite_index_assign(
198 alloc,
199 inst.out(),
200 op,
201 Operator::UncheckedIndexAssign,
202 unroll_factor,
203 mappings,
204 ))
205 }
206 Operation::Metadata(op) => {
207 TransformAction::Replace(self.transform_metadata(inst.out(), op, args))
208 }
209 _ => TransformAction::Replace(self.transform_basic(
210 alloc,
211 inst,
212 args,
213 unroll_factor,
214 mappings,
215 )),
216 }
217 } else {
218 TransformAction::Ignore
219 }
220 }
221
222 fn transform_cmma_load(
224 &self,
225 alloc: &Allocator,
226 out: Variable,
227 value: &Variable,
228 stride: &Variable,
229 offset: &Variable,
230 layout: &Option<MatrixLayout>,
231 ) -> Vec<Instruction> {
232 let line_size = value.line_size();
233 let unroll_factor = line_size / self.max_line_size;
234
235 let value = unroll_array(*value, self.max_line_size, unroll_factor);
236 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
237 let load = Instruction::new(
238 Operation::CoopMma(CoopMma::Load {
239 value,
240 stride: *stride,
241 offset: *offset,
242 layout: *layout,
243 }),
244 out,
245 );
246 vec![mul, load]
247 }
248
249 fn transform_cmma_store(
251 &self,
252 alloc: &Allocator,
253 out: Variable,
254 mat: &Variable,
255 stride: &Variable,
256 offset: &Variable,
257 layout: &MatrixLayout,
258 ) -> Vec<Instruction> {
259 let line_size = out.line_size();
260 let unroll_factor = line_size / self.max_line_size;
261
262 let out = unroll_array(out, self.max_line_size, unroll_factor);
263 let (mul, offset) = mul_index(alloc, *offset, unroll_factor);
264 let store = Instruction::new(
265 Operation::CoopMma(CoopMma::Store {
266 mat: *mat,
267 stride: *stride,
268 offset: *offset,
269 layout: *layout,
270 }),
271 out,
272 );
273 vec![mul, store]
274 }
275
276 fn transform_memcpy(
278 &self,
279 alloc: &Allocator,
280 op: &CopyMemoryBulkOperator,
281 out: Variable,
282 unroll_factor: usize,
283 ) -> Vec<Instruction> {
284 let (mul1, in_index) = mul_index(alloc, op.in_index, unroll_factor);
285 let (mul2, offset_input) = mul_index(alloc, op.offset_input, unroll_factor);
286 let (mul3, out_index) = mul_index(alloc, op.out_index, unroll_factor);
287 let (mul4, offset_out) = mul_index(alloc, op.offset_out, unroll_factor);
288
289 let input = unroll_array(op.input, self.max_line_size, unroll_factor);
290 let out = unroll_array(out, self.max_line_size, unroll_factor);
291
292 vec![
293 mul1,
294 mul2,
295 mul3,
296 mul4,
297 Instruction::new(
298 Operator::CopyMemoryBulk(CopyMemoryBulkOperator {
299 input,
300 in_index: *in_index,
301 out_index: *out_index,
302 len: op.len * unroll_factor,
303 offset_input: *offset_input,
304 offset_out: *offset_out,
305 }),
306 out,
307 ),
308 ]
309 }
310
311 fn transform_array_index(
314 &self,
315 alloc: &Allocator,
316 out: Variable,
317 op: &IndexOperator,
318 operator: impl Fn(IndexOperator) -> Operator,
319 unroll_factor: usize,
320 mappings: &mut Mappings,
321 ) -> Vec<Instruction> {
322 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
323 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
324
325 let list = unroll_array(op.list, self.max_line_size, unroll_factor);
326
327 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
328 let mut instructions = vec![mul];
329 instructions.extend((0..unroll_factor).flat_map(|i| {
330 let (add, idx) = indices.next().unwrap();
331 let index = Instruction::new(
332 operator(IndexOperator {
333 list,
334 index: *idx,
335 line_size: 0,
336 unroll_factor,
337 }),
338 out[i],
339 );
340 [add, index]
341 }));
342
343 instructions
344 }
345
346 fn transform_array_index_assign(
349 &self,
350 alloc: &Allocator,
351 out: Variable,
352 op: &IndexAssignOperator,
353 operator: impl Fn(IndexAssignOperator) -> Operator,
354 unroll_factor: usize,
355 mappings: &mut Mappings,
356 ) -> Vec<Instruction> {
357 let (mul, start_idx) = mul_index(alloc, op.index, unroll_factor);
358 let mut indices = (0..unroll_factor).map(|i| add_index(alloc, *start_idx, i));
359
360 let out = unroll_array(out, self.max_line_size, unroll_factor);
361
362 let value = mappings.get(alloc, op.value, unroll_factor, self.max_line_size);
363
364 let mut instructions = vec![mul];
365 instructions.extend((0..unroll_factor).flat_map(|i| {
366 let (add, idx) = indices.next().unwrap();
367 let index = Instruction::new(
368 operator(IndexAssignOperator {
369 index: *idx,
370 line_size: 0,
371 value: value[i],
372 unroll_factor,
373 }),
374 out,
375 );
376
377 [add, index]
378 }));
379
380 instructions
381 }
382
383 fn transform_composite_index(
388 &self,
389 alloc: &Allocator,
390 out: Variable,
391 op: &IndexOperator,
392 operator: impl Fn(IndexOperator) -> Operator,
393 unroll_factor: usize,
394 mappings: &mut Mappings,
395 ) -> Vec<Instruction> {
396 let index = op
397 .index
398 .as_const()
399 .expect("Can't unroll non-constant vector index")
400 .as_usize();
401
402 let unroll_idx = index / self.max_line_size;
403 let sub_idx = index % self.max_line_size;
404
405 let value = mappings.get(alloc, op.list, unroll_factor, self.max_line_size);
406
407 vec![Instruction::new(
408 operator(IndexOperator {
409 list: value[unroll_idx],
410 index: sub_idx.into(),
411 line_size: 1,
412 unroll_factor,
413 }),
414 out,
415 )]
416 }
417
418 fn transform_composite_index_assign(
423 &self,
424 alloc: &Allocator,
425 out: Variable,
426 op: &IndexAssignOperator,
427 operator: impl Fn(IndexAssignOperator) -> Operator,
428 unroll_factor: usize,
429 mappings: &mut Mappings,
430 ) -> Vec<Instruction> {
431 let index = op
432 .index
433 .as_const()
434 .expect("Can't unroll non-constant vector index")
435 .as_usize();
436
437 let unroll_idx = index / self.max_line_size;
438 let sub_idx = index % self.max_line_size;
439
440 let out = mappings.get(alloc, out, unroll_factor, self.max_line_size);
441
442 vec![Instruction::new(
443 operator(IndexAssignOperator {
444 index: sub_idx.into(),
445 line_size: 1,
446 value: op.value,
447 unroll_factor,
448 }),
449 out[unroll_idx],
450 )]
451 }
452
453 fn transform_metadata(
456 &self,
457 out: Variable,
458 op: &Metadata,
459 args: Vec<Variable>,
460 ) -> Vec<Instruction> {
461 let op_code = op.op_code();
462 let args = args
463 .into_iter()
464 .map(|mut var| {
465 if var.line_size() > self.max_line_size {
466 var.ty = var.ty.line(self.max_line_size);
467 }
468 var
469 })
470 .collect::<Vec<_>>();
471 let operation = Metadata::from_code_and_args(op_code, &args).unwrap();
472 vec![Instruction::new(operation, out)]
473 }
474
475 fn transform_basic(
478 &self,
479 alloc: &Allocator,
480 inst: &Instruction,
481 args: Vec<Variable>,
482 unroll_factor: usize,
483 mappings: &mut Mappings,
484 ) -> Vec<Instruction> {
485 let op_code = inst.operation.op_code();
486 let out = inst
487 .out
488 .map(|out| mappings.get(alloc, out, unroll_factor, self.max_line_size));
489 let args = args
490 .into_iter()
491 .map(|arg| {
492 if arg.line_size() > 1 {
493 mappings.get(alloc, arg, unroll_factor, self.max_line_size)
494 } else {
495 vec![arg]
497 }
498 })
499 .collect::<Vec<_>>();
500
501 (0..unroll_factor)
502 .map(|i| {
503 let out = out.as_ref().map(|out| out[i]);
504 let args = args
505 .iter()
506 .map(|arg| if arg.len() == 1 { arg[0] } else { arg[i] })
507 .collect::<Vec<_>>();
508 let operation = Operation::from_code_and_args(op_code, &args)
509 .expect("Failed to reconstruct operation");
510 Instruction {
511 out,
512 source_loc: inst.source_loc.clone(),
513 modes: inst.modes,
514 operation,
515 }
516 })
517 .collect()
518 }
519
520 fn transform_instructions(
521 &self,
522 allocator: &Allocator,
523 instructions: Vec<Instruction>,
524 mappings: &mut Mappings,
525 ) -> Vec<Instruction> {
526 let mut new_instructions = Vec::with_capacity(instructions.len());
527
528 for mut instruction in instructions {
529 if let Operation::Branch(branch) = &mut instruction.operation {
530 match branch {
531 Branch::If(op) => {
532 op.scope.instructions = self.transform_instructions(
533 allocator,
534 op.scope.instructions.drain(..).collect(),
535 mappings,
536 );
537 }
538 Branch::IfElse(op) => {
539 op.scope_if.instructions = self.transform_instructions(
540 allocator,
541 op.scope_if.instructions.drain(..).collect(),
542 mappings,
543 );
544 op.scope_else.instructions = self.transform_instructions(
545 allocator,
546 op.scope_else.instructions.drain(..).collect(),
547 mappings,
548 );
549 }
550 Branch::Switch(op) => {
551 for (_, case) in &mut op.cases {
552 case.instructions = self.transform_instructions(
553 allocator,
554 case.instructions.drain(..).collect(),
555 mappings,
556 );
557 }
558 op.scope_default.instructions = self.transform_instructions(
559 allocator,
560 op.scope_default.instructions.drain(..).collect(),
561 mappings,
562 );
563 }
564 Branch::RangeLoop(op) => {
565 op.scope.instructions = self.transform_instructions(
566 allocator,
567 op.scope.instructions.drain(..).collect(),
568 mappings,
569 );
570 }
571 Branch::Loop(op) => {
572 op.scope.instructions = self.transform_instructions(
573 allocator,
574 op.scope.instructions.drain(..).collect(),
575 mappings,
576 );
577 }
578 _ => {}
579 }
580 }
581 match self.maybe_transform(allocator, &instruction, mappings) {
582 TransformAction::Ignore => {
583 new_instructions.push(instruction);
584 }
585 TransformAction::Replace(replacement) => {
586 new_instructions.extend(replacement);
587 }
588 }
589 }
590
591 new_instructions
592 }
593}
594
595impl Processor for UnrollProcessor {
596 fn transform(&self, processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
597 let mut mappings = Mappings(Default::default());
598
599 let instructions =
600 self.transform_instructions(&allocator, processing.instructions, &mut mappings);
601
602 ScopeProcessing {
603 variables: processing.variables,
604 instructions,
605 typemap: processing.typemap.clone(),
606 }
607 }
608}
609
610fn max_line_size(out: &Option<Variable>, args: &[Variable]) -> LineSize {
611 let line_size = args.iter().map(|it| it.line_size()).max().unwrap();
612 line_size.max(out.map(|out| out.line_size()).unwrap_or(1))
613}
614
615fn create_unrolled(
616 allocator: &Allocator,
617 var: &Variable,
618 max_line_size: LineSize,
619 unroll_factor: usize,
620) -> Vec<ExpandElement> {
621 if var.line_size() == 1 {
623 return vec![ExpandElement::Plain(*var); unroll_factor];
624 }
625
626 let item = Type::new(var.storage_type()).line(max_line_size);
627 (0..unroll_factor)
628 .map(|_| match var.kind {
629 VariableKind::LocalMut { .. } | VariableKind::Versioned { .. } => {
630 allocator.create_local_mut(item)
631 }
632 VariableKind::Shared { .. } => {
633 let id = allocator.new_local_index();
634 let shared = VariableKind::Shared { id };
635 ExpandElement::Plain(Variable::new(shared, item))
636 }
637 VariableKind::LocalConst { .. } => allocator.create_local(item),
638 other => panic!("Out must be local, found {other:?}"),
639 })
640 .collect()
641}
642
643fn add_index(alloc: &Allocator, idx: Variable, i: usize) -> (Instruction, ExpandElement) {
644 let add_idx = alloc.create_local(idx.ty);
645 let add = Instruction::new(
646 Arithmetic::Add(BinaryOperator {
647 lhs: idx,
648 rhs: i.into(),
649 }),
650 *add_idx,
651 );
652 (add, add_idx)
653}
654
655fn mul_index(
656 alloc: &Allocator,
657 idx: Variable,
658 unroll_factor: usize,
659) -> (Instruction, ExpandElement) {
660 let mul_idx = alloc.create_local(idx.ty);
661 let mul = Instruction::new(
662 Arithmetic::Mul(BinaryOperator {
663 lhs: idx,
664 rhs: unroll_factor.into(),
665 }),
666 *mul_idx,
667 );
668 (mul, mul_idx)
669}
670
671fn unroll_array(mut var: Variable, max_line_size: LineSize, factor: usize) -> Variable {
672 var.ty = var.ty.line(max_line_size);
673
674 match &mut var.kind {
675 VariableKind::LocalArray { unroll_factor, .. }
676 | VariableKind::ConstantArray { unroll_factor, .. }
677 | VariableKind::SharedArray { unroll_factor, .. } => {
678 *unroll_factor = factor;
679 }
680 _ => {}
681 }
682
683 var
684}