cubek_std/
matrix_layout.rs1use cubecl::{
2 prelude::*,
3 quant::scheme::QuantScheme,
4 zspace::{Strides, strides},
5};
6
7use crate::InvalidConfigError;
8
9#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
10pub enum MatrixLayout {
13 #[default]
14 RowMajor,
15 ColMajor,
16}
17
18impl MatrixLayout {
19 pub fn from_shape_and_strides(
20 shape: &[usize],
21 strides: &[usize],
22 scheme: Option<&QuantScheme>,
23 ) -> Result<Self, InvalidConfigError> {
24 assert!(
25 shape.len() >= 2 && shape.len() == strides.len(),
26 "Shape/stride mismatch or not a matrix"
27 );
28
29 if let Some(packing_dim) = scheme.and_then(|s| s.packing_dim()) {
30 if packing_dim == 0 {
31 return Ok(MatrixLayout::RowMajor);
32 }
33 if packing_dim == 1 {
34 return Ok(MatrixLayout::ColMajor);
35 }
36
37 return Err(Box::new(format!(
38 "Invalid or non-contiguous matrix layout: packing_dim={packing_dim:?}"
39 )));
40 }
41
42 let n = shape.len();
43
44 let outer = shape[n - 2];
45 let inner = shape[n - 1];
46
47 let stride_outer = strides[n - 2];
48 let stride_inner = strides[n - 1];
49
50 if (stride_inner == 1) && stride_outer >= inner {
59 return Ok(MatrixLayout::RowMajor);
60 }
61
62 if (stride_outer == 1) && stride_inner >= outer {
64 return Ok(MatrixLayout::ColMajor);
65 }
66
67 Err(Box::new(format!(
68 "Invalid or non-contiguous matrix layout: shape={shape:?}, strides={strides:?}",
69 )))
70 }
71
72 pub fn to_strides(&self, shape: &[usize]) -> Strides {
73 assert!(shape.len() >= 2, "Shape must have at least 2 dimensions");
74
75 let n = shape.len();
76 let mut strides = strides![0; n];
77
78 match self {
80 MatrixLayout::RowMajor => {
81 strides[n - 1] = 1; strides[n - 2] = shape[n - 1]; }
84 MatrixLayout::ColMajor => {
85 strides[n - 2] = 1; strides[n - 1] = shape[n - 2]; }
88 }
89
90 for i in (0..n - 2).rev() {
92 strides[i] = strides[i + 1] * shape[i + 1];
93 }
94
95 strides
96 }
97}
98
/// Converts a [`MatrixLayout`] into the equivalent `cmma::MatrixLayout`.
///
/// `layout` is a `#[comptime]` parameter. Each variant maps to the identically named
/// `cmma` variant.
#[cube]
pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
    match layout {
        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
    }
}
107
/// Converts a `cmma::MatrixLayout` back into a [`MatrixLayout`], as a comptime value.
///
/// `RowMajor` and `ColMajor` map to their namesakes. `cmma::MatrixLayout::Undefined` has
/// no counterpart and falls back to [`MatrixLayout::RowMajor`].
// NOTE(review): the `Undefined` fallback is silent — confirm no caller needs to
// distinguish an undefined layout from an explicit row-major one.
#[cube]
pub fn from_cmma_layout(#[comptime] layout: cmma::MatrixLayout) -> comptime_type!(MatrixLayout) {
    match layout {
        cmma::MatrixLayout::RowMajor => MatrixLayout::RowMajor,
        cmma::MatrixLayout::ColMajor => MatrixLayout::ColMajor,
        cmma::MatrixLayout::Undefined => MatrixLayout::RowMajor,
    }
}