cubecl_core/frontend/
cmma.rs

//! This module exposes cooperative matrix-multiply and accumulate operations.
//!
//! Most of these functions are effectively unsafe: they mutate their input even though it is
//! passed by shared reference.
//!
//! # Example
//!
//! This is a basic 16x16x16 matrix multiplication example.
//!
//! ```rust, ignore
//! #[cube(launch)]
//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//!     let a = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::A,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::RowMajor,
//!         )
//!     };
//!     let b = unsafe {
//!         cmma::Matrix::<F16>::uninitialized(
//!             cmma::MatrixIdent::B,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::ColMajor,
//!         )
//!     };
//!     let c = unsafe {
//!         cmma::Matrix::<F32>::uninitialized(
//!             cmma::MatrixIdent::Accumulator,
//!             16,
//!             16,
//!             16,
//!             cmma::MatrixLayout::Undefined,
//!         )
//!     };
//!     cmma::fill::<F32>(&c, F32::new(0.0));
//!     cmma::load(&a, lhs.as_slice(), u32::new(16));
//!     cmma::load(&b, rhs.as_slice(), u32::new(16));
//!
//!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//!     cmma::store(
//!         out.as_slice_mut(),
//!         &c,
//!         u32::new(16),
//!         cmma::MatrixLayout::RowMajor,
//!     );
//! }
//! ```
48
49use super::{
50    CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51    SliceMut,
52};
53use crate::{
54    self as cubecl,
55    prelude::{Array, Line},
56};
57use crate::{
58    ir::{self, Instruction},
59    unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
67/// A matrix represent a 2D grid of numbers.
68///
69/// They can either be in a [row major](MatrixLayout::RowMajor) or a
70/// [column major](MatrixLayout::ColMajor) format.
71#[derive(Copy, Clone)]
72pub struct Matrix<C: CubeType> {
73    _c: PhantomData<C>,
74}
75
76/// Defines a matrix multiplication operation, including the input and output type, and the shape.
77#[derive(Copy, Clone)]
78pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
79    _a: PhantomData<A>,
80    _b: PhantomData<B>,
81    _cd: PhantomData<CD>,
82}
83
84/// Expand type of [Matrix].
85pub struct MatrixExpand<C: CubeType> {
86    elem: ExpandElement,
87    ident: MatrixIdent,
88    _c: PhantomData<C>,
89}
90
91/// Expand type of [MmaDefinition].
92#[derive(Debug)]
93pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
94    pub m: u32,
95    pub n: u32,
96    pub k: u32,
97    pub a_type: StorageType,
98    pub b_type: StorageType,
99    pub cd_type: StorageType,
100    pub scales_factor: Option<u32>,
101    pub scales_type: Option<StorageType>,
102    _a: PhantomData<A>,
103    _b: PhantomData<B>,
104    _cd: PhantomData<CD>,
105}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108    fn clone(&self) -> Self {
109        Self {
110            elem: self.elem.clone(),
111            ident: self.ident,
112            _c: self._c,
113        }
114    }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118    fn clone(&self) -> Self {
119        Self {
120            m: self.m,
121            n: self.n,
122            k: self.k,
123            a_type: self.a_type,
124            b_type: self.b_type,
125            cd_type: self.cd_type,
126            scales_factor: self.scales_factor,
127            scales_type: self.scales_type,
128            _a: PhantomData,
129            _b: PhantomData,
130            _cd: PhantomData,
131        }
132    }
133}
134
135impl<C: CubeType> CubeType for Matrix<C> {
136    type ExpandType = MatrixExpand<C>;
137}
138
139impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
140    type ExpandType = MmaDefinitionExpand<A, B, CD>;
141}
142
143impl<C: CubeType> IntoMut for MatrixExpand<C> {
144    fn into_mut(self, _scope: &mut Scope) -> Self {
145        self
146    }
147}
148
149impl<C: CubeType> CubeDebug for MatrixExpand<C> {
150    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
151        scope.update_variable_name(*self.elem, name);
152    }
153}
154
155impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
156    fn into_mut(self, _scope: &mut Scope) -> Self {
157        self
158    }
159}
160
161impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
163#[cube]
164impl<C: CubePrimitive> Matrix<C> {
165    /// Create a new uninitialized matrix that is going to be used in the
166    /// [matrix-multiply and accumulate](execute()) function.
167    ///
168    /// # Safety
169    /// Must be initialized with `load` or `fill` before use. Using it without initialization is
170    /// undefined behaviour on CUDA, and completely invalid on Vulkan.
171    ///
172    /// You have to declare the shape used for the execution.
173    /// The shape of the current matrix is determined using the [MatrixIdent].
174    ///
175    /// * [MatrixIdent::A] Shape => (M, K)
176    /// * [MatrixIdent::B] Shape => (K, N)
177    /// * [MatrixIdent::Accumulator] Shape => (M, N)
178    ///
179    /// Not all shapes are supported, and the permitted shapes depend on the element type.
180    ///
181    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
182    #[allow(unused_variables)]
183    pub unsafe fn uninitialized(
184        #[comptime] ident: MatrixIdent,
185        m: u32,
186        n: u32,
187        k: u32,
188        layout: MatrixLayout,
189    ) -> Self {
190        intrinsic!(|scope| {
191            let elem = C::as_type(scope);
192            let elem = scope.create_matrix(ir::Matrix::new(
193                ident,
194                m.constant().unwrap().as_u32(),
195                n.constant().unwrap().as_u32(),
196                k.constant().unwrap().as_u32(),
197                elem,
198                layout,
199            ));
200            MatrixExpand {
201                elem,
202                ident,
203                _c: PhantomData,
204            }
205        })
206    }
207
208    /// Create a new matrix that is going to be used in the
209    /// [matrix-multiply and accumulate](execute()) function and is filled with `value`.
210    ///
211    /// You have to declare the shape used for the execution.
212    /// The shape of the current matrix is determined using the [MatrixIdent].
213    ///
214    /// * [MatrixIdent::A] Shape => (M, K)
215    /// * [MatrixIdent::B] Shape => (K, N)
216    /// * [MatrixIdent::Accumulator] Shape => (M, N)
217    ///
218    /// Not all shapes are supported, and the permitted shapes depend on the element type.
219    ///
220    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
221    #[allow(unused_variables)]
222    pub fn from_value(
223        #[comptime] ident: MatrixIdent,
224        m: u32,
225        n: u32,
226        k: u32,
227        layout: MatrixLayout,
228        value: C,
229    ) -> Self {
230        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
231
232        intrinsic!(|scope| {
233            fill::expand(scope, mat.clone(), value);
234            mat
235        })
236    }
237
238    /// Create a new matrix that is going to be used in the
239    /// [matrix-multiply and accumulate](execute()) function and is loaded from `value` with `stride`.
240    ///
241    /// You have to declare the shape used for the execution.
242    /// The shape of the current matrix is determined using the [MatrixIdent].
243    ///
244    /// * [MatrixIdent::A] Shape => (M, K)
245    /// * [MatrixIdent::B] Shape => (K, N)
246    /// * [MatrixIdent::Accumulator] Shape => (M, N)
247    ///
248    /// Not all shapes are supported, and the permitted shapes depend on the element type.
249    ///
250    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
251    #[allow(unused_variables)]
252    pub fn from_slice(
253        #[comptime] ident: MatrixIdent,
254        m: u32,
255        n: u32,
256        k: u32,
257        layout: MatrixLayout,
258        value: &Slice<C>,
259        stride: u32,
260    ) -> Self {
261        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
262
263        intrinsic!(|scope| {
264            load::expand(scope, mat.clone(), value, stride);
265            mat
266        })
267    }
268}
269
270#[cube]
271impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
272    /// Create a new matrix definition that is going to be used in the manual
273    /// [matrix-multiply and accumulate](execute_manual()) function.
274    ///
275    /// You have to declare the shape used for the execution.
276    /// The shape of the current matrix is determined using the [MatrixIdent].
277    ///
278    /// * [MatrixIdent::A] Shape => (M, K)
279    /// * [MatrixIdent::B] Shape => (K, N)
280    /// * [MatrixIdent::Accumulator] Shape => (M, N)
281    ///
282    /// Not all shapes are supported, and the permitted shapes depend on the element type.
283    /// Layout for manual MMA is determined by the runtime and must be handled manually.
284    /// Use [`line_layout`] to check the correct data layout for each element.
285    ///
286    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
287    #[allow(unused_variables)]
288    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
289        intrinsic!(|scope| {
290            let a_type = A::as_type(scope);
291            let b_type = B::as_type(scope);
292            let cd_type = CD::as_type(scope);
293
294            MmaDefinitionExpand {
295                m,
296                n,
297                k,
298                a_type,
299                b_type,
300                cd_type,
301                scales_factor: None,
302                scales_type: None,
303                _a: PhantomData,
304                _b: PhantomData,
305                _cd: PhantomData,
306            }
307        })
308    }
309
310    /// Create a new matrix definition that is going to be used in the manual
311    /// [matrix-multiply and accumulate](execute_manual()) function.
312    ///
313    /// You have to declare the shape used for the execution.
314    /// The shape of the current matrix is determined using the [MatrixIdent].
315    ///
316    /// * [MatrixIdent::A] Shape => (M, K)
317    /// * [MatrixIdent::B] Shape => (K, N)
318    /// * [MatrixIdent::Accumulator] Shape => (M, N)
319    ///
320    /// Not all shapes are supported, and the permitted shapes depend on the element type.
321    /// Layout for manual MMA is determined by the runtime and must be handled manually.
322    /// Use [`line_layout`] to check the correct data layout for each element.
323    ///
324    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
325    #[allow(unused_variables)]
326    pub fn new_scaled<S: CubePrimitive>(
327        #[comptime] m: u32,
328        #[comptime] n: u32,
329        #[comptime] k: u32,
330        #[comptime] scale_factor: u32,
331    ) -> Self {
332        intrinsic!(|scope| {
333            let a_type = A::as_type(scope);
334            let b_type = B::as_type(scope);
335            let cd_type = CD::as_type(scope);
336
337            MmaDefinitionExpand {
338                m,
339                n,
340                k,
341                a_type,
342                b_type,
343                cd_type,
344                scales_factor: Some(scale_factor),
345                scales_type: Some(S::as_type(scope)),
346                _a: PhantomData,
347                _b: PhantomData,
348                _cd: PhantomData,
349            }
350        })
351    }
352
353    /// Number of elements in the matrix
354    #[allow(unused)]
355    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
356        intrinsic!(|scope| {
357            match ident {
358                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
359                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
360                MatrixIdent::Accumulator => {
361                    (self.m * self.n) / self.cd_type.packing_factor() as u32
362                }
363            }
364        })
365    }
366
367    /// Returns the number of elements handled by each lane. Should be packed into `Line`s of size
368    /// `line_size` with [`line_layout`].
369    ///
370    /// # Note
371    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
372    /// to a cube.
373    #[allow(unused)]
374    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
375        intrinsic!(|scope| {
376            let elems = self.__expand_num_elems_method(scope, ident);
377            let plane_dim = scope.runtime_properties.mma.const_plane_size;
378            let duplication = match ident {
379                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
380                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
381                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
382            };
383            (elems / plane_dim) * duplication
384        })
385    }
386
387    /// Returns the number of lines of size `line_size` with layout `line_layout` per lane.
388    ///
389    /// # Note
390    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
391    /// to a cube.
392    #[allow(unused)]
393    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
394        intrinsic!(|scope| {
395            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
396            let line_size = self.__expand_line_size_method(scope, ident);
397            elems / line_size
398        })
399    }
400
401    /// The layout of each line in this matrix (row major or column major)
402    #[allow(unused)]
403    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
404        intrinsic!(|scope| {
405            match ident {
406                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
407                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
408                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
409            }
410        })
411    }
412
413    /// Number of elements in each line passed to the execute function. Represents the maximum
414    /// number of contiguous elements held by the thread.
415    #[allow(unused_variables)]
416    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
417        intrinsic!(|scope| {
418            let storage = match ident {
419                MatrixIdent::A => self.a_type,
420                MatrixIdent::B => self.b_type,
421                MatrixIdent::Accumulator => self.cd_type,
422            };
423            let matrix = cubecl_ir::Matrix {
424                ident,
425                m: self.m,
426                n: self.n,
427                k: self.k,
428                storage: storage,
429                layout: MatrixLayout::ColMajor,
430            };
431            scope
432                .runtime_properties
433                .mma
434                .contiguous_elements
435                .apply(ident, matrix)
436        })
437    }
438
439    /// Returns the coordinates of the `nth` element handled by the `lane_id`
440    /// Each lane contains [`elems_per_lane`] elements in [`line_size`] chunks.
441    /// Returns (`row_idx`, `col_idx`)
442    ///
443    /// # Note
444    /// "Lane" here refers to the unit relative to a plane, to distinguish it from a unit relative
445    /// to a cube.
446    #[allow(unused_variables)]
447    pub fn position_of_nth(
448        &self,
449        lane_id: u32,
450        elem_idx: u32,
451        #[comptime] ident: MatrixIdent,
452    ) -> (u32, u32) {
453        intrinsic!(|scope| {
454            let lane_id: ExpandElement = lane_id.into();
455            let elem_idx: ExpandElement = elem_idx.into();
456
457            let ty = match ident {
458                MatrixIdent::A => self.a_type,
459                MatrixIdent::B => self.b_type,
460                MatrixIdent::Accumulator => self.cd_type,
461            };
462            let layout = match ident {
463                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
464                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
465                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
466            };
467            let matrix = cubecl_ir::Matrix {
468                ident,
469                m: self.m,
470                n: self.n,
471                k: self.k,
472                storage: ty,
473                layout,
474            };
475
476            let row = scope.create_local(Type::new(u32::as_type(scope)));
477            let col = scope.create_local(Type::new(u32::as_type(scope)));
478            scope.register(Instruction::new(
479                CoopMma::RowIndex {
480                    lane_id: *lane_id,
481                    i: *elem_idx,
482                    matrix,
483                },
484                *row,
485            ));
486            scope.register(Instruction::new(
487                CoopMma::ColIndex {
488                    lane_id: *lane_id,
489                    i: *elem_idx,
490                    matrix,
491                },
492                *col,
493            ));
494            (row.into(), col.into())
495        })
496    }
497
498    /// Index of the scales for this thread, along the non-major dimension of the matrix.
499    /// Each thread loads all scales in the major direction into a single `Line`.
500    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
501        // Just do CUDA for now, call an actual intrinsic when HIP gets support
502        let quad_id = lane_id / 4;
503        let t_id = lane_id % 4;
504        match ident {
505            MatrixIdent::A => quad_id + (t_id % 2) * 8,
506            MatrixIdent::B => quad_id,
507            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
508        }
509    }
510
511    /// Number of scales in each line (not the line size!). Line size may include padding bytes.
512    pub fn scales_count(&self) -> comptime_type!(u32) {
513        // We only have the CUDA version for now, so just use `scales_factor`. The function can
514        // be modified for HIP in the future without having to redo all uses.
515        intrinsic!(|_| {
516            self.scales_factor
517                .expect("Can't retrieve scales count for matrix with no scales")
518        })
519    }
520
521    /// Line size for the scale factors. May be larger than the total number of scales.
522    pub fn scales_line_size(&self) -> comptime_type!(u32) {
523        intrinsic!(|scope| {
524            let elem = self
525                .scales_type
526                .expect("Can't retrieve scales line size for matrix with no scales");
527            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
528        })
529    }
530
531    /// Load one or more matrix register using intrinsic instructions. CUDA only.
532    /// The number of matrices must be 1, 2, or 4. The rows for the nth matrix are passed by the 8
533    /// lanes starting at `n * 8`. All slice starts must be valid, even for non-participating lanes.
534    /// The slice determines the starting address for a 16-byte row loaded by this unit, with
535    /// the row index being `UNIT_POS_PLANE % 8`.
536    /// The number of elements is determined by element size.
537    ///
538    /// # Constraints:
539    /// Address must be aligned to 16 bytes
540    /// Address must be in shared memory
541    #[allow(unused_variables)]
542    pub fn load_matrix<E: CubePrimitive>(
543        &self,
544        row: &Slice<Line<E>>,
545        #[comptime] ident: MatrixIdent,
546        #[comptime] num_matrices: u32,
547        #[comptime] transpose: bool,
548    ) -> Array<Line<E>> {
549        intrinsic!(|scope| {
550            let line_size = self.__expand_line_size_method(scope, ident);
551            let slice_line_size = row.line_size;
552            let (buffer, offset) = row.__to_raw_parts();
553            let out = Array::__expand_vectorized(scope, num_matrices, line_size);
554            scope.register(Instruction::new(
555                CoopMma::LoadMatrix {
556                    buffer,
557                    offset,
558                    line_size: slice_line_size,
559                    factor: num_matrices,
560                    transpose,
561                },
562                *out.expand,
563            ));
564            out
565        })
566    }
567
568    /// Execute a low level `mma` operation with manually managed registers. Register layout
569    /// and index mapping can be retrieved from the [`MatrixDefinition`]
570    #[allow(unused)]
571    pub fn execute(
572        &self,
573        registers_a: &Array<Line<A>>,
574        registers_b: &Array<Line<B>>,
575        registers_c: &Array<Line<CD>>,
576    ) -> Array<Line<CD>> {
577        intrinsic!(|scope| {
578            let acc_elems = self
579                .clone()
580                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
581            let acc_line_size = self
582                .clone()
583                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
584            let num_registers = acc_elems / acc_line_size;
585
586            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
587
588            let registers_a = *registers_a.expand;
589            let registers_b = *registers_b.expand;
590            let registers_c = *registers_c.expand;
591
592            // Only shape is actually used
593            let matrix = cubecl_ir::Matrix {
594                ident: MatrixIdent::A,
595                m: self.m,
596                n: self.n,
597                k: self.k,
598                storage: self.a_type,
599                layout: MatrixLayout::ColMajor,
600            };
601
602            scope.register(Instruction::new(
603                CoopMma::ExecuteManual {
604                    matrix,
605                    registers_a,
606                    registers_b,
607                    registers_c,
608                },
609                *registers_d.expand,
610            ));
611
612            registers_d
613        })
614    }
615
616    /// Execute a low level block scaled `mma` operation with manually managed registers. Register
617    /// layout and index mapping can be retrieved from the [`MatrixDefinition`]
618    #[allow(unused)]
619    pub fn execute_scaled<S: CubePrimitive>(
620        &self,
621        registers_a: &Array<Line<A>>,
622        registers_b: &Array<Line<B>>,
623        registers_c: &Array<Line<CD>>,
624        scales_a: Line<S>,
625        scales_b: Line<S>,
626    ) -> Array<Line<CD>> {
627        intrinsic!(|scope| {
628            let acc_elems = self
629                .clone()
630                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
631            let acc_line_size = self
632                .clone()
633                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
634            let num_registers = acc_elems / acc_line_size;
635
636            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
637
638            let registers_a = *registers_a.expand;
639            let registers_b = *registers_b.expand;
640            let registers_c = *registers_c.expand;
641
642            // Only shape is actually used
643            let matrix = cubecl_ir::Matrix {
644                ident: MatrixIdent::A,
645                m: self.m,
646                n: self.n,
647                k: self.k,
648                storage: self.a_type,
649                layout: MatrixLayout::ColMajor,
650            };
651
652            scope.register(Instruction::new(
653                CoopMma::ExecuteScaled {
654                    matrix,
655                    registers_a,
656                    registers_b,
657                    registers_c,
658                    scales_a: *scales_a.expand,
659                    scales_b: *scales_b.expand,
660                    scales_factor: self
661                        .scales_factor
662                        .expect("Can't execute scaled on matrix with no scales"),
663                },
664                *registers_d.expand,
665            ));
666
667            registers_d
668        })
669    }
670}
671
672/// Fill the matrix with the provided value.
673#[allow(unused_variables)]
674pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
675    unexpanded!()
676}
677
678/// Module containing the expand function for [fill()].
679pub mod fill {
680    use super::*;
681
682    /// Expand method of [fill()].
683    pub fn expand<C: CubeType>(
684        scope: &mut Scope,
685        mat: MatrixExpand<C>,
686        value: ExpandElementTyped<C>,
687    ) {
688        let value: ExpandElement = value.into();
689        scope.register(Instruction::new(
690            ir::CoopMma::Fill { value: *value },
691            *mat.elem,
692        ));
693    }
694}
695
696/// Load the matrix with the provided array using the stride.
697#[allow(unused_variables)]
698pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
699    unexpanded!()
700}
701
702/// Module containing the expand function for [load()].
703pub mod load {
704    use super::*;
705
706    /// Expand method of [load()].
707    #[allow(unused_variables)]
708    pub fn expand<C: CubePrimitive, V: CubePrimitive>(
709        scope: &mut Scope,
710        mat: MatrixExpand<C>,
711        value: SliceExpand<V, ReadOnly>,
712        stride: ExpandElementTyped<u32>,
713    ) {
714        let stride: ExpandElement = stride.into();
715        assert_ne!(
716            mat.ident,
717            MatrixIdent::Accumulator,
718            "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
719        );
720
721        let (value, offset) = value.__to_raw_parts();
722
723        scope.register(Instruction::new(
724            ir::CoopMma::Load {
725                value,
726                stride: *stride,
727                offset,
728                layout: None,
729            },
730            *mat.elem,
731        ));
732    }
733}
734
735/// Load the matrix with the provided array using the stride with an explicit layout.
736/// Explicit layouts are required when loading accumulators.
737#[allow(unused_variables)]
738pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
739    mat: &Matrix<C>,
740    value: &Slice<V>,
741    stride: u32,
742    layout: MatrixLayout,
743) {
744    unexpanded!()
745}
746
747/// Module containing the expand function for [load_with_layout()].
748pub mod load_with_layout {
749    use super::*;
750
751    /// Expand method of [load_with_layout()].
752    #[allow(unused_variables)]
753    pub fn expand<C: CubeType, V: CubePrimitive>(
754        scope: &mut Scope,
755        mat: MatrixExpand<C>,
756        value: SliceExpand<V, ReadOnly>,
757        stride: ExpandElementTyped<u32>,
758        layout: MatrixLayout,
759    ) {
760        let stride: ExpandElement = stride.into();
761        let (value, offset) = value.__to_raw_parts();
762
763        scope.register(Instruction::new(
764            ir::CoopMma::Load {
765                value,
766                stride: *stride,
767                offset,
768                layout: Some(layout),
769            },
770            *mat.elem,
771        ));
772    }
773}
774
775/// Store the matrix in the given array following the given stride and layout.
776#[allow(unused_variables)]
777pub fn store<C: CubePrimitive, O: CubePrimitive>(
778    output: &mut SliceMut<O>,
779    mat: &Matrix<C>,
780    stride: u32,
781    layout: MatrixLayout,
782) {
783    unexpanded!()
784}
785
786/// Module containing the expand function for [store()].
787pub mod store {
788    use crate::prelude::ReadWrite;
789
790    use super::*;
791
792    /// Expand method of [store()].
793    #[allow(unused_variables)]
794    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
795        scope: &mut Scope,
796        output: SliceExpand<O, ReadWrite>,
797        mat: MatrixExpand<C>,
798        stride: ExpandElementTyped<u32>,
799        layout: MatrixLayout,
800    ) {
801        let stride: ExpandElement = stride.into();
802
803        let (output, offset) = output.__to_raw_parts();
804
805        scope.register(Instruction::new(
806            ir::CoopMma::Store {
807                mat: *mat.elem,
808                offset,
809                stride: *stride,
810                layout,
811            },
812            output,
813        ));
814    }
815}
816
817/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
818#[allow(unused_variables)]
819pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
820    mat_a: &Matrix<A>,
821    mat_b: &Matrix<B>,
822    mat_c: &Matrix<C>,
823    mat_d: &Matrix<D>,
824) {
825    unexpanded!()
826}
827
828/// Module containing the expand function for [execute()].
829pub mod execute {
830    use super::*;
831
832    /// Expand method of [execute()].
833    pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
834        scope: &mut Scope,
835        mat_a: MatrixExpand<A>,
836        mat_b: MatrixExpand<B>,
837        mat_c: MatrixExpand<C>,
838        mat_d: MatrixExpand<D>,
839    ) {
840        scope.register(Instruction::new(
841            ir::CoopMma::Execute {
842                mat_a: *mat_a.elem,
843                mat_b: *mat_b.elem,
844                mat_c: *mat_c.elem,
845            },
846            *mat_d.elem,
847        ));
848    }
849}
850
851/// Store the matrix in the given array following the given stride and layout.
852#[allow(unused_variables)]
853pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
854    unexpanded!()
855}
856
857/// Module containing the expand function for [store()].
858pub mod cast {
859    use super::*;
860
861    /// Expand method of [store()].
862    #[allow(unused_variables)]
863    pub fn expand<C: CubePrimitive, O: CubePrimitive>(
864        scope: &mut Scope,
865        input: MatrixExpand<C>,
866    ) -> MatrixExpand<O> {
867        let ident = input.ident;
868
869        if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
870            return MatrixExpand {
871                elem: input.elem,
872                ident,
873                _c: PhantomData,
874            };
875        }
876        let input = *input.elem;
877        let input_mat = match input.kind {
878            ir::VariableKind::Matrix { mat, .. } => mat,
879            _ => unreachable!(),
880        };
881
882        let elem = O::as_type(scope);
883        let elem = scope.create_matrix(ir::Matrix::new(
884            ident,
885            input_mat.m,
886            input_mat.n,
887            input_mat.k,
888            elem,
889            MatrixLayout::Undefined,
890        ));
891
892        let output = MatrixExpand {
893            ident,
894            elem,
895            _c: PhantomData,
896        };
897        scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
898
899        output
900    }
901}
902
903impl CubeType for MatrixLayout {
904    type ExpandType = Self;
905}
906
907impl IntoMut for MatrixLayout {
908    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
909        self
910    }
911}
912
913impl CubeDebug for MatrixLayout {}