cubecl_core/frontend/
cmma.rs

//! This module exposes cooperative matrix-multiply and accumulate (CMMA) operations.
//!
//! Most of these functions are effectively unsafe, since they mutate their input even when it is
//! passed by reference.
//!
//! # Example
//!
//! This is a basic 16x16x16 matrix multiplication example.
//!
//! ```rust, ignore
//! #[cube(launch)]
//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     let a = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::A,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::RowMajor,
//!         )
//!     };
//!     let b = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::B,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::ColMajor,
//!         )
//!     };
//!     let c = unsafe {
//!         cmma::Matrix::<F32>::uninitialized(
//!             cmma::MatrixIdent::Accumulator,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::Undefined,
//!         )
//!     };
//!     cmma::fill::<F32>(&c, F32::new(0.0));
//!     cmma::load(&a, lhs.as_slice(), u32::new(16));
//!     cmma::load(&b, rhs.as_slice(), u32::new(16));
//!
//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//!     cmma::store(
//!         out.as_slice_mut(),
//!         &c,
//!         u32::new(16),
//!         cmma::MatrixLayout::RowMajor,
//!     );
//! }
//! ```

use super::{
    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
    SliceMut,
};
use crate::{
    self as cubecl,
    prelude::{Array, Line},
};
use crate::{
    ir::{self, Instruction},
    unexpanded,
};
use cubecl_macros::{comptime_type, cube, intrinsic};
use std::marker::PhantomData;

use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
pub use ir::{MatrixIdent, MatrixLayout};

/// A matrix represents a 2D grid of numbers.
///
/// It can be stored in either a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}

/// Defines a matrix multiplication operation, including the input and output types and the shape.
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}

/// Expand type of [Matrix].
pub struct MatrixExpand<C: CubeType> {
    elem: ExpandElement,
    ident: MatrixIdent,
    _c: PhantomData<C>,
}

/// Expand type of [MmaDefinition].
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    pub m: u32,
    pub n: u32,
    pub k: u32,
    pub a_type: StorageType,
    pub b_type: StorageType,
    pub cd_type: StorageType,
    pub scales_factor: Option<u32>,
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}

impl<C: CubeType> Clone for MatrixExpand<C> {
    fn clone(&self) -> Self {
        Self {
            elem: self.elem.clone(),
            ident: self.ident,
            _c: self._c,
        }
    }
}

impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
    fn clone(&self) -> Self {
        Self {
            m: self.m,
            n: self.n,
            k: self.k,
            a_type: self.a_type,
            b_type: self.b_type,
            cd_type: self.cd_type,
            scales_factor: self.scales_factor,
            scales_type: self.scales_type,
            _a: PhantomData,
            _b: PhantomData,
            _cd: PhantomData,
        }
    }
}

impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}

impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}

impl<C: CubeType> IntoMut for MatrixExpand<C> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

impl<C: CubeType> CubeDebug for MatrixExpand<C> {
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        scope.update_variable_name(*self.elem, name);
    }
}

impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}

#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new uninitialized matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// # Safety
    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
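    ///
    /// # Example
    ///
    /// A minimal sketch (assuming the 16x16x16 shape and the `F16` element type from the
    /// module-level example, with `lhs` as some input array):
    ///
    /// ```rust, ignore
    /// let a = unsafe {
    ///     cmma::Matrix::<F16>::uninitialized(
    ///         cmma::MatrixIdent::A,
    ///         16,
    ///         16,
    ///         16,
    ///         cmma::MatrixLayout::RowMajor,
    ///     )
    /// };
    /// // The matrix must be initialized before it is used in `execute`.
    /// cmma::load(&a, lhs.as_slice(), u32::new(16));
    /// ```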
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            let elem = C::as_type(scope);
            let elem = scope.create_matrix(ir::Matrix::new(
                ident,
                m.constant().unwrap().as_u32(),
                n.constant().unwrap().as_u32(),
                k.constant().unwrap().as_u32(),
                elem,
                layout,
            ));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
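    ///
    /// # Example
    ///
    /// A minimal sketch of a zero-filled accumulator (the 16x16x16 shape and `F32` type are
    /// assumptions taken from the module-level example):
    ///
    /// ```rust, ignore
    /// let c = cmma::Matrix::<F32>::from_value(
    ///     cmma::MatrixIdent::Accumulator,
    ///     16,
    ///     16,
    ///     16,
    ///     cmma::MatrixLayout::Undefined,
    ///     F32::new(0.0),
    /// );
    /// ```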
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
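    ///
    /// # Example
    ///
    /// A minimal sketch (shape, element type, `lhs` input and stride are assumptions matching the
    /// module-level example):
    ///
    /// ```rust, ignore
    /// let a = cmma::Matrix::<F16>::from_slice(
    ///     cmma::MatrixIdent::A,
    ///     16,
    ///     16,
    ///     16,
    ///     cmma::MatrixLayout::RowMajor,
    ///     lhs.as_slice(),
    ///     u32::new(16),
    /// );
    /// ```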
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}

#[cube]
impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
    /// Create a new MMA definition that is going to be used in the manual
    /// [matrix-multiply and accumulate](Self::execute) function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of each fragment is determined by its [MatrixIdent]:
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// The register layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::line_layout`] to check the correct data layout for each fragment.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
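    ///
    /// # Example
    ///
    /// A minimal sketch (the 16x16x16 shape with `F16` inputs and `F32` accumulation is an
    /// assumption; supported combinations depend on the target):
    ///
    /// ```rust, ignore
    /// let def = cmma::MmaDefinition::<F16, F16, F32>::new(16, 16, 16);
    /// // Query how the runtime expects the per-lane registers to be laid out.
    /// let layout_a = def.line_layout(cmma::MatrixIdent::A);
    /// let lines_a = def.lines_per_lane(cmma::MatrixIdent::A);
    /// ```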
    #[allow(unused_variables)]
    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Create a new block-scaled MMA definition that is going to be used in the manual
    /// [matrix-multiply and accumulate](Self::execute_scaled) function, with scale type `S` and
    /// block scale factor `scale_factor`.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of each fragment is determined by its [MatrixIdent]:
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    /// The register layout for manual MMA is determined by the runtime and must be handled manually.
    /// Use [`Self::line_layout`] to check the correct data layout for each fragment.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
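    ///
    /// # Example
    ///
    /// A minimal sketch (the shape, the scale type `S` and the scale factor are placeholders;
    /// supported block-scaled configurations depend on the target):
    ///
    /// ```rust, ignore
    /// let def = cmma::MmaDefinition::<A, B, CD>::new_scaled::<S>(16, 16, 64, 32);
    /// let scales_per_line = def.scales_count();
    /// let scales_line_size = def.scales_line_size();
    /// ```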
    #[allow(unused_variables)]
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: u32,
        #[comptime] n: u32,
        #[comptime] k: u32,
        #[comptime] scale_factor: u32,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope)),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Number of elements in the matrix fragment identified by `ident`, accounting for packed
    /// storage types.
    #[allow(unused)]
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
                MatrixIdent::Accumulator => {
                    (self.m * self.n) / self.cd_type.packing_factor() as u32
                }
            }
        })
    }

    /// Returns the number of elements handled by each lane. Should be packed into `Line`s of size
    /// `line_size` with [`line_layout`].
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size;
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            (elems / plane_dim) * duplication
        })
    }

    /// Returns the number of lines of size `line_size` with layout `line_layout` per lane.
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
    #[allow(unused)]
    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let line_size = self.__expand_line_size_method(scope, ident);
            elems / line_size
        })
    }

    /// The layout of each line in this matrix (row major or column major)
    #[allow(unused)]
    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    /// Number of elements in each line passed to the execute function
    #[allow(unused_variables)]
    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let bits = match ident {
                MatrixIdent::A => StorageType::size_bits(&self.a_type) as u32,
                MatrixIdent::B => StorageType::size_bits(&self.b_type) as u32,
                MatrixIdent::Accumulator => StorageType::size_bits(&self.cd_type) as u32,
            };
            let register_size = scope.runtime_properties.mma.register_size_bits;
            // div_ceil for potential compatibility with f64
            register_size.div_ceil(bits)
        })
    }

    /// Returns the coordinates of the `nth` element handled by `lane_id`, as (`row_idx`, `col_idx`).
    /// Each lane contains [`elems_per_lane`] elements in chunks of [`line_size`].
    ///
    /// # Note
    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
    /// to a cube.
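    ///
    /// # Example
    ///
    /// A minimal sketch of mapping every accumulator element held by this lane back to matrix
    /// coordinates (`def`, `registers_d`, `out` and the `16` column stride are assumptions):
    ///
    /// ```rust, ignore
    /// let lane_id = UNIT_POS_PLANE;
    /// let line_size = def.line_size(cmma::MatrixIdent::Accumulator);
    /// let elems = def.elems_per_lane(cmma::MatrixIdent::Accumulator);
    /// for i in 0..elems {
    ///     let (row, col) = def.position_of_nth(lane_id, i, cmma::MatrixIdent::Accumulator);
    ///     out[row * 16 + col] = registers_d[i / line_size][i % line_size];
    /// }
    /// ```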
    #[allow(unused_variables)]
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ExpandElement = lane_id.into();
            let elem_idx: ExpandElement = elem_idx.into();

            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };

            let row = scope.create_local(Type::new(u32::as_type(scope)));
            let col = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scales for this thread, along the non-major dimension of the matrix.
    /// Each thread loads all scales in the major direction into a single `Line`.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        // Just do CUDA for now, call an actual intrinsic when HIP gets support
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Number of scales in each line (not the line size!). Line size may include padding bytes.
    pub fn scales_count(&self) -> comptime_type!(u32) {
        // We only have the CUDA version for now, so just use `scales_factor`. The function can
        // be modified for HIP in the future without having to redo all uses.
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Line size for the scale factors. May be larger than the total number of scales.
    pub fn scales_line_size(&self) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales line size for matrix with no scales");
            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
        })
    }

    /// Load one or more matrix registers using intrinsic instructions. CUDA only.
    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
    /// The slice determines the starting address of the 16-byte row loaded by this unit, with
    /// the row index being `UNIT_POS_PLANE % 8`.
    /// The number of elements per row is determined by the element size.
    ///
    /// # Constraints
    /// * The address must be aligned to 16 bytes.
    /// * The address must be in shared memory.
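    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `smem` is a shared-memory buffer of `Line<F16>` holding the A
    /// fragment rows, and `row_stride` is the number of lines per 16-byte row:
    ///
    /// ```rust, ignore
    /// let row_start = (UNIT_POS_PLANE % 8) * row_stride;
    /// let registers_a = def.load_matrix::<F16>(
    ///     smem.slice(row_start, row_start + row_stride),
    ///     cmma::MatrixIdent::A,
    ///     1,     // load a single matrix
    ///     false, // do not transpose
    /// );
    /// ```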
    #[allow(unused_variables)]
    pub fn load_matrix<E: CubePrimitive>(
        &self,
        row: &Slice<Line<E>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: u32,
        #[comptime] transpose: bool,
    ) -> Array<Line<E>> {
        intrinsic!(|scope| {
            let line_size = self.__expand_line_size_method(scope, ident);
            let slice_line_size = row.line_size;
            let (buffer, offset) = row.__to_raw_parts();
            let out = Array::__expand_vectorized(scope, num_matrices, line_size);
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    line_size: slice_line_size,
                    factor: num_matrices,
                    transpose,
                },
                *out.expand,
            ));
            out
        })
    }

    /// Execute a low level `mma` operation with manually managed registers. Register layout
    /// and index mapping can be retrieved from the [`MmaDefinition`].
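    ///
    /// # Example
    ///
    /// A minimal sketch of a single manual MMA (`registers_a`, `registers_b` and `registers_c` are
    /// assumed to already be filled according to the layouts reported by this definition):
    ///
    /// ```rust, ignore
    /// let def = cmma::MmaDefinition::<F16, F16, F32>::new(16, 16, 16);
    /// let registers_d = def.execute(&registers_a, &registers_b, &registers_c);
    /// // `registers_d` holds this lane's accumulator fragment, one `Line` per register.
    /// ```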
    #[allow(unused)]
    pub fn execute(
        &self,
        registers_a: &Array<Line<A>>,
        registers_b: &Array<Line<B>>,
        registers_c: &Array<Line<CD>>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }

    /// Execute a low level block scaled `mma` operation with manually managed registers. Register
    /// layout and index mapping can be retrieved from the [`MmaDefinition`].
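    ///
    /// # Example
    ///
    /// A minimal sketch (the definition is assumed to come from [`Self::new_scaled`], and
    /// `scales_a` / `scales_b` are assumed to already hold this lane's scale lines):
    ///
    /// ```rust, ignore
    /// let registers_d = def.execute_scaled::<S>(
    ///     &registers_a,
    ///     &registers_b,
    ///     &registers_c,
    ///     scales_a,
    ///     scales_b,
    /// );
    /// ```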
    #[allow(unused)]
    pub fn execute_scaled<S: CubePrimitive>(
        &self,
        registers_a: &Array<Line<A>>,
        registers_b: &Array<Line<B>>,
        registers_c: &Array<Line<CD>>,
        scales_a: Line<S>,
        scales_b: Line<S>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // Only shape is actually used
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
}

/// Fill the matrix with the provided value.
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}

/// Module containing the expand function for [fill()].
pub mod fill {
    use super::*;

    /// Expand method of [fill()].
    pub fn expand<C: CubeType>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: ExpandElementTyped<C>,
    ) {
        let value: ExpandElement = value.into();
        scope.register(Instruction::new(
            ir::CoopMma::Fill { value: *value },
            *mat.elem,
        ));
    }
}

/// Load the matrix from the provided slice using the given stride.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}

/// Module containing the expand function for [load()].
pub mod load {
    use super::*;

    /// Expand method of [load()].
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: ExpandElementTyped<u32>,
    ) {
        let stride: ExpandElement = stride.into();
        assert_ne!(
            mat.ident,
            MatrixIdent::Accumulator,
            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
        );

        let (value, offset) = value.__to_raw_parts();

        scope.register(Instruction::new(
            ir::CoopMma::Load {
                value,
                stride: *stride,
                offset,
                layout: None,
            },
            *mat.elem,
        ));
    }
}

/// Load the matrix from the provided slice using the given stride and an explicit layout.
/// Explicit layouts are required when loading accumulators.
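///
/// # Example
///
/// A minimal sketch of loading an accumulator (the matrix `c`, the `acc_input` array and the
/// stride are assumptions following the module-level example):
///
/// ```rust, ignore
/// cmma::load_with_layout(
///     &c,
///     acc_input.as_slice(),
///     u32::new(16),
///     cmma::MatrixLayout::RowMajor,
/// );
/// ```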
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}

/// Module containing the expand function for [load_with_layout()].
pub mod load_with_layout {
    use super::*;

    /// Expand method of [load_with_layout()].
    #[allow(unused_variables)]
    pub fn expand<C: CubeType, V: CubePrimitive>(
        scope: &mut Scope,
        mat: MatrixExpand<C>,
        value: SliceExpand<V, ReadOnly>,
        stride: ExpandElementTyped<u32>,
        layout: MatrixLayout,
    ) {
        let stride: ExpandElement = stride.into();
        let (value, offset) = value.__to_raw_parts();

        scope.register(Instruction::new(
            ir::CoopMma::Load {
                value,
                stride: *stride,
                offset,
                layout: Some(layout),
            },
            *mat.elem,
        ));
    }
}

/// Store the matrix into the given slice following the given stride and layout.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}

/// Module containing the expand function for [store()].
pub mod store {
    use crate::prelude::ReadWrite;

    use super::*;

    /// Expand method of [store()].
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        output: SliceExpand<O, ReadWrite>,
        mat: MatrixExpand<C>,
        stride: ExpandElementTyped<u32>,
        layout: MatrixLayout,
    ) {
        let stride: ExpandElement = stride.into();

        let (output, offset) = output.__to_raw_parts();

        scope.register(Instruction::new(
            ir::CoopMma::Store {
                mat: *mat.elem,
                offset,
                stride: *stride,
                layout,
            },
            output,
        ));
    }
}

/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}

/// Module containing the expand function for [execute()].
pub mod execute {
    use super::*;

    /// Expand method of [execute()].
    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
        scope: &mut Scope,
        mat_a: MatrixExpand<A>,
        mat_b: MatrixExpand<B>,
        mat_c: MatrixExpand<C>,
        mat_d: MatrixExpand<D>,
    ) {
        scope.register(Instruction::new(
            ir::CoopMma::Execute {
                mat_a: *mat_a.elem,
                mat_b: *mat_b.elem,
                mat_c: *mat_c.elem,
            },
            *mat_d.elem,
        ));
    }
}

/// Cast the matrix to another element type, returning a new [Matrix] with the same shape.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}

/// Module containing the expand function for [cast()].
pub mod cast {
    use super::*;

    /// Expand method of [cast()].
    #[allow(unused_variables)]
    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
        scope: &mut Scope,
        input: MatrixExpand<C>,
    ) -> MatrixExpand<O> {
        let ident = input.ident;

        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
            return MatrixExpand {
                elem: input.elem,
                ident,
                _c: PhantomData,
            };
        }
        let input = *input.elem;
        let input_mat = match input.kind {
            ir::VariableKind::Matrix { mat, .. } => mat,
            _ => unreachable!(),
        };

        let elem = O::as_type(scope);
        let elem = scope.create_matrix(ir::Matrix::new(
            ident,
            input_mat.m,
            input_mat.n,
            input_mat.k,
            elem,
            MatrixLayout::Undefined,
        ));

        let output = MatrixExpand {
            ident,
            elem,
            _c: PhantomData,
        };
        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));

        output
    }
}

impl CubeType for MatrixLayout {
    type ExpandType = Self;
}

impl IntoMut for MatrixLayout {
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}

impl CubeDebug for MatrixLayout {}