cubecl_core/frontend/
cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
3//! Most of the functions are actually unsafe, since they mutate their input, even if they are
4//! passed as reference.
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
//! ```rust, ignore
//! #[cube(launch)]
//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     // Every matrix below is filled or loaded before it is used.
//!     let a = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::A,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::RowMajor,
//!         )
//!     };
//!     let b = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::B,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::ColMajor,
//!         )
//!     };
//!     let c = unsafe {
//!         cmma::Matrix::<F32>::uninitialized(
//!             cmma::MatrixIdent::Accumulator,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::Undefined,
//!         )
//!     };
//!     cmma::fill::<F32>(&c, F32::new(0.0));
//!     cmma::load::<F16, F16>(&a, lhs.as_slice(), u32::new(16));
//!     cmma::load::<F16, F16>(&b, rhs.as_slice(), u32::new(16));
//!
//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//!     cmma::store::<F32, F32>(
//!         out.as_slice_mut(),
//!         &c,
//!         u32::new(16),
//!         cmma::MatrixLayout::RowMajor,
//!     );
//! }
//! ```
48
49use super::{
50    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51    SliceMut,
52};
53use crate::{
54    self as cubecl,
55    prelude::{Array, Line, ReadWrite},
56};
57use crate::{
58    ir::{self, Instruction},
59    unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, LineSize, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
/// A matrix represents a 2D grid of numbers.
///
/// They can either be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
///
/// This is a marker type: it stores no data itself and only carries the element type `C`.
/// Its expand counterpart is [`MatrixExpand`].
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    // Zero-sized marker tying the matrix to its element type.
    _c: PhantomData<C>,
}
75
/// Defines a matrix multiplication operation, including the input and output type, and the shape.
///
/// `A` and `B` are the element types of the two inputs, `CD` the element type shared by the
/// accumulator and the result. This is a marker type; the actual shape and storage types live
/// in its expand counterpart, [`MmaDefinitionExpand`].
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
84impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for &MmaDefinitionExpand<A, B, CD> {
85    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
86        MmaDefinitionExpand::set_debug_name(self, scope, name);
87    }
88}
89
/// Expand type of [Matrix].
///
/// Holds the IR variable backing the matrix together with the role it plays in the
/// multiplication.
pub struct MatrixExpand<C: CubeType> {
    // IR element created for this matrix (see `scope.create_matrix` in `Matrix::uninitialized`).
    elem: ExpandElement,
    // Role of the matrix (A, B or accumulator); checked by `load::expand`.
    ident: MatrixIdent,
    // Zero-sized marker for the element type.
    _c: PhantomData<C>,
}
96
/// Expand type of [MmaDefinition].
///
/// Carries the compile-time shape and storage types of a manual MMA operation.
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `M` dimension of the operation (rows of A and of the accumulator).
    pub m: usize,
    /// `N` dimension of the operation (columns of B and of the accumulator).
    pub n: usize,
    /// `K` dimension of the operation (columns of A, rows of B).
    pub k: usize,
    /// Storage type of the A matrix, obtained from `A::as_type`.
    pub a_type: StorageType,
    /// Storage type of the B matrix, obtained from `B::as_type`.
    pub b_type: StorageType,
    /// Storage type of the accumulator/result, obtained from `CD::as_type`.
    pub cd_type: StorageType,
    /// Scale factor; `Some` only for definitions created with `new_scaled`.
    pub scales_factor: Option<usize>,
    /// Storage type of the scales; `Some` only for definitions created with `new_scaled`.
    pub scales_type: Option<StorageType>,
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
112
113impl<C: CubeType> Clone for MatrixExpand<C> {
114    fn clone(&self) -> Self {
115        Self {
116            elem: self.elem.clone(),
117            ident: self.ident,
118            _c: self._c,
119        }
120    }
121}
122
123impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
124    fn clone(&self) -> Self {
125        Self {
126            m: self.m,
127            n: self.n,
128            k: self.k,
129            a_type: self.a_type,
130            b_type: self.b_type,
131            cd_type: self.cd_type,
132            scales_factor: self.scales_factor,
133            scales_type: self.scales_type,
134            _a: PhantomData,
135            _b: PhantomData,
136            _cd: PhantomData,
137        }
138    }
139}
140
// Associates the frontend `Matrix` marker with its expand representation.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
144
// Associates the frontend `MmaDefinition` marker with its expand representation.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
148
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    // No conversion is performed: the expand value is returned unchanged.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
154
155impl<C: CubeType> CubeDebug for MatrixExpand<C> {
156    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
157        scope.update_variable_name(*self.elem, name);
158    }
159}
160
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    // No conversion is performed: the definition is returned unchanged.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
166
// Uses the default `CubeDebug` methods; nothing is overridden for the definition.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
168
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new uninitialized matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// # Safety
    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            // Declare the matrix variable in the scope; no fill or load is emitted here.
            let elem = C::as_type(scope);
            let elem = scope.create_matrix(ir::Matrix::new(ident, m, n, k, elem, layout));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        // SAFETY: the matrix is filled right below, before it is handed to the caller.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// # Panics
    /// The load goes through [`load::expand`], which asserts that `ident` is not
    /// [MatrixIdent::Accumulator]; accumulators need an explicit layout
    /// (see [`load_with_layout()`]).
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the matrix is loaded right below, before it is handed to the caller.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
268
269#[cube(self_type = "ref")]
270impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
271    /// Create a new matrix definition that is going to be used in the manual
272    /// [matrix-multiply and accumulate](execute_manual()) function.
273    ///
274    /// You have to declare the shape used for the execution.
275    /// The shape of the current matrix is determined using the [MatrixIdent].
276    ///
277    /// * [MatrixIdent::A] Shape => (M, K)
278    /// * [MatrixIdent::B] Shape => (K, N)
279    /// * [MatrixIdent::Accumulator] Shape => (M, N)
280    ///
281    /// Not all shapes are supported, and the permitted shapes depend on the element type.
282    /// Layout for manual MMA is determined by the runtime and must be handled manually.
283    /// Use [`line_layout`] to check the correct data layout for each element.
284    ///
285    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
286    #[allow(unused_variables)]
287    pub fn new(#[comptime] m: usize, #[comptime] n: usize, #[comptime] k: usize) -> Self {
288        intrinsic!(|scope| {
289            let a_type = A::as_type(scope);
290            let b_type = B::as_type(scope);
291            let cd_type = CD::as_type(scope);
292
293            MmaDefinitionExpand {
294                m,
295                n,
296                k,
297                a_type,
298                b_type,
299                cd_type,
300                scales_factor: None,
301                scales_type: None,
302                _a: PhantomData,
303                _b: PhantomData,
304                _cd: PhantomData,
305            }
306        })
307    }
308
309    /// Create a new matrix definition that is going to be used in the manual
310    /// [matrix-multiply and accumulate](execute_manual()) function.
311    ///
312    /// You have to declare the shape used for the execution.
313    /// The shape of the current matrix is determined using the [MatrixIdent].
314    ///
315    /// * [MatrixIdent::A] Shape => (M, K)
316    /// * [MatrixIdent::B] Shape => (K, N)
317    /// * [MatrixIdent::Accumulator] Shape => (M, N)
318    ///
319    /// Not all shapes are supported, and the permitted shapes depend on the element type.
320    /// Layout for manual MMA is determined by the runtime and must be handled manually.
321    /// Use [`line_layout`] to check the correct data layout for each element.
322    ///
323    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
324    #[allow(unused_variables)]
325    pub fn new_scaled<S: CubePrimitive>(
326        #[comptime] m: usize,
327        #[comptime] n: usize,
328        #[comptime] k: usize,
329        #[comptime] scale_factor: usize,
330    ) -> Self {
331        intrinsic!(|scope| {
332            let a_type = A::as_type(scope);
333            let b_type = B::as_type(scope);
334            let cd_type = CD::as_type(scope);
335
336            MmaDefinitionExpand {
337                m,
338                n,
339                k,
340                a_type,
341                b_type,
342                cd_type,
343                scales_factor: Some(scale_factor),
344                scales_type: Some(S::as_type(scope)),
345                _a: PhantomData,
346                _b: PhantomData,
347                _cd: PhantomData,
348            }
349        })
350    }
351
352    /// Number of elements in the matrix
353    #[allow(unused)]
354    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
355        intrinsic!(|scope| {
356            match ident {
357                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor(),
358                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor(),
359                MatrixIdent::Accumulator => (self.m * self.n) / self.cd_type.packing_factor(),
360            }
361        })
362    }
363
364    /// Returns the number of elements handled by each lane. Should be packed into `Line`s of size
365    /// `line_size` with [`line_layout`].
366    ///
367    /// # Note
368    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
369    /// to a cube.
370    #[allow(unused)]
371    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
372        intrinsic!(|scope| {
373            let elems = self.__expand_num_elems_method(scope, ident);
374            let plane_dim = scope.runtime_properties.mma.const_plane_size as usize;
375            let duplication = match ident {
376                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
377                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
378                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
379            };
380            (elems / plane_dim) * duplication
381        })
382    }
383
384    /// Returns the number of lines of size `line_size` with layout `line_layout` per lane.
385    ///
386    /// # Note
387    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
388    /// to a cube.
389    #[allow(unused)]
390    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
391        intrinsic!(|scope| {
392            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
393            let line_size = self.__expand_line_size_method(scope, ident);
394            elems / line_size
395        })
396    }
397
398    /// The layout of each line in this matrix (row major or column major)
399    #[allow(unused)]
400    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
401        intrinsic!(|scope| {
402            match ident {
403                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
404                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
405                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
406            }
407        })
408    }
409
410    /// Number of elements in each line passed to the execute function. Represents the maximum
411    /// number of contiguous elements held by the thread.
412    #[allow(unused_variables)]
413    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(LineSize) {
414        intrinsic!(|scope| {
415            let storage = match ident {
416                MatrixIdent::A => self.a_type,
417                MatrixIdent::B => self.b_type,
418                MatrixIdent::Accumulator => self.cd_type,
419            };
420            let matrix = cubecl_ir::Matrix {
421                ident,
422                m: self.m,
423                n: self.n,
424                k: self.k,
425                storage: storage,
426                layout: MatrixLayout::ColMajor,
427            };
428            scope
429                .runtime_properties
430                .mma
431                .contiguous_elements
432                .apply(ident, matrix)
433        })
434    }
435
436    /// Returns the coordinates of the `nth` element handled by the `lane_id`
437    /// Each lane contains [`elems_per_lane`] elements in [`line_size`] chunks.
438    /// Returns (`row_idx`, `col_idx`)
439    ///
440    /// # Note
441    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
442    /// to a cube.
443    #[allow(unused_variables)]
444    pub fn position_of_nth(
445        &self,
446        lane_id: u32,
447        elem_idx: u32,
448        #[comptime] ident: MatrixIdent,
449    ) -> (u32, u32) {
450        intrinsic!(|scope| {
451            let lane_id: ExpandElement = lane_id.into();
452            let elem_idx: ExpandElement = elem_idx.into();
453
454            let ty = match ident {
455                MatrixIdent::A => self.a_type,
456                MatrixIdent::B => self.b_type,
457                MatrixIdent::Accumulator => self.cd_type,
458            };
459            let layout = match ident {
460                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
461                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
462                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
463            };
464            let matrix = cubecl_ir::Matrix {
465                ident,
466                m: self.m,
467                n: self.n,
468                k: self.k,
469                storage: ty,
470                layout,
471            };
472
473            let row = scope.create_local(Type::new(u32::as_type(scope)));
474            let col = scope.create_local(Type::new(u32::as_type(scope)));
475            scope.register(Instruction::new(
476                CoopMma::RowIndex {
477                    lane_id: *lane_id,
478                    i: *elem_idx,
479                    matrix,
480                },
481                *row,
482            ));
483            scope.register(Instruction::new(
484                CoopMma::ColIndex {
485                    lane_id: *lane_id,
486                    i: *elem_idx,
487                    matrix,
488                },
489                *col,
490            ));
491            (row.into(), col.into())
492        })
493    }
494
    /// Index of the scales for this thread, along the non-major dimension of the matrix.
    /// Each thread loads all scales in the major direction into a single `Line`.
    ///
    /// # Panics
    /// Panics with "Accumulator doesn't have scales" when `ident` is
    /// [MatrixIdent::Accumulator].
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        // Just do CUDA for now, call an actual intrinsic when HIP gets support
        // Lanes are grouped by four: `quad_id` is the group, `t_id` the position inside it.
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            // Odd positions within a quad are offset by 8 for A.
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }
507
508    /// Number of scales in each line (not the line size!). Line size may include padding bytes.
509    pub fn scales_count(&self) -> comptime_type!(usize) {
510        // We only have the CUDA version for now, so just use `scales_factor`. The function can
511        // be modified for HIP in the future without having to redo all uses.
512        intrinsic!(|_| {
513            self.scales_factor
514                .expect("Can't retrieve scales count for matrix with no scales")
515        })
516    }
517
518    /// Line size for the scale factors. May be larger than the total number of scales.
519    pub fn scales_line_size(&self) -> comptime_type!(LineSize) {
520        intrinsic!(|scope| {
521            let elem = self
522                .scales_type
523                .expect("Can't retrieve scales line size for matrix with no scales");
524            scope.runtime_properties.mma.register_size_bits / elem.size_bits()
525        })
526    }
527
528    /// Load one or more matrix register using intrinsic instructions. CUDA only.
529    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
530    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
531    /// The slice determines the starting address for a 16-byte row loaded by this unit, with
532    /// the row index being `UNIT_POS_PLANE % 8`.
533    /// The number of elements is determined by element size.
534    ///
535    /// # Constraints:
536    /// Address must be aligned to 16 bytes
537    /// Address must be in shared memory
538    #[allow(unused_variables)]
539    pub fn load_matrix<E: CubePrimitive>(
540        &self,
541        row: &Slice<Line<E>>,
542        #[comptime] ident: MatrixIdent,
543        #[comptime] num_matrices: usize,
544        #[comptime] transpose: bool,
545    ) -> Array<Line<E>> {
546        intrinsic!(|scope| {
547            let line_size = self.__expand_line_size_method(scope, ident);
548            let slice_line_size = row.line_size;
549            let (buffer, offset) = row.__to_raw_parts();
550            let out = Array::__expand_lined(scope, num_matrices, line_size);
551            scope.register(Instruction::new(
552                CoopMma::LoadMatrix {
553                    buffer,
554                    offset,
555                    line_size: slice_line_size,
556                    factor: num_matrices,
557                    transpose,
558                },
559                *out.expand,
560            ));
561            out
562        })
563    }
564
565    /// Store one or more matrix register using intrinsic instructions. CUDA only.
566    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
567    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
568    /// The slice determines the starting address for a 16-byte row loaded by this unit, with
569    /// the row index being `UNIT_POS_PLANE % 8`.
570    /// The number of elements is determined by element size.
571    ///
572    /// # Constraints:
573    /// Address must be aligned to 16 bytes
574    /// Address must be in shared memory
575    #[allow(unused_variables)]
576    pub fn store_matrix<E: CubePrimitive>(
577        &self,
578        row: &mut Slice<Line<E>, ReadWrite>,
579        registers: &Array<Line<E>>,
580        #[comptime] ident: MatrixIdent,
581        #[comptime] num_matrices: usize,
582        #[comptime] transpose: bool,
583    ) {
584        intrinsic!(|scope| {
585            let line_size = self.__expand_line_size_method(scope, ident);
586            let slice_line_size = row.line_size;
587            let (buffer, offset) = row.__to_raw_parts();
588            scope.register(Instruction::new(
589                CoopMma::StoreMatrix {
590                    offset,
591                    line_size: slice_line_size,
592                    registers: *registers.expand,
593                    factor: num_matrices,
594                    transpose,
595                },
596                buffer,
597            ));
598        })
599    }
600
601    /// Execute a low level `mma` operation with manually managed registers. Register layout
602    /// and index mapping can be retrieved from the [`MatrixDefinition`]
603    #[allow(unused)]
604    pub fn execute(
605        &self,
606        registers_a: &Array<Line<A>>,
607        registers_b: &Array<Line<B>>,
608        registers_c: &Array<Line<CD>>,
609    ) -> Array<Line<CD>> {
610        intrinsic!(|scope| {
611            let acc_elems = self
612                .clone()
613                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
614            let acc_line_size = self
615                .clone()
616                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
617            let num_registers = acc_elems / acc_line_size;
618
619            let registers_d = Array::__expand_lined(scope, num_registers, acc_line_size);
620
621            let registers_a = *registers_a.expand;
622            let registers_b = *registers_b.expand;
623            let registers_c = *registers_c.expand;
624
625            // Only shape is actually used
626            let matrix = cubecl_ir::Matrix {
627                ident: MatrixIdent::A,
628                m: self.m,
629                n: self.n,
630                k: self.k,
631                storage: self.a_type,
632                layout: MatrixLayout::ColMajor,
633            };
634
635            scope.register(Instruction::new(
636                CoopMma::ExecuteManual {
637                    matrix,
638                    registers_a,
639                    registers_b,
640                    registers_c,
641                },
642                *registers_d.expand,
643            ));
644
645            registers_d
646        })
647    }
648
649    /// Execute a low level block scaled `mma` operation with manually managed registers. Register
650    /// layout and index mapping can be retrieved from the [`MatrixDefinition`]
651    #[allow(unused)]
652    pub fn execute_scaled<S: CubePrimitive>(
653        &self,
654        registers_a: &Array<Line<A>>,
655        registers_b: &Array<Line<B>>,
656        registers_c: &Array<Line<CD>>,
657        scales_a: Line<S>,
658        scales_b: Line<S>,
659    ) -> Array<Line<CD>> {
660        intrinsic!(|scope| {
661            let acc_elems = self
662                .clone()
663                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
664            let acc_line_size = self
665                .clone()
666                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
667            let num_registers = acc_elems / acc_line_size;
668
669            let registers_d = Array::__expand_lined(scope, num_registers, acc_line_size);
670
671            let registers_a = *registers_a.expand;
672            let registers_b = *registers_b.expand;
673            let registers_c = *registers_c.expand;
674
675            // Only shape is actually used
676            let matrix = cubecl_ir::Matrix {
677                ident: MatrixIdent::A,
678                m: self.m,
679                n: self.n,
680                k: self.k,
681                storage: self.a_type,
682                layout: MatrixLayout::ColMajor,
683            };
684
685            scope.register(Instruction::new(
686                CoopMma::ExecuteScaled {
687                    matrix,
688                    registers_a,
689                    registers_b,
690                    registers_c,
691                    scales_a: *scales_a.expand,
692                    scales_b: *scales_b.expand,
693                    scales_factor: self
694                        .scales_factor
695                        .expect("Can't execute scaled on matrix with no scales"),
696                },
697                *registers_d.expand,
698            ));
699
700            registers_d
701        })
702    }
703}
704
/// Fill the matrix with the provided value.
///
/// This is the unexpanded signature only; the IR is emitted by [`fill::expand`].
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
710
711/// Module containing the expand function for [fill()].
712pub mod fill {
713    use super::*;
714
715    /// Expand method of [fill()].
716    pub fn expand<C: CubeType>(
717        scope: &mut Scope,
718        mat: MatrixExpand<C>,
719        value: ExpandElementTyped<C>,
720    ) {
721        let value: ExpandElement = value.into();
722        scope.register(Instruction::new(
723            ir::CoopMma::Fill { value: *value },
724            *mat.elem,
725        ));
726    }
727}
728
/// Load the matrix with the provided array using the stride.
///
/// Accumulators cannot be loaded this way; use [`load_with_layout()`] for them.
/// This is the unexpanded signature only; the IR is emitted by [`load::expand`].
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
734
735/// Module containing the expand function for [load()].
736pub mod load {
737    use super::*;
738
739    /// Expand method of [load()].
740    #[allow(unused_variables)]
741    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
742        scope: &mut Scope,
743        mat: MatrixExpand<C>,
744        value: SliceExpand<V, ReadOnly>,
745        stride: ExpandElementTyped<u32>,
746    ) {
747        let stride: ExpandElement = stride.into();
748        assert_ne!(
749            mat.ident,
750            MatrixIdent::Accumulator,
751            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
752        );
753
754        let (value, offset) = value.__to_raw_parts();
755
756        scope.register(Instruction::new(
757            ir::CoopMma::Load {
758                value,
759                stride: *stride,
760                offset,
761                layout: None,
762            },
763            *mat.elem,
764        ));
765    }
766}
767
/// Load the matrix with the provided array using the stride with an explicit layout.
/// Explicit layouts are required when loading accumulators.
///
/// This is the unexpanded signature only; the IR is emitted by [`load_with_layout::expand`].
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
779
780/// Module containing the expand function for [load_with_layout()].
781pub mod load_with_layout {
782    use super::*;
783
784    /// Expand method of [load_with_layout()].
785    #[allow(unused_variables)]
786    pub fn expand<C: CubeType, V: CubePrimitive>(
787        scope: &mut Scope,
788        mat: MatrixExpand<C>,
789        value: SliceExpand<V, ReadOnly>,
790        stride: ExpandElementTyped<u32>,
791        layout: MatrixLayout,
792    ) {
793        let stride: ExpandElement = stride.into();
794        let (value, offset) = value.__to_raw_parts();
795
796        scope.register(Instruction::new(
797            ir::CoopMma::Load {
798                value,
799                stride: *stride,
800                offset,
801                layout: Some(layout),
802            },
803            *mat.elem,
804        ));
805    }
806}
807
/// Store the matrix in the given array following the given stride and layout.
///
/// This is the unexpanded signature only; the IR is emitted by [`store::expand`].
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
818
819/// Module containing the expand function for [store()].
820pub mod store {
821    use crate::prelude::ReadWrite;
822
823    use super::*;
824
825    /// Expand method of [store()].
826    #[allow(unused_variables)]
827    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
828        scope: &mut Scope,
829        output: SliceExpand<O, ReadWrite>,
830        mat: MatrixExpand<C>,
831        stride: ExpandElementTyped<u32>,
832        layout: MatrixLayout,
833    ) {
834        let stride: ExpandElement = stride.into();
835
836        let (output, offset) = output.__to_raw_parts();
837
838        scope.register(Instruction::new(
839            ir::CoopMma::Store {
840                mat: *mat.elem,
841                offset,
842                stride: *stride,
843                layout,
844            },
845            output,
846        ));
847    }
848}
849
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
///
/// `mat_a`, `mat_b` and `mat_c` are the inputs; the result is written to `mat_d`.
/// This is the unexpanded signature only; the IR is emitted by [`execute::expand`].
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
860
861/// Module containing the expand function for [execute()].
862pub mod execute {
863    use super::*;
864
865    /// Expand method of [execute()].
866    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
867        scope: &mut Scope,
868        mat_a: MatrixExpand<A>,
869        mat_b: MatrixExpand<B>,
870        mat_c: MatrixExpand<C>,
871        mat_d: MatrixExpand<D>,
872    ) {
873        scope.register(Instruction::new(
874            ir::CoopMma::Execute {
875                mat_a: *mat_a.elem,
876                mat_b: *mat_b.elem,
877                mat_c: *mat_c.elem,
878            },
879            *mat_d.elem,
880        ));
881    }
882}
883
/// Cast the matrix to a matrix with element type `O`.
///
/// This is the unexpanded signature only; the IR is emitted by [`cast::expand`].
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
889
890/// Module containing the expand function for [store()].
891pub mod cast {
892    use super::*;
893
894    /// Expand method of [store()].
895    #[allow(unused_variables)]
896    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
897        scope: &mut Scope,
898        input: MatrixExpand<C>,
899    ) -> MatrixExpand<O> {
900        let ident = input.ident;
901
902        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
903            return MatrixExpand {
904                elem: input.elem,
905                ident,
906                _c: PhantomData,
907            };
908        }
909        let input = *input.elem;
910        let input_mat = match input.kind {
911            ir::VariableKind::Matrix { mat, .. } => mat,
912            _ => unreachable!(),
913        };
914
915        let elem = O::as_type(scope);
916        let elem = scope.create_matrix(ir::Matrix::new(
917            ident,
918            input_mat.m,
919            input_mat.n,
920            input_mat.k,
921            elem,
922            MatrixLayout::Undefined,
923        ));
924
925        let output = MatrixExpand {
926            ident,
927            elem,
928            _c: PhantomData,
929        };
930        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
931
932        output
933    }
934}
935
// `MatrixLayout` is its own expand type.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
939
impl IntoMut for MatrixLayout {
    // No conversion is performed: the layout is returned unchanged.
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
945
// Uses the default `CubeDebug` methods; nothing is overridden for the layout.
impl CubeDebug for MatrixLayout {}