// cubecl_linalg/tensor/layout.rs
use serde::{Deserialize, Serialize};
2
3#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
4pub enum MatrixBatchLayout {
7 Contiguous,
9 MildlyPermuted {
11 transposed: bool,
13 batch_swap: bool,
15 },
16 HighlyPermuted,
18}
19
20pub fn matrix_batch_layout(strides: &[usize]) -> MatrixBatchLayout {
22 let rank = strides.len();
23 if rank <= 1 {
24 return MatrixBatchLayout::Contiguous;
25 }
26
27 let mut transposed = false;
28 let mut batch_swap = false;
29 let row_stride = strides[rank - 2];
30 let col_stride = strides[rank - 1];
31 if row_stride == 0 || col_stride == 0 {
32 return MatrixBatchLayout::HighlyPermuted;
34 }
35 if row_stride < col_stride {
36 transposed = true;
37 }
38 let mut previous_stride = row_stride;
39
40 for d in 0..rank - 2 {
41 let current_stride = strides[rank - 3 - d];
42 if current_stride < row_stride || current_stride < col_stride {
43 if current_stride == 0 {
44 batch_swap = true;
46 } else {
47 return MatrixBatchLayout::HighlyPermuted;
48 }
49 }
50 if current_stride < previous_stride {
51 batch_swap = true;
52 }
53
54 previous_stride = current_stride;
55 }
56
57 if transposed || batch_swap {
58 MatrixBatchLayout::MildlyPermuted {
59 transposed,
60 batch_swap,
61 }
62 } else {
63 MatrixBatchLayout::Contiguous
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected `MildlyPermuted` value so each test can compare
    /// with `assert_eq!` and get both values printed on failure.
    fn mildly_permuted(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        let strides = &[8, 4, 2, 1];
        assert_eq!(matrix_batch_layout(strides), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn vector_is_contiguous() {
        let strides = &[1];
        assert_eq!(matrix_batch_layout(strides), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn layout_is_transposed_only() {
        let strides = &[8, 4, 1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly_permuted(true, false));
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        let strides = &[4, 8, 2, 1];
        assert_eq!(matrix_batch_layout(strides), mildly_permuted(false, true));
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        let strides = &[4, 8, 1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly_permuted(true, true));
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        let strides = &[8, 2, 4, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        let strides = &[1, 4, 2, 8];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // The zero stride sits on the row dimension, not only on the batch.
        let strides = &[0, 0, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        // Zero row stride.
        let strides = &[0, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        // Zero column stride.
        let strides = &[1, 0];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // A zero stride on a batch dimension is reported as a batch swap.
        let strides = &[0, 4, 1];
        assert_eq!(matrix_batch_layout(strides), mildly_permuted(false, true));
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        // Several zero-stride batch dimensions are still only a batch swap.
        let strides = &[0, 0, 4, 1];
        assert_eq!(matrix_batch_layout(strides), mildly_permuted(false, true));
    }
}