cubecl_core/frontend/
cmma.rs

1//! This module exposes cooperative matrix-multiply and accumulate operations.
2//!
3//! Most of the functions are actually unsafe, since they mutate their input, even if they are
4//! passed as reference.
5//!
6//! # Example
7//!
8//! This is a basic 16x16x16 matrix multiplication example.
9//!
10//! ```rust, ignore
11//! #[cube(launch)]
12//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
13//!     let a = cmma::Matrix::<F16>::new(
14//!         cmma::MatrixIdent::A,
15//!         16,
16//!         16,
17//!         16,
18//!         cmma::MatrixLayout::RowMajor,
19//!     );
20//!     let b = cmma::Matrix::<F16>::new(
21//!         cmma::MatrixIdent::B,
22//!         16,
23//!         16,
24//!         16,
25//!         cmma::MatrixLayout::ColMajor,
26//!     );
27//!     let c = cmma::Matrix::<F32>::new(
28//!         cmma::MatrixIdent::Accumulator,
29//!         16,
30//!         16,
31//!         16,
32//!         cmma::MatrixLayout::Undefined,
33//!     );
34//!     cmma::fill::<F32>(&c, F32::new(0.0));
35//!     cmma::load::<F16>(&a, lhs.as_slice(), u32::new(16));
36//!     cmma::load::<F16>(&b, rhs.as_slice(), u32::new(16));
37//!
38//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
39//!
40//!     cmma::store::<F32>(
41//!         out.as_slice_mut(),
42//!         &c,
43//!         u32::new(16),
44//!         cmma::MatrixLayout::RowMajor,
45//!     );
46//! }
47//! ```
48
49use super::{
50    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51    SliceMut,
52};
53use crate::{
54    self as cubecl,
55    prelude::{Array, Line, Sequence},
56};
57use crate::{
58    ir::{self, Instruction},
59    unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
/// A matrix represents a 2D grid of numbers.
///
/// They can either be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
///
/// The struct itself stores nothing but a [`PhantomData`] marker for the element type `C`;
/// its expand type is [`MatrixExpand`].
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    // Zero-sized marker for the element type.
    _c: PhantomData<C>,
}
75
/// Defines a matrix multiplication operation, including the input and output type, and the shape.
///
/// `A` and `B` are the element types of the two inputs and `CD` is the element type shared by
/// the accumulator input and the output. The struct only holds [`PhantomData`] markers; the
/// shape and storage types live in its expand type, [`MmaDefinitionExpand`].
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
/// Expand type of [Matrix].
///
/// Holds the IR element of the matrix together with the operand role it was declared with.
pub struct MatrixExpand<C: CubeType> {
    // IR element backing this matrix (in this module it comes from `Scope::create_matrix`).
    elem: ExpandElement,
    // Operand role (A, B or accumulator) the matrix was declared with.
    ident: MatrixIdent,
    // Zero-sized marker for the element type.
    _c: PhantomData<C>,
}
90
/// Expand type of [MmaDefinition].
///
/// All fields are plain compile-time values: the `m`/`n`/`k` shape, the storage type of each
/// operand, and the optional block-scaling parameters.
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `M` dimension of the operation (`A` is `M x K`, the accumulator is `M x N`).
    pub m: u32,
    /// `N` dimension of the operation (`B` is `K x N`, the accumulator is `M x N`).
    pub n: u32,
    /// `K` dimension of the operation (`A` is `M x K`, `B` is `K x N`).
    pub k: u32,
    /// Storage type of the `A` operand, as returned by `A::as_type`.
    pub a_type: StorageType,
    /// Storage type of the `B` operand, as returned by `B::as_type`.
    pub b_type: StorageType,
    /// Storage type of the accumulator/output, as returned by `CD::as_type`.
    pub cd_type: StorageType,
    /// Scale factor given to `new_scaled`; `None` for definitions created with `new`.
    pub scales_factor: Option<u32>,
    /// Storage type of the scales given to `new_scaled`; `None` for definitions created with
    /// `new`.
    pub scales_type: Option<StorageType>,
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108    fn clone(&self) -> Self {
109        Self {
110            elem: self.elem.clone(),
111            ident: self.ident,
112            _c: self._c,
113        }
114    }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118    fn clone(&self) -> Self {
119        Self {
120            m: self.m,
121            n: self.n,
122            k: self.k,
123            a_type: self.a_type,
124            b_type: self.b_type,
125            cd_type: self.cd_type,
126            scales_factor: self.scales_factor,
127            scales_type: self.scales_type,
128            _a: PhantomData,
129            _b: PhantomData,
130            _cd: PhantomData,
131        }
132    }
133}
134
// Associates `Matrix<C>` with `MatrixExpand<C>` as its expand type.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
138
// Associates `MmaDefinition<A, B, CD>` with `MmaDefinitionExpand<A, B, CD>` as its expand type.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
142
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
148
impl<C: CubeType> CubeDebug for MatrixExpand<C> {
    /// Forwards `name` to `Scope::update_variable_name` for the matrix's underlying variable.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        scope.update_variable_name(*self.elem, name);
    }
}
154
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
160
// Empty impl: relies entirely on the methods `CubeDebug` provides by default.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new uninitialized matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    ///
    /// # Safety
    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
    ///
    /// # Panics
    /// Expansion panics if `m`, `n` or `k` is not a constant, because each of them is read with
    /// `constant().unwrap()`.
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            // Storage type of the element type `C`.
            let elem = C::as_type(scope);
            // The shape must be known while expanding: each dimension is unwrapped as a constant.
            let elem = scope.create_matrix(ir::Matrix::new(
                ident,
                m.constant().unwrap().as_u32(),
                n.constant().unwrap().as_u32(),
                k.constant().unwrap().as_u32(),
                elem,
                layout,
            ));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
    ///
    /// This creates the matrix with [`Matrix::uninitialized`] and immediately fills it through
    /// `fill::expand`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    ///
    /// # Panics
    /// Same as [`Matrix::uninitialized`]: `m`, `n` and `k` must be constants.
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        // SAFETY: the matrix is filled right below, before it is handed to the caller.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
    ///
    /// This creates the matrix with [`Matrix::uninitialized`] and immediately loads it through
    /// `load::expand`, i.e. without an explicit layout for the load.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    ///
    /// # Panics
    /// * Same as [`Matrix::uninitialized`]: `m`, `n` and `k` must be constants.
    /// * `load::expand` asserts that `ident` is not [MatrixIdent::Accumulator]; accumulators have
    ///   to be loaded with `load_with_layout` instead.
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the matrix is loaded right below, before it is handed to the caller.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
269
#[cube]
impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
    /// Create a new matrix definition that is going to be used in the manual
    /// [matrix-multiply and accumulate](Self::execute) function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of each operand is determined by its [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// Layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::line_layout`] to check the correct data layout for each element.
    ///
    /// The resulting definition has no scales (`scales_factor` and `scales_type` are `None`).
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Create a new matrix definition that is going to be used in the manual block scaled
    /// [matrix-multiply and accumulate](Self::execute_scaled) function.
    ///
    /// `S` is the element type of the scales and `scale_factor` is stored as the definition's
    /// `scales_factor`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of each operand is determined by its [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// Layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::line_layout`] to check the correct data layout for each element.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: u32,
        #[comptime] n: u32,
        #[comptime] k: u32,
        #[comptime] scale_factor: u32,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope)),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Number of elements in the matrix
    ///
    /// Computed as the product of the two dimensions of the operand selected by `ident`,
    /// divided by the `packing_factor()` of that operand's storage type.
    #[allow(unused)]
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
                MatrixIdent::Accumulator => {
                    (self.m * self.n) / self.cd_type.packing_factor() as u32
                }
            }
        })
    }

    /// Returns the number of elements handled by each lane. Should be packed into `Line`s of size
    /// [`Self::line_size`] with [`Self::line_layout`].
    ///
    /// Computed as `(num_elems / plane_size) * register_duplication`, where the plane size and
    /// the per-operand duplication are read from the scope's runtime MMA properties.
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size;
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            (elems / plane_dim) * duplication
        })
    }

    /// Returns the number of lines of size [`Self::line_size`] with layout
    /// [`Self::line_layout`] per lane.
    ///
    /// Computed as [`Self::elems_per_lane`] divided by [`Self::line_size`].
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let line_size = self.__expand_line_size_method(scope, ident);
            elems / line_size
        })
    }

    /// The layout of each line in this matrix (row major or column major)
    ///
    /// The value is read from the scope's runtime MMA properties for the operand selected by
    /// `ident`.
    #[allow(unused)]
    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    /// Number of elements in each line passed to the execute function
    ///
    /// Computed as the runtime's `register_size_bits` divided (rounding up) by the bit size of
    /// the storage type of the operand selected by `ident`.
    #[allow(unused_variables)]
    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let bits = match ident {
                MatrixIdent::A => StorageType::size_bits(&self.a_type) as u32,
                MatrixIdent::B => StorageType::size_bits(&self.b_type) as u32,
                MatrixIdent::Accumulator => StorageType::size_bits(&self.cd_type) as u32,
            };
            let register_size = scope.runtime_properties.mma.register_size_bits;
            // div_ceil for potential compatibility with f64
            register_size.div_ceil(bits)
        })
    }

    /// Returns the coordinates of the `nth` element handled by the `lane_id`
    /// Each lane contains [`Self::elems_per_lane`] elements in [`Self::line_size`] chunks.
    /// Returns (`row_idx`, `col_idx`)
    ///
    /// Two instructions are registered, `CoopMma::RowIndex` and `CoopMma::ColIndex`, each
    /// writing into a fresh `u32` local.
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused_variables)]
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ExpandElement = lane_id.into();
            let elem_idx: ExpandElement = elem_idx.into();

            // Storage type and register layout of the operand selected by `ident`.
            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };

            // Output locals for the row and column index.
            let row = scope.create_local(Type::new(u32::as_type(scope)));
            let col = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scales for this thread, along the non-major dimension of the matrix.
    /// Each thread loads all scales in the major direction into a single `Line`.
    ///
    /// # Panics
    /// Panics when `ident` is [MatrixIdent::Accumulator], which has no scales.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        // Just do CUDA for now, call an actual intrinsic when HIP gets support
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Number of scales in each line (not the line size!). Line size may include padding bytes.
    ///
    /// # Panics
    /// Panics if the definition has no `scales_factor`, i.e. it was created with [`Self::new`]
    /// rather than [`Self::new_scaled`].
    pub fn scales_count(&self) -> comptime_type!(u32) {
        // We only have the CUDA version for now, so just use `scales_factor`. The function can
        // be modified for HIP in the future without having to redo all uses.
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Line size for the scale factors. May be larger than the total number of scales.
    ///
    /// Computed as the runtime's `register_size_bits` divided by the bit size of the scales
    /// type. Unlike [`Self::line_size`], this is a plain (truncating) division.
    ///
    /// # Panics
    /// Panics if the definition has no `scales_type`, i.e. it was created with [`Self::new`]
    /// rather than [`Self::new_scaled`].
    pub fn scales_line_size(&self) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales line size for matrix with no scales");
            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
        })
    }

    /// Execute a low level `mma` operation with manually managed registers. Register layout
    /// and index mapping can be retrieved from the [`MmaDefinition`]
    ///
    /// Registers a `CoopMma::ExecuteManual` instruction taking the three register sequences as
    /// inputs. The output is a newly created array built from
    /// `elems_per_lane(Accumulator) / line_size(Accumulator)` and `line_size(Accumulator)`
    /// (presumably the number of lines and the line size; confirm against
    /// `Array::__expand_vectorized`).
    #[allow(unused)]
    pub fn execute(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            // Output registers (`D`), sized from the accumulator's per-lane element count.
            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            // Flatten each sequence into the list of underlying variables.
            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }

    /// Execute a low level block scaled `mma` operation with manually managed registers. Register
    /// layout and index mapping can be retrieved from the [`MmaDefinition`]
    ///
    /// Works like [`Self::execute`], but registers a `CoopMma::ExecuteScaled` instruction that
    /// additionally carries `scales_a`, `scales_b` and the definition's `scales_factor`.
    ///
    /// # Panics
    /// Panics if the definition has no `scales_factor`, i.e. it was created with [`Self::new`]
    /// rather than [`Self::new_scaled`].
    #[allow(unused)]
    pub fn execute_scaled<S: CubePrimitive>(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
        scales_a: Line<S>,
        scales_b: Line<S>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            // Output registers (`D`), sized from the accumulator's per-lane element count.
            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            // Flatten each sequence into the list of underlying variables.
            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
}
641
/// Fill the matrix with the provided value.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `fill` module.
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
647
648/// Module containing the expand function for [fill()].
649pub mod fill {
650    use super::*;
651
652    /// Expand method of [fill()].
653    pub fn expand<C: CubeType>(
654        scope: &mut Scope,
655        mat: MatrixExpand<C>,
656        value: ExpandElementTyped<C>,
657    ) {
658        let value: ExpandElement = value.into();
659        scope.register(Instruction::new(
660            ir::CoopMma::Fill { value: *value },
661            *mat.elem,
662        ));
663    }
664}
665
/// Load the matrix with the provided array using the stride.
///
/// No explicit layout is passed to the load. The expand function rejects
/// [MatrixIdent::Accumulator] matrices; use [load_with_layout()] for those.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `load` module.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
671
672/// Module containing the expand function for [load()].
673pub mod load {
674    use super::*;
675
676    /// Expand method of [load()].
677    #[allow(unused_variables)]
678    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
679        scope: &mut Scope,
680        mat: MatrixExpand<C>,
681        value: SliceExpand<V, ReadOnly>,
682        stride: ExpandElementTyped<u32>,
683    ) {
684        let stride: ExpandElement = stride.into();
685        assert_ne!(
686            mat.ident,
687            MatrixIdent::Accumulator,
688            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
689        );
690
691        let (value, offset) = value.__to_raw_parts();
692
693        scope.register(Instruction::new(
694            ir::CoopMma::Load {
695                value,
696                stride: *stride,
697                offset,
698                layout: None,
699            },
700            *mat.elem,
701        ));
702    }
703}
704
/// Load the matrix with the provided array using the stride with an explicit layout.
/// Explicit layouts are required when loading accumulators.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `load_with_layout` module.
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
716
717/// Module containing the expand function for [load_with_layout()].
718pub mod load_with_layout {
719    use super::*;
720
721    /// Expand method of [load_with_layout()].
722    #[allow(unused_variables)]
723    pub fn expand<C: CubeType, V: CubePrimitive>(
724        scope: &mut Scope,
725        mat: MatrixExpand<C>,
726        value: SliceExpand<V, ReadOnly>,
727        stride: ExpandElementTyped<u32>,
728        layout: MatrixLayout,
729    ) {
730        let stride: ExpandElement = stride.into();
731        let (value, offset) = value.__to_raw_parts();
732
733        scope.register(Instruction::new(
734            ir::CoopMma::Load {
735                value,
736                stride: *stride,
737                offset,
738                layout: Some(layout),
739            },
740            *mat.elem,
741        ));
742    }
743}
744
/// Store the matrix in the given array following the given stride and layout.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `store` module.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
755
756/// Module containing the expand function for [store()].
757pub mod store {
758    use crate::prelude::ReadWrite;
759
760    use super::*;
761
762    /// Expand method of [store()].
763    #[allow(unused_variables)]
764    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
765        scope: &mut Scope,
766        output: SliceExpand<O, ReadWrite>,
767        mat: MatrixExpand<C>,
768        stride: ExpandElementTyped<u32>,
769        layout: MatrixLayout,
770    ) {
771        let stride: ExpandElement = stride.into();
772
773        let (output, offset) = output.__to_raw_parts();
774
775        scope.register(Instruction::new(
776            ir::CoopMma::Store {
777                mat: *mat.elem,
778                offset,
779                stride: *stride,
780                layout,
781            },
782            output,
783        ));
784    }
785}
786
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
///
/// `mat_a`, `mat_b` and `mat_c` are the inputs of the operation and `mat_d` receives the
/// result.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `execute` module.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
797
798/// Module containing the expand function for [execute()].
799pub mod execute {
800    use super::*;
801
802    /// Expand method of [execute()].
803    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
804        scope: &mut Scope,
805        mat_a: MatrixExpand<A>,
806        mat_b: MatrixExpand<B>,
807        mat_c: MatrixExpand<C>,
808        mat_d: MatrixExpand<D>,
809    ) {
810        scope.register(Instruction::new(
811            ir::CoopMma::Execute {
812                mat_a: *mat_a.elem,
813                mat_b: *mat_b.elem,
814                mat_c: *mat_c.elem,
815            },
816            *mat_d.elem,
817        ));
818    }
819}
820
/// Cast the matrix to a matrix with element type `O`.
///
/// The body is only the `unexpanded!()` placeholder; the instruction itself is emitted by the
/// `expand` function of the `cast` module.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
826
827/// Module containing the expand function for [store()].
828pub mod cast {
829    use super::*;
830
831    /// Expand method of [store()].
832    #[allow(unused_variables)]
833    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
834        scope: &mut Scope,
835        input: MatrixExpand<C>,
836    ) -> MatrixExpand<O> {
837        let ident = input.ident;
838
839        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
840            return MatrixExpand {
841                elem: input.elem,
842                ident,
843                _c: PhantomData,
844            };
845        }
846        let input = *input.elem;
847        let input_mat = match input.kind {
848            ir::VariableKind::Matrix { mat, .. } => mat,
849            _ => unreachable!(),
850        };
851
852        let elem = O::as_type(scope);
853        let elem = scope.create_matrix(ir::Matrix::new(
854            ident,
855            input_mat.m,
856            input_mat.n,
857            input_mat.k,
858            elem,
859            MatrixLayout::Undefined,
860        ));
861
862        let output = MatrixExpand {
863            ident,
864            elem,
865            _c: PhantomData,
866        };
867        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
868
869        output
870    }
871}
872
// `MatrixLayout` is its own expand type.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
876
impl IntoMut for MatrixLayout {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
882
// Empty impl: relies entirely on the methods `CubeDebug` provides by default.
impl CubeDebug for MatrixLayout {}