cubecl_cpp/cuda/
processors.rs1use cubecl_core::{
2 self as cubecl,
3 ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6 cube,
7 ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct CudaMmaProcessor;
12
13impl Processor for CudaMmaProcessor {
14 fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15 let mut instructions = Vec::new();
16 core::mem::swap(&mut processing.instructions, &mut instructions);
17
18 for instruction in instructions {
19 match instruction.operation {
20 Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21 let lane_id = ExpandElement::Plain(lane_id);
22 let i = ExpandElement::Plain(i);
23 let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
24 let mut scope = Scope::root(false)
25 .with_allocator(allocator.clone())
26 .with_types(processing.typemap.clone());
27 let row_idx: ExpandElement = row_index::expand(
28 &mut scope,
29 lane_id.into(),
30 i.into(),
31 elems_per_reg as u32,
32 matrix.ident,
33 )
34 .into();
35 let tmp_processing = scope.process([]);
36 for inst in tmp_processing.instructions {
37 processing.instructions.push(inst);
38 }
39 for var in tmp_processing.variables {
40 processing.variables.push(var);
41 }
42
43 processing.instructions.push(Instruction::new(
44 Operation::Copy(*row_idx),
45 instruction.out(),
46 ));
47 }
48 Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
49 let lane_id = ExpandElement::Plain(lane_id);
50 let i = ExpandElement::Plain(i);
51 let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
52 let mut scope = Scope::root(false)
53 .with_allocator(allocator.clone())
54 .with_types(processing.typemap.clone());
55 let col_idx: ExpandElement = col_index::expand(
56 &mut scope,
57 lane_id.into(),
58 i.into(),
59 elems_per_reg as u32,
60 matrix.ident,
61 )
62 .into();
63 let tmp_processing = scope.process([]);
64 for inst in tmp_processing.instructions {
65 processing.instructions.push(inst);
66 }
67 for var in tmp_processing.variables {
68 processing.variables.push(var);
69 }
70
71 processing.instructions.push(Instruction::new(
72 Operation::Copy(*col_idx),
73 instruction.out(),
74 ));
75 }
76 _ => {
77 processing.instructions.push(instruction);
78 }
79 }
80 }
81
82 processing
83 }
84}
85
#[cube]
/// Returns the row coordinate of element `i` of the matrix fragment held by lane `lane_id`.
///
/// * `lane_id` - presumably the lane index within the warp (0..32); TODO confirm.
/// * `i` - index of the element within this lane's fragment.
/// * `elems_per_reg` - how many elements are packed into one 32-bit register
///   (the processor passes `32 / element bit-width`).
/// * `ident` - which operand (A, B or accumulator) the fragment belongs to; this
///   selects the layout formula.
///
/// NOTE(review): the formulas appear to follow the PTX `mma.sync.m16n8k*` fragment
/// layouts (`groupID = lane_id / 4`, `threadID_in_group = lane_id % 4`). Confirm
/// against the PTX ISA for every shape this is used with.
fn row_index(
    lane_id: u32,
    i: u32,
    #[comptime] elems_per_reg: u32,
    #[comptime] ident: MatrixIdent,
) -> u32 {
    match ident {
        MatrixIdent::A => {
            // Each group of four consecutive lanes shares one row.
            let group_id = lane_id / 4;
            // Odd-numbered registers (register index = i / elems_per_reg) hold rows
            // shifted 8 down from the even ones.
            let odd_register = (i / elems_per_reg) & 1;
            group_id + odd_register * 8
        }
        MatrixIdent::B => {
            // The lane's position in its group of four picks a run of `elems_per_reg`
            // rows; the element's slot inside its register picks the row in that run.
            let thread_id_in_group = lane_id % 4;
            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
            // Each further register advances by a full group stride:
            // 4 lanes * elems_per_reg rows.
            let reg = i / elems_per_reg;
            offset + elems_per_reg * 4 * reg
        }
        MatrixIdent::Accumulator => {
            let group_id = lane_id / 4;
            // (i << 2) & 8 is 8 exactly when bit 1 of `i` is set (i = 2, 3, 6, 7, ...),
            // so elements 2 and 3 of each group of four sit 8 rows below 0 and 1.
            let offset = (i << 2) & 8;
            group_id + offset
        }
    }
}
114
#[cube]
/// Returns the column coordinate of element `i` of the matrix fragment held by lane `lane_id`.
///
/// * `lane_id` - presumably the lane index within the warp (0..32); TODO confirm.
/// * `i` - index of the element within this lane's fragment.
/// * `elems_per_reg` - how many elements are packed into one 32-bit register
///   (the processor passes `32 / element bit-width`).
/// * `ident` - which operand (A, B or accumulator) the fragment belongs to; this
///   selects the layout formula.
///
/// NOTE(review): the formulas appear to follow the PTX `mma.sync.m16n8k*` fragment
/// layouts (`groupID = lane_id / 4`, `threadID_in_group = lane_id % 4`). Confirm
/// against the PTX ISA for every shape this is used with.
fn col_index(
    lane_id: u32,
    i: u32,
    #[comptime] elems_per_reg: u32,
    #[comptime] ident: MatrixIdent,
) -> u32 {
    match ident {
        MatrixIdent::A => {
            // The lane's position in its group of four picks a run of `elems_per_reg`
            // columns; the element's slot inside its register picks the column in it.
            let thread_id_in_group = lane_id % 4;
            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
            // Bit 1 of the register index (i / elems_per_reg) selects the second half
            // of the columns, shifted right by 4 * elems_per_reg.
            let group_2 = (i / (2 * elems_per_reg)) & 1;
            offset + 4 * elems_per_reg * group_2
        }
        // One column per group of four lanes (equivalent to lane_id / 4).
        MatrixIdent::B => lane_id >> 2,
        MatrixIdent::Accumulator => {
            let thread_id_in_group = lane_id % 4;
            // Each lane owns two adjacent columns; the low bit of `i` picks which one.
            (thread_id_in_group * 2) + (i % 2)
        }
    }
}
137}