cubecl_std/tensor/
matrix_batch_layout.rs1use cubecl_core::{quant::scheme::QuantScheme, zspace::Strides};
2use serde::{Deserialize, Serialize};
3
4#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
5pub enum MatrixBatchLayout {
8 Contiguous,
10 MildlyPermuted {
12 transposed: bool,
14 batch_swap: bool,
16 },
17 HighlyPermuted,
19}
20
21pub fn matrix_batch_layout(strides: &Strides, scheme: Option<&QuantScheme>) -> MatrixBatchLayout {
23 let packing_dim = scheme.and_then(|s| s.packing_dim());
24 let rank = strides.len();
25 if rank <= 1 {
26 return MatrixBatchLayout::Contiguous;
27 }
28
29 let mut transposed = false;
30 let mut batch_swap = false;
31 let row_stride = strides[rank - 2];
32 let col_stride = strides[rank - 1];
33 if row_stride == 0 || col_stride == 0 {
34 return MatrixBatchLayout::HighlyPermuted;
36 }
37 if let Some(packing_dim) = packing_dim {
38 match packing_dim {
39 0 => {}
40 1 => {
41 transposed = true;
42 }
43 _ => {
44 return MatrixBatchLayout::HighlyPermuted;
45 }
46 }
47 } else if row_stride < col_stride {
48 transposed = true;
49 }
50 let mut previous_stride = row_stride;
51
52 for d in 0..rank - 2 {
53 let current_stride = strides[rank - 3 - d];
54 if current_stride < row_stride || current_stride < col_stride {
55 if current_stride == 0 {
56 batch_swap = true;
58 } else {
59 return MatrixBatchLayout::HighlyPermuted;
60 }
61 }
62 if current_stride < previous_stride {
63 batch_swap = true;
64 }
65
66 previous_stride = current_stride;
67 }
68
69 if transposed || batch_swap {
70 MatrixBatchLayout::MildlyPermuted {
71 transposed,
72 batch_swap,
73 }
74 } else {
75 MatrixBatchLayout::Contiguous
76 }
77}
78
#[cfg(test)]
mod tests {
    use cubecl_core::zspace::strides;

    use super::*;

    /// Builds the expected `MildlyPermuted` value so tests can compare the whole
    /// layout with `assert_eq!`, which prints the actual layout on failure
    /// (the previous `if let … else { unreachable!() }` pattern did not).
    fn mildly_permuted(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        let strides = strides![8, 4, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn vector_is_contiguous() {
        let strides = strides![1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn matrix_without_batch_dims_is_transposed() {
        let strides = strides![1, 4];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_is_transposed_only() {
        let strides = strides![8, 4, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        let strides = strides![4, 8, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        let strides = strides![4, 8, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, true)
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        let strides = strides![8, 2, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        let strides = strides![1, 4, 2, 8];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // Both a batch dim and the row dim have zero strides.
        let strides = strides![0, 0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        let strides = strides![0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        let strides = strides![1, 0];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // A zero-stride batch dim counts as a batch swap, not a full permutation.
        let strides = strides![0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        let strides = strides![0, 0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_inner_batch_broadcasted() {
        // Broadcast inner batch dim below a regular outer batch dim.
        let strides = strides![8, 0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }
}
222}