cubecl_spirv/
cmma.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    lookups::Matrix,
5    variable::Variable,
6};
7use cubecl_core::ir::{self as core, CoopMma, ElemType, Id, MatrixLayout};
8use rspirv::spirv::{
9    Capability, CooperativeMatrixLayout, CooperativeMatrixOperands, CooperativeMatrixUse,
10    StorageClass,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14    pub fn compile_cmma(&mut self, cmma: CoopMma, out: Option<core::Variable>) {
15        self.capabilities.insert(Capability::CooperativeMatrixKHR);
16        let out = out.unwrap();
17        match cmma {
18            CoopMma::Fill { value } => self.compile_fill(out, value),
19            CoopMma::Load {
20                value,
21                stride,
22                layout,
23                offset,
24            } => self.compile_load(out, value, stride, offset, layout),
25            CoopMma::Execute {
26                mat_a,
27                mat_b,
28                mat_c,
29            } => self.compile_execute(mat_a, mat_b, mat_c, out),
30            CoopMma::Store {
31                mat,
32                stride,
33                layout,
34                offset,
35            } => self.compile_store(mat, out, stride, offset, layout),
36            CoopMma::Cast { input } => self.compile_cast(input, out),
37            CoopMma::RowIndex { .. }
38            | CoopMma::ColIndex { .. }
39            | CoopMma::LoadMatrix { .. }
40            | CoopMma::StoreMatrix { .. }
41            | CoopMma::ExecuteManual { .. }
42            | CoopMma::ExecuteScaled { .. } => {
43                panic!("Manual register management not currently supported in SPIR-V")
44            }
45        }
46    }
47
48    fn compile_load(
49        &mut self,
50        mat: core::Variable,
51        value: core::Variable,
52        stride: core::Variable,
53        offset: core::Variable,
54        layout: Option<MatrixLayout>,
55    ) {
56        let mat = self.compile_variable(mat);
57        let mat = self.matrix_var(&mat).1;
58
59        let value = self.compile_variable(value);
60        let stride = self.compile_variable(stride);
61        let stride_item = stride.item();
62        let mut stride = self.read(&stride);
63
64        if let Item::Vector(_, line_size) = value.item() {
65            let shift = stride_item.const_u32(self, line_size.trailing_zeros());
66            let stride_ty = stride_item.id(self);
67            stride = self
68                .shift_right_logical(stride_ty, None, stride, shift)
69                .unwrap();
70        }
71
72        let layout = layout
73            .and_then(compile_layout)
74            .or(mat.layout)
75            .unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
76        let memory_layout = self.const_u32(layout as u32);
77
78        let offset = self.compile_variable(offset);
79        let ptr = self.index_ptr(&value, &offset);
80        let out_ty = self.item(&mat);
81        let ty = out_ty.id(self);
82
83        let mat_id = self
84            .cooperative_matrix_load_khr(ty, None, ptr, memory_layout, Some(stride), None, vec![])
85            .unwrap();
86
87        self.store(mat.id, mat_id, None, vec![]).unwrap();
88    }
89
90    fn compile_fill(&mut self, mat: core::Variable, value: core::Variable) {
91        let mat = self.compile_variable(mat);
92        let value = self.compile_variable(value);
93        let mat = self.matrix_var(&mat).1;
94        let item = self.item(&mat);
95        let ty = item.id(self);
96        let mat_id = match value {
97            Variable::ConstantScalar(id, _, _) => self.constant_composite(ty, vec![id]),
98            var => {
99                let var = self.read(&var);
100                self.composite_construct(ty, None, vec![var]).unwrap()
101            }
102        };
103
104        self.store(mat.id, mat_id, None, vec![]).unwrap();
105    }
106
107    fn compile_store(
108        &mut self,
109        mat: core::Variable,
110        out: core::Variable,
111        stride: core::Variable,
112        offset: core::Variable,
113        layout: MatrixLayout,
114    ) {
115        let mat = self.compile_variable(mat);
116        let mat = self.matrix_var(&mat).1;
117        let item = self.item(&mat);
118        let ty = item.id(self);
119        let mat_obj = self.load(ty, None, mat.id, None, vec![]).unwrap();
120        //assert_ne!(mat_obj, 0, "Can't store uninitialized matrix");
121
122        let out = self.compile_variable(out);
123        let stride = self.compile_variable(stride);
124        let stride_item = stride.item();
125        let mut stride = self.read(&stride);
126        let layout = compile_layout(layout).unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
127        let memory_layout = self.const_u32(layout as u32);
128        let offset = self.compile_variable(offset);
129        let ptr = self.index_ptr(&out, &offset);
130
131        if let Item::Vector(_, line_size) = out.item() {
132            let shift = stride_item.const_u32(self, line_size.trailing_zeros());
133            let stride_ty = stride_item.id(self);
134            stride = self
135                .shift_right_logical(stride_ty, None, stride, shift)
136                .unwrap();
137        }
138
139        self.cooperative_matrix_store_khr(ptr, mat_obj, memory_layout, Some(stride), None, vec![])
140            .unwrap();
141    }
142
143    fn compile_execute(
144        &mut self,
145        mat_a: core::Variable,
146        mat_b: core::Variable,
147        mat_c: core::Variable,
148        mat_d: core::Variable,
149    ) {
150        let mat_a = self.compile_variable(mat_a);
151        let mat_b = self.compile_variable(mat_b);
152        let mat_c = self.compile_variable(mat_c);
153        let mat_d = self.compile_variable(mat_d);
154
155        let mat_a = self.matrix_var(&mat_a).1;
156        let mat_b = self.matrix_var(&mat_b).1;
157        let mat_c = self.matrix_var(&mat_c).1;
158        let mat_d = self.matrix_var(&mat_d).1;
159
160        let mat_a_ty = self.item(&mat_a).id(self);
161        let mat_b_ty = self.item(&mat_b).id(self);
162        let mat_c_ty = self.item(&mat_c).id(self);
163
164        let mat_a_id = self.load(mat_a_ty, None, mat_a.id, None, vec![]).unwrap();
165        let mat_b_id = self.load(mat_b_ty, None, mat_b.id, None, vec![]).unwrap();
166        let mat_c_id = self.load(mat_c_ty, None, mat_c.id, None, vec![]).unwrap();
167
168        let ty = self.item(&mat_d).id(self);
169
170        let mut operands = CooperativeMatrixOperands::NONE_KHR;
171        if matches!(mat_a.elem, Elem::Int(_, true)) {
172            operands |= CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR;
173        }
174        if matches!(mat_b.elem, Elem::Int(_, true)) {
175            operands |= CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR;
176        }
177        if matches!(mat_c.elem, Elem::Int(_, true)) {
178            operands |= CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR;
179        }
180        if matches!(mat_d.elem, Elem::Int(_, true)) {
181            operands |= CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR;
182        }
183
184        let mat_d_id = self
185            .cooperative_matrix_mul_add_khr(ty, None, mat_a_id, mat_b_id, mat_c_id, Some(operands))
186            .unwrap();
187
188        self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
189    }
190
191    fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
192        let input = self.compile_variable(input);
193        let output = self.compile_variable(output);
194
195        let input = self.matrix_var(&input).1;
196        let output = self.matrix_var(&output).1;
197
198        let input_ty = self.item(&input).id(self);
199        let output_ty = self.item(&output).id(self);
200
201        let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();
202
203        let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();
204
205        self.store(output.id, frag_new, None, vec![]).unwrap();
206    }
207
208    fn matrix_var(&mut self, var: &Variable) -> (Id, Matrix) {
209        let id = match var {
210            Variable::CoopMatrix(id, _) => *id,
211            _ => unreachable!(),
212        };
213        let mat = self.state.matrices[&id];
214        (id, mat)
215    }
216
217    fn rows(&mut self, mat: &Matrix) -> u32 {
218        let rows = match mat.ident {
219            CooperativeMatrixUse::MatrixAKHR => mat.m,
220            CooperativeMatrixUse::MatrixBKHR => mat.k,
221            CooperativeMatrixUse::MatrixAccumulatorKHR => mat.m,
222        };
223        self.const_u32(rows)
224    }
225
226    fn columns(&mut self, mat: &Matrix) -> u32 {
227        let columns = match mat.ident {
228            CooperativeMatrixUse::MatrixAKHR => mat.k,
229            CooperativeMatrixUse::MatrixBKHR => mat.n,
230            CooperativeMatrixUse::MatrixAccumulatorKHR => mat.n,
231        };
232        self.const_u32(columns)
233    }
234
235    pub fn item(&mut self, mat: &Matrix) -> Item {
236        Item::CoopMatrix {
237            ty: mat.elem,
238            rows: self.rows(mat),
239            columns: self.columns(mat),
240            ident: mat.ident,
241        }
242    }
243
244    pub fn init_coop_matrix(&mut self, mat: core::Matrix, var: core::Variable) -> Matrix {
245        if mat.storage.elem_type() == ElemType::Float(core::FloatKind::BF16) {
246            self.capabilities
247                .insert(Capability::BFloat16CooperativeMatrixKHR);
248        }
249        if matches!(
250            mat.storage.elem_type(),
251            ElemType::Float(core::FloatKind::E5M2 | core::FloatKind::E4M3)
252        ) {
253            self.capabilities
254                .insert(Capability::Float8CooperativeMatrixEXT);
255        }
256
257        let elem = self.compile_type(core::Type::new(mat.storage)).elem();
258        let ident = match mat.ident {
259            core::MatrixIdent::A => CooperativeMatrixUse::MatrixAKHR,
260            core::MatrixIdent::B => CooperativeMatrixUse::MatrixBKHR,
261            core::MatrixIdent::Accumulator => CooperativeMatrixUse::MatrixAccumulatorKHR,
262        };
263        let layout = compile_layout(mat.layout);
264
265        let mut mat = Matrix {
266            id: 0,
267            ident,
268            m: mat.m,
269            n: mat.n,
270            k: mat.k,
271            elem,
272            layout,
273        };
274
275        let item = Item::Pointer(StorageClass::Function, Box::new(self.item(&mat)));
276        let ty = item.id(self);
277        mat.id = self.declare_function_variable(ty);
278        self.debug_var_name(mat.id, var);
279
280        mat
281    }
282}
283
284fn compile_layout(layout: MatrixLayout) -> Option<CooperativeMatrixLayout> {
285    match layout {
286        core::MatrixLayout::ColMajor => Some(CooperativeMatrixLayout::ColumnMajorKHR),
287        core::MatrixLayout::RowMajor => Some(CooperativeMatrixLayout::RowMajorKHR),
288        core::MatrixLayout::Undefined => None,
289    }
290}