// cubecl_core/frontend/cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
//! Most of these functions are effectively unsafe: they mutate their matrix arguments even though
//! those are passed by shared reference.
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
10//! ```rust, ignore
11//! #[cube(launch)]
12//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     let a = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::A,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::RowMajor,
//!         )
//!     };
//!     let b = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::B,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::ColMajor,
//!         )
//!     };
//!     let c = unsafe {
//!         cmma::Matrix::<F32>::uninitialized(
//!             cmma::MatrixIdent::Accumulator,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::Undefined,
//!         )
//!     };
34//!     cmma::fill::<F32>(&c, F32::new(0.0));
35//!     cmma::load::<F16>(&a, lhs.as_slice(), u32::new(16));
36//!     cmma::load::<F16>(&b, rhs.as_slice(), u32::new(16));
37//!
38//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
39//!
40//!     cmma::store::<F32>(
41//!         out.as_slice_mut(),
42//!         &c,
43//!         u32::new(16),
44//!         cmma::MatrixLayout::RowMajor,
45//!     );
46//! }
47//! ```
48
49use std::marker::PhantomData;
50
51use crate::{
52    ir::{self, Instruction},
53    unexpanded,
54};
55
56use super::{CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, Init, Slice, SliceMut};
57
58use cubecl_ir::{ExpandElement, Scope};
59pub use ir::{MatrixIdent, MatrixLayout};
60
61/// A matrix represent a 2D grid of numbers.
62///
63/// They can either be in a [row major](MatrixLayout::RowMajor) or a
64/// [column major](MatrixLayout::ColMajor) format.
65#[derive(Copy, Clone)]
66pub struct Matrix<C: CubeType> {
67    _c: PhantomData<C>,
68}
69
70/// Expand type of [Matrix].
71pub struct MatrixExpand<C: CubeType> {
72    elem: ExpandElement,
73    ident: MatrixIdent,
74    _c: PhantomData<C>,
75}
76
77impl<C: CubeType> Clone for MatrixExpand<C> {
78    fn clone(&self) -> Self {
79        Self {
80            elem: self.elem.clone(),
81            ident: self.ident,
82            _c: self._c,
83        }
84    }
85}
86
87impl<C: CubeType> CubeType for Matrix<C> {
88    type ExpandType = MatrixExpand<C>;
89}
90
91impl<C: CubeType> Init for MatrixExpand<C> {
92    fn init(self, _scope: &mut Scope) -> Self {
93        self
94    }
95}
96
97impl<C: CubeType> CubeDebug for MatrixExpand<C> {
98    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
99        scope.update_variable_name(*self.elem, name);
100    }
101}
102
103impl<C: CubePrimitive> Matrix<C> {
104    /// Create a new uninitialized matrix that is going to be used in the
105    /// [matrix-multiply and accumulate](execute()) function.
106    ///
107    /// # Safety
108    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
109    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
110    ///
111    /// You have to declare the shape used for the execution.
112    /// The shape of the current matrix is determined using the [MatrixIdent].
113    ///
114    /// * [MatrixIdent::A] Shape => (M, K)
115    /// * [MatrixIdent::B] Shape => (K, N)
116    /// * [MatrixIdent::Accumulator] Shape => (M, N)
117    ///
118    /// Not all shapes are supported, and the permitted shapes depend on the element type.
119    ///
120    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
121    #[allow(unused_variables)]
122    pub unsafe fn uninitialized(
123        ident: MatrixIdent,
124        m: u32,
125        n: u32,
126        k: u32,
127        layout: MatrixLayout,
128    ) -> Self {
129        Matrix { _c: PhantomData }
130    }
131
132    /// Create a new matrix that is going to be used in the
133    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
134    ///
135    /// You have to declare the shape used for the execution.
136    /// The shape of the current matrix is determined using the [MatrixIdent].
137    ///
138    /// * [MatrixIdent::A] Shape => (M, K)
139    /// * [MatrixIdent::B] Shape => (K, N)
140    /// * [MatrixIdent::Accumulator] Shape => (M, N)
141    ///
142    /// Not all shapes are supported, and the permitted shapes depend on the element type.
143    ///
144    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
145    #[allow(unused_variables)]
146    pub fn from_value(
147        ident: MatrixIdent,
148        m: u32,
149        n: u32,
150        k: u32,
151        layout: MatrixLayout,
152        value: C,
153    ) -> Self {
154        Matrix { _c: PhantomData }
155    }
156
157    /// Create a new matrix that is going to be used in the
158    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
159    ///
160    /// You have to declare the shape used for the execution.
161    /// The shape of the current matrix is determined using the [MatrixIdent].
162    ///
163    /// * [MatrixIdent::A] Shape => (M, K)
164    /// * [MatrixIdent::B] Shape => (K, N)
165    /// * [MatrixIdent::Accumulator] Shape => (M, N)
166    ///
167    /// Not all shapes are supported, and the permitted shapes depend on the element type.
168    ///
169    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
170    #[allow(unused_variables)]
171    pub fn from_slice(
172        ident: MatrixIdent,
173        m: u32,
174        n: u32,
175        k: u32,
176        layout: MatrixLayout,
177        value: &Slice<C>,
178        stride: u32,
179    ) -> Self {
180        Matrix { _c: PhantomData }
181    }
182
183    pub fn __expand_uninitialized(
184        scope: &mut Scope,
185        ident: MatrixIdent,
186        m: ExpandElementTyped<u32>,
187        n: ExpandElementTyped<u32>,
188        k: ExpandElementTyped<u32>,
189        layout: MatrixLayout,
190    ) -> MatrixExpand<C> {
191        let elem = C::as_elem(scope);
192        let elem = scope.create_matrix(ir::Matrix {
193            ident,
194            m: m.constant().unwrap().as_u32() as u8,
195            n: n.constant().unwrap().as_u32() as u8,
196            k: k.constant().unwrap().as_u32() as u8,
197            elem,
198            layout,
199        });
200        MatrixExpand {
201            elem,
202            ident,
203            _c: PhantomData,
204        }
205    }
206
207    pub fn __expand_from_value(
208        scope: &mut Scope,
209        ident: MatrixIdent,
210        m: ExpandElementTyped<u32>,
211        n: ExpandElementTyped<u32>,
212        k: ExpandElementTyped<u32>,
213        layout: MatrixLayout,
214        value: ExpandElementTyped<C>,
215    ) -> MatrixExpand<C> {
216        let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
217        fill::expand(scope, mat.clone(), value);
218        mat
219    }
220
221    #[allow(clippy::too_many_arguments)]
222    pub fn __expand_from_slice(
223        scope: &mut Scope,
224        ident: MatrixIdent,
225        m: ExpandElementTyped<u32>,
226        n: ExpandElementTyped<u32>,
227        k: ExpandElementTyped<u32>,
228        layout: MatrixLayout,
229        value: ExpandElementTyped<Slice<C>>,
230        stride: ExpandElementTyped<u32>,
231    ) -> MatrixExpand<C> {
232        let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
233        load::expand(scope, mat.clone(), value, stride);
234        mat
235    }
236}
237
238/// Fill the matrix with the provided value.
239#[allow(unused_variables)]
240pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
241    unexpanded!()
242}
243
244/// Module containing the expand function for [fill()].
245pub mod fill {
246    use super::*;
247
248    /// Expand method of [fill()].
249    pub fn expand<C: CubeType>(
250        scope: &mut Scope,
251        mat: MatrixExpand<C>,
252        value: ExpandElementTyped<C>,
253    ) {
254        let value: ExpandElement = value.into();
255        scope.register(Instruction::new(
256            ir::CoopMma::Fill { value: *value },
257            *mat.elem,
258        ));
259    }
260}
261
262/// Load the matrix with the provided array using the stride.
263#[allow(unused_variables)]
264pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
265    unexpanded!()
266}
267
268/// Module containing the expand function for [load()].
269pub mod load {
270    use super::*;
271
272    /// Expand method of [load()].
273    #[allow(unused_variables)]
274    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
275        scope: &mut Scope,
276        mat: MatrixExpand<C>,
277        value: ExpandElementTyped<Slice<V>>,
278        stride: ExpandElementTyped<u32>,
279    ) {
280        let stride: ExpandElement = stride.into();
281        assert_ne!(
282            mat.ident,
283            MatrixIdent::Accumulator,
284            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
285        );
286
287        scope.register(Instruction::new(
288            ir::CoopMma::Load {
289                value: *value.expand,
290                stride: *stride,
291                layout: None,
292            },
293            *mat.elem,
294        ));
295    }
296}
297
298/// Load the matrix with the provided array using the stride with an explicit layout.
299/// Explicit layouts are required when loading accumulators.
300#[allow(unused_variables)]
301pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
302    mat: &Matrix<C>,
303    value: &Slice<V>,
304    stride: u32,
305    layout: MatrixLayout,
306) {
307    unexpanded!()
308}
309
310/// Module containing the expand function for [load_with_layout()].
311pub mod load_with_layout {
312    use super::*;
313
314    /// Expand method of [load_with_layout()].
315    #[allow(unused_variables)]
316    pub fn expand<C: CubeType, V: CubePrimitive>(
317        scope: &mut Scope,
318        mat: MatrixExpand<C>,
319        value: ExpandElementTyped<Slice<V>>,
320        stride: ExpandElementTyped<u32>,
321        layout: MatrixLayout,
322    ) {
323        let stride: ExpandElement = stride.into();
324
325        scope.register(Instruction::new(
326            ir::CoopMma::Load {
327                value: *value.expand,
328                stride: *stride,
329                layout: Some(layout),
330            },
331            *mat.elem,
332        ));
333    }
334}
335
336/// Store the matrix in the given array following the given stride and layout.
337#[allow(unused_variables)]
338pub fn store<C: CubePrimitive, O: CubePrimitive>(
339    output: &mut SliceMut<O>,
340    mat: &Matrix<C>,
341    stride: u32,
342    layout: MatrixLayout,
343) {
344    unexpanded!()
345}
346
347/// Module containing the expand function for [store()].
348pub mod store {
349    use super::*;
350
351    /// Expand method of [store()].
352    #[allow(unused_variables)]
353    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
354        scope: &mut Scope,
355        output: ExpandElementTyped<SliceMut<O>>,
356        mat: MatrixExpand<C>,
357        stride: ExpandElementTyped<u32>,
358        layout: MatrixLayout,
359    ) {
360        let stride: ExpandElement = stride.into();
361
362        scope.register(Instruction::new(
363            ir::CoopMma::Store {
364                mat: *mat.elem,
365                stride: *stride,
366                layout,
367            },
368            *output.expand,
369        ));
370    }
371}
372
373/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
374#[allow(unused_variables)]
375pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
376    mat_a: &Matrix<A>,
377    mat_b: &Matrix<B>,
378    mat_c: &Matrix<C>,
379    mat_d: &Matrix<D>,
380) {
381    unexpanded!()
382}
383
384/// Module containing the expand function for [execute()].
385pub mod execute {
386    use super::*;
387
388    /// Expand method of [execute()].
389    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
390        scope: &mut Scope,
391        mat_a: MatrixExpand<A>,
392        mat_b: MatrixExpand<B>,
393        mat_c: MatrixExpand<C>,
394        mat_d: MatrixExpand<D>,
395    ) {
396        scope.register(Instruction::new(
397            ir::CoopMma::Execute {
398                mat_a: *mat_a.elem,
399                mat_b: *mat_b.elem,
400                mat_c: *mat_c.elem,
401            },
402            *mat_d.elem,
403        ));
404    }
405}
406
407/// Store the matrix in the given array following the given stride and layout.
408#[allow(unused_variables)]
409pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
410    unexpanded!()
411}
412
413/// Module containing the expand function for [store()].
414pub mod cast {
415    use super::*;
416
417    /// Expand method of [store()].
418    #[allow(unused_variables)]
419    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
420        scope: &mut Scope,
421        input: MatrixExpand<C>,
422    ) -> MatrixExpand<O> {
423        let ident = input.ident;
424
425        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
426            return MatrixExpand {
427                elem: input.elem,
428                ident,
429                _c: PhantomData,
430            };
431        }
432        let input = *input.elem;
433        let input_mat = match input.kind {
434            ir::VariableKind::Matrix { mat, .. } => mat,
435            _ => unreachable!(),
436        };
437
438        let elem = O::as_elem(scope);
439        let elem = scope.create_matrix(ir::Matrix {
440            ident,
441            m: input_mat.m,
442            n: input_mat.n,
443            k: input_mat.k,
444            elem,
445            layout: MatrixLayout::Undefined,
446        });
447
448        let output = MatrixExpand {
449            ident,
450            elem,
451            _c: PhantomData,
452        };
453        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
454
455        output
456    }
457}
458
459impl CubeType for MatrixLayout {
460    type ExpandType = Self;
461}
462
463impl Init for MatrixLayout {
464    fn init(self, _scope: &mut crate::ir::Scope) -> Self {
465        self
466    }
467}
468
469impl CubeDebug for MatrixLayout {}