1use alloc::vec::Vec;
4use core::error::Error;
5use core::fmt::{Display, Formatter};
6
/// A copy of a `shape` / `strides` pair, carried by every [`StrideError`] variant so the
/// offending input can be reported.
///
/// Strides are stored as `isize` so a single record type can hold the output of both
/// [`StrideRecord::from_usize_strides`] and [`StrideRecord::from_isize_strides`].
#[derive(Debug, Clone, PartialEq)]
pub struct StrideRecord {
    /// The extent of each dimension.
    pub shape: Vec<usize>,
    /// The stride associated with each dimension.
    pub strides: Vec<isize>,
}

impl StrideRecord {
    /// Builds a record from unsigned strides.
    ///
    /// Each stride is converted with an `as` cast, so a stride above `isize::MAX`
    /// wraps to a negative value in the record (e.g. `usize::MAX` becomes `-1`).
    pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> StrideRecord {
        let signed_strides: Vec<isize> = strides.iter().map(|&stride| stride as isize).collect();
        Self {
            shape: shape.to_vec(),
            strides: signed_strides,
        }
    }

    /// Builds a record from signed strides, copying both slices verbatim.
    pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> StrideRecord {
        Self {
            shape: Vec::from(shape),
            strides: Vec::from(strides),
        }
    }
}
34
/// Error returned by the `try_check_*` stride-validation functions in this module.
///
/// Every variant carries the offending [`StrideRecord`] so the failing input can be
/// reported alongside the error.
#[derive(Debug, Clone, PartialEq)]
pub enum StrideError {
    /// `shape` and `strides` have different lengths.
    MalformedRanks { record: StrideRecord },

    /// The number of dimensions is not supported by the check (the checks in this
    /// module report rank 0 this way).
    UnsupportedRank { rank: usize, record: StrideRecord },

    /// The ranks match, but the strides do not describe the layout being checked for.
    Invalid {
        /// Human-readable description of which layout check failed.
        message: String,
        /// The shape/strides pair that failed the check.
        record: StrideRecord,
    },
}
50
51impl Display for StrideError {
52 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
53 match self {
54 StrideError::MalformedRanks { record } => write!(f, "Malformed strides: {:?}", record),
55 StrideError::UnsupportedRank { rank, record } => {
56 write!(f, "Unsupported rank {}: {:?}", rank, record)
57 }
58 StrideError::Invalid { message, record } => {
59 write!(f, "Invalid strides: {}: {:?}", message, record)
60 }
61 }
62 }
63}
64
65impl Error for StrideError {}
66
67pub fn try_check_matching_ranks<A, B>(shape: A, strides: B) -> Result<usize, StrideError>
76where
77 A: AsRef<[usize]>,
78 B: AsRef<[usize]>,
79{
80 let shape = shape.as_ref();
81 let strides = strides.as_ref();
82
83 let rank = shape.len();
84 if strides.len() != rank {
85 Err(StrideError::MalformedRanks {
86 record: StrideRecord::from_usize_strides(shape, strides),
87 })
88 } else {
89 Ok(rank)
90 }
91}
92
93pub fn try_check_pitched_row_major_strides<A, B>(shape: A, strides: B) -> Result<(), StrideError>
104where
105 A: AsRef<[usize]>,
106 B: AsRef<[usize]>,
107{
108 let shape = shape.as_ref();
109 let strides = strides.as_ref();
110
111 let rank = try_check_matching_ranks(shape, strides)?;
112
113 if rank == 0 {
114 return Err(StrideError::UnsupportedRank {
115 rank,
116 record: StrideRecord::from_usize_strides(shape, strides),
117 });
118 }
119
120 let mut valid_layout = strides[rank - 1] == 1 && strides.iter().all(|s| *s != 0);
121 if valid_layout && rank > 1 {
122 if strides[rank - 2] < shape[rank - 1] {
123 valid_layout = false;
124 }
125 for i in 0..rank - 2 {
126 if strides[i] != shape[i + 1] * strides[i + 1] {
127 valid_layout = false;
128 break;
129 }
130 }
131 }
132
133 if valid_layout {
134 Ok(())
135 } else {
136 Err(StrideError::Invalid {
137 message: "strides are not valid pitched row major order".to_string(),
138 record: StrideRecord::from_usize_strides(shape, strides),
139 })
140 }
141}
142
143pub fn has_pitched_row_major_strides<A, B>(shape: A, strides: B) -> bool
153where
154 A: AsRef<[usize]>,
155 B: AsRef<[usize]>,
156{
157 match try_check_pitched_row_major_strides(shape, strides) {
162 Ok(()) => true,
163 Err(err) => match err {
164 StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. } => {
165 panic!("{err}")
166 }
167 StrideError::Invalid { .. } => false,
168 },
169 }
170}
171
172pub fn try_check_contiguous_row_major_strides<A, B>(shape: A, strides: B) -> Result<(), StrideError>
183where
184 A: AsRef<[usize]>,
185 B: AsRef<[usize]>,
186{
187 let shape = shape.as_ref();
188 let strides = strides.as_ref();
189
190 let rank = try_check_matching_ranks(shape, strides)?;
191
192 if rank == 0 {
193 return Err(StrideError::UnsupportedRank {
194 rank,
195 record: StrideRecord::from_usize_strides(shape, strides),
196 });
197 }
198
199 let mut valid_layout = strides[rank - 1] == 1;
200 if valid_layout && rank > 1 {
201 for i in 0..rank - 1 {
202 if strides[i] != shape[i + 1] * strides[i + 1] {
203 valid_layout = false;
204 break;
205 }
206 }
207 }
208 if valid_layout {
209 Ok(())
210 } else {
211 Err(StrideError::Invalid {
212 message: "strides are not contiguous in row major order".to_string(),
213 record: StrideRecord::from_usize_strides(shape, strides),
214 })
215 }
216}
217
218pub fn has_contiguous_row_major_strides<A, B>(shape: A, strides: B) -> bool
228where
229 A: AsRef<[usize]>,
230 B: AsRef<[usize]>,
231{
232 match try_check_contiguous_row_major_strides(shape, strides) {
237 Ok(()) => true,
238 Err(err) => match err {
239 StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. } => {
240 panic!("{err}")
241 }
242 StrideError::Invalid { .. } => false,
243 },
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_try_check_matching_ranks() {
        assert_eq!(try_check_matching_ranks([1, 2, 3], [1, 2, 3]).unwrap(), 3);

        assert_eq!(
            &try_check_matching_ranks([1, 2], [1, 2, 3]),
            &Err(StrideError::MalformedRanks {
                record: StrideRecord {
                    shape: vec![1, 2],
                    strides: vec![1, 2, 3]
                }
            })
        );
    }

    #[test]
    fn test_stride_error_display() {
        let record = StrideRecord {
            shape: vec![1, 2],
            strides: vec![1, 2, 3],
        };

        assert_eq!(
            StrideError::MalformedRanks {
                record: record.clone()
            }
            .to_string(),
            "Malformed strides: StrideRecord { shape: [1, 2], strides: [1, 2, 3] }"
        );
        assert_eq!(
            StrideError::UnsupportedRank {
                rank: 0,
                record: record.clone()
            }
            .to_string(),
            "Unsupported rank 0: StrideRecord { shape: [1, 2], strides: [1, 2, 3] }"
        );
        assert_eq!(
            StrideError::Invalid {
                message: "bad".to_string(),
                record
            }
            .to_string(),
            "Invalid strides: bad: StrideRecord { shape: [1, 2], strides: [1, 2, 3] }"
        );
    }

    #[test]
    fn test_try_check_contiguous_row_major_strides() {
        try_check_contiguous_row_major_strides([0], [1]).unwrap();
        try_check_contiguous_row_major_strides([2], [1]).unwrap();
        try_check_contiguous_row_major_strides([3, 2], [2, 1]).unwrap();
        try_check_contiguous_row_major_strides([4, 3, 2], [6, 2, 1]).unwrap();

        assert_eq!(
            try_check_contiguous_row_major_strides([], []),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: vec![],
                    strides: vec![]
                }
            })
        );

        assert_eq!(
            try_check_contiguous_row_major_strides([2, 2], [3, 1]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: vec![2, 2],
                    strides: vec![3, 1]
                }
            })
        );

        assert_eq!(
            try_check_contiguous_row_major_strides([1, 2], [1, 2]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: vec![1, 2],
                    strides: vec![1, 2]
                }
            })
        );
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_malformed_ranks() {
        has_contiguous_row_major_strides([1, 2], [1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_unsupported_rank() {
        has_contiguous_row_major_strides([], []);
    }

    #[test]
    fn test_has_contiguous_row_major_strides() {
        assert!(has_contiguous_row_major_strides([0], [1]));
        assert!(has_contiguous_row_major_strides([2], [1]));
        assert!(has_contiguous_row_major_strides([3, 2], [2, 1]));
        assert!(has_contiguous_row_major_strides([4, 3, 2], [6, 2, 1]));

        assert!(!has_contiguous_row_major_strides([1], [2]));

        assert!(!has_contiguous_row_major_strides([1, 2], [1, 2]));

        // Padded rows are pitched, but not contiguous.
        assert!(!has_contiguous_row_major_strides([3, 2], [4, 1]));
    }

    #[test]
    fn test_try_check_pitched_row_major_strides() {
        // Contiguous layouts are also valid pitched layouts.
        try_check_pitched_row_major_strides([4], [1]).unwrap();
        try_check_pitched_row_major_strides([3, 2], [2, 1]).unwrap();
        // Rows of length 2 stored with a pitch of 4.
        try_check_pitched_row_major_strides([3, 2], [4, 1]).unwrap();
        // Outer dimension packed over pitched rows: 12 == 3 * 4.
        try_check_pitched_row_major_strides([5, 3, 2], [12, 4, 1]).unwrap();

        assert_eq!(
            try_check_pitched_row_major_strides([], []),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: vec![],
                    strides: vec![]
                }
            })
        );

        // Pitch (3) shorter than the row length (4).
        assert_eq!(
            try_check_pitched_row_major_strides([3, 4], [3, 1]),
            Err(StrideError::Invalid {
                message: "strides are not valid pitched row major order".to_string(),
                record: StrideRecord {
                    shape: vec![3, 4],
                    strides: vec![3, 1]
                }
            })
        );

        // Outer stride (7) does not equal 3 * 2.
        assert_eq!(
            try_check_pitched_row_major_strides([2, 3, 2], [7, 2, 1]),
            Err(StrideError::Invalid {
                message: "strides are not valid pitched row major order".to_string(),
                record: StrideRecord {
                    shape: vec![2, 3, 2],
                    strides: vec![7, 2, 1]
                }
            })
        );

        // Zero strides are rejected.
        assert!(try_check_pitched_row_major_strides([2, 0], [0, 1]).is_err());
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_malformed_ranks() {
        has_pitched_row_major_strides([1, 2], [1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_unsupported_rank() {
        has_pitched_row_major_strides([], []);
    }

    #[test]
    fn test_has_pitched_row_major_strides() {
        assert!(has_pitched_row_major_strides([3, 2], [4, 1]));
        assert!(has_pitched_row_major_strides([4, 3, 2], [6, 2, 1]));

        assert!(!has_pitched_row_major_strides([2, 2], [1, 1]));
        assert!(!has_pitched_row_major_strides([1], [2]));
    }
}