1use super::{
50 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{
54 self as cubecl,
55 prelude::{Array, Line, ReadWrite},
56};
57use crate::{
58 ir::{self, Instruction},
59 unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, LineSize, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
#[derive(Copy, Clone)]
/// Kernel-side handle to a cooperative matrix fragment whose elements are of type `C`.
///
/// At the type level this only carries `C` (via `PhantomData`); the real IR variable lives in
/// its expand type, [`MatrixExpand`].
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}
75
#[derive(Copy, Clone)]
/// Kernel-side handle describing a manual matrix multiply-accumulate whose A, B and
/// accumulator (C/D) operands have element types `A`, `B` and `CD`.
///
/// Only type markers are stored here; the comptime shape and storage types live in
/// [`MmaDefinitionExpand`].
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
84impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for &MmaDefinitionExpand<A, B, CD> {
85 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
86 MmaDefinitionExpand::set_debug_name(self, scope, name);
87 }
88}
89
/// Expand-time representation of a [`Matrix`]: the IR matrix variable plus the operand
/// role (`A`, `B` or accumulator) it was created for.
pub struct MatrixExpand<C: CubeType> {
    /// IR variable created by `scope.create_matrix` (or reused by a same-type cast).
    elem: ExpandElement,
    /// Operand role, checked e.g. by `load::expand` to reject accumulators.
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
96
#[derive(Debug)]
/// Comptime description of a manual MMA operation, built by `MmaDefinition::new` or
/// `MmaDefinition::new_scaled`.
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `m` dimension: A holds `m * k` elements and the accumulator `m * n` (see `num_elems`).
    pub m: usize,
    /// `n` dimension: B holds `k * n` elements and the accumulator `m * n`.
    pub n: usize,
    /// `k` dimension shared by A (`m * k`) and B (`k * n`).
    pub k: usize,
    /// Storage type of the A operand.
    pub a_type: StorageType,
    /// Storage type of the B operand.
    pub b_type: StorageType,
    /// Storage type of the accumulator (input C and output D).
    pub cd_type: StorageType,
    /// Number of scale values; `Some` only for definitions created with `new_scaled`.
    pub scales_factor: Option<usize>,
    /// Storage type of the scales; `Some` only for definitions created with `new_scaled`.
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
112
113impl<C: CubeType> Clone for MatrixExpand<C> {
114 fn clone(&self) -> Self {
115 Self {
116 elem: self.elem.clone(),
117 ident: self.ident,
118 _c: self._c,
119 }
120 }
121}
122
123impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
124 fn clone(&self) -> Self {
125 Self {
126 m: self.m,
127 n: self.n,
128 k: self.k,
129 a_type: self.a_type,
130 b_type: self.b_type,
131 cd_type: self.cd_type,
132 scales_factor: self.scales_factor,
133 scales_type: self.scales_type,
134 _a: PhantomData,
135 _b: PhantomData,
136 _cd: PhantomData,
137 }
138 }
139}
140
/// At expansion time a [`Matrix`] is represented by a [`MatrixExpand`].
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
144
/// At expansion time an [`MmaDefinition`] is represented by an [`MmaDefinitionExpand`].
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
148
/// Matrix handles are returned unchanged; no copy into a mutable variable is made.
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
154
155impl<C: CubeType> CubeDebug for MatrixExpand<C> {
156 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
157 scope.update_variable_name(*self.elem, name);
158 }
159}
160
/// A definition is pure comptime data, so it is returned unchanged.
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
166
// Uses the trait's default `set_debug_name`; the definition has no IR variable to name.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
168
#[cube]
impl<C: CubePrimitive> Matrix<C> {
    /// Creates a matrix fragment of element type `C` for operand `ident` of an
    /// `m x n x k` multiply, with the given `layout`.
    ///
    /// Only the matrix variable is created; no fill or load instruction is registered.
    ///
    /// # Safety
    ///
    /// The fragment's contents are not initialized. The caller must write it (e.g. with
    /// `fill`, `load`/`load_with_layout`, or as the output of `execute`) before reading it.
    #[allow(unused_variables)]
    pub unsafe fn uninitialized(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
    ) -> Self {
        intrinsic!(|scope| {
            let elem = C::as_type(scope);
            let elem = scope.create_matrix(ir::Matrix::new(ident, m, n, k, elem, layout));
            MatrixExpand {
                elem,
                ident,
                _c: PhantomData,
            }
        })
    }

    /// Creates a matrix fragment (same parameters as `uninitialized`) with every element
    /// set to `value`.
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: C,
    ) -> Self {
        // SAFETY: the fragment is filled with `value` below before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }

    /// Creates a matrix fragment (same parameters as `uninitialized`) loaded from `value`
    /// using `stride`.
    ///
    /// # Panics
    ///
    /// At expansion, if `ident` is `MatrixIdent::Accumulator`: `load::expand` rejects
    /// accumulators, which require `load_with_layout`.
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the fragment is loaded from `value` below before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
}
268
269#[cube(self_type = "ref")]
270impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
271 #[allow(unused_variables)]
287 pub fn new(#[comptime] m: usize, #[comptime] n: usize, #[comptime] k: usize) -> Self {
288 intrinsic!(|scope| {
289 let a_type = A::as_type(scope);
290 let b_type = B::as_type(scope);
291 let cd_type = CD::as_type(scope);
292
293 MmaDefinitionExpand {
294 m,
295 n,
296 k,
297 a_type,
298 b_type,
299 cd_type,
300 scales_factor: None,
301 scales_type: None,
302 _a: PhantomData,
303 _b: PhantomData,
304 _cd: PhantomData,
305 }
306 })
307 }
308
309 #[allow(unused_variables)]
325 pub fn new_scaled<S: CubePrimitive>(
326 #[comptime] m: usize,
327 #[comptime] n: usize,
328 #[comptime] k: usize,
329 #[comptime] scale_factor: usize,
330 ) -> Self {
331 intrinsic!(|scope| {
332 let a_type = A::as_type(scope);
333 let b_type = B::as_type(scope);
334 let cd_type = CD::as_type(scope);
335
336 MmaDefinitionExpand {
337 m,
338 n,
339 k,
340 a_type,
341 b_type,
342 cd_type,
343 scales_factor: Some(scale_factor),
344 scales_type: Some(S::as_type(scope)),
345 _a: PhantomData,
346 _b: PhantomData,
347 _cd: PhantomData,
348 }
349 })
350 }
351
352 #[allow(unused)]
354 pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
355 intrinsic!(|scope| {
356 match ident {
357 MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor(),
358 MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor(),
359 MatrixIdent::Accumulator => (self.m * self.n) / self.cd_type.packing_factor(),
360 }
361 })
362 }
363
364 #[allow(unused)]
371 pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
372 intrinsic!(|scope| {
373 let elems = self.__expand_num_elems_method(scope, ident);
374 let plane_dim = scope.runtime_properties.mma.const_plane_size as usize;
375 let duplication = match ident {
376 MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
377 MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
378 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
379 };
380 (elems / plane_dim) * duplication
381 })
382 }
383
384 #[allow(unused)]
390 pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
391 intrinsic!(|scope| {
392 let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
393 let line_size = self.__expand_line_size_method(scope, ident);
394 elems / line_size
395 })
396 }
397
398 #[allow(unused)]
400 pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
401 intrinsic!(|scope| {
402 match ident {
403 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
404 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
405 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
406 }
407 })
408 }
409
410 #[allow(unused_variables)]
413 pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(LineSize) {
414 intrinsic!(|scope| {
415 let storage = match ident {
416 MatrixIdent::A => self.a_type,
417 MatrixIdent::B => self.b_type,
418 MatrixIdent::Accumulator => self.cd_type,
419 };
420 let matrix = cubecl_ir::Matrix {
421 ident,
422 m: self.m,
423 n: self.n,
424 k: self.k,
425 storage: storage,
426 layout: MatrixLayout::ColMajor,
427 };
428 scope
429 .runtime_properties
430 .mma
431 .contiguous_elements
432 .apply(ident, matrix)
433 })
434 }
435
436 #[allow(unused_variables)]
444 pub fn position_of_nth(
445 &self,
446 lane_id: u32,
447 elem_idx: u32,
448 #[comptime] ident: MatrixIdent,
449 ) -> (u32, u32) {
450 intrinsic!(|scope| {
451 let lane_id: ExpandElement = lane_id.into();
452 let elem_idx: ExpandElement = elem_idx.into();
453
454 let ty = match ident {
455 MatrixIdent::A => self.a_type,
456 MatrixIdent::B => self.b_type,
457 MatrixIdent::Accumulator => self.cd_type,
458 };
459 let layout = match ident {
460 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
461 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
462 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
463 };
464 let matrix = cubecl_ir::Matrix {
465 ident,
466 m: self.m,
467 n: self.n,
468 k: self.k,
469 storage: ty,
470 layout,
471 };
472
473 let row = scope.create_local(Type::new(u32::as_type(scope)));
474 let col = scope.create_local(Type::new(u32::as_type(scope)));
475 scope.register(Instruction::new(
476 CoopMma::RowIndex {
477 lane_id: *lane_id,
478 i: *elem_idx,
479 matrix,
480 },
481 *row,
482 ));
483 scope.register(Instruction::new(
484 CoopMma::ColIndex {
485 lane_id: *lane_id,
486 i: *elem_idx,
487 matrix,
488 },
489 *col,
490 ));
491 (row.into(), col.into())
492 })
493 }
494
495 pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
498 let quad_id = lane_id / 4;
500 let t_id = lane_id % 4;
501 match ident {
502 MatrixIdent::A => quad_id + (t_id % 2) * 8,
503 MatrixIdent::B => quad_id,
504 MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
505 }
506 }
507
508 pub fn scales_count(&self) -> comptime_type!(usize) {
510 intrinsic!(|_| {
513 self.scales_factor
514 .expect("Can't retrieve scales count for matrix with no scales")
515 })
516 }
517
518 pub fn scales_line_size(&self) -> comptime_type!(LineSize) {
520 intrinsic!(|scope| {
521 let elem = self
522 .scales_type
523 .expect("Can't retrieve scales line size for matrix with no scales");
524 scope.runtime_properties.mma.register_size_bits / elem.size_bits()
525 })
526 }
527
528 #[allow(unused_variables)]
539 pub fn load_matrix<E: CubePrimitive>(
540 &self,
541 row: &Slice<Line<E>>,
542 #[comptime] ident: MatrixIdent,
543 #[comptime] num_matrices: usize,
544 #[comptime] transpose: bool,
545 ) -> Array<Line<E>> {
546 intrinsic!(|scope| {
547 let line_size = self.__expand_line_size_method(scope, ident);
548 let slice_line_size = row.line_size;
549 let (buffer, offset) = row.__to_raw_parts();
550 let out = Array::__expand_lined(scope, num_matrices, line_size);
551 scope.register(Instruction::new(
552 CoopMma::LoadMatrix {
553 buffer,
554 offset,
555 line_size: slice_line_size,
556 factor: num_matrices,
557 transpose,
558 },
559 *out.expand,
560 ));
561 out
562 })
563 }
564
565 #[allow(unused_variables)]
576 pub fn store_matrix<E: CubePrimitive>(
577 &self,
578 row: &mut Slice<Line<E>, ReadWrite>,
579 registers: &Array<Line<E>>,
580 #[comptime] ident: MatrixIdent,
581 #[comptime] num_matrices: usize,
582 #[comptime] transpose: bool,
583 ) {
584 intrinsic!(|scope| {
585 let line_size = self.__expand_line_size_method(scope, ident);
586 let slice_line_size = row.line_size;
587 let (buffer, offset) = row.__to_raw_parts();
588 scope.register(Instruction::new(
589 CoopMma::StoreMatrix {
590 offset,
591 line_size: slice_line_size,
592 registers: *registers.expand,
593 factor: num_matrices,
594 transpose,
595 },
596 buffer,
597 ));
598 })
599 }
600
601 #[allow(unused)]
604 pub fn execute(
605 &self,
606 registers_a: &Array<Line<A>>,
607 registers_b: &Array<Line<B>>,
608 registers_c: &Array<Line<CD>>,
609 ) -> Array<Line<CD>> {
610 intrinsic!(|scope| {
611 let acc_elems = self
612 .clone()
613 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
614 let acc_line_size = self
615 .clone()
616 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
617 let num_registers = acc_elems / acc_line_size;
618
619 let registers_d = Array::__expand_lined(scope, num_registers, acc_line_size);
620
621 let registers_a = *registers_a.expand;
622 let registers_b = *registers_b.expand;
623 let registers_c = *registers_c.expand;
624
625 let matrix = cubecl_ir::Matrix {
627 ident: MatrixIdent::A,
628 m: self.m,
629 n: self.n,
630 k: self.k,
631 storage: self.a_type,
632 layout: MatrixLayout::ColMajor,
633 };
634
635 scope.register(Instruction::new(
636 CoopMma::ExecuteManual {
637 matrix,
638 registers_a,
639 registers_b,
640 registers_c,
641 },
642 *registers_d.expand,
643 ));
644
645 registers_d
646 })
647 }
648
649 #[allow(unused)]
652 pub fn execute_scaled<S: CubePrimitive>(
653 &self,
654 registers_a: &Array<Line<A>>,
655 registers_b: &Array<Line<B>>,
656 registers_c: &Array<Line<CD>>,
657 scales_a: Line<S>,
658 scales_b: Line<S>,
659 ) -> Array<Line<CD>> {
660 intrinsic!(|scope| {
661 let acc_elems = self
662 .clone()
663 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
664 let acc_line_size = self
665 .clone()
666 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
667 let num_registers = acc_elems / acc_line_size;
668
669 let registers_d = Array::__expand_lined(scope, num_registers, acc_line_size);
670
671 let registers_a = *registers_a.expand;
672 let registers_b = *registers_b.expand;
673 let registers_c = *registers_c.expand;
674
675 let matrix = cubecl_ir::Matrix {
677 ident: MatrixIdent::A,
678 m: self.m,
679 n: self.n,
680 k: self.k,
681 storage: self.a_type,
682 layout: MatrixLayout::ColMajor,
683 };
684
685 scope.register(Instruction::new(
686 CoopMma::ExecuteScaled {
687 matrix,
688 registers_a,
689 registers_b,
690 registers_c,
691 scales_a: *scales_a.expand,
692 scales_b: *scales_b.expand,
693 scales_factor: self
694 .scales_factor
695 .expect("Can't execute scaled on matrix with no scales"),
696 },
697 *registers_d.expand,
698 ));
699
700 registers_d
701 })
702 }
703}
704
#[allow(unused_variables)]
/// Sets every element of `mat` to `value`.
///
/// The direct body is `unexpanded!()`. Inside kernels the work is done by `fill::expand`,
/// which registers a `CoopMma::Fill` instruction.
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
710
711pub mod fill {
713 use super::*;
714
715 pub fn expand<C: CubeType>(
717 scope: &mut Scope,
718 mat: MatrixExpand<C>,
719 value: ExpandElementTyped<C>,
720 ) {
721 let value: ExpandElement = value.into();
722 scope.register(Instruction::new(
723 ir::CoopMma::Fill { value: *value },
724 *mat.elem,
725 ));
726 }
727}
728
#[allow(unused_variables)]
/// Loads `mat` from the slice `value` using `stride`, with no explicit layout.
///
/// The direct body is `unexpanded!()`; in kernels `load::expand` does the work. That function
/// panics at expansion if `mat` is an accumulator; use `load_with_layout` for those.
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
734
735pub mod load {
737 use super::*;
738
739 #[allow(unused_variables)]
741 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
742 scope: &mut Scope,
743 mat: MatrixExpand<C>,
744 value: SliceExpand<V, ReadOnly>,
745 stride: ExpandElementTyped<u32>,
746 ) {
747 let stride: ExpandElement = stride.into();
748 assert_ne!(
749 mat.ident,
750 MatrixIdent::Accumulator,
751 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
752 );
753
754 let (value, offset) = value.__to_raw_parts();
755
756 scope.register(Instruction::new(
757 ir::CoopMma::Load {
758 value,
759 stride: *stride,
760 offset,
761 layout: None,
762 },
763 *mat.elem,
764 ));
765 }
766}
767
#[allow(unused_variables)]
/// Loads `mat` from the slice `value` using `stride` and an explicit `layout`.
///
/// Unlike `load`, this also accepts accumulator matrices. The direct body is `unexpanded!()`;
/// in kernels `load_with_layout::expand` does the work.
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
779
780pub mod load_with_layout {
782 use super::*;
783
784 #[allow(unused_variables)]
786 pub fn expand<C: CubeType, V: CubePrimitive>(
787 scope: &mut Scope,
788 mat: MatrixExpand<C>,
789 value: SliceExpand<V, ReadOnly>,
790 stride: ExpandElementTyped<u32>,
791 layout: MatrixLayout,
792 ) {
793 let stride: ExpandElement = stride.into();
794 let (value, offset) = value.__to_raw_parts();
795
796 scope.register(Instruction::new(
797 ir::CoopMma::Load {
798 value,
799 stride: *stride,
800 offset,
801 layout: Some(layout),
802 },
803 *mat.elem,
804 ));
805 }
806}
807
#[allow(unused_variables)]
/// Stores `mat` into `output` using `stride` and `layout`.
///
/// The direct body is `unexpanded!()`; in kernels `store::expand` does the work.
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
818
819pub mod store {
821 use crate::prelude::ReadWrite;
822
823 use super::*;
824
825 #[allow(unused_variables)]
827 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
828 scope: &mut Scope,
829 output: SliceExpand<O, ReadWrite>,
830 mat: MatrixExpand<C>,
831 stride: ExpandElementTyped<u32>,
832 layout: MatrixLayout,
833 ) {
834 let stride: ExpandElement = stride.into();
835
836 let (output, offset) = output.__to_raw_parts();
837
838 scope.register(Instruction::new(
839 ir::CoopMma::Store {
840 mat: *mat.elem,
841 offset,
842 stride: *stride,
843 layout,
844 },
845 output,
846 ));
847 }
848}
849
#[allow(unused_variables)]
/// Multiply-accumulate on matrix fragments, writing the result into `mat_d`.
///
/// The direct body is `unexpanded!()`; in kernels `execute::expand` registers a
/// `CoopMma::Execute` with `mat_a`, `mat_b`, `mat_c` as inputs and `mat_d` as output.
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
860
861pub mod execute {
863 use super::*;
864
865 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
867 scope: &mut Scope,
868 mat_a: MatrixExpand<A>,
869 mat_b: MatrixExpand<B>,
870 mat_c: MatrixExpand<C>,
871 mat_d: MatrixExpand<D>,
872 ) {
873 scope.register(Instruction::new(
874 ir::CoopMma::Execute {
875 mat_a: *mat_a.elem,
876 mat_b: *mat_b.elem,
877 mat_c: *mat_c.elem,
878 },
879 *mat_d.elem,
880 ));
881 }
882}
883
#[allow(unused_variables)]
/// Converts a matrix fragment of element type `C` into one of element type `O`.
///
/// The direct body is `unexpanded!()`; in kernels `cast::expand` does the work. It returns
/// the same fragment when `C` and `O` are the same Rust type.
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
889
890pub mod cast {
892 use super::*;
893
894 #[allow(unused_variables)]
896 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
897 scope: &mut Scope,
898 input: MatrixExpand<C>,
899 ) -> MatrixExpand<O> {
900 let ident = input.ident;
901
902 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
903 return MatrixExpand {
904 elem: input.elem,
905 ident,
906 _c: PhantomData,
907 };
908 }
909 let input = *input.elem;
910 let input_mat = match input.kind {
911 ir::VariableKind::Matrix { mat, .. } => mat,
912 _ => unreachable!(),
913 };
914
915 let elem = O::as_type(scope);
916 let elem = scope.create_matrix(ir::Matrix::new(
917 ident,
918 input_mat.m,
919 input_mat.n,
920 input_mat.k,
921 elem,
922 MatrixLayout::Undefined,
923 ));
924
925 let output = MatrixExpand {
926 ident,
927 elem,
928 _c: PhantomData,
929 };
930 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
931
932 output
933 }
934}
935
/// `MatrixLayout` is plain comptime data, so it is its own expand type.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
939
/// Layouts are comptime values and are returned unchanged.
impl IntoMut for MatrixLayout {
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
945
// Uses the trait's default `set_debug_name`; a layout has no IR variable to name.
impl CubeDebug for MatrixLayout {}