// cubecl_linalg/tensor/layout.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
4/// Layout for matrix batch tensors, i.e. tensors whose interpretation
5/// is a bunch of batched matrices of 2 dimensions
6pub enum MatrixBatchLayout {
7    /// Memory is wholly contiguous, with row major layout
8    Contiguous,
9    /// Permutations happened, but may not impact some kernels
10    MildlyPermuted {
11        /// Last two dims are inverted
12        transposed: bool,
13        /// Some permutations exist in batch dimensions
14        batch_swap: bool,
15    },
16    /// Permutations happened between batch dimensions and last two dims
17    HighlyPermuted,
18}
19
20/// Return the layout of a matrix batch given the strides.
21pub fn matrix_batch_layout(strides: &[usize]) -> MatrixBatchLayout {
22    let rank = strides.len();
23    if rank <= 1 {
24        return MatrixBatchLayout::Contiguous;
25    }
26
27    let mut transposed = false;
28    let mut batch_swap = false;
29    let row_stride = strides[rank - 2];
30    let col_stride = strides[rank - 1];
31    if row_stride == 0 || col_stride == 0 {
32        // Broadcasted last two dims
33        return MatrixBatchLayout::HighlyPermuted;
34    }
35    if row_stride < col_stride {
36        transposed = true;
37    }
38    let mut previous_stride = row_stride;
39
40    for d in 0..rank - 2 {
41        let current_stride = strides[rank - 3 - d];
42        if current_stride < row_stride || current_stride < col_stride {
43            if current_stride == 0 {
44                // Broadcasted batch dim
45                batch_swap = true;
46            } else {
47                return MatrixBatchLayout::HighlyPermuted;
48            }
49        }
50        if current_stride < previous_stride {
51            batch_swap = true;
52        }
53
54        previous_stride = current_stride;
55    }
56
57    if transposed || batch_swap {
58        MatrixBatchLayout::MildlyPermuted {
59            transposed,
60            batch_swap,
61        }
62    } else {
63        MatrixBatchLayout::Contiguous
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the expected `MildlyPermuted` value, so that a failing test
    /// prints both the expected and the actual layout instead of hitting an
    /// `unreachable!()`.
    fn mildly(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        let strides = &[8, 4, 2, 1];
        assert_eq!(matrix_batch_layout(strides), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn vector_is_contiguous() {
        let strides = &[1];
        assert_eq!(matrix_batch_layout(strides), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn scalar_is_contiguous() {
        // Rank 0: no strides at all.
        assert_eq!(matrix_batch_layout(&[]), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn single_matrix_is_contiguous() {
        let strides = &[2, 1];
        assert_eq!(matrix_batch_layout(strides), MatrixBatchLayout::Contiguous);
    }

    #[test]
    fn single_matrix_is_transposed() {
        let strides = &[1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly(true, false));
    }

    #[test]
    fn layout_is_transposed_only() {
        let strides = &[8, 4, 1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly(true, false));
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        let strides = &[4, 8, 2, 1];
        assert_eq!(matrix_batch_layout(strides), mildly(false, true));
    }

    #[test]
    fn layout_has_swapped_outer_batches_only() {
        // The innermost batch dim is in order; the swap is one level further out.
        let strides = &[3, 6, 2, 1];
        assert_eq!(matrix_batch_layout(strides), mildly(false, true));
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        let strides = &[4, 8, 1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly(true, true));
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        let strides = &[8, 2, 4, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        let strides = &[1, 4, 2, 8];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // E.g., tensor w/ shape [1, 4] expanded to [2, 3, 4]
        let strides = &[0, 0, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        // E.g., tensor w/ shape [1, 4] expanded to [3, 4]
        let strides = &[0, 1];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        // E.g., tensor w/ shape [2, 1] expanded to [2, 3]
        let strides = &[1, 0];
        assert_eq!(
            matrix_batch_layout(strides),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 4]
        let strides = &[0, 4, 1];
        assert_eq!(matrix_batch_layout(strides), mildly(false, true));
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 2, 4]
        let strides = &[0, 0, 4, 1];
        assert_eq!(matrix_batch_layout(strides), mildly(false, true));
    }

    #[test]
    fn layout_has_batch_broadcasted_and_is_transposed() {
        // A transposed [2, 1] matrix expanded along a new batch dimension.
        let strides = &[0, 1, 2];
        assert_eq!(matrix_batch_layout(strides), mildly(true, true));
    }
}