cubecl_spirv/
cmma.rs

1use crate::{
2    SpirvCompiler, SpirvTarget,
3    item::{Elem, Item},
4    lookups::Matrix,
5    variable::Variable,
6};
7use cubecl_core::ir::{self as core, CoopMma, ElemType, Id, MatrixLayout};
8use rspirv::spirv::{
9    Capability, CooperativeMatrixLayout, CooperativeMatrixOperands, CooperativeMatrixUse,
10    StorageClass,
11};
12
13impl<T: SpirvTarget> SpirvCompiler<T> {
14    pub fn compile_cmma(&mut self, cmma: CoopMma, out: Option<core::Variable>) {
15        self.capabilities.insert(Capability::CooperativeMatrixKHR);
16        let out = out.unwrap();
17        match cmma {
18            CoopMma::Fill { value } => self.compile_fill(out, value),
19            CoopMma::Load {
20                value,
21                stride,
22                layout,
23                offset,
24            } => self.compile_load(out, value, stride, offset, layout),
25            CoopMma::Execute {
26                mat_a,
27                mat_b,
28                mat_c,
29            } => self.compile_execute(mat_a, mat_b, mat_c, out),
30            CoopMma::Store {
31                mat,
32                stride,
33                layout,
34                offset,
35            } => self.compile_store(mat, out, stride, offset, layout),
36            CoopMma::Cast { input } => self.compile_cast(input, out),
37            CoopMma::RowIndex { .. }
38            | CoopMma::ColIndex { .. }
39            | CoopMma::ExecuteManual { .. }
40            | CoopMma::ExecuteScaled { .. } => {
41                panic!("Manual register management not currently supported in SPIR-V")
42            }
43        }
44    }
45
46    fn compile_load(
47        &mut self,
48        mat: core::Variable,
49        value: core::Variable,
50        stride: core::Variable,
51        offset: core::Variable,
52        layout: Option<MatrixLayout>,
53    ) {
54        let mat = self.compile_variable(mat);
55        let mat = self.matrix_var(&mat).1;
56
57        let value = self.compile_variable(value);
58        let stride = self.compile_variable(stride);
59        let stride_item = stride.item();
60        let mut stride = self.read(&stride);
61
62        if let Item::Vector(_, line_size) = value.item() {
63            let shift = stride_item.const_u32(self, line_size.trailing_zeros());
64            let stride_ty = stride_item.id(self);
65            stride = self
66                .shift_right_logical(stride_ty, None, stride, shift)
67                .unwrap();
68        }
69
70        let layout = layout
71            .and_then(compile_layout)
72            .or(mat.layout)
73            .unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
74        let memory_layout = self.const_u32(layout as u32);
75
76        let offset = self.compile_variable(offset);
77        let ptr = self.index_ptr(&value, &offset);
78        let out_ty = self.item(&mat);
79        let ty = out_ty.id(self);
80
81        let mat_id = self
82            .cooperative_matrix_load_khr(ty, None, ptr, memory_layout, Some(stride), None, vec![])
83            .unwrap();
84
85        self.store(mat.id, mat_id, None, vec![]).unwrap();
86    }
87
88    fn compile_fill(&mut self, mat: core::Variable, value: core::Variable) {
89        let mat = self.compile_variable(mat);
90        let value = self.compile_variable(value);
91        let mat = self.matrix_var(&mat).1;
92        let item = self.item(&mat);
93        let ty = item.id(self);
94        let mat_id = match value {
95            Variable::ConstantScalar(id, _, _) => self.constant_composite(ty, vec![id]),
96            var => {
97                let var = self.read(&var);
98                self.composite_construct(ty, None, vec![var]).unwrap()
99            }
100        };
101
102        self.store(mat.id, mat_id, None, vec![]).unwrap();
103    }
104
105    fn compile_store(
106        &mut self,
107        mat: core::Variable,
108        out: core::Variable,
109        stride: core::Variable,
110        offset: core::Variable,
111        layout: MatrixLayout,
112    ) {
113        let mat = self.compile_variable(mat);
114        let mat = self.matrix_var(&mat).1;
115        let item = self.item(&mat);
116        let ty = item.id(self);
117        let mat_obj = self.load(ty, None, mat.id, None, vec![]).unwrap();
118        //assert_ne!(mat_obj, 0, "Can't store uninitialized matrix");
119
120        let out = self.compile_variable(out);
121        let stride = self.compile_variable(stride);
122        let stride_item = stride.item();
123        let mut stride = self.read(&stride);
124        let layout = compile_layout(layout).unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
125        let memory_layout = self.const_u32(layout as u32);
126        let offset = self.compile_variable(offset);
127        let ptr = self.index_ptr(&out, &offset);
128
129        if let Item::Vector(_, line_size) = out.item() {
130            let shift = stride_item.const_u32(self, line_size.trailing_zeros());
131            let stride_ty = stride_item.id(self);
132            stride = self
133                .shift_right_logical(stride_ty, None, stride, shift)
134                .unwrap();
135        }
136
137        self.cooperative_matrix_store_khr(ptr, mat_obj, memory_layout, Some(stride), None, vec![])
138            .unwrap();
139    }
140
141    fn compile_execute(
142        &mut self,
143        mat_a: core::Variable,
144        mat_b: core::Variable,
145        mat_c: core::Variable,
146        mat_d: core::Variable,
147    ) {
148        let mat_a = self.compile_variable(mat_a);
149        let mat_b = self.compile_variable(mat_b);
150        let mat_c = self.compile_variable(mat_c);
151        let mat_d = self.compile_variable(mat_d);
152
153        let mat_a = self.matrix_var(&mat_a).1;
154        let mat_b = self.matrix_var(&mat_b).1;
155        let mat_c = self.matrix_var(&mat_c).1;
156        let mat_d = self.matrix_var(&mat_d).1;
157
158        let mat_a_ty = self.item(&mat_a).id(self);
159        let mat_b_ty = self.item(&mat_b).id(self);
160        let mat_c_ty = self.item(&mat_c).id(self);
161
162        let mat_a_id = self.load(mat_a_ty, None, mat_a.id, None, vec![]).unwrap();
163        let mat_b_id = self.load(mat_b_ty, None, mat_b.id, None, vec![]).unwrap();
164        let mat_c_id = self.load(mat_c_ty, None, mat_c.id, None, vec![]).unwrap();
165
166        let ty = self.item(&mat_d).id(self);
167
168        let mut operands = CooperativeMatrixOperands::NONE_KHR;
169        if matches!(mat_a.elem, Elem::Int(_, true)) {
170            operands |= CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR;
171        }
172        if matches!(mat_b.elem, Elem::Int(_, true)) {
173            operands |= CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR;
174        }
175        if matches!(mat_c.elem, Elem::Int(_, true)) {
176            operands |= CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR;
177        }
178        if matches!(mat_d.elem, Elem::Int(_, true)) {
179            operands |= CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR;
180        }
181
182        let mat_d_id = self
183            .cooperative_matrix_mul_add_khr(ty, None, mat_a_id, mat_b_id, mat_c_id, Some(operands))
184            .unwrap();
185
186        self.store(mat_d.id, mat_d_id, None, vec![]).unwrap();
187    }
188
189    fn compile_cast(&mut self, input: core::Variable, output: core::Variable) {
190        let input = self.compile_variable(input);
191        let output = self.compile_variable(output);
192
193        let input = self.matrix_var(&input).1;
194        let output = self.matrix_var(&output).1;
195
196        let input_ty = self.item(&input).id(self);
197        let output_ty = self.item(&output).id(self);
198
199        let fragment_id = self.load(input_ty, None, input.id, None, vec![]).unwrap();
200
201        let frag_new = self.f_convert(output_ty, None, fragment_id).unwrap();
202
203        self.store(output.id, frag_new, None, vec![]).unwrap();
204    }
205
206    fn matrix_var(&mut self, var: &Variable) -> (Id, Matrix) {
207        let id = match var {
208            Variable::CoopMatrix(id, _) => *id,
209            _ => unreachable!(),
210        };
211        let mat = self.state.matrices[&id];
212        (id, mat)
213    }
214
215    fn rows(&mut self, mat: &Matrix) -> u32 {
216        let rows = match mat.ident {
217            CooperativeMatrixUse::MatrixAKHR => mat.m,
218            CooperativeMatrixUse::MatrixBKHR => mat.k,
219            CooperativeMatrixUse::MatrixAccumulatorKHR => mat.m,
220        };
221        self.const_u32(rows)
222    }
223
224    fn columns(&mut self, mat: &Matrix) -> u32 {
225        let columns = match mat.ident {
226            CooperativeMatrixUse::MatrixAKHR => mat.k,
227            CooperativeMatrixUse::MatrixBKHR => mat.n,
228            CooperativeMatrixUse::MatrixAccumulatorKHR => mat.n,
229        };
230        self.const_u32(columns)
231    }
232
233    pub fn item(&mut self, mat: &Matrix) -> Item {
234        Item::CoopMatrix {
235            ty: mat.elem,
236            rows: self.rows(mat),
237            columns: self.columns(mat),
238            ident: mat.ident,
239        }
240    }
241
242    pub fn init_coop_matrix(&mut self, mat: core::Matrix, var: core::Variable) -> Matrix {
243        if mat.storage.elem_type() == ElemType::Float(core::FloatKind::BF16) {
244            self.capabilities
245                .insert(Capability::BFloat16CooperativeMatrixKHR);
246        }
247        if matches!(
248            mat.storage.elem_type(),
249            ElemType::Float(core::FloatKind::E5M2 | core::FloatKind::E4M3)
250        ) {
251            self.capabilities
252                .insert(Capability::Float8CooperativeMatrixEXT);
253        }
254
255        let elem = self.compile_type(core::Type::new(mat.storage)).elem();
256        let ident = match mat.ident {
257            core::MatrixIdent::A => CooperativeMatrixUse::MatrixAKHR,
258            core::MatrixIdent::B => CooperativeMatrixUse::MatrixBKHR,
259            core::MatrixIdent::Accumulator => CooperativeMatrixUse::MatrixAccumulatorKHR,
260        };
261        let layout = compile_layout(mat.layout);
262
263        let mut mat = Matrix {
264            id: 0,
265            ident,
266            m: mat.m,
267            n: mat.n,
268            k: mat.k,
269            elem,
270            layout,
271        };
272
273        let item = Item::Pointer(StorageClass::Function, Box::new(self.item(&mat)));
274        let ty = item.id(self);
275        mat.id = self.declare_function_variable(ty);
276        self.debug_var_name(mat.id, var);
277
278        mat
279    }
280}
281
282fn compile_layout(layout: MatrixLayout) -> Option<CooperativeMatrixLayout> {
283    match layout {
284        core::MatrixLayout::ColMajor => Some(CooperativeMatrixLayout::ColumnMajorKHR),
285        core::MatrixLayout::RowMajor => Some(CooperativeMatrixLayout::RowMajorKHR),
286        core::MatrixLayout::Undefined => None,
287    }
288}