cubecl_core/frontend/
cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
//! Most of these functions are effectively unsafe: they mutate their matrix arguments even
//! though those arguments are passed by shared reference.
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
10//! ```rust, ignore
11//! #[cube(launch)]
12//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     let a = cmma::Matrix::<F16>::from_slice(
//!         cmma::MatrixIdent::A,
//!         16,
//!         16,
//!         16,
//!         cmma::MatrixLayout::RowMajor,
//!         lhs.as_slice(),
//!         u32::new(16),
//!     );
//!     let b = cmma::Matrix::<F16>::from_slice(
//!         cmma::MatrixIdent::B,
//!         16,
//!         16,
//!         16,
//!         cmma::MatrixLayout::ColMajor,
//!         rhs.as_slice(),
//!         u32::new(16),
//!     );
//!     let c = cmma::Matrix::<F32>::from_value(
//!         cmma::MatrixIdent::Accumulator,
//!         16,
//!         16,
//!         16,
//!         cmma::MatrixLayout::Undefined,
//!         F32::new(0.0),
//!     );
37//!
38//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
39//!
40//!     cmma::store::<F32>(
41//!         out.as_slice_mut(),
42//!         &c,
43//!         u32::new(16),
44//!         cmma::MatrixLayout::RowMajor,
45//!     );
46//! }
47//! ```
48
49use std::marker::PhantomData;
50
51use crate::{
52    ir::{self, Instruction, Operation},
53    unexpanded,
54};
55
56use super::{
57    CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, IntoRuntime,
58    Slice, SliceMut,
59};
60
61pub use ir::{MatrixIdent, MatrixLayout};
62
/// A matrix represents a 2D grid of numbers used by cooperative matrix-multiply and
/// accumulate operations.
///
/// It can be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format, or have an
/// [undefined](MatrixLayout::Undefined) layout (as used for accumulators).
///
/// This type is only a typed placeholder in user code; the actual matrix lives in the
/// expanded IR (see [MatrixExpand]).
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}
71
/// Expand type of [Matrix].
///
/// Holds the IR matrix variable created during kernel expansion.
pub struct MatrixExpand<C: CubeType> {
    // Handle to the IR matrix variable; instructions target it through `*elem`.
    elem: ExpandElement,
    // Role of the matrix (A, B or accumulator); `load` uses it to reject accumulators and
    // `cast` carries it over to the converted matrix.
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
78
79impl<C: CubeType> Clone for MatrixExpand<C> {
80    fn clone(&self) -> Self {
81        Self {
82            elem: self.elem.clone(),
83            ident: self.ident,
84            _c: self._c,
85        }
86    }
87}
88
89impl<C: CubeType> CubeType for Matrix<C> {
90    type ExpandType = MatrixExpand<C>;
91}
92
93impl<C: CubeType> IntoRuntime for Matrix<C> {
94    fn __expand_runtime_method(self, _context: &mut CubeContext) -> MatrixExpand<C> {
95        unimplemented!("Matrices can't exist at compile time")
96    }
97}
98
99impl<C: CubeType> Init for MatrixExpand<C> {
100    fn init(self, _context: &mut CubeContext) -> Self {
101        self
102    }
103}
104
105impl<C: CubePrimitive> Matrix<C> {
106    /// Create a new uninitialized matrix that is going to be used in the
107    /// [matrix-multiply and accumulate](execute()) function.
108    ///
109    /// # Safety
110    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
111    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
112    ///
113    /// You have to declare the shape used for the execution.
114    /// The shape of the current matrix is determined using the [MatrixIdent].
115    ///
116    /// * [MatrixIdent::A] Shape => (M, K)
117    /// * [MatrixIdent::B] Shape => (K, N)
118    /// * [MatrixIdent::Accumulator] Shape => (M, N)
119    ///
120    /// Not all shapes are supported, and the permitted shapes depend on the element type.
121    ///
122    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
123    #[allow(unused_variables)]
124    pub unsafe fn uninitialized(
125        ident: MatrixIdent,
126        m: u32,
127        n: u32,
128        k: u32,
129        layout: MatrixLayout,
130    ) -> Self {
131        Matrix { _c: PhantomData }
132    }
133
134    /// Create a new matrix that is going to be used in the
135    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
136    ///
137    /// You have to declare the shape used for the execution.
138    /// The shape of the current matrix is determined using the [MatrixIdent].
139    ///
140    /// * [MatrixIdent::A] Shape => (M, K)
141    /// * [MatrixIdent::B] Shape => (K, N)
142    /// * [MatrixIdent::Accumulator] Shape => (M, N)
143    ///
144    /// Not all shapes are supported, and the permitted shapes depend on the element type.
145    ///
146    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
147    #[allow(unused_variables)]
148    pub fn from_value(
149        ident: MatrixIdent,
150        m: u32,
151        n: u32,
152        k: u32,
153        layout: MatrixLayout,
154        value: C,
155    ) -> Self {
156        Matrix { _c: PhantomData }
157    }
158
159    /// Create a new matrix that is going to be used in the
160    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
161    ///
162    /// You have to declare the shape used for the execution.
163    /// The shape of the current matrix is determined using the [MatrixIdent].
164    ///
165    /// * [MatrixIdent::A] Shape => (M, K)
166    /// * [MatrixIdent::B] Shape => (K, N)
167    /// * [MatrixIdent::Accumulator] Shape => (M, N)
168    ///
169    /// Not all shapes are supported, and the permitted shapes depend on the element type.
170    ///
171    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
172    #[allow(unused_variables)]
173    pub fn from_slice(
174        ident: MatrixIdent,
175        m: u32,
176        n: u32,
177        k: u32,
178        layout: MatrixLayout,
179        value: &Slice<C>,
180        stride: u32,
181    ) -> Self {
182        Matrix { _c: PhantomData }
183    }
184
185    pub fn __expand_uninitialized(
186        context: &mut CubeContext,
187        ident: MatrixIdent,
188        m: ExpandElementTyped<u32>,
189        n: ExpandElementTyped<u32>,
190        k: ExpandElementTyped<u32>,
191        layout: MatrixLayout,
192    ) -> MatrixExpand<C> {
193        let elem = C::as_elem(context);
194        let elem = context.create_matrix(ir::Matrix {
195            ident,
196            m: m.constant().unwrap().as_u32() as u8,
197            n: n.constant().unwrap().as_u32() as u8,
198            k: k.constant().unwrap().as_u32() as u8,
199            elem,
200            layout,
201        });
202        MatrixExpand {
203            elem,
204            ident,
205            _c: PhantomData,
206        }
207    }
208
209    pub fn __expand_from_value(
210        context: &mut CubeContext,
211        ident: MatrixIdent,
212        m: ExpandElementTyped<u32>,
213        n: ExpandElementTyped<u32>,
214        k: ExpandElementTyped<u32>,
215        layout: MatrixLayout,
216        value: ExpandElementTyped<C>,
217    ) -> MatrixExpand<C> {
218        let mat = Self::__expand_uninitialized(context, ident, m, n, k, layout);
219        fill::expand(context, mat.clone(), value);
220        mat
221    }
222
223    #[allow(clippy::too_many_arguments)]
224    pub fn __expand_from_slice(
225        context: &mut CubeContext,
226        ident: MatrixIdent,
227        m: ExpandElementTyped<u32>,
228        n: ExpandElementTyped<u32>,
229        k: ExpandElementTyped<u32>,
230        layout: MatrixLayout,
231        value: ExpandElementTyped<Slice<C>>,
232        stride: ExpandElementTyped<u32>,
233    ) -> MatrixExpand<C> {
234        let mat = Self::__expand_uninitialized(context, ident, m, n, k, layout);
235        load::expand(context, mat.clone(), value, stride);
236        mat
237    }
238}
239
240/// Fill the matrix with the provided value.
241#[allow(unused_variables)]
242pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
243    unexpanded!()
244}
245
246/// Module containing the expand function for [fill()].
247pub mod fill {
248    use super::*;
249
250    /// Expand method of [fill()].
251    pub fn expand<C: CubeType>(
252        context: &mut CubeContext,
253        mat: MatrixExpand<C>,
254        value: ExpandElementTyped<C>,
255    ) {
256        let value: ExpandElement = value.into();
257        context.register(Instruction::new(
258            ir::CoopMma::Fill { value: *value },
259            *mat.elem,
260        ));
261    }
262}
263
264/// Load the matrix with the provided array using the stride.
265#[allow(unused_variables)]
266pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
267    unexpanded!()
268}
269
270/// Module containing the expand function for [load()].
271pub mod load {
272    use super::*;
273
274    /// Expand method of [load()].
275    #[allow(unused_variables)]
276    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
277        context: &mut CubeContext,
278        mat: MatrixExpand<C>,
279        value: ExpandElementTyped<Slice<V>>,
280        stride: ExpandElementTyped<u32>,
281    ) {
282        let stride: ExpandElement = stride.into();
283        assert_ne!(
284            mat.ident,
285            MatrixIdent::Accumulator,
286            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
287        );
288
289        context.register(Instruction::new(
290            ir::CoopMma::Load {
291                value: *value.expand,
292                stride: *stride,
293                layout: None,
294            },
295            *mat.elem,
296        ));
297    }
298}
299
300/// Load the matrix with the provided array using the stride with an explicit layout.
301/// Explicit layouts are required when loading accumulators.
302#[allow(unused_variables)]
303pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
304    mat: &Matrix<C>,
305    value: &Slice<V>,
306    stride: u32,
307    layout: MatrixLayout,
308) {
309    unexpanded!()
310}
311
312/// Module containing the expand function for [load_with_layout()].
313pub mod load_with_layout {
314    use super::*;
315
316    /// Expand method of [load_with_layout()].
317    #[allow(unused_variables)]
318    pub fn expand<C: CubeType, V: CubePrimitive>(
319        context: &mut CubeContext,
320        mat: MatrixExpand<C>,
321        value: ExpandElementTyped<Slice<V>>,
322        stride: ExpandElementTyped<u32>,
323        layout: MatrixLayout,
324    ) {
325        let stride: ExpandElement = stride.into();
326
327        context.register(Instruction::new(
328            ir::CoopMma::Load {
329                value: *value.expand,
330                stride: *stride,
331                layout: Some(layout),
332            },
333            *mat.elem,
334        ));
335    }
336}
337
338/// Store the matrix in the given array following the given stride and layout.
339#[allow(unused_variables)]
340pub fn store<C: CubePrimitive, O: CubePrimitive>(
341    output: &mut SliceMut<O>,
342    mat: &Matrix<C>,
343    stride: u32,
344    layout: MatrixLayout,
345) {
346    unexpanded!()
347}
348
349/// Module containing the expand function for [store()].
350pub mod store {
351    use super::*;
352
353    /// Expand method of [store()].
354    #[allow(unused_variables)]
355    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
356        context: &mut CubeContext,
357        output: ExpandElementTyped<SliceMut<O>>,
358        mat: MatrixExpand<C>,
359        stride: ExpandElementTyped<u32>,
360        layout: MatrixLayout,
361    ) {
362        let stride: ExpandElement = stride.into();
363
364        context.register(Instruction::new(
365            ir::CoopMma::Store {
366                mat: *mat.elem,
367                stride: *stride,
368                layout,
369            },
370            *output.expand,
371        ));
372    }
373}
374
375/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
376#[allow(unused_variables)]
377pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
378    mat_a: &Matrix<A>,
379    mat_b: &Matrix<B>,
380    mat_c: &Matrix<C>,
381    mat_d: &Matrix<D>,
382) {
383    unexpanded!()
384}
385
386/// Module containing the expand function for [execute()].
387pub mod execute {
388    use super::*;
389
390    /// Expand method of [execute()].
391    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
392        context: &mut CubeContext,
393        mat_a: MatrixExpand<A>,
394        mat_b: MatrixExpand<B>,
395        mat_c: MatrixExpand<C>,
396        mat_d: MatrixExpand<D>,
397    ) {
398        context.register(Instruction::new(
399            ir::CoopMma::Execute {
400                mat_a: *mat_a.elem,
401                mat_b: *mat_b.elem,
402                mat_c: *mat_c.elem,
403            },
404            *mat_d.elem,
405        ));
406    }
407}
408
409/// Store the matrix in the given array following the given stride and layout.
410#[allow(unused_variables)]
411pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
412    unexpanded!()
413}
414
415/// Module containing the expand function for [store()].
416pub mod cast {
417    use super::*;
418
419    /// Expand method of [store()].
420    #[allow(unused_variables)]
421    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
422        context: &mut CubeContext,
423        input: MatrixExpand<C>,
424    ) -> MatrixExpand<O> {
425        let ident = input.ident;
426
427        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
428            return MatrixExpand {
429                elem: input.elem,
430                ident,
431                _c: PhantomData,
432            };
433        }
434        let input = *input.elem;
435        let input_mat = match input.kind {
436            ir::VariableKind::Matrix { mat, .. } => mat,
437            _ => unreachable!(),
438        };
439
440        let elem = O::as_elem(context);
441        let elem = context.create_matrix(ir::Matrix {
442            ident,
443            m: input_mat.m,
444            n: input_mat.n,
445            k: input_mat.k,
446            elem,
447            layout: MatrixLayout::Undefined,
448        });
449
450        let output = MatrixExpand {
451            ident,
452            elem,
453            _c: PhantomData,
454        };
455        context.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
456
457        output
458    }
459}
460
461impl From<ir::CoopMma> for Operation {
462    fn from(value: ir::CoopMma) -> Self {
463        Operation::CoopMma(value)
464    }
465}