// cubek_std/matrix_layout.rs

1use cubecl::{
2    prelude::*,
3    quant::scheme::QuantScheme,
4    zspace::{Strides, strides},
5};
6
7use crate::InvalidConfigError;
8
9#[derive(CubeType, Copy, Clone, PartialEq, Eq, Hash, Debug, Default)]
10/// Layout of a 2D structure such as a tensor, shared memory or slice,
11/// used within any matmul kernel level
12pub enum MatrixLayout {
13    #[default]
14    RowMajor,
15    ColMajor,
16}
17
18impl MatrixLayout {
19    pub fn from_shape_and_strides(
20        shape: &[usize],
21        strides: &[usize],
22        scheme: Option<&QuantScheme>,
23    ) -> Result<Self, InvalidConfigError> {
24        assert!(
25            shape.len() >= 2 && shape.len() == strides.len(),
26            "Shape/stride mismatch or not a matrix"
27        );
28
29        if let Some(packing_dim) = scheme.and_then(|s| s.packing_dim()) {
30            if packing_dim == 0 {
31                return Ok(MatrixLayout::RowMajor);
32            }
33            if packing_dim == 1 {
34                return Ok(MatrixLayout::ColMajor);
35            }
36
37            return Err(Box::new(format!(
38                "Invalid or non-contiguous matrix layout: packing_dim={packing_dim:?}"
39            )));
40        }
41
42        let n = shape.len();
43
44        let outer = shape[n - 2];
45        let inner = shape[n - 1];
46
47        let stride_outer = strides[n - 2];
48        let stride_inner = strides[n - 1];
49
50        // These checks are actually broken for quantized inputs (and are not trivially fixable).
51        // For quantized tensors the quantized axis will probably need to be stored, since it can be
52        // hard to tell on which axis it is packed.
53        // The packed axis is always the contiguous one. One test case has a logical shape of [4, 4]
54        // for example, with strides of [1, 1]. It is not possible to determine the packing dimension
55        // accurately for this problem.
56
57        // Row-major: inner dimension is contiguous
58        if (stride_inner == 1) && stride_outer >= inner {
59            return Ok(MatrixLayout::RowMajor);
60        }
61
62        // Col-major: outer dimension is contiguous
63        if (stride_outer == 1) && stride_inner >= outer {
64            return Ok(MatrixLayout::ColMajor);
65        }
66
67        Err(Box::new(format!(
68            "Invalid or non-contiguous matrix layout: shape={shape:?}, strides={strides:?}",
69        )))
70    }
71
72    pub fn to_strides(&self, shape: &[usize]) -> Strides {
73        assert!(shape.len() >= 2, "Shape must have at least 2 dimensions");
74
75        let n = shape.len();
76        let mut strides = strides![0; n];
77
78        // Start with contiguous layout for last two dims
79        match self {
80            MatrixLayout::RowMajor => {
81                strides[n - 1] = 1; // inner dim contiguous
82                strides[n - 2] = shape[n - 1]; // outer stride = inner size
83            }
84            MatrixLayout::ColMajor => {
85                strides[n - 2] = 1; // outer dim contiguous
86                strides[n - 1] = shape[n - 2]; // inner stride = outer size
87            }
88        }
89
90        // Batch dims: contiguous
91        for i in (0..n - 2).rev() {
92            strides[i] = strides[i + 1] * shape[i + 1];
93        }
94
95        strides
96    }
97}
98
99#[cube]
100/// Maps the matmul MatrixLayout to cmma's MatrixLayout, for use in Cmma API.
101pub fn as_cmma_layout(#[comptime] layout: MatrixLayout) -> cmma::MatrixLayout {
102    match layout {
103        MatrixLayout::RowMajor => cmma::MatrixLayout::RowMajor,
104        MatrixLayout::ColMajor => cmma::MatrixLayout::ColMajor,
105    }
106}
107
108#[cube]
109/// Maps the cmma's MatrixLayout to matmul MatrixLayout.
110pub fn from_cmma_layout(#[comptime] layout: cmma::MatrixLayout) -> comptime_type!(MatrixLayout) {
111    match layout {
112        cmma::MatrixLayout::RowMajor => MatrixLayout::RowMajor,
113        cmma::MatrixLayout::ColMajor => MatrixLayout::ColMajor,
114        cmma::MatrixLayout::Undefined => MatrixLayout::RowMajor,
115    }
116}