// cubecl_cpp/cuda/processors.rs

1use cubecl_core::{
2    self as cubecl,
3    ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6    cube,
7    ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct CudaMmaProcessor;
12
13impl Processor for CudaMmaProcessor {
14    fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15        let mut instructions = Vec::new();
16        core::mem::swap(&mut processing.instructions, &mut instructions);
17
18        for instruction in instructions {
19            match instruction.operation {
20                Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21                    let lane_id = ExpandElement::Plain(lane_id);
22                    let i = ExpandElement::Plain(i);
23                    let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
24                    let mut scope = Scope::root(false)
25                        .with_allocator(allocator.clone())
26                        .with_types(processing.typemap.clone());
27                    let row_idx: ExpandElement = row_index::expand(
28                        &mut scope,
29                        lane_id.into(),
30                        i.into(),
31                        elems_per_reg as u32,
32                        matrix.ident,
33                    )
34                    .into();
35                    let tmp_processing = scope.process([]);
36                    for inst in tmp_processing.instructions {
37                        processing.instructions.push(inst);
38                    }
39                    for var in tmp_processing.variables {
40                        processing.variables.push(var);
41                    }
42
43                    processing.instructions.push(Instruction::new(
44                        Operation::Copy(*row_idx),
45                        instruction.out(),
46                    ));
47                }
48                Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
49                    let lane_id = ExpandElement::Plain(lane_id);
50                    let i = ExpandElement::Plain(i);
51                    let elems_per_reg = 32 / matrix.storage.elem_type().size_bits();
52                    let mut scope = Scope::root(false)
53                        .with_allocator(allocator.clone())
54                        .with_types(processing.typemap.clone());
55                    let col_idx: ExpandElement = col_index::expand(
56                        &mut scope,
57                        lane_id.into(),
58                        i.into(),
59                        elems_per_reg as u32,
60                        matrix.ident,
61                    )
62                    .into();
63                    let tmp_processing = scope.process([]);
64                    for inst in tmp_processing.instructions {
65                        processing.instructions.push(inst);
66                    }
67                    for var in tmp_processing.variables {
68                        processing.variables.push(var);
69                    }
70
71                    processing.instructions.push(Instruction::new(
72                        Operation::Copy(*col_idx),
73                        instruction.out(),
74                    ));
75                }
76                _ => {
77                    processing.instructions.push(instruction);
78                }
79            }
80        }
81
82        processing
83    }
84}
85
86/// Derived from PTX shape documentation
87/// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-mma
88#[cube]
89fn row_index(
90    lane_id: u32,
91    i: u32,
92    #[comptime] elems_per_reg: u32,
93    #[comptime] ident: MatrixIdent,
94) -> u32 {
95    match ident {
96        MatrixIdent::A => {
97            let group_id = lane_id / 4;
98            let odd_register = (i / elems_per_reg) & 1;
99            group_id + odd_register * 8
100        }
101        MatrixIdent::B => {
102            let thread_id_in_group = lane_id % 4;
103            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
104            let reg = i / elems_per_reg;
105            offset + elems_per_reg * 4 * reg
106        }
107        MatrixIdent::Accumulator => {
108            let group_id = lane_id / 4;
109            let offset = (i << 2) & 8;
110            group_id + offset
111        }
112    }
113}
114
115/// Derived from PTX shape documentation
116/// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-mma
117#[cube]
118fn col_index(
119    lane_id: u32,
120    i: u32,
121    #[comptime] elems_per_reg: u32,
122    #[comptime] ident: MatrixIdent,
123) -> u32 {
124    match ident {
125        MatrixIdent::A => {
126            let thread_id_in_group = lane_id % 4;
127            let offset = thread_id_in_group * elems_per_reg + (i % elems_per_reg);
128            let group_2 = (i / (2 * elems_per_reg)) & 1;
129            offset + 4 * elems_per_reg * group_2
130        }
131        MatrixIdent::B => lane_id >> 2,
132        MatrixIdent::Accumulator => {
133            let thread_id_in_group = lane_id % 4;
134            (thread_id_in_group * 2) + (i % 2)
135        }
136    }
137}