cubecl_cpp/cuda/processors.rs

1use cubecl_core::{
2    self as cubecl,
3    ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6    cube,
7    ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct CudaMmaProcessor;
12
13impl Processor for CudaMmaProcessor {
14    fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15        let mut instructions = Vec::new();
16        core::mem::swap(&mut processing.instructions, &mut instructions);
17
18        for instruction in instructions {
19            match instruction.operation {
20                Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21                    let lane_id = ExpandElement::Plain(lane_id);
22                    let i = ExpandElement::Plain(i);
23                    let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
24                    let mut scope = Scope::root(false).with_allocator(allocator.clone());
25                    let row_idx: ExpandElement = row_index::expand(
26                        &mut scope,
27                        lane_id.into(),
28                        i.into(),
29                        elems_per_reg as u32,
30                        matrix.ident,
31                    )
32                    .into();
33                    let tmp_processing = scope.process([]);
34                    for inst in tmp_processing.instructions {
35                        processing.instructions.push(inst);
36                    }
37                    for var in tmp_processing.variables {
38                        processing.variables.push(var);
39                    }
40
41                    processing.instructions.push(Instruction::new(
42                        Operation::Copy(*row_idx),
43                        instruction.out(),
44                    ));
45                }
46                Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
47                    let lane_id = ExpandElement::Plain(lane_id);
48                    let i = ExpandElement::Plain(i);
49                    let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
50                    let mut scope = Scope::root(false).with_allocator(allocator.clone());
51                    let col_idx: ExpandElement = col_index::expand(
52                        &mut scope,
53                        lane_id.into(),
54                        i.into(),
55                        elems_per_reg as u32,
56                        matrix.ident,
57                    )
58                    .into();
59                    let tmp_processing = scope.process([]);
60                    for inst in tmp_processing.instructions {
61                        processing.instructions.push(inst);
62                    }
63                    for var in tmp_processing.variables {
64                        processing.variables.push(var);
65                    }
66
67                    processing.instructions.push(Instruction::new(
68                        Operation::Copy(*col_idx),
69                        instruction.out(),
70                    ));
71                }
72                _ => {
73                    processing.instructions.push(instruction);
74                }
75            }
76        }
77
78        processing
79    }
80}
81
82/// Derived from PTX shape documentation
83/// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-mma
84#[cube]
85fn row_index(
86    lane_id: u32,
87    i: u32,
88    #[comptime] elems_per_reg: u32,
89    #[comptime] ident: MatrixIdent,
90) -> u32 {
91    match ident {
92        MatrixIdent::A => {
93            let group_id = lane_id / 4;
94            let odd_register = (i / elems_per_reg) & 1;
95            group_id + odd_register * 8
96        }
97        MatrixIdent::B => {
98            let thread_id_in_group = lane_id % 4;
99            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
100            let reg = i / elems_per_reg;
101            offset + elems_per_reg * 4 * reg
102        }
103        MatrixIdent::Accumulator => {
104            let group_id = lane_id / 4;
105            let offset = (i << 2) & 8;
106            group_id + offset
107        }
108    }
109}
110
111/// Derived from PTX shape documentation
112/// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-mma
113#[cube]
114fn col_index(
115    lane_id: u32,
116    i: u32,
117    #[comptime] elems_per_reg: u32,
118    #[comptime] ident: MatrixIdent,
119) -> u32 {
120    match ident {
121        MatrixIdent::A => {
122            let thread_id_in_group = lane_id % 4;
123            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
124            let group_2 = (i / (2 * elems_per_reg)) & 1;
125            offset + 4 * elems_per_reg * group_2
126        }
127        MatrixIdent::B => lane_id >> 2,
128        MatrixIdent::Accumulator => {
129            let thread_id_in_group = lane_id % 4;
130            (thread_id_in_group * 2) + (i % 2)
131        }
132    }
133}