1use super::{
50 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{
54 self as cubecl,
55 prelude::{Array, Line},
56};
57use crate::{
58 ir::{self, Instruction},
59 unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
/// Type-level handle for a matrix whose elements have type `C`.
///
/// The struct holds no data: `C` is only carried through [`PhantomData`]. The value that is
/// manipulated while the kernel is being expanded is [`MatrixExpand`].
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    // Zero-sized marker binding the handle to its element type.
    _c: PhantomData<C>,
}
75
/// Type-level description of a matrix-multiply-accumulate operation.
///
/// `A` and `B` are the element types of the two input operands and `CD` is the element type
/// shared by the accumulator input and the result (see `execute`, which takes
/// `Array<Line<A>>`, `Array<Line<B>>`, `Array<Line<CD>>` and returns `Array<Line<CD>>`).
/// The struct holds no data; the expansion-time value is [`MmaDefinitionExpand`].
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
/// Expansion-time counterpart of [`Matrix`].
///
/// Wraps the IR element returned by `Scope::create_matrix` together with the operand role the
/// matrix was declared with.
pub struct MatrixExpand<C: CubeType> {
    // IR variable backing the matrix.
    elem: ExpandElement,
    // Operand role (`A`, `B` or `Accumulator`) the matrix was created for.
    ident: MatrixIdent,
    // Zero-sized marker for the element type.
    _c: PhantomData<C>,
}
90
/// Expansion-time counterpart of [`MmaDefinition`].
///
/// Stores the `m`/`n`/`k` shape, the storage type of each operand and, for definitions built
/// with `new_scaled`, the scales factor and scales storage type.
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `m` dimension passed to the constructor.
    pub m: u32,
    /// `n` dimension passed to the constructor.
    pub n: u32,
    /// `k` dimension passed to the constructor.
    pub k: u32,
    /// Storage type obtained from `A::as_type`.
    pub a_type: StorageType,
    /// Storage type obtained from `B::as_type`.
    pub b_type: StorageType,
    /// Storage type obtained from `CD::as_type`.
    pub cd_type: StorageType,
    /// Scales factor; `None` when the definition was built with `new`.
    pub scales_factor: Option<u32>,
    /// Storage type of the scales; `None` when the definition was built with `new`.
    pub scales_type: Option<StorageType>,
    // Zero-sized markers for the three element types.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108 fn clone(&self) -> Self {
109 Self {
110 elem: self.elem.clone(),
111 ident: self.ident,
112 _c: self._c,
113 }
114 }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118 fn clone(&self) -> Self {
119 Self {
120 m: self.m,
121 n: self.n,
122 k: self.k,
123 a_type: self.a_type,
124 b_type: self.b_type,
125 cd_type: self.cd_type,
126 scales_factor: self.scales_factor,
127 scales_type: self.scales_type,
128 _a: PhantomData,
129 _b: PhantomData,
130 _cd: PhantomData,
131 }
132 }
133}
134
/// Associates [`Matrix`] with its expansion-time representation.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
138
/// Associates [`MmaDefinition`] with its expansion-time representation.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
142
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    /// Returns the matrix unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
148
149impl<C: CubeType> CubeDebug for MatrixExpand<C> {
150 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
151 scope.update_variable_name(*self.elem, name);
152 }
153}
154
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    /// Returns the definition unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
160
// Empty impl: only the trait's provided behaviour is used, no method is overridden here.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
163#[cube]
164impl<C: CubePrimitive> Matrix<C> {
165 #[allow(unused_variables)]
183 pub unsafe fn uninitialized(
184 #[comptime] ident: MatrixIdent,
185 m: u32,
186 n: u32,
187 k: u32,
188 layout: MatrixLayout,
189 ) -> Self {
190 intrinsic!(|scope| {
191 let elem = C::as_type(scope);
192 let elem = scope.create_matrix(ir::Matrix::new(
193 ident,
194 m.constant().unwrap().as_u32(),
195 n.constant().unwrap().as_u32(),
196 k.constant().unwrap().as_u32(),
197 elem,
198 layout,
199 ));
200 MatrixExpand {
201 elem,
202 ident,
203 _c: PhantomData,
204 }
205 })
206 }
207
208 #[allow(unused_variables)]
222 pub fn from_value(
223 #[comptime] ident: MatrixIdent,
224 m: u32,
225 n: u32,
226 k: u32,
227 layout: MatrixLayout,
228 value: C,
229 ) -> Self {
230 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
231
232 intrinsic!(|scope| {
233 fill::expand(scope, mat.clone(), value);
234 mat
235 })
236 }
237
238 #[allow(unused_variables)]
252 pub fn from_slice(
253 #[comptime] ident: MatrixIdent,
254 m: u32,
255 n: u32,
256 k: u32,
257 layout: MatrixLayout,
258 value: &Slice<C>,
259 stride: u32,
260 ) -> Self {
261 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
262
263 intrinsic!(|scope| {
264 load::expand(scope, mat.clone(), value, stride);
265 mat
266 })
267 }
268}
269
270#[cube]
271impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
272 #[allow(unused_variables)]
288 pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
289 intrinsic!(|scope| {
290 let a_type = A::as_type(scope);
291 let b_type = B::as_type(scope);
292 let cd_type = CD::as_type(scope);
293
294 MmaDefinitionExpand {
295 m,
296 n,
297 k,
298 a_type,
299 b_type,
300 cd_type,
301 scales_factor: None,
302 scales_type: None,
303 _a: PhantomData,
304 _b: PhantomData,
305 _cd: PhantomData,
306 }
307 })
308 }
309
310 #[allow(unused_variables)]
326 pub fn new_scaled<S: CubePrimitive>(
327 #[comptime] m: u32,
328 #[comptime] n: u32,
329 #[comptime] k: u32,
330 #[comptime] scale_factor: u32,
331 ) -> Self {
332 intrinsic!(|scope| {
333 let a_type = A::as_type(scope);
334 let b_type = B::as_type(scope);
335 let cd_type = CD::as_type(scope);
336
337 MmaDefinitionExpand {
338 m,
339 n,
340 k,
341 a_type,
342 b_type,
343 cd_type,
344 scales_factor: Some(scale_factor),
345 scales_type: Some(S::as_type(scope)),
346 _a: PhantomData,
347 _b: PhantomData,
348 _cd: PhantomData,
349 }
350 })
351 }
352
353 #[allow(unused)]
355 pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
356 intrinsic!(|scope| {
357 match ident {
358 MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
359 MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
360 MatrixIdent::Accumulator => {
361 (self.m * self.n) / self.cd_type.packing_factor() as u32
362 }
363 }
364 })
365 }
366
367 #[allow(unused)]
374 pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
375 intrinsic!(|scope| {
376 let elems = self.__expand_num_elems_method(scope, ident);
377 let plane_dim = scope.runtime_properties.mma.const_plane_size;
378 let duplication = match ident {
379 MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
380 MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
381 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
382 };
383 (elems / plane_dim) * duplication
384 })
385 }
386
387 #[allow(unused)]
393 pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
394 intrinsic!(|scope| {
395 let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
396 let line_size = self.__expand_line_size_method(scope, ident);
397 elems / line_size
398 })
399 }
400
401 #[allow(unused)]
403 pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
404 intrinsic!(|scope| {
405 match ident {
406 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
407 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
408 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
409 }
410 })
411 }
412
413 #[allow(unused_variables)]
416 pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
417 intrinsic!(|scope| {
418 let storage = match ident {
419 MatrixIdent::A => self.a_type,
420 MatrixIdent::B => self.b_type,
421 MatrixIdent::Accumulator => self.cd_type,
422 };
423 let matrix = cubecl_ir::Matrix {
424 ident,
425 m: self.m,
426 n: self.n,
427 k: self.k,
428 storage: storage,
429 layout: MatrixLayout::ColMajor,
430 };
431 scope
432 .runtime_properties
433 .mma
434 .contiguous_elements
435 .apply(ident, matrix)
436 })
437 }
438
439 #[allow(unused_variables)]
447 pub fn position_of_nth(
448 &self,
449 lane_id: u32,
450 elem_idx: u32,
451 #[comptime] ident: MatrixIdent,
452 ) -> (u32, u32) {
453 intrinsic!(|scope| {
454 let lane_id: ExpandElement = lane_id.into();
455 let elem_idx: ExpandElement = elem_idx.into();
456
457 let ty = match ident {
458 MatrixIdent::A => self.a_type,
459 MatrixIdent::B => self.b_type,
460 MatrixIdent::Accumulator => self.cd_type,
461 };
462 let layout = match ident {
463 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
464 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
465 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
466 };
467 let matrix = cubecl_ir::Matrix {
468 ident,
469 m: self.m,
470 n: self.n,
471 k: self.k,
472 storage: ty,
473 layout,
474 };
475
476 let row = scope.create_local(Type::new(u32::as_type(scope)));
477 let col = scope.create_local(Type::new(u32::as_type(scope)));
478 scope.register(Instruction::new(
479 CoopMma::RowIndex {
480 lane_id: *lane_id,
481 i: *elem_idx,
482 matrix,
483 },
484 *row,
485 ));
486 scope.register(Instruction::new(
487 CoopMma::ColIndex {
488 lane_id: *lane_id,
489 i: *elem_idx,
490 matrix,
491 },
492 *col,
493 ));
494 (row.into(), col.into())
495 })
496 }
497
    /// Returns the scales index associated with `lane_id` for the `A` or `B` operand.
    ///
    /// With `quad_id = lane_id / 4` and `t_id = lane_id % 4`, the result is
    /// `quad_id + (t_id % 2) * 8` for `A` and `quad_id` for `B`.
    ///
    /// # Panics
    ///
    /// Panics for `MatrixIdent::Accumulator`, which has no scales.
    // NOTE(review): the constants 4 and 8 are hard-coded; presumably they mirror one specific
    // hardware register layout — confirm before relying on this for other targets.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }
510
    /// Returns the scales factor stored in this definition.
    ///
    /// # Panics
    ///
    /// Panics during expansion when the definition has no scales factor, i.e. it was built
    /// with `new` rather than `new_scaled`.
    pub fn scales_count(&self) -> comptime_type!(u32) {
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }
520
521 pub fn scales_line_size(&self) -> comptime_type!(u32) {
523 intrinsic!(|scope| {
524 let elem = self
525 .scales_type
526 .expect("Can't retrieve scales line size for matrix with no scales");
527 scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
528 })
529 }
530
531 #[allow(unused_variables)]
542 pub fn load_matrix<E: CubePrimitive>(
543 &self,
544 row: &Slice<Line<E>>,
545 #[comptime] ident: MatrixIdent,
546 #[comptime] num_matrices: u32,
547 #[comptime] transpose: bool,
548 ) -> Array<Line<E>> {
549 intrinsic!(|scope| {
550 let line_size = self.__expand_line_size_method(scope, ident);
551 let slice_line_size = row.line_size;
552 let (buffer, offset) = row.__to_raw_parts();
553 let out = Array::__expand_vectorized(scope, num_matrices, line_size);
554 scope.register(Instruction::new(
555 CoopMma::LoadMatrix {
556 buffer,
557 offset,
558 line_size: slice_line_size,
559 factor: num_matrices,
560 transpose,
561 },
562 *out.expand,
563 ));
564 out
565 })
566 }
567
568 #[allow(unused)]
571 pub fn execute(
572 &self,
573 registers_a: &Array<Line<A>>,
574 registers_b: &Array<Line<B>>,
575 registers_c: &Array<Line<CD>>,
576 ) -> Array<Line<CD>> {
577 intrinsic!(|scope| {
578 let acc_elems = self
579 .clone()
580 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
581 let acc_line_size = self
582 .clone()
583 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
584 let num_registers = acc_elems / acc_line_size;
585
586 let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
587
588 let registers_a = *registers_a.expand;
589 let registers_b = *registers_b.expand;
590 let registers_c = *registers_c.expand;
591
592 let matrix = cubecl_ir::Matrix {
594 ident: MatrixIdent::A,
595 m: self.m,
596 n: self.n,
597 k: self.k,
598 storage: self.a_type,
599 layout: MatrixLayout::ColMajor,
600 };
601
602 scope.register(Instruction::new(
603 CoopMma::ExecuteManual {
604 matrix,
605 registers_a,
606 registers_b,
607 registers_c,
608 },
609 *registers_d.expand,
610 ));
611
612 registers_d
613 })
614 }
615
616 #[allow(unused)]
619 pub fn execute_scaled<S: CubePrimitive>(
620 &self,
621 registers_a: &Array<Line<A>>,
622 registers_b: &Array<Line<B>>,
623 registers_c: &Array<Line<CD>>,
624 scales_a: Line<S>,
625 scales_b: Line<S>,
626 ) -> Array<Line<CD>> {
627 intrinsic!(|scope| {
628 let acc_elems = self
629 .clone()
630 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
631 let acc_line_size = self
632 .clone()
633 .__expand_line_size_method(scope, MatrixIdent::Accumulator);
634 let num_registers = acc_elems / acc_line_size;
635
636 let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);
637
638 let registers_a = *registers_a.expand;
639 let registers_b = *registers_b.expand;
640 let registers_c = *registers_c.expand;
641
642 let matrix = cubecl_ir::Matrix {
644 ident: MatrixIdent::A,
645 m: self.m,
646 n: self.n,
647 k: self.k,
648 storage: self.a_type,
649 layout: MatrixLayout::ColMajor,
650 };
651
652 scope.register(Instruction::new(
653 CoopMma::ExecuteScaled {
654 matrix,
655 registers_a,
656 registers_b,
657 registers_c,
658 scales_a: *scales_a.expand,
659 scales_b: *scales_b.expand,
660 scales_factor: self
661 .scales_factor
662 .expect("Can't execute scaled on matrix with no scales"),
663 },
664 *registers_d.expand,
665 ));
666
667 registers_d
668 })
669 }
670}
671
/// Unexpanded signature of the operation that fills `mat` with `value`.
///
/// The body is only `unexpanded!()`; the IR is produced by `fill::expand`, which registers a
/// `CoopMma::Fill` instruction.
// NOTE(review): presumably `unexpanded!()` panics when this is called outside of kernel
// expansion — confirm against the macro definition.
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
677
678pub mod fill {
680 use super::*;
681
682 pub fn expand<C: CubeType>(
684 scope: &mut Scope,
685 mat: MatrixExpand<C>,
686 value: ExpandElementTyped<C>,
687 ) {
688 let value: ExpandElement = value.into();
689 scope.register(Instruction::new(
690 ir::CoopMma::Fill { value: *value },
691 *mat.elem,
692 ));
693 }
694}
695
/// Unexpanded signature of the operation that loads `mat` from `value` with the given `stride`.
///
/// The body is only `unexpanded!()`; the IR is produced by `load::expand`, which registers a
/// `CoopMma::Load` instruction without a layout and rejects accumulator matrices.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
701
702pub mod load {
704 use super::*;
705
706 #[allow(unused_variables)]
708 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
709 scope: &mut Scope,
710 mat: MatrixExpand<C>,
711 value: SliceExpand<V, ReadOnly>,
712 stride: ExpandElementTyped<u32>,
713 ) {
714 let stride: ExpandElement = stride.into();
715 assert_ne!(
716 mat.ident,
717 MatrixIdent::Accumulator,
718 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
719 );
720
721 let (value, offset) = value.__to_raw_parts();
722
723 scope.register(Instruction::new(
724 ir::CoopMma::Load {
725 value,
726 stride: *stride,
727 offset,
728 layout: None,
729 },
730 *mat.elem,
731 ));
732 }
733}
734
/// Unexpanded signature of the operation that loads `mat` from `value` with an explicit
/// `layout`.
///
/// The body is only `unexpanded!()`; the IR is produced by `load_with_layout::expand`, which
/// registers a `CoopMma::Load` instruction carrying `Some(layout)`.
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
746
747pub mod load_with_layout {
749 use super::*;
750
751 #[allow(unused_variables)]
753 pub fn expand<C: CubeType, V: CubePrimitive>(
754 scope: &mut Scope,
755 mat: MatrixExpand<C>,
756 value: SliceExpand<V, ReadOnly>,
757 stride: ExpandElementTyped<u32>,
758 layout: MatrixLayout,
759 ) {
760 let stride: ExpandElement = stride.into();
761 let (value, offset) = value.__to_raw_parts();
762
763 scope.register(Instruction::new(
764 ir::CoopMma::Load {
765 value,
766 stride: *stride,
767 offset,
768 layout: Some(layout),
769 },
770 *mat.elem,
771 ));
772 }
773}
774
/// Unexpanded signature of the operation that stores `mat` into `output` with the given
/// `stride` and `layout`.
///
/// The body is only `unexpanded!()`; the IR is produced by `store::expand`, which registers a
/// `CoopMma::Store` instruction.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
785
786pub mod store {
788 use crate::prelude::ReadWrite;
789
790 use super::*;
791
792 #[allow(unused_variables)]
794 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
795 scope: &mut Scope,
796 output: SliceExpand<O, ReadWrite>,
797 mat: MatrixExpand<C>,
798 stride: ExpandElementTyped<u32>,
799 layout: MatrixLayout,
800 ) {
801 let stride: ExpandElement = stride.into();
802
803 let (output, offset) = output.__to_raw_parts();
804
805 scope.register(Instruction::new(
806 ir::CoopMma::Store {
807 mat: *mat.elem,
808 offset,
809 stride: *stride,
810 layout,
811 },
812 output,
813 ));
814 }
815}
816
/// Unexpanded signature of the matrix operation combining `mat_a`, `mat_b` and `mat_c` into
/// `mat_d`.
///
/// The body is only `unexpanded!()`; the IR is produced by `execute::expand`, which registers a
/// `CoopMma::Execute` instruction with `mat_d` as its output.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
827
828pub mod execute {
830 use super::*;
831
832 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
834 scope: &mut Scope,
835 mat_a: MatrixExpand<A>,
836 mat_b: MatrixExpand<B>,
837 mat_c: MatrixExpand<C>,
838 mat_d: MatrixExpand<D>,
839 ) {
840 scope.register(Instruction::new(
841 ir::CoopMma::Execute {
842 mat_a: *mat_a.elem,
843 mat_b: *mat_b.elem,
844 mat_c: *mat_c.elem,
845 },
846 *mat_d.elem,
847 ));
848 }
849}
850
/// Unexpanded signature of the operation that converts a matrix of `C` into a matrix of `O`.
///
/// The body is only `unexpanded!()`; the IR is produced by `cast::expand`, which reuses the
/// input when `C` and `O` are the same type and otherwise registers a `CoopMma::Cast`
/// instruction into a newly created matrix.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
856
857pub mod cast {
859 use super::*;
860
861 #[allow(unused_variables)]
863 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
864 scope: &mut Scope,
865 input: MatrixExpand<C>,
866 ) -> MatrixExpand<O> {
867 let ident = input.ident;
868
869 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
870 return MatrixExpand {
871 elem: input.elem,
872 ident,
873 _c: PhantomData,
874 };
875 }
876 let input = *input.elem;
877 let input_mat = match input.kind {
878 ir::VariableKind::Matrix { mat, .. } => mat,
879 _ => unreachable!(),
880 };
881
882 let elem = O::as_type(scope);
883 let elem = scope.create_matrix(ir::Matrix::new(
884 ident,
885 input_mat.m,
886 input_mat.n,
887 input_mat.k,
888 elem,
889 MatrixLayout::Undefined,
890 ));
891
892 let output = MatrixExpand {
893 ident,
894 elem,
895 _c: PhantomData,
896 };
897 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
898
899 output
900 }
901}
902
/// `MatrixLayout` is its own expansion-time representation.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
906
impl IntoMut for MatrixLayout {
    /// Returns the layout unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
912
// Empty impl: only the trait's provided behaviour is used, no method is overridden here.
impl CubeDebug for MatrixLayout {}