Skip to main content

cubecl_core/frontend/
cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
3//! Most of the functions are actually unsafe, since they mutate their input, even if they are
4//! passed as reference.
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
10//! ```rust, ignore
11//! #[cube(launch)]
12//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
13//!     let a = cmma::Matrix::<F16>::new(
14//!         cmma::MatrixIdent::A,
15//!         16,
16//!         16,
17//!         16,
18//!         cmma::MatrixLayout::RowMajor,
19//!     );
20//!     let b = cmma::Matrix::<F16>::new(
21//!         cmma::MatrixIdent::B,
22//!         16,
23//!         16,
24//!         16,
25//!         cmma::MatrixLayout::ColMajor,
26//!     );
27//!     let c = cmma::Matrix::<F32>::new(
28//!         cmma::MatrixIdent::Accumulator,
29//!         16,
30//!         16,
31//!         16,
32//!         cmma::MatrixLayout::Undefined,
33//!     );
34//!     cmma::fill::<F32>(&c, F32::new(0.0));
35//!     cmma::load::<F16>(&a, lhs.as_slice(), u32::new(16));
36//!     cmma::load::<F16>(&b, rhs.as_slice(), u32::new(16));
37//!
38//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
39//!
40//!     cmma::store::<F32>(
41//!         out.as_slice_mut(),
42//!         &c,
43//!         u32::new(16),
44//!         cmma::MatrixLayout::RowMajor,
45//!     );
46//! }
47//! ```
48
49use super::{
50    CubeDebug, CubePrimitive, CubeType, IntoMut, NativeExpand, ReadOnly, Slice, SliceExpand,
51    SliceMut,
52};
53use crate::{self as cubecl, prelude::*};
54use crate::{
55    ir::{self, Instruction},
56    unexpanded,
57};
58use core::marker::PhantomData;
59use cubecl_macros::{comptime_type, cube, intrinsic};
60
61use cubecl_ir::{CoopMma, ManagedVariable, Scope, StorageType, VectorSize};
62pub use ir::{MatrixIdent, MatrixLayout};
63
/// A matrix represent a 2D grid of numbers.
///
/// They can either be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    // Launch-side marker only; all real state lives in [`MatrixExpand`].
    _c: PhantomData<C>,
}
72
/// Defines a matrix multiplication operation, including the input and output type, and the shape.
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    // Launch-side markers only; the shape/type data lives in [`MmaDefinitionExpand`].
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
80
// Allow debug-naming through a reference by delegating to the by-value impl.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for &MmaDefinitionExpand<A, B, CD> {
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        MmaDefinitionExpand::set_debug_name(self, scope, name);
    }
}
86
/// Expand type of [Matrix].
pub struct MatrixExpand<C: CubeType> {
    // IR variable backing the matrix in the kernel scope.
    elem: ManagedVariable,
    // Role of this matrix (A, B or Accumulator); fixes its shape semantics.
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
93
/// Expand type of [`MmaDefinition`].
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    // Comptime MMA shape: C (M, N) = A (M, K) * B (K, N) + C.
    pub m: usize,
    pub n: usize,
    pub k: usize,
    // Storage types of the A/B operands and of the accumulator/output.
    pub a_type: StorageType,
    pub b_type: StorageType,
    pub cd_type: StorageType,
    // Block-scaling configuration; both are `None` unless created via `new_scaled`.
    pub scales_factor: Option<usize>,
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
109
110impl<C: CubeType> Clone for MatrixExpand<C> {
111    fn clone(&self) -> Self {
112        Self {
113            elem: self.elem.clone(),
114            ident: self.ident,
115            _c: self._c,
116        }
117    }
118}
119
120impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
121    fn clone(&self) -> Self {
122        Self {
123            m: self.m,
124            n: self.n,
125            k: self.k,
126            a_type: self.a_type,
127            b_type: self.b_type,
128            cd_type: self.cd_type,
129            scales_factor: self.scales_factor,
130            scales_type: self.scales_type,
131            _a: PhantomData,
132            _b: PhantomData,
133            _cd: PhantomData,
134        }
135    }
136}
137
// Ties the launch-side `Matrix` type to its expansion-side representation.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
141
// Ties the launch-side `MmaDefinition` type to its expansion-side representation.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
145
// Matrices are mutated through CMMA instructions, not reassignment: no-op conversion.
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
151
impl<C: CubeType> CubeDebug for MatrixExpand<C> {
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        // Rename the underlying IR variable so generated code shows `name`.
        scope.update_variable_name(*self.elem, name);
    }
}
157
// Definitions are immutable comptime data: no-op conversion.
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
163
164impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
165
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new uninitialized matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// # Safety
    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [`MatrixIdent::B`] Shape => (K, N)
    /// * [`MatrixIdent::Accumulator`] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            // Declare a matrix variable of C's storage type in the kernel scope.
            let elem = C::as_type(scope).storage_type();
            let elem = scope.create_matrix(ir::Matrix::new(ident, m, n, k, elem, layout));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [`MatrixIdent::B`] Shape => (K, N)
    /// * [`MatrixIdent::Accumulator`] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: C,
    ) -> Self
    where
        C: Scalar,
    {
        // Safe because the matrix is initialized below before it can be observed.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [`MatrixIdent::B`] Shape => (K, N)
    /// * [`MatrixIdent::Accumulator`] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // Safe because the matrix is initialized below before it can be observed.
        // NOTE(review): `load::expand` asserts `ident != Accumulator`, so this
        // constructor panics at expansion time for accumulators — confirm intended.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
268
#[cube(self_type = "ref")]
impl<A: Scalar, B: Scalar, CD: Scalar> MmaDefinition<A, B, CD> {
    /// Create a new matrix definition that is going to be used in the manual
    /// matrix-multiply and accumulate ``execute_manual_mma()`` function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [`MatrixIdent::B`] Shape => (K, N)
    /// * [`MatrixIdent::Accumulator`] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// Layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::vector_layout`] to check the correct data layout for each element.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new(#[comptime] m: usize, #[comptime] n: usize, #[comptime] k: usize) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope).storage_type();
            let b_type = B::as_type(scope).storage_type();
            let cd_type = CD::as_type(scope).storage_type();

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Like [`Self::new`], but additionally registers block scales of storage type `S`
    /// with comptime `scale_factor`, for use with [`Self::execute_scaled`].
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [`MatrixIdent::B`] Shape => (K, N)
    /// * [`MatrixIdent::Accumulator`] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// Layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::vector_layout`] to check the correct data layout for each element.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        #[comptime] scale_factor: usize,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope).storage_type();
            let b_type = B::as_type(scope).storage_type();
            let cd_type = CD::as_type(scope).storage_type();

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope).storage_type()),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Number of elements in the matrix
    #[allow(unused)]
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
        intrinsic!(|scope| {
            // Shape product divided by packing factor (several logical values may
            // share one storage element, e.g. sub-byte types).
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor(),
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor(),
                MatrixIdent::Accumulator => (self.m * self.n) / self.cd_type.packing_factor(),
            }
        })
    }

    /// Returns the number of elements handled by each lane. Should be packed into `Vector`s of size
    /// `vector_size` with [`Self::vector_layout`].
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size as usize;
            // Some runtimes replicate registers across lanes; account for that.
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            (elems * duplication) / plane_dim
        })
    }

    /// Returns the number of vectors of size `vector_size` with layout `vector_layout` per lane.
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn vectors_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let vector_size = self.__expand_vector_size_method(scope, ident);
            elems / vector_size
        })
    }

    /// The layout of each vector in this matrix (row major or column major)
    #[allow(unused)]
    pub fn vector_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            // Layout is dictated by the runtime's register mapping, not by the user.
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    /// Number of elements in each vector passed to the execute function. Represents the maximum
    /// number of contiguous elements held by the thread.
    #[allow(unused_variables)]
    pub fn vector_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(VectorSize) {
        intrinsic!(|scope| {
            let storage = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            // The layout field is a placeholder; `contiguous_elements` decides per-runtime.
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: storage,
                layout: MatrixLayout::ColMajor,
            };
            scope
                .runtime_properties
                .mma
                .contiguous_elements
                .apply(ident, matrix)
        })
    }

    /// Returns the coordinates of the `nth` element handled by the `lane_id`
    /// Each lane contains [`Self::elems_per_lane`] elements in [`Self::vector_size`] chunks.
    /// Returns (`row_idx`, `col_idx`)
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused_variables)]
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ManagedVariable = lane_id.into();
            let elem_idx: ManagedVariable = elem_idx.into();

            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };

            // Row and column are resolved by the backend via dedicated IR ops.
            let row = scope.create_local(u32::as_type(scope));
            let col = scope.create_local(u32::as_type(scope));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scales for this thread, along the non-major dimension of the matrix.
    /// Each thread loads all scales in the major direction into a single `Vector`.
    ///
    /// # Panics
    /// Panics for [`MatrixIdent::Accumulator`], which carries no scales.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        // Just do CUDA for now, call an actual intrinsic when HIP gets support
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Number of scales in each vector (not the vector size!). Vector size may include padding bytes.
    pub fn scales_count(&self) -> comptime_type!(usize) {
        // We only have the CUDA version for now, so just use `scales_factor`. The function can
        // be modified for HIP in the future without having to redo all uses.
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Vector size for the scale factors. May be larger than the total number of scales.
    pub fn scales_vector_size(&self) -> comptime_type!(VectorSize) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales vector size for matrix with no scales");
            // How many scale elements fit into one hardware register.
            scope.runtime_properties.mma.register_size_bits / elem.size_bits()
        })
    }

    /// Load one or more matrix register using intrinsic instructions. CUDA only.
    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
    /// The slice determines the starting address for a 16-byte row loaded by this unit, with
    /// the row index being `UNIT_POS_PLANE % 8`.
    /// The number of elements is determined by element size.
    ///
    /// # Constraints:
    /// Address must be aligned to 16 bytes
    /// Address must be in shared memory
    #[allow(unused_variables)]
    pub fn load_matrix<E: CubePrimitive, NO: Size>(
        &self,
        row: &Slice<E>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) -> Array<Vector<E::Scalar, NO>> {
        intrinsic!(|scope| {
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            // One output register vector per loaded matrix.
            let out = Array::__expand_new(scope, num_matrices);
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    vector_size: slice_vector_size,
                    factor: num_matrices,
                    transpose,
                },
                *out.expand,
            ));
            out
        })
    }

    /// Variant of [`Self::load_matrix`] that writes into an existing `fragment`
    /// array instead of allocating a new one.
    #[allow(unused_variables)]
    pub fn load_matrix_inplace<E: Scalar, N: Size>(
        &self,
        row: &Slice<E>,
        fragment: &mut Array<Vector<E, N>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) {
        intrinsic!(|scope| {
            // NOTE(review): `vector_size` is computed but never used here — confirm
            // whether it was meant to feed the instruction or can be removed.
            let vector_size = self.__expand_vector_size_method(scope, ident);
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    vector_size: slice_vector_size,
                    factor: num_matrices,
                    transpose,
                },
                *fragment.expand,
            ));
        })
    }

    /// Store one or more matrix register using intrinsic instructions. CUDA only.
    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
    /// The slice determines the starting address for a 16-byte row loaded by this unit, with
    /// the row index being `UNIT_POS_PLANE % 8`.
    /// The number of elements is determined by element size.
    ///
    /// # Constraints:
    /// Address must be aligned to 16 bytes
    /// Address must be in shared memory
    #[allow(unused_variables)]
    pub fn store_matrix<E: CubePrimitive, N: Size>(
        &self,
        row: &mut Slice<E, ReadWrite>,
        registers: &Array<Vector<E::Scalar, N>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) {
        intrinsic!(|scope| {
            // NOTE(review): `vector_size` is computed but never used here — confirm
            // whether it was meant to feed the instruction or can be removed.
            let vector_size = self.__expand_vector_size_method(scope, ident);
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            scope.register(Instruction::new(
                CoopMma::StoreMatrix {
                    offset,
                    vector_size: slice_vector_size,
                    registers: *registers.expand,
                    factor: num_matrices,
                    transpose,
                },
                buffer,
            ));
        })
    }

    /// Execute a low level `mma` operation with manually managed registers. Register layout
    /// and index mapping can be retrieved from the [`MmaDefinition`]
    #[allow(unused)]
    pub fn execute<NA: Size, NB: Size, NC: Size>(
        &self,
        registers_a: &Array<Vector<A, NA>>,
        registers_b: &Array<Vector<B, NB>>,
        registers_c: &Array<Vector<CD, NC>>,
    ) -> Array<Vector<CD, NC>> {
        intrinsic!(|scope| {
            // Size the output array from the accumulator's per-lane register count.
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_vector_size = self
                .clone()
                .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_vector_size;

            let registers_d = Array::__expand_new(scope, num_registers);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }

    /// Variant of [`Self::execute`] that writes the result back into `registers_c`
    /// instead of allocating a new output array.
    #[allow(unused)]
    pub fn execute_inplace<NA: Size, NB: Size, NC: Size>(
        &self,
        registers_a: &Array<Vector<A, NA>>,
        registers_b: &Array<Vector<B, NB>>,
        registers_c: &mut Array<Vector<CD, NC>>,
    ) {
        intrinsic!(|scope| {
            // NOTE(review): `num_registers` is computed but unused in this in-place
            // variant (no output array is allocated) — confirm it can be removed.
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_vector_size = self
                .clone()
                .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_vector_size;

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            // `registers_c` is both the accumulator input and the destination.
            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                registers_c,
            ));
        })
    }

    /// Execute a low level block scaled `mma` operation with manually managed registers. Register
    /// layout and index mapping can be retrieved from the [`MmaDefinition`]
    ///
    /// # Panics
    /// Panics at expansion time if this definition was not created with [`Self::new_scaled`].
    #[allow(unused)]
    pub fn execute_scaled<S: Scalar, NA: Size, NB: Size, NC: Size, NS: Size>(
        &self,
        registers_a: &Array<Vector<A, NA>>,
        registers_b: &Array<Vector<B, NB>>,
        registers_c: &Array<Vector<CD, NC>>,
        scales_a: Vector<S, NS>,
        scales_b: Vector<S, NS>,
    ) -> Array<Vector<CD, NC>> {
        intrinsic!(|scope| {
            // Size the output array from the accumulator's per-lane register count.
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_vector_size = self
                .clone()
                .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_vector_size;

            let registers_d = Array::__expand_new(scope, num_registers);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
}
771
/// Fill the matrix with the provided value.
///
/// Launch-side stub only; the real work happens in [`fill::expand`].
#[allow(unused_variables)]
pub fn fill<C: Scalar>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
777
778/// Module containing the expand function for [`fill()`].
779pub mod fill {
780    use super::*;
781
782    /// Expand method of [`fill()`].
783    pub fn expand<C: Scalar>(scope: &mut Scope, mat: MatrixExpand<C>, value: NativeExpand<C>) {
784        let value: ManagedVariable = value.into();
785        scope.register(Instruction::new(
786            ir::CoopMma::Fill { value: *value },
787            *mat.elem,
788        ));
789    }
790}
791
/// Load the matrix with the provided array using the stride.
///
/// Launch-side stub only; the real work happens in [`load::expand`].
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
797
798/// Module containing the expand function for [`load()`].
799pub mod load {
800    use super::*;
801
802    /// Expand method of [`load()`].
803    #[allow(unused_variables)]
804    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
805        scope: &mut Scope,
806        mat: MatrixExpand<C>,
807        value: SliceExpand<V, ReadOnly>,
808        stride: NativeExpand<u32>,
809    ) {
810        let stride: ManagedVariable = stride.into();
811        assert_ne!(
812            mat.ident,
813            MatrixIdent::Accumulator,
814            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
815        );
816
817        let (value, offset) = value.__to_raw_parts();
818
819        scope.register(Instruction::new(
820            ir::CoopMma::Load {
821                value,
822                stride: *stride,
823                offset,
824                layout: None,
825            },
826            *mat.elem,
827        ));
828    }
829}
830
/// Load the matrix with the provided array using the stride with an explicit layout.
/// Explicit layouts are required when loading accumulators.
///
/// Launch-side stub only; the real work happens in [`load_with_layout::expand`].
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
842
843/// Module containing the expand function for [`load_with_layout()`].
844pub mod load_with_layout {
845    use super::*;
846
847    /// Expand method of [`load_with_layout()`].
848    #[allow(unused_variables)]
849    pub fn expand<C: CubeType, V: CubePrimitive>(
850        scope: &mut Scope,
851        mat: MatrixExpand<C>,
852        value: SliceExpand<V, ReadOnly>,
853        stride: NativeExpand<u32>,
854        layout: MatrixLayout,
855    ) {
856        let stride: ManagedVariable = stride.into();
857        let (value, offset) = value.__to_raw_parts();
858
859        scope.register(Instruction::new(
860            ir::CoopMma::Load {
861                value,
862                stride: *stride,
863                offset,
864                layout: Some(layout),
865            },
866            *mat.elem,
867        ));
868    }
869}
870
/// Store the matrix in the given array following the given stride and layout.
///
/// Kernel-language stub: calling it outside an expansion hits `unexpanded!()`.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
881
882/// Module containing the expand function for [`store()`].
883pub mod store {
884    use crate::prelude::ReadWrite;
885
886    use super::*;
887
888    /// Expand method of [`store()`].
889    #[allow(unused_variables)]
890    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
891        scope: &mut Scope,
892        output: SliceExpand<O, ReadWrite>,
893        mat: MatrixExpand<C>,
894        stride: NativeExpand<u32>,
895        layout: MatrixLayout,
896    ) {
897        let stride: ManagedVariable = stride.into();
898
899        let (output, offset) = output.__to_raw_parts();
900
901        scope.register(Instruction::new(
902            ir::CoopMma::Store {
903                mat: *mat.elem,
904                offset,
905                stride: *stride,
906                layout,
907            },
908            output,
909        ));
910    }
911}
912
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
///
/// `mat_d` receives the result of the operation (it is the output variable of
/// the emitted instruction).
///
/// Kernel-language stub: calling it outside an expansion hits `unexpanded!()`.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
923
924/// Module containing the expand function for [`execute()`].
925pub mod execute {
926    use super::*;
927
928    /// Expand method of [`execute()`].
929    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
930        scope: &mut Scope,
931        mat_a: MatrixExpand<A>,
932        mat_b: MatrixExpand<B>,
933        mat_c: MatrixExpand<C>,
934        mat_d: MatrixExpand<D>,
935    ) {
936        scope.register(Instruction::new(
937            ir::CoopMma::Execute {
938                mat_a: *mat_a.elem,
939                mat_b: *mat_b.elem,
940                mat_c: *mat_c.elem,
941            },
942            *mat_d.elem,
943        ));
944    }
945}
946
/// Cast the matrix to another element type, returning the new matrix.
///
/// (The previous doc comment was copy-pasted from [`store()`] and did not
/// describe this function.)
///
/// Kernel-language stub: calling it outside an expansion hits `unexpanded!()`.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
952
/// Module containing the expand function for [`cast()`].
pub mod cast {
    use super::*;

    /// Expand method of [`cast()`].
    ///
    /// When `C` and `O` are the same type, the input matrix variable is reused
    /// as-is and no instruction is emitted. Otherwise a new matrix variable is
    /// created with the same ident and dimensions but element type `O`, and a
    /// `CoopMma::Cast` instruction is registered.
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        input: MatrixExpand<C>,
    ) -> MatrixExpand<O> {
        let ident = input.ident;

        // Identity cast: same element type on both sides, so just rewrap the
        // existing variable without emitting anything.
        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
            return MatrixExpand {
                elem: input.elem,
                ident,
                _c: PhantomData,
            };
        }
        let input = *input.elem;
        // A `MatrixExpand` variable is always of kind `Matrix`, so any other
        // kind here is a construction bug.
        let input_mat = match input.kind {
            ir::VariableKind::Matrix { mat, .. } => mat,
            _ => unreachable!(),
        };

        // Build the output matrix: same ident and m/n/k as the input, but with
        // the target storage type. The layout is left undefined.
        let elem = O::as_type(scope).storage_type();
        let elem = scope.create_matrix(ir::Matrix::new(
            ident,
            input_mat.m,
            input_mat.n,
            input_mat.k,
            elem,
            MatrixLayout::Undefined,
        ));

        let output = MatrixExpand {
            ident,
            elem,
            _c: PhantomData,
        };
        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));

        output
    }
}
998
// `MatrixLayout` is a compile-time (comptime) value: it expands to itself
// rather than to a runtime IR variable.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
1002
// Converting a layout into its mutable expand form is a no-op: the value is a
// plain enum with no scope-managed state to register.
impl IntoMut for MatrixLayout {
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
1008
// Rely entirely on `CubeDebug`'s default methods; layouts need no custom
// debug handling.
impl CubeDebug for MatrixLayout {}