1use super::{
50 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{
54 self as cubecl,
55 prelude::{Array, Line},
56};
57use crate::{
58 ir::{self, Instruction},
59 unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
67#[derive(Copy, Clone)]
72pub struct Matrix<C: CubeType> {
73 _c: PhantomData<C>,
74}
75
76#[derive(Copy, Clone)]
78pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
79 _a: PhantomData<A>,
80 _b: PhantomData<B>,
81 _cd: PhantomData<CD>,
82}
83
84pub struct MatrixExpand<C: CubeType> {
86 elem: ExpandElement,
87 ident: MatrixIdent,
88 _c: PhantomData<C>,
89}
90
91#[derive(Debug)]
93pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
94 pub m: u32,
95 pub n: u32,
96 pub k: u32,
97 pub a_type: StorageType,
98 pub b_type: StorageType,
99 pub cd_type: StorageType,
100 pub scales_factor: Option<u32>,
101 pub scales_type: Option<StorageType>,
102 _a: PhantomData<A>,
103 _b: PhantomData<B>,
104 _cd: PhantomData<CD>,
105}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108 fn clone(&self) -> Self {
109 Self {
110 elem: self.elem.clone(),
111 ident: self.ident,
112 _c: self._c,
113 }
114 }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118 fn clone(&self) -> Self {
119 Self {
120 m: self.m,
121 n: self.n,
122 k: self.k,
123 a_type: self.a_type,
124 b_type: self.b_type,
125 cd_type: self.cd_type,
126 scales_factor: self.scales_factor,
127 scales_type: self.scales_type,
128 _a: PhantomData,
129 _b: PhantomData,
130 _cd: PhantomData,
131 }
132 }
133}
134
135impl<C: CubeType> CubeType for Matrix<C> {
136 type ExpandType = MatrixExpand<C>;
137}
138
139impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
140 type ExpandType = MmaDefinitionExpand<A, B, CD>;
141}
142
143impl<C: CubeType> IntoMut for MatrixExpand<C> {
144 fn into_mut(self, _scope: &mut Scope) -> Self {
145 self
146 }
147}
148
149impl<C: CubeType> CubeDebug for MatrixExpand<C> {
150 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
151 scope.update_variable_name(*self.elem, name);
152 }
153}
154
155impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
156 fn into_mut(self, _scope: &mut Scope) -> Self {
157 self
158 }
159}
160
161impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
163#[cube]
164impl<C: CubePrimitive> Matrix<C> {
165 #[allow(unused_variables)]
183 pub unsafe fn uninitialized(
184 #[comptime] ident: MatrixIdent,
185 m: u32,
186 n: u32,
187 k: u32,
188 layout: MatrixLayout,
189 ) -> Self {
190 intrinsic!(|scope| {
191 let elem = C::as_type(scope);
192 let elem = scope.create_matrix(ir::Matrix::new(
193 ident,
194 m.constant().unwrap().as_u32(),
195 n.constant().unwrap().as_u32(),
196 k.constant().unwrap().as_u32(),
197 elem,
198 layout,
199 ));
200 MatrixExpand {
201 elem,
202 ident,
203 _c: PhantomData,
204 }
205 })
206 }
207
208 #[allow(unused_variables)]
222 pub fn from_value(
223 #[comptime] ident: MatrixIdent,
224 m: u32,
225 n: u32,
226 k: u32,
227 layout: MatrixLayout,
228 value: C,
229 ) -> Self {
230 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
231
232 intrinsic!(|scope| {
233 fill::expand(scope, mat.clone(), value);
234 mat
235 })
236 }
237
238 #[allow(unused_variables)]
252 pub fn from_slice(
253 #[comptime] ident: MatrixIdent,
254 m: u32,
255 n: u32,
256 k: u32,
257 layout: MatrixLayout,
258 value: &Slice<C>,
259 stride: u32,
260 ) -> Self {
261 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
262
263 intrinsic!(|scope| {
264 load::expand(scope, mat.clone(), value, stride);
265 mat
266 })
267 }
268}
269
#[cube]
impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
    /// Creates an unscaled definition for an `m x n x k` matrix-multiply-accumulate.
    ///
    /// The storage types of the A, B and accumulator/result operands are resolved from `A`,
    /// `B` and `CD` at expansion time. `scales_factor` and `scales_type` are left as `None`.
    // The parameters are only consumed inside `intrinsic!`; presumably the allow silences
    // warnings on the non-expanded body.
    #[allow(unused_variables)]
    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Creates a scaled definition for an `m x n x k` matrix-multiply-accumulate.
    ///
    /// Same as `new`, but it also records:
    /// * `scale_factor`,
    /// * the storage type of the scale element type `S`.
    ///
    /// `scales_count`, `scales_line_size` and `execute_scaled` require both of these.
    #[allow(unused_variables)]
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: u32,
        #[comptime] n: u32,
        #[comptime] k: u32,
        #[comptime] scale_factor: u32,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope)),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    /// Comptime number of elements in the whole `ident` tile, divided by the packing factor of
    /// that operand's storage type.
    ///
    /// The tile sizes are `m * k` for A, `k * n` for B and `m * n` for the accumulator.
    #[allow(unused)]
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
                MatrixIdent::Accumulator => {
                    (self.m * self.n) / self.cd_type.packing_factor() as u32
                }
            }
        })
    }

    /// Comptime number of `ident` elements held by each lane.
    ///
    /// Computed as `num_elems(ident) / const_plane_size`, multiplied by the per-operand
    /// register duplication factor taken from the scope's MMA runtime properties.
    #[allow(unused)]
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size;
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            // Integer division: assumes the element count is a multiple of the plane size.
            (elems / plane_dim) * duplication
        })
    }

    /// Comptime number of lines (of `line_size(ident)` elements) held by each lane for
    /// `ident`, i.e. `elems_per_lane(ident) / line_size(ident)`.
    #[allow(unused)]
    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let line_size = self.__expand_line_size_method(scope, ident);
            elems / line_size
        })
    }

    /// Comptime register layout for `ident`, read from the scope's MMA runtime properties.
    #[allow(unused)]
    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    /// Comptime line size for `ident`: how many elements of that operand's storage type fit
    /// in one register of `register_size_bits` bits, rounded up.
    #[allow(unused_variables)]
    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let bits = match ident {
                MatrixIdent::A => StorageType::size_bits(&self.a_type) as u32,
                MatrixIdent::B => StorageType::size_bits(&self.b_type) as u32,
                MatrixIdent::Accumulator => StorageType::size_bits(&self.cd_type) as u32,
            };
            let register_size = scope.runtime_properties.mma.register_size_bits;
            register_size.div_ceil(bits)
        })
    }

    /// Returns the runtime `(row, col)` coordinates for element `elem_idx` held by lane
    /// `lane_id` of the `ident` operand.
    ///
    /// It registers a `CoopMma::RowIndex` and a `CoopMma::ColIndex` instruction. Both use a
    /// matrix descriptor built from this definition's shape, the operand's storage type and
    /// the operand's register layout. Presumably these are the coordinates within the tile;
    /// confirm against the backend lowering.
    #[allow(unused_variables)]
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ExpandElement = lane_id.into();
            let elem_idx: ExpandElement = elem_idx.into();

            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };

            // Fresh `u32` locals receive the computed row and column.
            let row = scope.create_local(Type::new(u32::as_type(scope)));
            let col = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scale value used by lane `lane_id` for operand `ident`.
    ///
    /// Lanes are grouped in fours:
    /// * A: `quad_id + (t_id % 2) * 8`,
    /// * B: `quad_id`,
    ///
    /// where `quad_id = lane_id / 4` and `t_id = lane_id % 4`. NOTE(review): presumably this
    /// mirrors a specific hardware scale-fragment layout; confirm against the target
    /// instruction's documentation.
    ///
    /// # Panics
    ///
    /// If `ident` is `MatrixIdent::Accumulator`, which has no scales.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Comptime scale factor recorded by `new_scaled`.
    ///
    /// # Panics
    ///
    /// At expansion time, if the definition was created without scales (via `new`).
    pub fn scales_count(&self) -> comptime_type!(u32) {
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Comptime line size for scales: `register_size_bits` divided by the bit width of the
    /// scale storage type.
    ///
    /// Unlike `line_size`, this uses truncating integer division rather than rounding up.
    ///
    /// # Panics
    ///
    /// At expansion time, if the definition was created without scales (via `new`).
    pub fn scales_line_size(&self) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales line size for matrix with no scales");
            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
        })
    }

    /// Loads `num_matrices` fragments through a `CoopMma::LoadMatrix` instruction.
    ///
    /// The instruction reads from the raw buffer and offset behind `row`. It is given:
    /// * `row`'s own line size,
    /// * `num_matrices` as its factor,
    /// * the `transpose` flag.
    ///
    /// Returns a new array of `num_matrices` lines, each `line_size(ident)` elements wide.
    /// `ident` only influences that output line size.
    #[allow(unused_variables)]
    pub fn load_matrix<E: CubePrimitive>(
        &self,
        row: &Slice<Line<E>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: u32,
        #[comptime] transpose: bool,
    ) -> Array<Line<E>> {
        intrinsic!(|scope| {
            let line_size = self.__expand_line_size_method(scope, ident);
            let slice_line_size = row.line_size;
            let (buffer, offset) = row.__to_raw_parts();
            let out = Array::__expand_vectorized(scope, num_matrices, line_size);
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    line_size: slice_line_size,
                    factor: num_matrices,
                    transpose,
                },
                *out.expand,
            ));
            out
        })
    }

    /// Issues a manual MMA (`CoopMma::ExecuteManual`) on per-lane register arrays.
    ///
    /// Inputs are `registers_a`, `registers_b` and the accumulator `registers_c`. The result
    /// goes into a new array of `elems_per_lane(Accumulator) / line_size(Accumulator)` lines,
    /// each `line_size(Accumulator)` elements wide, and that array is returned.
    #[allow(unused)]
    pub fn execute(
        &self,
        registers_a: &Array<Line<A>>,
        registers_b: &Array<Line<B>>,
        registers_c: &Array<Line<CD>>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // NOTE(review): the descriptor always uses `MatrixIdent::A` and `ColMajor`.
            // Presumably only its shape (m, n, k) and A's storage type are consulted by
            // `ExecuteManual`; confirm against the backends.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }

    /// Scaled variant of `execute`: issues `CoopMma::ExecuteScaled`, passing `scales_a`,
    /// `scales_b` and this definition's scale factor along with the register arrays.
    ///
    /// # Panics
    ///
    /// At expansion time, if the definition was created without scales (via `new`).
    #[allow(unused)]
    pub fn execute_scaled<S: CubePrimitive>(
        &self,
        registers_a: &Array<Line<A>>,
        registers_b: &Array<Line<B>>,
        registers_c: &Array<Line<CD>>,
        scales_a: Line<S>,
        scales_b: Line<S>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // NOTE(review): same fixed `A` / `ColMajor` descriptor as in `execute`;
            // presumably only shape and storage matter here.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
}
660
661#[allow(unused_variables)]
663pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
664 unexpanded!()
665}
666
667pub mod fill {
669 use super::*;
670
671 pub fn expand<C: CubeType>(
673 scope: &mut Scope,
674 mat: MatrixExpand<C>,
675 value: ExpandElementTyped<C>,
676 ) {
677 let value: ExpandElement = value.into();
678 scope.register(Instruction::new(
679 ir::CoopMma::Fill { value: *value },
680 *mat.elem,
681 ));
682 }
683}
684
685#[allow(unused_variables)]
687pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
688 unexpanded!()
689}
690
691pub mod load {
693 use super::*;
694
695 #[allow(unused_variables)]
697 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
698 scope: &mut Scope,
699 mat: MatrixExpand<C>,
700 value: SliceExpand<V, ReadOnly>,
701 stride: ExpandElementTyped<u32>,
702 ) {
703 let stride: ExpandElement = stride.into();
704 assert_ne!(
705 mat.ident,
706 MatrixIdent::Accumulator,
707 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
708 );
709
710 let (value, offset) = value.__to_raw_parts();
711
712 scope.register(Instruction::new(
713 ir::CoopMma::Load {
714 value,
715 stride: *stride,
716 offset,
717 layout: None,
718 },
719 *mat.elem,
720 ));
721 }
722}
723
724#[allow(unused_variables)]
727pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
728 mat: &Matrix<C>,
729 value: &Slice<V>,
730 stride: u32,
731 layout: MatrixLayout,
732) {
733 unexpanded!()
734}
735
736pub mod load_with_layout {
738 use super::*;
739
740 #[allow(unused_variables)]
742 pub fn expand<C: CubeType, V: CubePrimitive>(
743 scope: &mut Scope,
744 mat: MatrixExpand<C>,
745 value: SliceExpand<V, ReadOnly>,
746 stride: ExpandElementTyped<u32>,
747 layout: MatrixLayout,
748 ) {
749 let stride: ExpandElement = stride.into();
750 let (value, offset) = value.__to_raw_parts();
751
752 scope.register(Instruction::new(
753 ir::CoopMma::Load {
754 value,
755 stride: *stride,
756 offset,
757 layout: Some(layout),
758 },
759 *mat.elem,
760 ));
761 }
762}
763
764#[allow(unused_variables)]
766pub fn store<C: CubePrimitive, O: CubePrimitive>(
767 output: &mut SliceMut<O>,
768 mat: &Matrix<C>,
769 stride: u32,
770 layout: MatrixLayout,
771) {
772 unexpanded!()
773}
774
775pub mod store {
777 use crate::prelude::ReadWrite;
778
779 use super::*;
780
781 #[allow(unused_variables)]
783 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
784 scope: &mut Scope,
785 output: SliceExpand<O, ReadWrite>,
786 mat: MatrixExpand<C>,
787 stride: ExpandElementTyped<u32>,
788 layout: MatrixLayout,
789 ) {
790 let stride: ExpandElement = stride.into();
791
792 let (output, offset) = output.__to_raw_parts();
793
794 scope.register(Instruction::new(
795 ir::CoopMma::Store {
796 mat: *mat.elem,
797 offset,
798 stride: *stride,
799 layout,
800 },
801 output,
802 ));
803 }
804}
805
806#[allow(unused_variables)]
808pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
809 mat_a: &Matrix<A>,
810 mat_b: &Matrix<B>,
811 mat_c: &Matrix<C>,
812 mat_d: &Matrix<D>,
813) {
814 unexpanded!()
815}
816
817pub mod execute {
819 use super::*;
820
821 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
823 scope: &mut Scope,
824 mat_a: MatrixExpand<A>,
825 mat_b: MatrixExpand<B>,
826 mat_c: MatrixExpand<C>,
827 mat_d: MatrixExpand<D>,
828 ) {
829 scope.register(Instruction::new(
830 ir::CoopMma::Execute {
831 mat_a: *mat_a.elem,
832 mat_b: *mat_b.elem,
833 mat_c: *mat_c.elem,
834 },
835 *mat_d.elem,
836 ));
837 }
838}
839
840#[allow(unused_variables)]
842pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
843 unexpanded!()
844}
845
846pub mod cast {
848 use super::*;
849
850 #[allow(unused_variables)]
852 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
853 scope: &mut Scope,
854 input: MatrixExpand<C>,
855 ) -> MatrixExpand<O> {
856 let ident = input.ident;
857
858 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
859 return MatrixExpand {
860 elem: input.elem,
861 ident,
862 _c: PhantomData,
863 };
864 }
865 let input = *input.elem;
866 let input_mat = match input.kind {
867 ir::VariableKind::Matrix { mat, .. } => mat,
868 _ => unreachable!(),
869 };
870
871 let elem = O::as_type(scope);
872 let elem = scope.create_matrix(ir::Matrix::new(
873 ident,
874 input_mat.m,
875 input_mat.n,
876 input_mat.k,
877 elem,
878 MatrixLayout::Undefined,
879 ));
880
881 let output = MatrixExpand {
882 ident,
883 elem,
884 _c: PhantomData,
885 };
886 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
887
888 output
889 }
890}
891
892impl CubeType for MatrixLayout {
893 type ExpandType = Self;
894}
895
896impl IntoMut for MatrixLayout {
897 fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
898 self
899 }
900}
901
902impl CubeDebug for MatrixLayout {}