1use super::{
50 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{
54 self as cubecl,
55 prelude::{Array, Line, ReadWrite},
56};
57use crate::{
58 ir::{self, Instruction},
59 unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
/// Marker type for a matrix whose elements have type `C`.
///
/// The struct holds no data besides a `PhantomData<C>`; its `CubeType`
/// expansion is `MatrixExpand<C>`.
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}
75
/// Marker type for a matrix-multiply-accumulate definition over the element
/// types `A`, `B` and `CD`.
///
/// The struct holds only `PhantomData` markers; its `CubeType` expansion is
/// `MmaDefinitionExpand<A, B, CD>`, which carries the shape and storage types.
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
/// Expand-time counterpart of `Matrix<C>`.
pub struct MatrixExpand<C: CubeType> {
    /// IR element backing this matrix.
    elem: ExpandElement,
    /// Which operand this matrix is (`A`, `B` or `Accumulator`).
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
90
/// Expand-time counterpart of `MmaDefinition<A, B, CD>`.
///
/// The operand sizes derived from the shape are `m * k` for A, `k * n` for B
/// and `m * n` for the accumulator (see `MmaDefinition::num_elems`).
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `m` dimension of the operation.
    pub m: u32,
    /// `n` dimension of the operation.
    pub n: u32,
    /// `k` dimension of the operation.
    pub k: u32,
    /// Storage type of the A operand.
    pub a_type: StorageType,
    /// Storage type of the B operand.
    pub b_type: StorageType,
    /// Storage type of the accumulator (C/D) operand.
    pub cd_type: StorageType,
    /// Scale factor; `MmaDefinition::new` sets `None`, `new_scaled` sets `Some`.
    pub scales_factor: Option<u32>,
    /// Storage type of the scales; `MmaDefinition::new` sets `None`,
    /// `new_scaled` sets `Some`.
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108 fn clone(&self) -> Self {
109 Self {
110 elem: self.elem.clone(),
111 ident: self.ident,
112 _c: self._c,
113 }
114 }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118 fn clone(&self) -> Self {
119 Self {
120 m: self.m,
121 n: self.n,
122 k: self.k,
123 a_type: self.a_type,
124 b_type: self.b_type,
125 cd_type: self.cd_type,
126 scales_factor: self.scales_factor,
127 scales_type: self.scales_type,
128 _a: PhantomData,
129 _b: PhantomData,
130 _cd: PhantomData,
131 }
132 }
133}
134
// `Matrix<C>` expands to `MatrixExpand<C>`.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
138
// `MmaDefinition<A, B, CD>` expands to `MmaDefinitionExpand<A, B, CD>`.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
142
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
148
149impl<C: CubeType> CubeDebug for MatrixExpand<C> {
150 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
151 scope.update_variable_name(*self.elem, name);
152 }
153}
154
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
160
// No overrides: the definition relies entirely on `CubeDebug`'s default methods.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
// Constructors for `Matrix`. The `intrinsic!` bodies below operate directly on
// the expand-time `scope`.
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    #[allow(unused_variables)]
    /// Creates a matrix in the current scope without emitting any fill or load
    /// for it.
    ///
    /// `ident`, `m`, `n`, `k` and `layout` are forwarded to `ir::Matrix::new`
    /// together with the storage type of `C`.
    ///
    /// # Panics
    ///
    /// During expansion `m`, `n` and `k` are each read with
    /// `.constant().unwrap()`, which panics when `.constant()` returns `None`.
    ///
    /// # Safety
    ///
    /// This function only creates the matrix; nothing in it initializes the
    /// contents. NOTE(review): presumably the caller must `fill` or `load` the
    /// matrix before reading it — confirm against the backend contract.
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            let elem = C::as_type(scope);
            let elem = scope.create_matrix(ir::Matrix::new(
                ident,
                m.constant().unwrap().as_u32(),
                n.constant().unwrap().as_u32(),
                k.constant().unwrap().as_u32(),
                elem,
                layout,
            ));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    #[allow(unused_variables)]
    /// Creates a matrix with `uninitialized` and then emits a fill of that
    /// matrix with `value` (via `fill::expand`).
    ///
    /// The same constraints on `m`, `n` and `k` as for `uninitialized` apply.
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        // SAFETY: the matrix is filled immediately below, before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    #[allow(unused_variables)]
    /// Creates a matrix with `uninitialized` and then emits a load into it from
    /// the slice `value` using `stride` (via `load::expand`).
    ///
    /// The same constraints on `m`, `n` and `k` as for `uninitialized` apply.
    ///
    /// # Panics
    ///
    /// `load::expand` asserts that the matrix is not an accumulator, so the
    /// expansion panics when `ident` is `MatrixIdent::Accumulator`.
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        m: u32,
        n: u32,
        k: u32,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the matrix is loaded immediately below, before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
269
270#[cube]
271impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
272 #[allow(unused_variables)]
288 pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
289 intrinsic!(|scope| {
290 let a_type = A::as_type(scope);
291 let b_type = B::as_type(scope);
292 let cd_type = CD::as_type(scope);
293
294 MmaDefinitionExpand {
295 m,
296 n,
297 k,
298 a_type,
299 b_type,
300 cd_type,
301 scales_factor: None,
302 scales_type: None,
303 _a: PhantomData,
304 _b: PhantomData,
305 _cd: PhantomData,
306 }
307 })
308 }
309
310 #[allow(unused_variables)]
326 pub fn new_scaled<S: CubePrimitive>(
327 #[comptime] m: u32,
328 #[comptime] n: u32,
329 #[comptime] k: u32,
330 #[comptime] scale_factor: u32,
331 ) -> Self {
332 intrinsic!(|scope| {
333 let a_type = A::as_type(scope);
334 let b_type = B::as_type(scope);
335 let cd_type = CD::as_type(scope);
336
337 MmaDefinitionExpand {
338 m,
339 n,
340 k,
341 a_type,
342 b_type,
343 cd_type,
344 scales_factor: Some(scale_factor),
345 scales_type: Some(S::as_type(scope)),
346 _a: PhantomData,
347 _b: PhantomData,
348 _cd: PhantomData,
349 }
350 })
351 }
352
353 #[allow(unused)]
355 pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
356 intrinsic!(|scope| {
357 match ident {
358 MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
359 MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
360 MatrixIdent::Accumulator => {
361 (self.m * self.n) / self.cd_type.packing_factor() as u32
362 }
363 }
364 })
365 }
366
367 #[allow(unused)]
374 pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
375 intrinsic!(|scope| {
376 let elems = self.__expand_num_elems_method(scope, ident);
377 let plane_dim = scope.runtime_properties.mma.const_plane_size;
378 let duplication = match ident {
379 MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
380 MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
381 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
382 };
383 (elems / plane_dim) * duplication
384 })
385 }
386
387 #[allow(unused)]
393 pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
394 intrinsic!(|scope| {
395 let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
396 let line_size = self.__expand_line_size_method(scope, ident);
397 elems / line_size
398 })
399 }
400
401 #[allow(unused)]
403 pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
404 intrinsic!(|scope| {
405 match ident {
406 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
407 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
408 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
409 }
410 })
411 }
412
413 #[allow(unused_variables)]
416 pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
417 intrinsic!(|scope| {
418 let storage = match ident {
419 MatrixIdent::A => self.a_type,
420 MatrixIdent::B => self.b_type,
421 MatrixIdent::Accumulator => self.cd_type,
422 };
423 let matrix = cubecl_ir::Matrix {
424 ident,
425 m: self.m,
426 n: self.n,
427 k: self.k,
428 storage: storage,
429 layout: MatrixLayout::ColMajor,
430 };
431 scope
432 .runtime_properties
433 .mma
434 .contiguous_elements
435 .apply(ident, matrix)
436 })
437 }
438
439 #[allow(unused_variables)]
447 pub fn position_of_nth(
448 &self,
449 lane_id: u32,
450 elem_idx: u32,
451 #[comptime] ident: MatrixIdent,
452 ) -> (u32, u32) {
453 intrinsic!(|scope| {
454 let lane_id: ExpandElement = lane_id.into();
455 let elem_idx: ExpandElement = elem_idx.into();
456
457 let ty = match ident {
458 MatrixIdent::A => self.a_type,
459 MatrixIdent::B => self.b_type,
460 MatrixIdent::Accumulator => self.cd_type,
461 };
462 let layout = match ident {
463 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
464 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
465 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
466 };
467 let matrix = cubecl_ir::Matrix {
468 ident,
469 m: self.m,
470 n: self.n,
471 k: self.k,
472 storage: ty,
473 layout,
474 };
475
476 let row = scope.create_local(Type::new(u32::as_type(scope)));
477 let col = scope.create_local(Type::new(u32::as_type(scope)));
478 scope.register(Instruction::new(
479 CoopMma::RowIndex {
480 lane_id: *lane_id,
481 i: *elem_idx,
482 matrix,
483 },
484 *row,
485 ));
486 scope.register(Instruction::new(
487 CoopMma::ColIndex {
488 lane_id: *lane_id,
489 i: *elem_idx,
490 matrix,
491 },
492 *col,
493 ));
494 (row.into(), col.into())
495 })
496 }
497
498 pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
501 let quad_id = lane_id / 4;
503 let t_id = lane_id % 4;
504 match ident {
505 MatrixIdent::A => quad_id + (t_id % 2) * 8,
506 MatrixIdent::B => quad_id,
507 MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
508 }
509 }
510
511 pub fn scales_count(&self) -> comptime_type!(u32) {
513 intrinsic!(|_| {
516 self.scales_factor
517 .expect("Can't retrieve scales count for matrix with no scales")
518 })
519 }
520
521 pub fn scales_line_size(&self) -> comptime_type!(u32) {
523 intrinsic!(|scope| {
524 let elem = self
525 .scales_type
526 .expect("Can't retrieve scales line size for matrix with no scales");
527 scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
528 })
529 }
530
531 #[allow(unused_variables)]
542 pub fn load_matrix<E: CubePrimitive>(
543 &self,
544 row: &Slice<Line<E>>,
545 #[comptime] ident: MatrixIdent,
546 #[comptime] num_matrices: u32,
547 #[comptime] transpose: bool,
548 ) -> Array<Line<E>> {
549 intrinsic!(|scope| {
550 let line_size = self.__expand_line_size_method(scope, ident);
551 let slice_line_size = row.line_size;
552 let (buffer, offset) = row.__to_raw_parts();
553 let out = Array::__expand_vectorized(scope, num_matrices, line_size);
554 scope.register(Instruction::new(
555 CoopMma::LoadMatrix {
556 buffer,
557 offset,
558 line_size: slice_line_size,
559 factor: num_matrices,
560 transpose,
561 },
562 *out.expand,
563 ));
564 out
565 })
566 }
567
568 #[allow(unused_variables)]
579 pub fn store_matrix<E: CubePrimitive>(
580 &self,
581 row: &mut Slice<Line<E>, ReadWrite>,
582 registers: &Array<Line<E>>,
583 #[comptime] ident: MatrixIdent,
584 #[comptime] num_matrices: u32,
585 #[comptime] transpose: bool,
586 ) {
587 intrinsic!(|scope| {
588 let line_size = self.__expand_line_size_method(scope, ident);
589 let slice_line_size = row.line_size;
590 let (buffer, offset) = row.__to_raw_parts();
591 scope.register(Instruction::new(
592 CoopMma::StoreMatrix {
593 offset,
594 line_size: slice_line_size,
595 registers: *registers.expand,
596 factor: num_matrices,
597 transpose,
598 },
599 buffer,
600 ));
601 })
602 }
603
604 #[allow(unused)]
607 pub fn execute(
608 &self,
609 registers_a: &Array<Line<A>>,
610 registers_b: &Array<Line<B>>,
611 registers_c: &Array<Line<CD>>,
612 ) -> Array<Line<CD>> {
613 intrinsic!(|scope| {
614 let acc_elems = self
615 .clone()
616 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
617 let acc_line_size = self
618 .clone()
619 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
620 let num_registers = acc_elems / acc_line_size;
621
622 let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
623
624 let registers_a = *registers_a.expand;
625 let registers_b = *registers_b.expand;
626 let registers_c = *registers_c.expand;
627
628 let matrix = cubecl_ir::Matrix {
630 ident: MatrixIdent::A,
631 m: self.m,
632 n: self.n,
633 k: self.k,
634 storage: self.a_type,
635 layout: MatrixLayout::ColMajor,
636 };
637
638 scope.register(Instruction::new(
639 CoopMma::ExecuteManual {
640 matrix,
641 registers_a,
642 registers_b,
643 registers_c,
644 },
645 *registers_d.expand,
646 ));
647
648 registers_d
649 })
650 }
651
652 #[allow(unused)]
655 pub fn execute_scaled<S: CubePrimitive>(
656 &self,
657 registers_a: &Array<Line<A>>,
658 registers_b: &Array<Line<B>>,
659 registers_c: &Array<Line<CD>>,
660 scales_a: Line<S>,
661 scales_b: Line<S>,
662 ) -> Array<Line<CD>> {
663 intrinsic!(|scope| {
664 let acc_elems = self
665 .clone()
666 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
667 let acc_line_size = self
668 .clone()
669 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
670 let num_registers = acc_elems / acc_line_size;
671
672 let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
673
674 let registers_a = *registers_a.expand;
675 let registers_b = *registers_b.expand;
676 let registers_c = *registers_c.expand;
677
678 let matrix = cubecl_ir::Matrix {
680 ident: MatrixIdent::A,
681 m: self.m,
682 n: self.n,
683 k: self.k,
684 storage: self.a_type,
685 layout: MatrixLayout::ColMajor,
686 };
687
688 scope.register(Instruction::new(
689 CoopMma::ExecuteScaled {
690 matrix,
691 registers_a,
692 registers_b,
693 registers_c,
694 scales_a: *scales_a.expand,
695 scales_b: *scales_b.expand,
696 scales_factor: self
697 .scales_factor
698 .expect("Can't execute scaled on matrix with no scales"),
699 },
700 *registers_d.expand,
701 ));
702
703 registers_d
704 })
705 }
706}
707
#[allow(unused_variables)]
/// Fills `mat` with `value`.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the instruction is emitted by `fill::expand`.
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
713
714pub mod fill {
716 use super::*;
717
718 pub fn expand<C: CubeType>(
720 scope: &mut Scope,
721 mat: MatrixExpand<C>,
722 value: ExpandElementTyped<C>,
723 ) {
724 let value: ExpandElement = value.into();
725 scope.register(Instruction::new(
726 ir::CoopMma::Fill { value: *value },
727 *mat.elem,
728 ));
729 }
730}
731
#[allow(unused_variables)]
/// Loads `mat` from the slice `value` using `stride`, without an explicit layout.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the instruction is emitted by `load::expand`, which rejects accumulator
/// matrices (use `load_with_layout` for those).
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
737
738pub mod load {
740 use super::*;
741
742 #[allow(unused_variables)]
744 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
745 scope: &mut Scope,
746 mat: MatrixExpand<C>,
747 value: SliceExpand<V, ReadOnly>,
748 stride: ExpandElementTyped<u32>,
749 ) {
750 let stride: ExpandElement = stride.into();
751 assert_ne!(
752 mat.ident,
753 MatrixIdent::Accumulator,
754 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
755 );
756
757 let (value, offset) = value.__to_raw_parts();
758
759 scope.register(Instruction::new(
760 ir::CoopMma::Load {
761 value,
762 stride: *stride,
763 offset,
764 layout: None,
765 },
766 *mat.elem,
767 ));
768 }
769}
770
#[allow(unused_variables)]
/// Loads `mat` from the slice `value` using `stride` and an explicit `layout`.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the instruction is emitted by `load_with_layout::expand`.
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
782
783pub mod load_with_layout {
785 use super::*;
786
787 #[allow(unused_variables)]
789 pub fn expand<C: CubeType, V: CubePrimitive>(
790 scope: &mut Scope,
791 mat: MatrixExpand<C>,
792 value: SliceExpand<V, ReadOnly>,
793 stride: ExpandElementTyped<u32>,
794 layout: MatrixLayout,
795 ) {
796 let stride: ExpandElement = stride.into();
797 let (value, offset) = value.__to_raw_parts();
798
799 scope.register(Instruction::new(
800 ir::CoopMma::Load {
801 value,
802 stride: *stride,
803 offset,
804 layout: Some(layout),
805 },
806 *mat.elem,
807 ));
808 }
809}
810
#[allow(unused_variables)]
/// Stores `mat` into the slice `output` using `stride` and `layout`.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the instruction is emitted by `store::expand`.
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
821
822pub mod store {
824 use crate::prelude::ReadWrite;
825
826 use super::*;
827
828 #[allow(unused_variables)]
830 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
831 scope: &mut Scope,
832 output: SliceExpand<O, ReadWrite>,
833 mat: MatrixExpand<C>,
834 stride: ExpandElementTyped<u32>,
835 layout: MatrixLayout,
836 ) {
837 let stride: ExpandElement = stride.into();
838
839 let (output, offset) = output.__to_raw_parts();
840
841 scope.register(Instruction::new(
842 ir::CoopMma::Store {
843 mat: *mat.elem,
844 offset,
845 stride: *stride,
846 layout,
847 },
848 output,
849 ));
850 }
851}
852
#[allow(unused_variables)]
/// Executes the matrix operation on `mat_a`, `mat_b` and `mat_c`, with `mat_d`
/// as the output.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the instruction is emitted by `execute::expand`.
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
863
864pub mod execute {
866 use super::*;
867
868 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
870 scope: &mut Scope,
871 mat_a: MatrixExpand<A>,
872 mat_b: MatrixExpand<B>,
873 mat_c: MatrixExpand<C>,
874 mat_d: MatrixExpand<D>,
875 ) {
876 scope.register(Instruction::new(
877 ir::CoopMma::Execute {
878 mat_a: *mat_a.elem,
879 mat_b: *mat_b.elem,
880 mat_c: *mat_c.elem,
881 },
882 *mat_d.elem,
883 ));
884 }
885}
886
#[allow(unused_variables)]
/// Casts `input` to a matrix with element type `O`.
///
/// This is the non-expanded signature only: the body is `unexpanded!()`, and
/// the work is done by `cast::expand`.
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
892
893pub mod cast {
895 use super::*;
896
897 #[allow(unused_variables)]
899 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
900 scope: &mut Scope,
901 input: MatrixExpand<C>,
902 ) -> MatrixExpand<O> {
903 let ident = input.ident;
904
905 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
906 return MatrixExpand {
907 elem: input.elem,
908 ident,
909 _c: PhantomData,
910 };
911 }
912 let input = *input.elem;
913 let input_mat = match input.kind {
914 ir::VariableKind::Matrix { mat, .. } => mat,
915 _ => unreachable!(),
916 };
917
918 let elem = O::as_type(scope);
919 let elem = scope.create_matrix(ir::Matrix::new(
920 ident,
921 input_mat.m,
922 input_mat.n,
923 input_mat.k,
924 elem,
925 MatrixLayout::Undefined,
926 ));
927
928 let output = MatrixExpand {
929 ident,
930 elem,
931 _c: PhantomData,
932 };
933 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
934
935 output
936 }
937}
938
// `MatrixLayout` is its own expand type.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
942
impl IntoMut for MatrixLayout {
    /// Returns `self` unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
948
// No overrides: the layout relies entirely on `CubeDebug`'s default methods.
impl CubeDebug for MatrixLayout {}