cubecl_zspace/striding/layout_validation.rs

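//! # Stride Layout Utilities
//!
//! Validation helpers for tensor shape/stride pairs: matching ranks,
//! contiguous row-major layouts, and pitched row-major layouts.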
2
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::error::Error;
use core::fmt::{Display, Formatter};
6
/// Collected shape/stride record.
///
/// As this is used for error messages, there is no expectation that this is valid,
/// or that the ranks match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StrideRecord {
    /// The shape exactly as supplied by the caller.
    pub shape: Vec<usize>,
    /// The strides as supplied by the caller, stored signed so that
    /// [`StrideRecord::from_isize_strides`] can record negative strides.
    pub strides: Vec<isize>,
}

impl StrideRecord {
    /// Create a new `StrideRecord` from a slice of `usize` strides.
    ///
    /// Each stride is converted with an `as` cast, so a stride larger than
    /// `isize::MAX` wraps to a negative value. This is tolerated because the
    /// record only exists for diagnostics.
    pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            strides: strides.iter().map(|&s| s as isize).collect(),
        }
    }

    /// Create a new `StrideRecord` from a slice of `isize` strides.
    pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> Self {
        Self {
            shape: shape.to_vec(),
            strides: strides.to_vec(),
        }
    }
}

/// Error describing striding issues.
///
/// Every variant carries the offending [`StrideRecord`] for diagnostics.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StrideError {
    /// The ranks of the shape and strides do not match.
    MalformedRanks { record: StrideRecord },

    /// The rank is not supported by the check that produced this error
    /// (the row-major checks in this module reject rank 0).
    UnsupportedRank { rank: usize, record: StrideRecord },

    /// The strides violate a layout constraint; `message` names the constraint.
    Invalid {
        message: String,
        record: StrideRecord,
    },
}
50
51impl Display for StrideError {
52    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
53        match self {
54            StrideError::MalformedRanks { record } => write!(f, "Malformed strides: {:?}", record),
55            StrideError::UnsupportedRank { rank, record } => {
56                write!(f, "Unsupported rank {}: {:?}", rank, record)
57            }
58            StrideError::Invalid { message, record } => {
59                write!(f, "Invalid strides: {}: {:?}", message, record)
60            }
61        }
62    }
63}
64
65impl Error for StrideError {}
66
67/// Validate that a `shape`/`stride` pair has matching ranks.
68///
69/// # Arguments
70/// * `shape` - the shape of a tensor.
71/// * `strides` - the skip-strides of a tensor.
72///
73/// # Returns
74/// `Ok(rank)` if the ranks match, otherwise `Err(StrideError::MalformedRanks)`
75pub fn try_check_matching_ranks<A, B>(shape: A, strides: B) -> Result<usize, StrideError>
76where
77    A: AsRef<[usize]>,
78    B: AsRef<[usize]>,
79{
80    let shape = shape.as_ref();
81    let strides = strides.as_ref();
82
83    let rank = shape.len();
84    if strides.len() != rank {
85        Err(StrideError::MalformedRanks {
86            record: StrideRecord::from_usize_strides(shape, strides),
87        })
88    } else {
89        Ok(rank)
90    }
91}
92
93/// Validate that a `shape`/`stride` pair is row-major and non-zero on all dimensions.
94///
95/// # Arguments
96/// * `shape` - the shape of a tensor.
97/// * `strides` - the skip-strides of a tensor.
98///
99/// # Returns
100/// * `Ok(())` - if the strides are non-zero and row-major,
101/// * `Err(StrideError::MalformedRanks)` - if the ranks do not match,
102/// * `Err(StrideError::UnsupportedRank)` - if the rank is 0,
103pub fn try_check_pitched_row_major_strides<A, B>(shape: A, strides: B) -> Result<(), StrideError>
104where
105    A: AsRef<[usize]>,
106    B: AsRef<[usize]>,
107{
108    let shape = shape.as_ref();
109    let strides = strides.as_ref();
110
111    let rank = try_check_matching_ranks(shape, strides)?;
112
113    if rank == 0 {
114        return Err(StrideError::UnsupportedRank {
115            rank,
116            record: StrideRecord::from_usize_strides(shape, strides),
117        });
118    }
119
120    let mut valid_layout = strides[rank - 1] == 1 && strides.iter().all(|s| *s != 0);
121    if valid_layout && rank > 1 {
122        if strides[rank - 2] < shape[rank - 1] {
123            valid_layout = false;
124        }
125        for i in 0..rank - 2 {
126            if strides[i] != shape[i + 1] * strides[i + 1] {
127                valid_layout = false;
128                break;
129            }
130        }
131    }
132
133    if valid_layout {
134        Ok(())
135    } else {
136        Err(StrideError::Invalid {
137            message: "strides are not valid pitched row major order".to_string(),
138            record: StrideRecord::from_usize_strides(shape, strides),
139        })
140    }
141}
142
143/// Check that the shape/stride layout is valid for cubecl layout.
144///
145/// # Returns
146///
147/// `true` if the shape and strides are valid for cubecl layout, `false` otherwise.
148///
149/// # Panics
150/// - if `shape.len() == 0`.
151/// - If `shape.len() != strides.len()`.
152pub fn has_pitched_row_major_strides<A, B>(shape: A, strides: B) -> bool
153where
154    A: AsRef<[usize]>,
155    B: AsRef<[usize]>,
156{
157    // TODO: migrate call sites to the `try_..()` form.
158    // This contract (bool for some things, panic for others)
159    // is a continuation of legacy code,
160
161    match try_check_pitched_row_major_strides(shape, strides) {
162        Ok(()) => true,
163        Err(err) => match err {
164            StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. } => {
165                panic!("{err}")
166            }
167            StrideError::Invalid { .. } => false,
168        },
169    }
170}
171
172/// Validate that a `shape`/`stride` pair is contiguous and row-major.
173///
174/// # Arguments
175/// * `shape` - the shape of a tensor.
176/// * `strides` - the skip-strides of a tensor.
177///
178/// # Returns
179/// * `Ok(())` - if the strides are contiguous and row-major,
180/// * `Err(StrideError::MalformedRanks)` - if the ranks do not match,
181/// * `Err(StrideError::UnsupportedRank)` - if the rank is 0,
182pub fn try_check_contiguous_row_major_strides<A, B>(shape: A, strides: B) -> Result<(), StrideError>
183where
184    A: AsRef<[usize]>,
185    B: AsRef<[usize]>,
186{
187    let shape = shape.as_ref();
188    let strides = strides.as_ref();
189
190    let rank = try_check_matching_ranks(shape, strides)?;
191
192    if rank == 0 {
193        return Err(StrideError::UnsupportedRank {
194            rank,
195            record: StrideRecord::from_usize_strides(shape, strides),
196        });
197    }
198
199    let mut valid_layout = strides[rank - 1] == 1;
200    if valid_layout && rank > 1 {
201        for i in 0..rank - 1 {
202            if strides[i] != shape[i + 1] * strides[i + 1] {
203                valid_layout = false;
204                break;
205            }
206        }
207    }
208    if valid_layout {
209        Ok(())
210    } else {
211        Err(StrideError::Invalid {
212            message: "strides are not contiguous in row major order".to_string(),
213            record: StrideRecord::from_usize_strides(shape, strides),
214        })
215    }
216}
217
218/// Check that the shape/stride layout is contiguous
219///
220/// # Returns
221///
222/// `true` if the shape and strides are contiguous, `false` otherwise.
223///
224/// # Panics
225/// - if `shape.len() == 0`.
226/// - If `shape.len() != strides.len()`.
227pub fn has_contiguous_row_major_strides<A, B>(shape: A, strides: B) -> bool
228where
229    A: AsRef<[usize]>,
230    B: AsRef<[usize]>,
231{
232    // TODO: migrate call sites to the `try_..()` form.
233    // This contract (bool for some things, panic for others)
234    // is a continuation of legacy code,
235
236    match try_check_contiguous_row_major_strides(shape, strides) {
237        Ok(()) => true,
238        Err(err) => match err {
239            StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. } => {
240                panic!("{err}")
241            }
242            StrideError::Invalid { .. } => false,
243        },
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_try_check_matching_ranks() {
        assert_eq!(try_check_matching_ranks([1, 2, 3], [1, 2, 3]).unwrap(), 3);

        assert_eq!(
            &try_check_matching_ranks([1, 2], [1, 2, 3]),
            &Err(StrideError::MalformedRanks {
                record: StrideRecord {
                    shape: vec![1, 2],
                    strides: vec![1, 2, 3]
                }
            })
        );
    }

    #[test]
    fn test_try_check_contiguous_row_major_strides() {
        try_check_contiguous_row_major_strides([0], [1]).unwrap();
        try_check_contiguous_row_major_strides([2], [1]).unwrap();
        try_check_contiguous_row_major_strides([3, 2], [2, 1]).unwrap();
        try_check_contiguous_row_major_strides([4, 3, 2], [6, 2, 1]).unwrap();

        // rank=0
        assert_eq!(
            try_check_contiguous_row_major_strides([], []),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: vec![],
                    strides: vec![]
                }
            })
        );

        // non-contiguous
        assert_eq!(
            try_check_contiguous_row_major_strides([2, 2], [3, 1]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: vec![2, 2],
                    strides: vec![3, 1]
                }
            })
        );

        // not row-major
        assert_eq!(
            try_check_contiguous_row_major_strides([1, 2], [1, 2]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: vec![1, 2],
                    strides: vec![1, 2]
                }
            })
        );
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_malformed_ranks() {
        has_contiguous_row_major_strides([1, 2], [1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_unsupported_rank() {
        has_contiguous_row_major_strides([], []);
    }

    #[test]
    fn test_has_contiguous_row_major_strides() {
        assert!(has_contiguous_row_major_strides([0], [1]));
        assert!(has_contiguous_row_major_strides([2], [1]));
        assert!(has_contiguous_row_major_strides([3, 2], [2, 1]));
        assert!(has_contiguous_row_major_strides([4, 3, 2], [6, 2, 1]));

        // non-contiguous
        assert!(!has_contiguous_row_major_strides([1], [2]));

        // not row-major
        assert!(!has_contiguous_row_major_strides([1, 2], [1, 2]));
    }

    #[test]
    fn test_try_check_pitched_row_major_strides() {
        // Contiguous row-major layouts are pitched layouts without padding.
        try_check_pitched_row_major_strides([0], [1]).unwrap();
        try_check_pitched_row_major_strides([2], [1]).unwrap();
        try_check_pitched_row_major_strides([3, 2], [2, 1]).unwrap();
        try_check_pitched_row_major_strides([4, 3, 2], [6, 2, 1]).unwrap();

        // Padded rows: a pitch of 4 for rows of length 2.
        try_check_pitched_row_major_strides([3, 2], [4, 1]).unwrap();
        try_check_pitched_row_major_strides([4, 3, 2], [12, 4, 1]).unwrap();

        // rank=0
        assert_eq!(
            try_check_pitched_row_major_strides([], []),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: vec![],
                    strides: vec![]
                }
            })
        );

        // mismatched ranks
        assert_eq!(
            try_check_pitched_row_major_strides([1, 2], [1]),
            Err(StrideError::MalformedRanks {
                record: StrideRecord {
                    shape: vec![1, 2],
                    strides: vec![1]
                }
            })
        );

        fn invalid(shape: &[usize], strides: &[usize]) -> Result<(), StrideError> {
            Err(StrideError::Invalid {
                message: "strides are not valid pitched row major order".to_string(),
                record: StrideRecord::from_usize_strides(shape, strides),
            })
        }

        // pitch smaller than the row length
        assert_eq!(
            try_check_pitched_row_major_strides([3, 2], [1, 1]),
            invalid(&[3, 2], &[1, 1])
        );

        // zero stride
        assert_eq!(
            try_check_pitched_row_major_strides([3, 2], [0, 1]),
            invalid(&[3, 2], &[0, 1])
        );

        // innermost stride is not 1
        assert_eq!(
            try_check_pitched_row_major_strides([3, 2], [4, 2]),
            invalid(&[3, 2], &[4, 2])
        );

        // outer dimension not packed over the pitched rows
        assert_eq!(
            try_check_pitched_row_major_strides([4, 3, 2], [13, 4, 1]),
            invalid(&[4, 3, 2], &[13, 4, 1])
        );
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_malformed_ranks() {
        has_pitched_row_major_strides([1, 2], [1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_unsupported_rank() {
        has_pitched_row_major_strides([], []);
    }

    #[test]
    fn test_has_pitched_row_major_strides() {
        assert!(has_pitched_row_major_strides([2], [1]));
        assert!(has_pitched_row_major_strides([3, 2], [2, 1]));
        assert!(has_pitched_row_major_strides([3, 2], [4, 1]));
        assert!(has_pitched_row_major_strides([4, 3, 2], [12, 4, 1]));

        assert!(!has_pitched_row_major_strides([3, 2], [1, 1]));
        assert!(!has_pitched_row_major_strides([3, 2], [0, 1]));
        assert!(!has_pitched_row_major_strides([4, 3, 2], [13, 4, 1]));
    }
}
333        assert!(!has_contiguous_row_major_strides([1, 2], [1, 2]));
334    }
335}