cubecl_core/frontend/
cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
//! Most of these functions are conceptually unsafe: they mutate the matrix they are given even
//! though it is passed by shared reference (`&Matrix`).
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
10//! ```rust, ignore
11//! #[cube(launch)]
12//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     // SAFETY: `a` and `b` are loaded below before they are used.
//!     let a = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::A,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::RowMajor,
//!         )
//!     };
//!     let b = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::B,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::ColMajor,
//!         )
//!     };
//!     let c = cmma::Matrix::<F32>::from_value(
//!         cmma::MatrixIdent::Accumulator,
//!         16,
//!         16,
//!         16,
//!         cmma::MatrixLayout::Undefined,
//!         F32::new(0.0),
//!     );
//!     cmma::load::<F16, F16>(&a, lhs.as_slice(), 16);
//!     cmma::load::<F16, F16>(&b, rhs.as_slice(), 16);
//!
//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//!     cmma::store::<F32, F32>(
//!         out.as_slice_mut(),
//!         &c,
//!         16,
//!         cmma::MatrixLayout::RowMajor,
//!     );
46//! }
47//! ```
48
49use std::marker::PhantomData;
50
51use crate::{
52    ir::{self, Instruction},
53    unexpanded,
54};
55
56use super::{
57    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
58    SliceMut,
59};
60
61use cubecl_ir::{ExpandElement, Scope};
62pub use ir::{MatrixIdent, MatrixLayout};
63
64/// A matrix represent a 2D grid of numbers.
65///
66/// They can either be in a [row major](MatrixLayout::RowMajor) or a
67/// [column major](MatrixLayout::ColMajor) format.
68#[derive(Copy, Clone)]
69pub struct Matrix<C: CubeType> {
70    _c: PhantomData<C>,
71}
72
73/// Expand type of [Matrix].
74pub struct MatrixExpand<C: CubeType> {
75    elem: ExpandElement,
76    ident: MatrixIdent,
77    _c: PhantomData<C>,
78}
79
80impl<C: CubeType> Clone for MatrixExpand<C> {
81    fn clone(&self) -> Self {
82        Self {
83            elem: self.elem.clone(),
84            ident: self.ident,
85            _c: self._c,
86        }
87    }
88}
89
90impl<C: CubeType> CubeType for Matrix<C> {
91    type ExpandType = MatrixExpand<C>;
92}
93
94impl<C: CubeType> IntoMut for MatrixExpand<C> {
95    fn into_mut(self, _scope: &mut Scope) -> Self {
96        self
97    }
98}
99
100impl<C: CubeType> CubeDebug for MatrixExpand<C> {
101    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
102        scope.update_variable_name(*self.elem, name);
103    }
104}
105
106impl<C: CubePrimitive> Matrix<C> {
107    /// Create a new uninitialized matrix that is going to be used in the
108    /// [matrix-multiply and accumulate](execute()) function.
109    ///
110    /// # Safety
111    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
112    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
113    ///
114    /// You have to declare the shape used for the execution.
115    /// The shape of the current matrix is determined using the [MatrixIdent].
116    ///
117    /// * [MatrixIdent::A] Shape => (M, K)
118    /// * [MatrixIdent::B] Shape => (K, N)
119    /// * [MatrixIdent::Accumulator] Shape => (M, N)
120    ///
121    /// Not all shapes are supported, and the permitted shapes depend on the element type.
122    ///
123    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
124    #[allow(unused_variables)]
125    pub unsafe fn uninitialized(
126        ident: MatrixIdent,
127        m: u32,
128        n: u32,
129        k: u32,
130        layout: MatrixLayout,
131    ) -> Self {
132        Matrix { _c: PhantomData }
133    }
134
135    /// Create a new matrix that is going to be used in the
136    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
137    ///
138    /// You have to declare the shape used for the execution.
139    /// The shape of the current matrix is determined using the [MatrixIdent].
140    ///
141    /// * [MatrixIdent::A] Shape => (M, K)
142    /// * [MatrixIdent::B] Shape => (K, N)
143    /// * [MatrixIdent::Accumulator] Shape => (M, N)
144    ///
145    /// Not all shapes are supported, and the permitted shapes depend on the element type.
146    ///
147    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
148    #[allow(unused_variables)]
149    pub fn from_value(
150        ident: MatrixIdent,
151        m: u32,
152        n: u32,
153        k: u32,
154        layout: MatrixLayout,
155        value: C,
156    ) -> Self {
157        Matrix { _c: PhantomData }
158    }
159
160    /// Create a new matrix that is going to be used in the
161    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
162    ///
163    /// You have to declare the shape used for the execution.
164    /// The shape of the current matrix is determined using the [MatrixIdent].
165    ///
166    /// * [MatrixIdent::A] Shape => (M, K)
167    /// * [MatrixIdent::B] Shape => (K, N)
168    /// * [MatrixIdent::Accumulator] Shape => (M, N)
169    ///
170    /// Not all shapes are supported, and the permitted shapes depend on the element type.
171    ///
172    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
173    #[allow(unused_variables)]
174    pub fn from_slice(
175        ident: MatrixIdent,
176        m: u32,
177        n: u32,
178        k: u32,
179        layout: MatrixLayout,
180        value: &Slice<C>,
181        stride: u32,
182    ) -> Self {
183        Matrix { _c: PhantomData }
184    }
185
186    pub fn __expand_uninitialized(
187        scope: &mut Scope,
188        ident: MatrixIdent,
189        m: ExpandElementTyped<u32>,
190        n: ExpandElementTyped<u32>,
191        k: ExpandElementTyped<u32>,
192        layout: MatrixLayout,
193    ) -> MatrixExpand<C> {
194        let elem = C::as_elem(scope);
195        let elem = scope.create_matrix(ir::Matrix {
196            ident,
197            m: m.constant().unwrap().as_u32() as u8,
198            n: n.constant().unwrap().as_u32() as u8,
199            k: k.constant().unwrap().as_u32() as u8,
200            elem,
201            layout,
202        });
203        MatrixExpand {
204            elem,
205            ident,
206            _c: PhantomData,
207        }
208    }
209
210    pub fn __expand_from_value(
211        scope: &mut Scope,
212        ident: MatrixIdent,
213        m: ExpandElementTyped<u32>,
214        n: ExpandElementTyped<u32>,
215        k: ExpandElementTyped<u32>,
216        layout: MatrixLayout,
217        value: ExpandElementTyped<C>,
218    ) -> MatrixExpand<C> {
219        let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
220        fill::expand(scope, mat.clone(), value);
221        mat
222    }
223
224    #[allow(clippy::too_many_arguments)]
225    pub fn __expand_from_slice(
226        scope: &mut Scope,
227        ident: MatrixIdent,
228        m: ExpandElementTyped<u32>,
229        n: ExpandElementTyped<u32>,
230        k: ExpandElementTyped<u32>,
231        layout: MatrixLayout,
232        value: SliceExpand<C, ReadOnly>,
233        stride: ExpandElementTyped<u32>,
234    ) -> MatrixExpand<C> {
235        let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
236        load::expand(scope, mat.clone(), value, stride);
237        mat
238    }
239}
240
241/// Fill the matrix with the provided value.
242#[allow(unused_variables)]
243pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
244    unexpanded!()
245}
246
247/// Module containing the expand function for [fill()].
248pub mod fill {
249    use super::*;
250
251    /// Expand method of [fill()].
252    pub fn expand<C: CubeType>(
253        scope: &mut Scope,
254        mat: MatrixExpand<C>,
255        value: ExpandElementTyped<C>,
256    ) {
257        let value: ExpandElement = value.into();
258        scope.register(Instruction::new(
259            ir::CoopMma::Fill { value: *value },
260            *mat.elem,
261        ));
262    }
263}
264
265/// Load the matrix with the provided array using the stride.
266#[allow(unused_variables)]
267pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
268    unexpanded!()
269}
270
271/// Module containing the expand function for [load()].
272pub mod load {
273    use super::*;
274
275    /// Expand method of [load()].
276    #[allow(unused_variables)]
277    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
278        scope: &mut Scope,
279        mat: MatrixExpand<C>,
280        value: SliceExpand<V, ReadOnly>,
281        stride: ExpandElementTyped<u32>,
282    ) {
283        let stride: ExpandElement = stride.into();
284        assert_ne!(
285            mat.ident,
286            MatrixIdent::Accumulator,
287            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
288        );
289
290        let (value, offset) = value.__to_raw_parts();
291
292        scope.register(Instruction::new(
293            ir::CoopMma::Load {
294                value,
295                stride: *stride,
296                offset,
297                layout: None,
298            },
299            *mat.elem,
300        ));
301    }
302}
303
304/// Load the matrix with the provided array using the stride with an explicit layout.
305/// Explicit layouts are required when loading accumulators.
306#[allow(unused_variables)]
307pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
308    mat: &Matrix<C>,
309    value: &Slice<V>,
310    stride: u32,
311    layout: MatrixLayout,
312) {
313    unexpanded!()
314}
315
316/// Module containing the expand function for [load_with_layout()].
317pub mod load_with_layout {
318    use super::*;
319
320    /// Expand method of [load_with_layout()].
321    #[allow(unused_variables)]
322    pub fn expand<C: CubeType, V: CubePrimitive>(
323        scope: &mut Scope,
324        mat: MatrixExpand<C>,
325        value: SliceExpand<V, ReadOnly>,
326        stride: ExpandElementTyped<u32>,
327        layout: MatrixLayout,
328    ) {
329        let stride: ExpandElement = stride.into();
330        let (value, offset) = value.__to_raw_parts();
331
332        scope.register(Instruction::new(
333            ir::CoopMma::Load {
334                value,
335                stride: *stride,
336                offset,
337                layout: Some(layout),
338            },
339            *mat.elem,
340        ));
341    }
342}
343
344/// Store the matrix in the given array following the given stride and layout.
345#[allow(unused_variables)]
346pub fn store<C: CubePrimitive, O: CubePrimitive>(
347    output: &mut SliceMut<O>,
348    mat: &Matrix<C>,
349    stride: u32,
350    layout: MatrixLayout,
351) {
352    unexpanded!()
353}
354
355/// Module containing the expand function for [store()].
356pub mod store {
357    use crate::prelude::ReadWrite;
358
359    use super::*;
360
361    /// Expand method of [store()].
362    #[allow(unused_variables)]
363    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
364        scope: &mut Scope,
365        output: SliceExpand<O, ReadWrite>,
366        mat: MatrixExpand<C>,
367        stride: ExpandElementTyped<u32>,
368        layout: MatrixLayout,
369    ) {
370        let stride: ExpandElement = stride.into();
371
372        let (output, offset) = output.__to_raw_parts();
373
374        scope.register(Instruction::new(
375            ir::CoopMma::Store {
376                mat: *mat.elem,
377                offset,
378                stride: *stride,
379                layout,
380            },
381            output,
382        ));
383    }
384}
385
386/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
387#[allow(unused_variables)]
388pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
389    mat_a: &Matrix<A>,
390    mat_b: &Matrix<B>,
391    mat_c: &Matrix<C>,
392    mat_d: &Matrix<D>,
393) {
394    unexpanded!()
395}
396
397/// Module containing the expand function for [execute()].
398pub mod execute {
399    use super::*;
400
401    /// Expand method of [execute()].
402    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
403        scope: &mut Scope,
404        mat_a: MatrixExpand<A>,
405        mat_b: MatrixExpand<B>,
406        mat_c: MatrixExpand<C>,
407        mat_d: MatrixExpand<D>,
408    ) {
409        scope.register(Instruction::new(
410            ir::CoopMma::Execute {
411                mat_a: *mat_a.elem,
412                mat_b: *mat_b.elem,
413                mat_c: *mat_c.elem,
414            },
415            *mat_d.elem,
416        ));
417    }
418}
419
420/// Store the matrix in the given array following the given stride and layout.
421#[allow(unused_variables)]
422pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
423    unexpanded!()
424}
425
426/// Module containing the expand function for [store()].
427pub mod cast {
428    use super::*;
429
430    /// Expand method of [store()].
431    #[allow(unused_variables)]
432    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
433        scope: &mut Scope,
434        input: MatrixExpand<C>,
435    ) -> MatrixExpand<O> {
436        let ident = input.ident;
437
438        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
439            return MatrixExpand {
440                elem: input.elem,
441                ident,
442                _c: PhantomData,
443            };
444        }
445        let input = *input.elem;
446        let input_mat = match input.kind {
447            ir::VariableKind::Matrix { mat, .. } => mat,
448            _ => unreachable!(),
449        };
450
451        let elem = O::as_elem(scope);
452        let elem = scope.create_matrix(ir::Matrix {
453            ident,
454            m: input_mat.m,
455            n: input_mat.n,
456            k: input_mat.k,
457            elem,
458            layout: MatrixLayout::Undefined,
459        });
460
461        let output = MatrixExpand {
462            ident,
463            elem,
464            _c: PhantomData,
465        };
466        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
467
468        output
469    }
470}
471
472impl CubeType for MatrixLayout {
473    type ExpandType = Self;
474}
475
476impl IntoMut for MatrixLayout {
477    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
478        self
479    }
480}
481
482impl CubeDebug for MatrixLayout {}