cubecl_cpp/cuda/
processors.rs1use cubecl_core::{
2 self as cubecl,
3 ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6 cube,
7 ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct CudaMmaProcessor;
12
13impl Processor for CudaMmaProcessor {
14 fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15 let mut instructions = Vec::new();
16 core::mem::swap(&mut processing.instructions, &mut instructions);
17
18 for instruction in instructions {
19 match instruction.operation {
20 Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21 let lane_id = ExpandElement::Plain(lane_id);
22 let i = ExpandElement::Plain(i);
23 let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
24 let mut scope = Scope::root(false).with_allocator(allocator.clone());
25 let row_idx: ExpandElement = row_index::expand(
26 &mut scope,
27 lane_id.into(),
28 i.into(),
29 elems_per_reg as u32,
30 matrix.ident,
31 )
32 .into();
33 let tmp_processing = scope.process([]);
34 for inst in tmp_processing.instructions {
35 processing.instructions.push(inst);
36 }
37 for var in tmp_processing.variables {
38 processing.variables.push(var);
39 }
40
41 processing.instructions.push(Instruction::new(
42 Operation::Copy(*row_idx),
43 instruction.out(),
44 ));
45 }
46 Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
47 let lane_id = ExpandElement::Plain(lane_id);
48 let i = ExpandElement::Plain(i);
49 let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
50 let mut scope = Scope::root(false).with_allocator(allocator.clone());
51 let col_idx: ExpandElement = col_index::expand(
52 &mut scope,
53 lane_id.into(),
54 i.into(),
55 elems_per_reg as u32,
56 matrix.ident,
57 )
58 .into();
59 let tmp_processing = scope.process([]);
60 for inst in tmp_processing.instructions {
61 processing.instructions.push(inst);
62 }
63 for var in tmp_processing.variables {
64 processing.variables.push(var);
65 }
66
67 processing.instructions.push(Instruction::new(
68 Operation::Copy(*col_idx),
69 instruction.out(),
70 ));
71 }
72 _ => {
73 processing.instructions.push(instruction);
74 }
75 }
76 }
77
78 processing
79 }
80}
81
82#[cube]
85fn row_index(
86 lane_id: u32,
87 i: u32,
88 #[comptime] elems_per_reg: u32,
89 #[comptime] ident: MatrixIdent,
90) -> u32 {
91 match ident {
92 MatrixIdent::A => {
93 let group_id = lane_id / 4;
94 let odd_register = (i / elems_per_reg) & 1;
95 group_id + odd_register * 8
96 }
97 MatrixIdent::B => {
98 let thread_id_in_group = lane_id % 4;
99 let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
100 let reg = i / elems_per_reg;
101 offset + elems_per_reg * 4 * reg
102 }
103 MatrixIdent::Accumulator => {
104 let group_id = lane_id / 4;
105 let offset = (i << 2) & 8;
106 group_id + offset
107 }
108 }
109}
110
111#[cube]
114fn col_index(
115 lane_id: u32,
116 i: u32,
117 #[comptime] elems_per_reg: u32,
118 #[comptime] ident: MatrixIdent,
119) -> u32 {
120 match ident {
121 MatrixIdent::A => {
122 let thread_id_in_group = lane_id % 4;
123 let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
124 let group_2 = (i / (2 * elems_per_reg)) & 1;
125 offset + 4 * elems_per_reg * group_2
126 }
127 MatrixIdent::B => lane_id >> 2,
128 MatrixIdent::Accumulator => {
129 let thread_id_in_group = lane_id % 4;
130 (thread_id_in_group * 2) + (i % 2)
131 }
132 }
133}