1use super::{
50 CubeDebug, CubePrimitive, CubeType, IntoMut, NativeExpand, ReadOnly, Slice, SliceExpand,
51 SliceMut,
52};
53use crate::{self as cubecl, prelude::*};
54use crate::{
55 ir::{self, Instruction},
56 unexpanded,
57};
58use core::marker::PhantomData;
59use cubecl_macros::{comptime_type, cube, intrinsic};
60
61use cubecl_ir::{CoopMma, ManagedVariable, Scope, StorageType, VectorSize};
62pub use ir::{MatrixIdent, MatrixLayout};
63
/// Zero-sized marker for a matrix whose elements have type `C`.
///
/// The type stores no data; during expansion it is represented by
/// [`MatrixExpand`], which holds the actual variable and matrix identity.
#[derive(Copy, Clone)]
pub struct Matrix<C: CubeType> {
    // Ties the element type `C` to the marker without storing a value.
    _c: PhantomData<C>,
}
72
/// Zero-sized marker describing a matrix-multiply-accumulate over the element
/// types `A`, `B` and `CD`.
///
/// The type stores no data; during expansion it is represented by
/// [`MmaDefinitionExpand`], which records the shape and storage types.
#[derive(Copy, Clone)]
pub struct MmaDefinition<A: CubeType, B: CubeType, CD: CubeType> {
    // Phantom markers only: they bind the three element types to the definition.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
80
81impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for &MmaDefinitionExpand<A, B, CD> {
82 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
83 MmaDefinitionExpand::set_debug_name(self, scope, name);
84 }
85}
86
/// Expansion-time representation of a [`Matrix`].
pub struct MatrixExpand<C: CubeType> {
    // Variable that the matrix instructions in this file read from / write to.
    elem: ManagedVariable,
    // Identity of the matrix (`A`, `B` or `Accumulator`).
    ident: MatrixIdent,
    // Binds the element type `C` without storing a value.
    _c: PhantomData<C>,
}
93
/// Expansion-time representation of an [`MmaDefinition`]: the shape of the
/// operation plus the storage types of its operands.
#[derive(Debug)]
pub struct MmaDefinitionExpand<A: CubeType, B: CubeType, CD: CubeType> {
    /// `m` dimension, as passed to the constructor.
    pub m: usize,
    /// `n` dimension, as passed to the constructor.
    pub n: usize,
    /// `k` dimension, as passed to the constructor.
    pub k: usize,
    /// Storage type derived from `A`.
    pub a_type: StorageType,
    /// Storage type derived from `B`.
    pub b_type: StorageType,
    /// Storage type derived from `CD`.
    pub cd_type: StorageType,
    /// Scale factor; `None` unless the definition was built with `new_scaled`.
    pub scales_factor: Option<usize>,
    /// Storage type of the scales; `None` unless the definition was built with `new_scaled`.
    pub scales_type: Option<StorageType>,
    // Phantom markers binding the element types to the definition.
    _a: PhantomData<A>,
    _b: PhantomData<B>,
    _cd: PhantomData<CD>,
}
109
110impl<C: CubeType> Clone for MatrixExpand<C> {
111 fn clone(&self) -> Self {
112 Self {
113 elem: self.elem.clone(),
114 ident: self.ident,
115 _c: self._c,
116 }
117 }
118}
119
120impl<A: CubeType, B: CubeType, CD: CubeType> Clone for MmaDefinitionExpand<A, B, CD> {
121 fn clone(&self) -> Self {
122 Self {
123 m: self.m,
124 n: self.n,
125 k: self.k,
126 a_type: self.a_type,
127 b_type: self.b_type,
128 cd_type: self.cd_type,
129 scales_factor: self.scales_factor,
130 scales_type: self.scales_type,
131 _a: PhantomData,
132 _b: PhantomData,
133 _cd: PhantomData,
134 }
135 }
136}
137
/// A [`Matrix`] expands to a [`MatrixExpand`] with the same element type.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand<C>;
}
141
/// An [`MmaDefinition`] expands to an [`MmaDefinitionExpand`] with the same element types.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeType for MmaDefinition<A, B, CD> {
    type ExpandType = MmaDefinitionExpand<A, B, CD>;
}
145
/// `into_mut` is the identity for matrices: the value is returned unchanged.
impl<C: CubeType> IntoMut for MatrixExpand<C> {
    /// Returns `self` as-is; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
151
152impl<C: CubeType> CubeDebug for MatrixExpand<C> {
153 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
154 scope.update_variable_name(*self.elem, name);
155 }
156}
157
/// `into_mut` is the identity for MMA definitions: the value is returned unchanged.
impl<A: CubeType, B: CubeType, CD: CubeType> IntoMut for MmaDefinitionExpand<A, B, CD> {
    /// Returns `self` as-is; the scope is not used.
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}
163
/// Uses the methods provided by [`CubeDebug`] itself; nothing is overridden here.
impl<A: CubeType, B: CubeType, CD: CubeType> CubeDebug for MmaDefinitionExpand<A, B, CD> {}
165
166#[cube]
167impl<C: CubePrimitive> Matrix<C> {
168 #[allow(unused_variables)]
186 pub unsafe fn uninitialized(
187 #[comptime] ident: MatrixIdent,
188 #[comptime] m: usize,
189 #[comptime] n: usize,
190 #[comptime] k: usize,
191 layout: MatrixLayout,
192 ) -> Self {
193 intrinsic!(|scope| {
194 let elem = C::as_type(scope).storage_type();
195 let elem = scope.create_matrix(ir::Matrix::new(ident, m, n, k, elem, layout));
196 MatrixExpand {
197 elem,
198 ident,
199 _c: PhantomData,
200 }
201 })
202 }
203
    /// Creates a matrix and fills it with `value`.
    ///
    /// The matrix is created with [`Matrix::uninitialized`] from `ident`, the
    /// `m`/`n`/`k` shape and `layout`, then a fill with `value` is emitted
    /// through `fill::expand` before the matrix is returned.
    #[allow(unused_variables)]
    pub fn from_value(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: C,
    ) -> Self
    where
        C: Scalar,
    {
        // SAFETY: the matrix is filled with `value` right below, before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            // `fill::expand` takes the matrix by value, so pass a clone and return the original.
            fill::expand(scope, mat.clone(), value);
            mat
        })
    }
236
    /// Creates a matrix and loads it from the slice `value` with the given `stride`.
    ///
    /// The matrix is created with [`Matrix::uninitialized`] from `ident`, the
    /// `m`/`n`/`k` shape and `layout`, then a load is emitted through
    /// `load::expand` before the matrix is returned.
    ///
    /// # Panics
    ///
    /// `load::expand` asserts that `ident` is not `MatrixIdent::Accumulator`.
    #[allow(unused_variables)]
    pub fn from_slice(
        #[comptime] ident: MatrixIdent,
        #[comptime] m: usize,
        #[comptime] n: usize,
        #[comptime] k: usize,
        layout: MatrixLayout,
        value: &Slice<C>,
        stride: u32,
    ) -> Self {
        // SAFETY: the matrix is loaded from `value` right below, before it is returned.
        let mat = unsafe { Self::uninitialized(ident, m, n, k, layout) };

        intrinsic!(|scope| {
            // `load::expand` takes the matrix by value, so pass a clone and return the original.
            load::expand(scope, mat.clone(), value, stride);
            mat
        })
    }
267}
268
269#[cube(self_type = "ref")]
270impl<A: Scalar, B: Scalar, CD: Scalar> MmaDefinition<A, B, CD> {
271 #[allow(unused_variables)]
287 pub fn new(#[comptime] m: usize, #[comptime] n: usize, #[comptime] k: usize) -> Self {
288 intrinsic!(|scope| {
289 let a_type = A::as_type(scope).storage_type();
290 let b_type = B::as_type(scope).storage_type();
291 let cd_type = CD::as_type(scope).storage_type();
292
293 MmaDefinitionExpand {
294 m,
295 n,
296 k,
297 a_type,
298 b_type,
299 cd_type,
300 scales_factor: None,
301 scales_type: None,
302 _a: PhantomData,
303 _b: PhantomData,
304 _cd: PhantomData,
305 }
306 })
307 }
308
309 #[allow(unused_variables)]
325 pub fn new_scaled<S: CubePrimitive>(
326 #[comptime] m: usize,
327 #[comptime] n: usize,
328 #[comptime] k: usize,
329 #[comptime] scale_factor: usize,
330 ) -> Self {
331 intrinsic!(|scope| {
332 let a_type = A::as_type(scope).storage_type();
333 let b_type = B::as_type(scope).storage_type();
334 let cd_type = CD::as_type(scope).storage_type();
335
336 MmaDefinitionExpand {
337 m,
338 n,
339 k,
340 a_type,
341 b_type,
342 cd_type,
343 scales_factor: Some(scale_factor),
344 scales_type: Some(S::as_type(scope).storage_type()),
345 _a: PhantomData,
346 _b: PhantomData,
347 _cd: PhantomData,
348 }
349 })
350 }
351
352 #[allow(unused)]
354 pub fn num_elems(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
355 intrinsic!(|scope| {
356 match ident {
357 MatrixIdent::A => (self.m * self.k) / self.a_type.packing_factor(),
358 MatrixIdent::B => (self.k * self.n) / self.b_type.packing_factor(),
359 MatrixIdent::Accumulator => (self.m * self.n) / self.cd_type.packing_factor(),
360 }
361 })
362 }
363
364 #[allow(unused)]
371 pub fn elems_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
372 intrinsic!(|scope| {
373 let elems = self.__expand_num_elems_method(scope, ident);
374 let plane_dim = scope.runtime_properties.mma.const_plane_size as usize;
375 let duplication = match ident {
376 MatrixIdent::A => scope.runtime_properties.mma.register_duplication_a,
377 MatrixIdent::B => scope.runtime_properties.mma.register_duplication_b,
378 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_duplication_acc,
379 };
380 (elems * duplication) / plane_dim
381 })
382 }
383
    /// Returns the number of vectors of matrix `ident` per lane.
    ///
    /// This is `elems_per_lane(ident)` divided by `vector_size(ident)`.
    #[allow(unused)]
    pub fn vectors_per_lane(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(usize) {
        intrinsic!(|scope| {
            let elems = self.clone().__expand_elems_per_lane_method(scope, ident);
            let vector_size = self.__expand_vector_size_method(scope, ident);
            elems / vector_size
        })
    }
397
398 #[allow(unused)]
400 pub fn vector_layout(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(MatrixLayout) {
401 intrinsic!(|scope| {
402 match ident {
403 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
404 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
405 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
406 }
407 })
408 }
409
410 #[allow(unused_variables)]
413 pub fn vector_size(&self, #[comptime] ident: MatrixIdent) -> comptime_type!(VectorSize) {
414 intrinsic!(|scope| {
415 let storage = match ident {
416 MatrixIdent::A => self.a_type,
417 MatrixIdent::B => self.b_type,
418 MatrixIdent::Accumulator => self.cd_type,
419 };
420 let matrix = cubecl_ir::Matrix {
421 ident,
422 m: self.m,
423 n: self.n,
424 k: self.k,
425 storage: storage,
426 layout: MatrixLayout::ColMajor,
427 };
428 scope
429 .runtime_properties
430 .mma
431 .contiguous_elements
432 .apply(ident, matrix)
433 })
434 }
435
436 #[allow(unused_variables)]
444 pub fn position_of_nth(
445 &self,
446 lane_id: u32,
447 elem_idx: u32,
448 #[comptime] ident: MatrixIdent,
449 ) -> (u32, u32) {
450 intrinsic!(|scope| {
451 let lane_id: ManagedVariable = lane_id.into();
452 let elem_idx: ManagedVariable = elem_idx.into();
453
454 let ty = match ident {
455 MatrixIdent::A => self.a_type,
456 MatrixIdent::B => self.b_type,
457 MatrixIdent::Accumulator => self.cd_type,
458 };
459 let layout = match ident {
460 MatrixIdent::A => scope.runtime_properties.mma.register_layout_a,
461 MatrixIdent::B => scope.runtime_properties.mma.register_layout_b,
462 MatrixIdent::Accumulator => scope.runtime_properties.mma.register_layout_acc,
463 };
464 let matrix = cubecl_ir::Matrix {
465 ident,
466 m: self.m,
467 n: self.n,
468 k: self.k,
469 storage: ty,
470 layout,
471 };
472
473 let row = scope.create_local(u32::as_type(scope));
474 let col = scope.create_local(u32::as_type(scope));
475 scope.register(Instruction::new(
476 CoopMma::RowIndex {
477 lane_id: *lane_id,
478 i: *elem_idx,
479 matrix,
480 },
481 *row,
482 ));
483 scope.register(Instruction::new(
484 CoopMma::ColIndex {
485 lane_id: *lane_id,
486 i: *elem_idx,
487 matrix,
488 },
489 *col,
490 ));
491 (row.into(), col.into())
492 })
493 }
494
    /// Returns the scales index for lane `lane_id` and matrix `ident`.
    ///
    /// With `quad_id = lane_id / 4` and `t_id = lane_id % 4`, the result is
    /// `quad_id + (t_id % 2) * 8` for `A` and `quad_id` for `B`.
    ///
    /// NOTE(review): the constants 4 and 8 look tied to one specific hardware
    /// lane layout — confirm which targets this is valid for.
    ///
    /// # Panics
    ///
    /// Panics for `MatrixIdent::Accumulator`, which has no scales.
    pub fn scales_index(&self, lane_id: u32, #[comptime] ident: MatrixIdent) -> u32 {
        let quad_id = lane_id / 4;
        let t_id = lane_id % 4;
        match ident {
            MatrixIdent::A => quad_id + (t_id % 2) * 8,
            MatrixIdent::B => quad_id,
            MatrixIdent::Accumulator => panic!("Accumulator doesn't have scales"),
        }
    }
507
    /// Returns the scale factor stored in this definition.
    ///
    /// # Panics
    ///
    /// Panics if `scales_factor` is `None`, i.e. the definition has no scales.
    pub fn scales_count(&self) -> comptime_type!(usize) {
        intrinsic!(|_| {
            self.scales_factor
                .expect("Can't retrieve scales count for matrix with no scales")
        })
    }
517
518 pub fn scales_vector_size(&self) -> comptime_type!(VectorSize) {
520 intrinsic!(|scope| {
521 let elem = self
522 .scales_type
523 .expect("Can't retrieve scales vector size for matrix with no scales");
524 scope.runtime_properties.mma.register_size_bits / elem.size_bits()
525 })
526 }
527
    /// Emits a `CoopMma::LoadMatrix` instruction reading from `row` and returns
    /// the array it writes to.
    ///
    /// A new array with `num_matrices` entries is created as the output. The
    /// instruction receives the raw buffer and offset of `row`, the slice's
    /// vector size, `num_matrices` as `factor`, and the `transpose` flag.
    ///
    /// `ident` is not used by this method.
    #[allow(unused_variables)]
    pub fn load_matrix<E: CubePrimitive, NO: Size>(
        &self,
        row: &Slice<E>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) -> Array<Vector<E::Scalar, NO>> {
        intrinsic!(|scope| {
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            let out = Array::__expand_new(scope, num_matrices);
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    vector_size: slice_vector_size,
                    factor: num_matrices,
                    transpose,
                },
                *out.expand,
            ));
            out
        })
    }
563
    /// Emits a `CoopMma::LoadMatrix` instruction reading from `row` and writing
    /// into the existing array `fragment`.
    ///
    /// The instruction receives the raw buffer and offset of `row`, the slice's
    /// vector size, `num_matrices` as `factor`, and the `transpose` flag.
    #[allow(unused_variables)]
    pub fn load_matrix_inplace<E: Scalar, N: Size>(
        &self,
        row: &Slice<E>,
        fragment: &mut Array<Vector<E, N>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) {
        intrinsic!(|scope| {
            // NOTE(review): `vector_size` is computed but never used below.
            let vector_size = self.__expand_vector_size_method(scope, ident);
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            scope.register(Instruction::new(
                CoopMma::LoadMatrix {
                    buffer,
                    offset,
                    vector_size: slice_vector_size,
                    factor: num_matrices,
                    transpose,
                },
                *fragment.expand,
            ));
        })
    }
589
    /// Emits a `CoopMma::StoreMatrix` instruction writing `registers` into `row`.
    ///
    /// The raw buffer of `row` is the instruction's output. Its offset, the
    /// slice's vector size, `num_matrices` as `factor`, and the `transpose`
    /// flag are the instruction's operands.
    #[allow(unused_variables)]
    pub fn store_matrix<E: CubePrimitive, N: Size>(
        &self,
        row: &mut Slice<E, ReadWrite>,
        registers: &Array<Vector<E::Scalar, N>>,
        #[comptime] ident: MatrixIdent,
        #[comptime] num_matrices: usize,
        #[comptime] transpose: bool,
    ) {
        intrinsic!(|scope| {
            // NOTE(review): `vector_size` is computed but never used below.
            let vector_size = self.__expand_vector_size_method(scope, ident);
            let slice_vector_size = row.vector_size;
            let (buffer, offset) = row.__to_raw_parts();
            scope.register(Instruction::new(
                CoopMma::StoreMatrix {
                    offset,
                    vector_size: slice_vector_size,
                    registers: *registers.expand,
                    factor: num_matrices,
                    transpose,
                },
                buffer,
            ));
        })
    }
625
626 #[allow(unused)]
629 pub fn execute<NA: Size, NB: Size, NC: Size>(
630 &self,
631 registers_a: &Array<Vector<A, NA>>,
632 registers_b: &Array<Vector<B, NB>>,
633 registers_c: &Array<Vector<CD, NC>>,
634 ) -> Array<Vector<CD, NC>> {
635 intrinsic!(|scope| {
636 let acc_elems = self
637 .clone()
638 .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
639 let acc_vector_size = self
640 .clone()
641 .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
642 let num_registers = acc_elems / acc_vector_size;
643
644 let registers_d = Array::__expand_new(scope, num_registers);
645
646 let registers_a = *registers_a.expand;
647 let registers_b = *registers_b.expand;
648 let registers_c = *registers_c.expand;
649
650 let matrix = cubecl_ir::Matrix {
652 ident: MatrixIdent::A,
653 m: self.m,
654 n: self.n,
655 k: self.k,
656 storage: self.a_type,
657 layout: MatrixLayout::ColMajor,
658 };
659
660 scope.register(Instruction::new(
661 CoopMma::ExecuteManual {
662 matrix,
663 registers_a,
664 registers_b,
665 registers_c,
666 },
667 *registers_d.expand,
668 ));
669
670 registers_d
671 })
672 }
673
    /// Emits a `CoopMma::ExecuteManual` instruction whose output is
    /// `registers_c` itself, so the accumulator array is overwritten.
    #[allow(unused)]
    pub fn execute_inplace<NA: Size, NB: Size, NC: Size>(
        &self,
        registers_a: &Array<Vector<A, NA>>,
        registers_b: &Array<Vector<B, NB>>,
        registers_c: &mut Array<Vector<CD, NC>>,
    ) {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_vector_size = self
                .clone()
                .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
            // NOTE(review): `num_registers` is computed but never used below.
            let num_registers = acc_elems / acc_vector_size;

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // NOTE(review): the descriptor always uses A's ident and storage type with a
            // column-major layout — presumably only the shape is consumed; confirm.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            // `registers_c` is both an input operand and the output variable.
            scope.register(Instruction::new(
                CoopMma::ExecuteManual {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                },
                registers_c,
            ));
        })
    }
715
    /// Emits a `CoopMma::ExecuteScaled` instruction over the given register
    /// arrays and scales, and returns a newly created output array.
    ///
    /// The output array has `elems_per_lane(Accumulator) / vector_size(Accumulator)`
    /// entries.
    ///
    /// # Panics
    ///
    /// Panics if `scales_factor` is `None`, i.e. the definition has no scales.
    #[allow(unused)]
    pub fn execute_scaled<S: Scalar, NA: Size, NB: Size, NC: Size, NS: Size>(
        &self,
        registers_a: &Array<Vector<A, NA>>,
        registers_b: &Array<Vector<B, NB>>,
        registers_c: &Array<Vector<CD, NC>>,
        scales_a: Vector<S, NS>,
        scales_b: Vector<S, NS>,
    ) -> Array<Vector<CD, NC>> {
        intrinsic!(|scope| {
            let acc_elems = self
                .clone()
                .__expand_elems_per_lane_method(scope, MatrixIdent::Accumulator);
            let acc_vector_size = self
                .clone()
                .__expand_vector_size_method(scope, MatrixIdent::Accumulator);
            let num_registers = acc_elems / acc_vector_size;

            let registers_d = Array::__expand_new(scope, num_registers);

            let registers_a = *registers_a.expand;
            let registers_b = *registers_b.expand;
            let registers_c = *registers_c.expand;

            // NOTE(review): the descriptor always uses A's ident and storage type with a
            // column-major layout — presumably only the shape is consumed; confirm.
            let matrix = cubecl_ir::Matrix {
                ident: MatrixIdent::A,
                m: self.m,
                n: self.n,
                k: self.k,
                storage: self.a_type,
                layout: MatrixLayout::ColMajor,
            };

            scope.register(Instruction::new(
                CoopMma::ExecuteScaled {
                    matrix,
                    registers_a,
                    registers_b,
                    registers_c,
                    scales_a: *scales_a.expand,
                    scales_b: *scales_b.expand,
                    scales_factor: self
                        .scales_factor
                        .expect("Can't execute scaled on matrix with no scales"),
                },
                *registers_d.expand,
            ));

            registers_d
        })
    }
770}
771
/// Fills the matrix `mat` with `value`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `fill::expand`.
#[allow(unused_variables)]
pub fn fill<C: Scalar>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
777
778pub mod fill {
780 use super::*;
781
782 pub fn expand<C: Scalar>(scope: &mut Scope, mat: MatrixExpand<C>, value: NativeExpand<C>) {
784 let value: ManagedVariable = value.into();
785 scope.register(Instruction::new(
786 ir::CoopMma::Fill { value: *value },
787 *mat.elem,
788 ));
789 }
790}
791
/// Loads the matrix `mat` from the slice `value` using `stride`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `load::expand`.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
    unexpanded!()
}
797
798pub mod load {
800 use super::*;
801
802 #[allow(unused_variables)]
804 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
805 scope: &mut Scope,
806 mat: MatrixExpand<C>,
807 value: SliceExpand<V, ReadOnly>,
808 stride: NativeExpand<u32>,
809 ) {
810 let stride: ManagedVariable = stride.into();
811 assert_ne!(
812 mat.ident,
813 MatrixIdent::Accumulator,
814 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
815 );
816
817 let (value, offset) = value.__to_raw_parts();
818
819 scope.register(Instruction::new(
820 ir::CoopMma::Load {
821 value,
822 stride: *stride,
823 offset,
824 layout: None,
825 },
826 *mat.elem,
827 ));
828 }
829}
830
/// Loads the matrix `mat` from the slice `value` using `stride` and an explicit `layout`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `load_with_layout::expand`.
#[allow(unused_variables)]
pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
    mat: &Matrix<C>,
    value: &Slice<V>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
842
843pub mod load_with_layout {
845 use super::*;
846
847 #[allow(unused_variables)]
849 pub fn expand<C: CubeType, V: CubePrimitive>(
850 scope: &mut Scope,
851 mat: MatrixExpand<C>,
852 value: SliceExpand<V, ReadOnly>,
853 stride: NativeExpand<u32>,
854 layout: MatrixLayout,
855 ) {
856 let stride: ManagedVariable = stride.into();
857 let (value, offset) = value.__to_raw_parts();
858
859 scope.register(Instruction::new(
860 ir::CoopMma::Load {
861 value,
862 stride: *stride,
863 offset,
864 layout: Some(layout),
865 },
866 *mat.elem,
867 ));
868 }
869}
870
/// Stores the matrix `mat` into the slice `output` using `stride` and `layout`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `store::expand`.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
    output: &mut SliceMut<O>,
    mat: &Matrix<C>,
    stride: u32,
    layout: MatrixLayout,
) {
    unexpanded!()
}
881
882pub mod store {
884 use crate::prelude::ReadWrite;
885
886 use super::*;
887
888 #[allow(unused_variables)]
890 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
891 scope: &mut Scope,
892 output: SliceExpand<O, ReadWrite>,
893 mat: MatrixExpand<C>,
894 stride: NativeExpand<u32>,
895 layout: MatrixLayout,
896 ) {
897 let stride: ManagedVariable = stride.into();
898
899 let (output, offset) = output.__to_raw_parts();
900
901 scope.register(Instruction::new(
902 ir::CoopMma::Store {
903 mat: *mat.elem,
904 offset,
905 stride: *stride,
906 layout,
907 },
908 output,
909 ));
910 }
911}
912
/// Executes the matrix operation on `mat_a`, `mat_b` and `mat_c`, writing the result to `mat_d`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `execute::expand`.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
923
924pub mod execute {
926 use super::*;
927
928 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
930 scope: &mut Scope,
931 mat_a: MatrixExpand<A>,
932 mat_b: MatrixExpand<B>,
933 mat_c: MatrixExpand<C>,
934 mat_d: MatrixExpand<D>,
935 ) {
936 scope.register(Instruction::new(
937 ir::CoopMma::Execute {
938 mat_a: *mat_a.elem,
939 mat_b: *mat_b.elem,
940 mat_c: *mat_c.elem,
941 },
942 *mat_d.elem,
943 ));
944 }
945}
946
/// Casts the matrix `input` from element type `C` to element type `O`.
///
/// This is only the user-facing signature: the body is `unexpanded!()`, and the
/// instruction is emitted by `cast::expand`.
#[allow(unused_variables)]
pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
    unexpanded!()
}
952
953pub mod cast {
955 use super::*;
956
957 #[allow(unused_variables)]
959 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
960 scope: &mut Scope,
961 input: MatrixExpand<C>,
962 ) -> MatrixExpand<O> {
963 let ident = input.ident;
964
965 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
966 return MatrixExpand {
967 elem: input.elem,
968 ident,
969 _c: PhantomData,
970 };
971 }
972 let input = *input.elem;
973 let input_mat = match input.kind {
974 ir::VariableKind::Matrix { mat, .. } => mat,
975 _ => unreachable!(),
976 };
977
978 let elem = O::as_type(scope).storage_type();
979 let elem = scope.create_matrix(ir::Matrix::new(
980 ident,
981 input_mat.m,
982 input_mat.n,
983 input_mat.k,
984 elem,
985 MatrixLayout::Undefined,
986 ));
987
988 let output = MatrixExpand {
989 ident,
990 elem,
991 _c: PhantomData,
992 };
993 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
994
995 output
996 }
997}
998
/// `MatrixLayout` is its own expand type.
impl CubeType for MatrixLayout {
    type ExpandType = Self;
}
1002
/// `into_mut` is the identity for layouts: the value is returned unchanged.
impl IntoMut for MatrixLayout {
    /// Returns `self` as-is; the scope is not used.
    fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
        self
    }
}
1008
/// Uses the methods provided by [`CubeDebug`] itself; nothing is overridden here.
impl CubeDebug for MatrixLayout {}