cubecl_std/tensor/
matrix_batch_layout.rs1use cubecl_core::quant::scheme::QuantScheme;
2use serde::{Deserialize, Serialize};
3
/// Classification of the stride layout of a tensor viewed as a batch of
/// matrices, as produced by [`matrix_batch_layout`].
///
/// The last two dimensions are treated as the matrix rows and columns; all
/// dimensions before them are batch dimensions.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum MatrixBatchLayout {
    /// No permutation was detected: the matrix is not transposed and no batch
    /// stride is out of order. Also returned for tensors of rank 0 or 1.
    Contiguous,
    /// A permutation was detected, but batch dimensions and row/column
    /// dimensions are not interleaved. At least one of the flags is `true`.
    MildlyPermuted {
        /// The row stride is smaller than the column stride, or the
        /// quantization scheme reports packing dimension `1`.
        transposed: bool,
        /// Some batch stride is smaller than the stride of the dimension just
        /// inside it, or some batch stride is zero.
        batch_swap: bool,
    },
    /// The layout cannot be described by the two flags above: the row or
    /// column stride is zero, the packing dimension is neither `0` nor `1`, or
    /// a non-zero batch stride is smaller than the row or column stride.
    HighlyPermuted,
}
20
21pub fn matrix_batch_layout(strides: &[usize], scheme: Option<&QuantScheme>) -> MatrixBatchLayout {
23 let packing_dim = scheme.and_then(|s| s.packing_dim());
24 let rank = strides.len();
25 if rank <= 1 {
26 return MatrixBatchLayout::Contiguous;
27 }
28
29 let mut transposed = false;
30 let mut batch_swap = false;
31 let row_stride = strides[rank - 2];
32 let col_stride = strides[rank - 1];
33 if row_stride == 0 || col_stride == 0 {
34 return MatrixBatchLayout::HighlyPermuted;
36 }
37 if let Some(packing_dim) = packing_dim {
38 match packing_dim {
39 0 => {}
40 1 => {
41 transposed = true;
42 }
43 _ => {
44 return MatrixBatchLayout::HighlyPermuted;
45 }
46 }
47 } else if row_stride < col_stride {
48 transposed = true;
49 }
50 let mut previous_stride = row_stride;
51
52 for d in 0..rank - 2 {
53 let current_stride = strides[rank - 3 - d];
54 if current_stride < row_stride || current_stride < col_stride {
55 if current_stride == 0 {
56 batch_swap = true;
58 } else {
59 return MatrixBatchLayout::HighlyPermuted;
60 }
61 }
62 if current_stride < previous_stride {
63 batch_swap = true;
64 }
65
66 previous_stride = current_stride;
67 }
68
69 if transposed || batch_swap {
70 MatrixBatchLayout::MildlyPermuted {
71 transposed,
72 batch_swap,
73 }
74 } else {
75 MatrixBatchLayout::Contiguous
76 }
77}
78
#[cfg(test)]
mod tests {
    //! Unit tests for `matrix_batch_layout`.
    //!
    //! Every case passes `None` as the quantization scheme, so the
    //! packing-dimension branches are not exercised here.

    use super::*;

    /// Classifies `strides` without a quantization scheme.
    fn layout_of(strides: &[usize]) -> MatrixBatchLayout {
        matrix_batch_layout(strides, None)
    }

    /// Builds the expected `MildlyPermuted` value so that a failing
    /// `assert_eq!` prints both the expected and the actual layout.
    fn mildly_permuted(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        assert_eq!(layout_of(&[8, 4, 2, 1]), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn scalar_is_contiguous() {
        // Rank 0: no strides at all.
        assert_eq!(layout_of(&[]), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn vector_is_contiguous() {
        assert_eq!(layout_of(&[1]), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn matrix_without_batch_is_contiguous() {
        // Rank 2: the batch loop has nothing to visit.
        assert_eq!(layout_of(&[4, 1]), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn matrix_without_batch_is_transposed() {
        assert_eq!(layout_of(&[1, 4]), mildly_permuted(true, false));
    }

    #[test]
    fn layout_is_transposed_only() {
        assert_eq!(layout_of(&[8, 4, 1, 2]), mildly_permuted(true, false));
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        assert_eq!(layout_of(&[4, 8, 2, 1]), mildly_permuted(false, true));
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        assert_eq!(layout_of(&[4, 8, 1, 2]), mildly_permuted(true, true));
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        assert_eq!(layout_of(&[8, 2, 4, 1]), MatrixBatchLayout::HighlyPermuted);
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        assert_eq!(layout_of(&[1, 4, 2, 8]), MatrixBatchLayout::HighlyPermuted);
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // The row stride is zero, which takes precedence over the zero batch stride.
        assert_eq!(layout_of(&[0, 0, 1]), MatrixBatchLayout::HighlyPermuted);
    }

    #[test]
    fn layout_has_row_broadcasted() {
        assert_eq!(layout_of(&[0, 1]), MatrixBatchLayout::HighlyPermuted);
    }

    #[test]
    fn layout_has_col_broadcasted() {
        assert_eq!(layout_of(&[1, 0]), MatrixBatchLayout::HighlyPermuted);
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // A zero batch stride is reported as a batch swap.
        assert_eq!(layout_of(&[0, 4, 1]), mildly_permuted(false, true));
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        assert_eq!(layout_of(&[0, 0, 4, 1]), mildly_permuted(false, true));
    }
}