cubecl_std/tensor/matrix_batch_layout.rs

1use cubecl_core::quant::scheme::QuantScheme;
2use serde::{Deserialize, Serialize};
3
4#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize)]
5/// Layout for matrix batch tensors, i.e. tensors whose interpretation
6/// is a bunch of batched matrices of 2 dimensions
7pub enum MatrixBatchLayout {
8    /// Memory is wholly contiguous, with row major layout
9    Contiguous,
10    /// Permutations happened, but may not impact some kernels
11    MildlyPermuted {
12        /// Last two dims are inverted
13        transposed: bool,
14        /// Some permutations exist in batch dimensions
15        batch_swap: bool,
16    },
17    /// Permutations happened between batch dimensions and last two dims
18    HighlyPermuted,
19}
20
21/// Return the layout of a matrix batch given the strides.
22pub fn matrix_batch_layout(strides: &[usize], scheme: Option<&QuantScheme>) -> MatrixBatchLayout {
23    let packing_dim = scheme.and_then(|s| s.packing_dim());
24    let rank = strides.len();
25    if rank <= 1 {
26        return MatrixBatchLayout::Contiguous;
27    }
28
29    let mut transposed = false;
30    let mut batch_swap = false;
31    let row_stride = strides[rank - 2];
32    let col_stride = strides[rank - 1];
33    if row_stride == 0 || col_stride == 0 {
34        // Broadcasted last two dims
35        return MatrixBatchLayout::HighlyPermuted;
36    }
37    if let Some(packing_dim) = packing_dim {
38        match packing_dim {
39            0 => {}
40            1 => {
41                transposed = true;
42            }
43            _ => {
44                return MatrixBatchLayout::HighlyPermuted;
45            }
46        }
47    } else if row_stride < col_stride {
48        transposed = true;
49    }
50    let mut previous_stride = row_stride;
51
52    for d in 0..rank - 2 {
53        let current_stride = strides[rank - 3 - d];
54        if current_stride < row_stride || current_stride < col_stride {
55            if current_stride == 0 {
56                // Broadcasted batch dim
57                batch_swap = true;
58            } else {
59                return MatrixBatchLayout::HighlyPermuted;
60            }
61        }
62        if current_stride < previous_stride {
63            batch_swap = true;
64        }
65
66        previous_stride = current_stride;
67    }
68
69    if transposed || batch_swap {
70        MatrixBatchLayout::MildlyPermuted {
71            transposed,
72            batch_swap,
73        }
74    } else {
75        MatrixBatchLayout::Contiguous
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected `MildlyPermuted` value.
    ///
    /// Tests compare the whole result with `assert_eq!`, so a wrong variant prints the
    /// actual layout instead of panicking with an opaque "unreachable" message.
    fn mildly(transposed: bool, batch_swap: bool) -> MatrixBatchLayout {
        MatrixBatchLayout::MildlyPermuted {
            transposed,
            batch_swap,
        }
    }

    #[test]
    fn layout_is_contiguous() {
        assert_eq!(
            matrix_batch_layout(&[8, 4, 2, 1], None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn empty_strides_are_contiguous() {
        // Rank-0 input has no matrix dimensions at all.
        assert_eq!(
            matrix_batch_layout(&[], None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn vector_is_contiguous() {
        assert_eq!(
            matrix_batch_layout(&[1], None),
            MatrixBatchLayout::Contiguous
        );
    }

    #[test]
    fn matrix_is_transposed() {
        // Rank 2, column stride larger than row stride.
        assert_eq!(matrix_batch_layout(&[1, 3], None), mildly(true, false));
    }

    #[test]
    fn layout_is_transposed_only() {
        assert_eq!(matrix_batch_layout(&[8, 4, 1, 2], None), mildly(true, false));
    }

    #[test]
    fn layout_has_swapped_batches_only() {
        assert_eq!(matrix_batch_layout(&[4, 8, 2, 1], None), mildly(false, true));
    }

    #[test]
    fn layout_has_swapped_batches_and_is_transposed() {
        assert_eq!(matrix_batch_layout(&[4, 8, 1, 2], None), mildly(true, true));
    }

    #[test]
    fn layout_has_batch_swapped_with_row() {
        assert_eq!(
            matrix_batch_layout(&[8, 2, 4, 1], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_swapped_with_col() {
        assert_eq!(
            matrix_batch_layout(&[1, 4, 2, 8], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_multiple_broadcasted_dims() {
        // E.g., tensor w/ shape [1, 4] expanded to [2, 3, 4]
        assert_eq!(
            matrix_batch_layout(&[0, 0, 1], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_row_broadcasted() {
        // E.g., tensor w/ shape [1, 4] expanded to [3, 4]
        assert_eq!(
            matrix_batch_layout(&[0, 1], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_col_broadcasted() {
        // E.g., tensor w/ shape [2, 1] expanded to [2, 3]
        assert_eq!(
            matrix_batch_layout(&[1, 0], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }

    #[test]
    fn layout_has_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 4]
        assert_eq!(matrix_batch_layout(&[0, 4, 1], None), mildly(false, true));
    }

    #[test]
    fn layout_has_multiple_batch_broadcasted() {
        // E.g., tensor w/ shape [2, 4] expanded to [2, 2, 2, 4]
        assert_eq!(matrix_batch_layout(&[0, 0, 4, 1], None), mildly(false, true));
    }

    #[test]
    fn layout_has_middle_batch_broadcasted() {
        // A broadcast batch dim sitting between two regular ones.
        assert_eq!(matrix_batch_layout(&[8, 0, 4, 1], None), mildly(false, true));
    }

    #[test]
    fn layout_has_broadcast_then_batch_smaller_than_row() {
        // The outer batch stride (2) falls below the row stride (4), so it interleaves
        // with the matrix even though a broadcast dim sits in between.
        assert_eq!(
            matrix_batch_layout(&[2, 0, 4, 1], None),
            MatrixBatchLayout::HighlyPermuted
        );
    }
}