1use super::{
50 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{
54 self as cubecl,
55 prelude::{Array, Line, Sequence},
56};
57use crate::{
58 ir::{self, Instruction},
59 unexpanded,
60};
61use cubecl_macros::{comptime_type, cube, intrinsic};
62use std::marker::PhantomData;
63
64use cubecl_ir::{CoopMma, ExpandElement, Scope, StorageType, Type};
65pub use ir::{MatrixIdent, MatrixLayout};
66
#[derive(Copy, Clone)]
/// Kernel-side handle to a cooperative matrix (`CoopMma`) fragment whose elements are of
/// type `C`.
///
/// The type itself carries no data (only `PhantomData<C>`); during kernel expansion it is
/// represented by `MatrixExpand<C>`, which holds the actual IR variable.
pub struct Matrix<C: CubeType> {
    _c: PhantomData<C>,
}
75
#[derive(Copy, Clone)]
/// Kernel-side description of a matrix-multiply-accumulate operation with operand element
/// types `A` and `B` and accumulator/output element type `CD`.
///
/// Carries no data; its expanded form `MmaDefinitionExpand` records the `m`/`n`/`k` shape,
/// the resolved storage types and optional block-scaling information.
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
83
/// Expand-time representation of `Matrix<C>`: the IR variable holding the fragment plus
/// the operand it plays in an MMA.
pub struct MatrixExpand<C: CubeType> {
    /// IR matrix variable (created through `Scope::create_matrix`).
    elem: ExpandElement,
    /// Which MMA operand this fragment is (`A`, `B` or `Accumulator`).
    ident: MatrixIdent,
    _c: PhantomData<C>,
}
90
#[derive(Debug)]
/// Expand-time representation of `MmaDefinition<A, B, CD>`, holding the values resolved
/// when the kernel is expanded.
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `m` dimension: A holds `m * k` logical elements and the accumulator `m * n`.
    pub m: u32,
    /// `n` dimension: B holds `k * n` logical elements and the accumulator `m * n`.
    pub n: u32,
    /// `k` dimension, shared by A (`m * k`) and B (`k * n`).
    pub k: u32,
    /// Storage type of operand A (resolved from `A`).
    pub a_type: StorageType,
    /// Storage type of operand B (resolved from `B`).
    pub b_type: StorageType,
    /// Storage type of the accumulator / output (resolved from `CD`).
    pub cd_type: StorageType,
    /// `Some(scale_factor)` for definitions built with `new_scaled`, `None` otherwise.
    pub scales_factor: Option<u32>,
    /// Storage type of the scales for definitions built with `new_scaled`, `None` otherwise.
    pub scales_type: Option<StorageType>,
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
106
107impl<C: CubeType> Clone for MatrixExpand<C> {
108 fn clone(&self) -> Self {
109 Self {
110 elem: self.elem.clone(),
111 ident: self.ident,
112 _c: self._c,
113 }
114 }
115}
116
117impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
118 fn clone(&self) -> Self {
119 Self {
120 m: self.m,
121 n: self.n,
122 k: self.k,
123 a_type: self.a_type,
124 b_type: self.b_type,
125 cd_type: self.cd_type,
126 scales_factor: self.scales_factor,
127 scales_type: self.scales_type,
128 _a: PhantomData,
129 _b: PhantomData,
130 _cd: PhantomData,
131 }
132 }
133}
134
135impl<C: CubeType> CubeType for Matrix<C> {
136 type ExpandType = MatrixExpand<C>;
137}
138
139impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
140 type ExpandType = MmaDefinitionExpand<A, B, CD>;
141}
142
143impl<C: CubeType> IntoMut for MatrixExpand<C> {
144 fn into_mut(self, _scope: &mut Scope) -> Self {
145 self
146 }
147}
148
149impl<C: CubeType> CubeDebug for MatrixExpand<C> {
150 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
151 scope.update_variable_name(*self.elem, name);
152 }
153}
154
155impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
156 fn into_mut(self, _scope: &mut Scope) -> Self {
157 self
158 }
159}
160
161impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
162
163#[cube]
164impl<C: CubePrimitive> Matrix<C> {
165 #[allow(unused_variables)]
183 pub unsafe fn uninitialized(
184 #[comptime] ident: MatrixIdent,
185 m: u32,
186 n: u32,
187 k: u32,
188 layout: MatrixLayout,
189 ) -> Self {
190 intrinsic!(|scope| {
191 let elem = C::as_type(scope);
192 let elem = scope.create_matrix(ir::Matrix::new(
193 ident,
194 m.constant().unwrap().as_u32(),
195 n.constant().unwrap().as_u32(),
196 k.constant().unwrap().as_u32(),
197 elem,
198 layout,
199 ));
200 MatrixExpand {
201 elem,
202 ident,
203 _c: PhantomData,
204 }
205 })
206 }
207
208 #[allow(unused_variables)]
222 pub fn from_value(
223 #[comptime] ident: MatrixIdent,
224 m: u32,
225 n: u32,
226 k: u32,
227 layout: MatrixLayout,
228 value: C,
229 ) -> Self {
230 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
231
232 intrinsic!(|scope| {
233 fill::expand(scope, mat.clone(), value);
234 mat
235 })
236 }
237
238 #[allow(unused_variables)]
252 pub fn from_slice(
253 #[comptime] ident: MatrixIdent,
254 m: u32,
255 n: u32,
256 k: u32,
257 layout: MatrixLayout,
258 value: &Slice<C>,
259 stride: u32,
260 ) -> Self {
261 let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };
262
263 intrinsic!(|scope| {
264 load::expand(scope, mat.clone(), value, stride);
265 mat
266 })
267 }
268}
269
#[cube]
impl<A: CubePrimitive, B: CubePrimitive, CD: CubePrimitive> MmaDefinition<A, B, CD> {
    #[allow(unused_variables)]
    /// Creates an unscaled MMA definition of shape `m` x `n` x `k`.
    ///
    /// Storage types are resolved from the generic parameters `A`, `B` and `CD`;
    /// `scales_factor` and `scales_type` are left as `None`.
    pub fn new(#[comptime] m: u32, #[comptime] n: u32, #[comptime] k: u32) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: None,
                scales_type: None,
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    #[allow(unused_variables)]
    /// Creates a block-scaled MMA definition of shape `m` x `n` x `k`.
    ///
    /// Same as `new`, but also records `scale_factor` and the storage type of `S` as the
    /// scales type; these are consumed by `scales_count`, `scales_line_size` and
    /// `execute_scaled`.
    pub fn new_scaled<S: CubePrimitive>(
        #[comptime] m: u32,
        #[comptime] n: u32,
        #[comptime] k: u32,
        #[comptime] scale_factor: u32,
    ) -> Self {
        intrinsic!(|scope| {
            let a_type = A::as_type(scope);
            let b_type = B::as_type(scope);
            let cd_type = CD::as_type(scope);

            MmaDefinitionExpand {
                m,
                n,
                k,
                a_type,
                b_type,
                cd_type,
                scales_factor: Some(scale_factor),
                scales_type: Some(S::as_type(scope)),
                _a: PhantomData,
                _b: PhantomData,
                _cd: PhantomData,
            }
        })
    }

    #[allow(unused)]
    /// Comptime number of storage elements in the whole matrix for `ident`.
    ///
    /// A has `m * k`, B `k * n` and the accumulator `m * n` logical elements, divided by
    /// the storage type's packing factor (integer division).
    pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor() as u32,
                MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor() as u32,
                MatrixIdent::Accumulator => {
                    (self.m * self.n) / self.cd_type.packing_factor() as u32
                }
            }
        })
    }

    #[allow(unused)]
    /// Comptime number of elements each lane holds for `ident`.
    ///
    /// Computed as `num_elems / const_plane_size * register_duplication_<operand>`, with
    /// both factors read from `scope.runtime_properties.mma`. The division truncates.
    pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.__expand_num_elems_method(scope, ident);
            let plane_dim = scope.runtime_properties.mma.const_plane_size;
            let duplication = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
            };
            (elems / plane_dim) * duplication
        })
    }

    #[allow(unused)]
    /// Comptime number of lines (registers) each lane holds for `ident`:
    /// `elems_per_lane / line_size`.
    pub fn lines_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let line_size = self.__expand_line_size_method(scope, ident);
            elems / line_size
        })
    }

    #[allow(unused)]
    /// Comptime register layout for `ident`, taken from `scope.runtime_properties.mma`.
    pub fn line_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
        intrinsic!(|scope| {
            match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            }
        })
    }

    #[allow(unused_variables)]
    /// Comptime number of `ident` elements that fit in one register
    /// (`register_size_bits / element bits`, rounded up).
    pub fn line_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let bits = match ident {
                MatrixIdent::A => StorageType::size_bits(&self.a_type) as u32,
                MatrixIdent::B => StorageType::size_bits(&self.b_type) as u32,
                MatrixIdent::Accumulator => StorageType::size_bits(&self.cd_type) as u32,
            };
            let register_size = scope.runtime_properties.mma.register_size_bits;
            // `div_ceil` yields 1 (not 0) when an element is wider than a register.
            register_size.div_ceil(bits)
        })
    }

    #[allow(unused_variables)]
    /// Returns the `(row, col)` position in the `ident` matrix of the `elem_idx`-th element
    /// held by lane `lane_id`.
    ///
    /// Emits one `CoopMma::RowIndex` and one `CoopMma::ColIndex` instruction into fresh
    /// `u32` locals. The layout in the matrix descriptor is the register layout from
    /// `scope.runtime_properties.mma`.
    pub fn position_of_nth(
        &self,
        lane_id: u32,
        elem_idx: u32,
        #[comptime] ident: MatrixIdent,
    ) -> (u32, u32) {
        intrinsic!(|scope| {
            let lane_id: ExpandElement = lane_id.into();
            let elem_idx: ExpandElement = elem_idx.into();

            let ty = match ident {
                MatrixIdent::A => self.a_type,
                MatrixIdent::B => self.b_type,
                MatrixIdent::Accumulator => self.cd_type,
            };
            let layout = match ident {
                MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
                MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
                MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
            };
            let matrix = cubecl_ir::Matrix {
                ident,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: ty,
                layout,
            };

            let row = scope.create_local(Type::new(u32::as_type(scope)));
            let col = scope.create_local(Type::new(u32::as_type(scope)));
            scope.register(Instruction::new(
                CoopMma::RowIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *row,
            ));
            scope.register(Instruction::new(
                CoopMma::ColIndex {
                    lane_id: *lane_id,
                    i: *elem_idx,
                    matrix,
                },
                *col,
            ));
            (row.into(), col.into())
        })
    }

    /// Index of the scale value used by lane `lane_id` for operand `ident`.
    ///
    /// Lanes are grouped in fours: `quad_id = lane_id / 4`, `t_id = lane_id % 4`.
    /// A uses `quad_id + (t_id % 2) * 8`, and B uses `quad_id`.
    ///
    /// NOTE(review): this hard-codes one specific hardware scale-register mapping. Confirm it
    /// against the target ISA before reusing it for other MMA shapes.
    ///
    /// # Panics
    ///
    /// Panics for `MatrixIdent::Accumulator`, which has no scales.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }

    /// Comptime scale factor recorded by `new_scaled`.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales (`new`).
    pub fn scales_count(&self) -> comptime_type!(u32) {
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }

    /// Comptime number of scale elements per register:
    /// `register_size_bits / scales element bits`.
    ///
    /// NOTE(review): unlike `line_size`, this uses truncating division. A scales type wider
    /// than a register would give 0. Confirm that cannot happen for supported scale types.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales (`new`).
    pub fn scales_line_size(&self) -> comptime_type!(u32) {
        intrinsic!(|scope| {
            let elem = self
                .scales_type
                .expect("Can't retrieve scales line size for matrix with no scales");
            scope.runtime_properties.mma.register_size_bits / elem.size_bits() as u32
        })
    }

    #[allow(unused)]
    /// Executes the MMA on per-lane register sequences and returns the output registers.
    ///
    /// Allocates a vectorized array of `elems_per_lane(Accumulator) / line_size(Accumulator)`
    /// lines of `line_size(Accumulator)` elements. It then emits a `CoopMma::ExecuteManual`
    /// instruction that reads `registers_a`, `registers_b` and `registers_c` and writes into
    /// that array.
    pub fn execute(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();

            // NOTE(review): `ident` and `layout` here look like placeholders. Presumably only
            // the shape and storage type are consumed by backends for `ExecuteManual` —
            // confirm.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }

    #[allow(unused)]
    /// Block-scaled variant of `execute`.
    ///
    /// Allocates the output registers the same way and emits `CoopMma::ExecuteScaled` with
    /// `scales_a`, `scales_b` and the recorded `scales_factor`.
    ///
    /// # Panics
    ///
    /// Panics if the definition was created without scales (`new`).
    pub fn execute_scaled<S: CubePrimitive>(
        &self,
        registers_a: &Sequence<Line<A>>,
        registers_b: &Sequence<Line<B>>,
        registers_c: &Sequence<Line<CD>>,
        scales_a: Line<S>,
        scales_b: Line<S>,
    ) -> Array<Line<CD>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_line_size = self
                .clone()
                .__expand_line_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_line_size;

            let registers_d = Array::__expand_vectorized(scope, num_registers, acc_line_size);

            let registers_a = registers_a
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_b = registers_b
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();
            let registers_c = registers_c
                .iter_cloned()
                .map(|it| *it.expand)
                .collect::<Vec<_>>();

            // NOTE(review): same placeholder `ident`/`layout` as in `execute` — confirm only
            // the shape and storage type matter here.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
}
641
642#[allow(unused_variables)]
644pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
645 unexpanded!()
646}
647
648pub mod fill {
650 use super::*;
651
652 pub fn expand<C: CubeType>(
654 scope: &mut Scope,
655 mat: MatrixExpand<C>,
656 value: ExpandElementTyped<C>,
657 ) {
658 let value: ExpandElement = value.into();
659 scope.register(Instruction::new(
660 ir::CoopMma::Fill { value: *value },
661 *mat.elem,
662 ));
663 }
664}
665
666#[allow(unused_variables)]
668pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
669 unexpanded!()
670}
671
672pub mod load {
674 use super::*;
675
676 #[allow(unused_variables)]
678 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
679 scope: &mut Scope,
680 mat: MatrixExpand<C>,
681 value: SliceExpand<V, ReadOnly>,
682 stride: ExpandElementTyped<u32>,
683 ) {
684 let stride: ExpandElement = stride.into();
685 assert_ne!(
686 mat.ident,
687 MatrixIdent::Accumulator,
688 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
689 );
690
691 let (value, offset) = value.__to_raw_parts();
692
693 scope.register(Instruction::new(
694 ir::CoopMma::Load {
695 value,
696 stride: *stride,
697 offset,
698 layout: None,
699 },
700 *mat.elem,
701 ));
702 }
703}
704
705#[allow(unused_variables)]
708pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
709 mat: &Matrix<C>,
710 value: &Slice<V>,
711 stride: u32,
712 layout: MatrixLayout,
713) {
714 unexpanded!()
715}
716
717pub mod load_with_layout {
719 use super::*;
720
721 #[allow(unused_variables)]
723 pub fn expand<C: CubeType, V: CubePrimitive>(
724 scope: &mut Scope,
725 mat: MatrixExpand<C>,
726 value: SliceExpand<V, ReadOnly>,
727 stride: ExpandElementTyped<u32>,
728 layout: MatrixLayout,
729 ) {
730 let stride: ExpandElement = stride.into();
731 let (value, offset) = value.__to_raw_parts();
732
733 scope.register(Instruction::new(
734 ir::CoopMma::Load {
735 value,
736 stride: *stride,
737 offset,
738 layout: Some(layout),
739 },
740 *mat.elem,
741 ));
742 }
743}
744
745#[allow(unused_variables)]
747pub fn store<C: CubePrimitive, O: CubePrimitive>(
748 output: &mut SliceMut<O>,
749 mat: &Matrix<C>,
750 stride: u32,
751 layout: MatrixLayout,
752) {
753 unexpanded!()
754}
755
756pub mod store {
758 use crate::prelude::ReadWrite;
759
760 use super::*;
761
762 #[allow(unused_variables)]
764 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
765 scope: &mut Scope,
766 output: SliceExpand<O, ReadWrite>,
767 mat: MatrixExpand<C>,
768 stride: ExpandElementTyped<u32>,
769 layout: MatrixLayout,
770 ) {
771 let stride: ExpandElement = stride.into();
772
773 let (output, offset) = output.__to_raw_parts();
774
775 scope.register(Instruction::new(
776 ir::CoopMma::Store {
777 mat: *mat.elem,
778 offset,
779 stride: *stride,
780 layout,
781 },
782 output,
783 ));
784 }
785}
786
787#[allow(unused_variables)]
789pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
790 mat_a: &Matrix<A>,
791 mat_b: &Matrix<B>,
792 mat_c: &Matrix<C>,
793 mat_d: &Matrix<D>,
794) {
795 unexpanded!()
796}
797
798pub mod execute {
800 use super::*;
801
802 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
804 scope: &mut Scope,
805 mat_a: MatrixExpand<A>,
806 mat_b: MatrixExpand<B>,
807 mat_c: MatrixExpand<C>,
808 mat_d: MatrixExpand<D>,
809 ) {
810 scope.register(Instruction::new(
811 ir::CoopMma::Execute {
812 mat_a: *mat_a.elem,
813 mat_b: *mat_b.elem,
814 mat_c: *mat_c.elem,
815 },
816 *mat_d.elem,
817 ));
818 }
819}
820
821#[allow(unused_variables)]
823pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
824 unexpanded!()
825}
826
827pub mod cast {
829 use super::*;
830
831 #[allow(unused_variables)]
833 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
834 scope: &mut Scope,
835 input: MatrixExpand<C>,
836 ) -> MatrixExpand<O> {
837 let ident = input.ident;
838
839 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
840 return MatrixExpand {
841 elem: input.elem,
842 ident,
843 _c: PhantomData,
844 };
845 }
846 let input = *input.elem;
847 let input_mat = match input.kind {
848 ir::VariableKind::Matrix { mat, .. } => mat,
849 _ => unreachable!(),
850 };
851
852 let elem = O::as_type(scope);
853 let elem = scope.create_matrix(ir::Matrix::new(
854 ident,
855 input_mat.m,
856 input_mat.n,
857 input_mat.k,
858 elem,
859 MatrixLayout::Undefined,
860 ));
861
862 let output = MatrixExpand {
863 ident,
864 elem,
865 _c: PhantomData,
866 };
867 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
868
869 output
870 }
871}
872
873impl CubeType for MatrixLayout {
874 type ExpandType = Self;
875}
876
877impl IntoMut for MatrixLayout {
878 fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
879 self
880 }
881}
882
883impl CubeDebug for MatrixLayout {}