Skip to main content

cubecl_std/tensor/
matrix_batch_layout.rs

1use cubecl_core::{quant::scheme::QuantScheme, zspace::Strides};
2use serde::{Deserialize, Serialize};
3
4#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
5/// Layout for matrix batch tensors, i.e. tensors whose interpretation
6/// is a bunch of batched matrices of 2 dimensions
7pub enum MatrixBatchLayout {
8    /// Memory is wholly contiguous, with row major layout
9    Contiguous,
10    /// Permutations happened, but may not impact some kernels
11    MildlyPermuted {
12        /// Last two dims are inverted
13        transposed: bool,
14        /// Some permutations exist in batch dimensions
15        batch_swap: bool,
16    },
17    /// Permutations happened between batch dimensions and last two dims
18    HighlyPermuted,
19}
20
21/// Return the layout of a matrix batch given the strides.
22pub fn matrix_batch_layout(strides: &Strides, scheme: Option<&QuantScheme>) -> MatrixBatchLayout {
23    let packing_dim = scheme.and_then(|s| s.packing_dim());
24    let rank = strides.len();
25    if rank <= 1 {
26        return MatrixBatchLayout::Contiguous;
27    }
28
29    let mut transposed = false;
30    let mut batch_swap = false;
31    let row_stride = strides[rank - 2];
32    let col_stride = strides[rank - 1];
33    if row_stride == 0 || col_stride == 0 {
34        // Broadcasted last two dims
35        return MatrixBatchLayout::HighlyPermuted;
36    }
37    if let Some(packing_dim) = packing_dim {
38        match packing_dim {
39            0 => {}
40            1 => {
41                transposed = true;
42            }
43            _ => {
44                return MatrixBatchLayout::HighlyPermuted;
45            }
46        }
47    } else if row_stride < col_stride {
48        transposed = true;
49    }
50    let mut previous_stride = row_stride;
51
52    for d in 0..rank - 2 {
53        let current_stride = strides[rank - 3 - d];
54        if current_stride < row_stride || current_stride < col_stride {
55            if current_stride == 0 {
56                // Broadcasted batch dim
57                batch_swap = true;
58            } else {
59                return MatrixBatchLayout::HighlyPermuted;
60            }
61        }
62        if current_stride < previous_stride {
63            batch_swap = true;
64        }
65
66        previous_stride = current_stride;
67    }
68
69    if transposed || batch_swap {
70        MatrixBatchLayout::MildlyPermuted {
71            transposed,
72            batch_swap,
73        }
74    } else {
75        MatrixBatchLayout::Contiguous
76    }
77}
78
#[cfg(test)]
mod tests {
    use cubecl_core::zspace::strides;

    use super::*;

    /// Shorthand for the expected `MildlyPermuted` variant, so every test can
    /// compare the whole result with `assert_eq!` and report the actual layout
    /// on failure.
    fn mildly_permuted(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        let strides = strides![8, 4, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn vector_is_contiguous() {
        let strides = strides![1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn single_matrix_is_contiguous() {
        let strides = strides![4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn single_matrix_is_transposed() {
        let strides = strides![1, 4];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_is_transposed_only() {
        let strides = strides![8, 4, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, false)
        );
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        let strides = strides![4, 8, 2, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        let strides = strides![4, 8, 1, 2];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, true)
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        let strides = strides![8, 2, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        let strides = strides![1, 4, 2, 8];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // E.g., tensor w/ shape [1, 4] expanded to [2, 3, 4]
        let strides = strides![0, 0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        // E.g., tensor w/ shape [1, 4] expanded to [3, 4]
        let strides = strides![0, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        // E.g., tensor w/ shape [2, 1] expanded to [2, 3]
        let strides = strides![1, 0];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 4]
        let strides = strides![0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 2, 4]
        let strides = strides![0, 0, 4, 1];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(false, true)
        );
    }

    #[test]
    fn layout_has_batch_broadcasted_and_is_transposed() {
        // E.g., a transposed [4, 2] matrix expanded to [3, 4, 2]
        let strides = strides![0, 1, 4];
        assert_eq!(
            matrix_batch_layout(&strides, None),
            mildly_permuted(true, true)
        );
    }
}