use crate::{
Shape, Strides,
errors::{StrideError, StrideRecord},
};
use alloc::string::ToString;
use core::panic;
/// Returns the rank shared by `shape` and `strides`.
///
/// # Errors
///
/// Returns [`StrideError::MalformedRanks`] when `strides.len()` differs from
/// `shape.len()`; the error records both inputs.
pub fn try_check_matching_ranks(shape: &Shape, strides: &Strides) -> Result<usize, StrideError> {
    let shape_rank = shape.len();
    if strides.len() == shape_rank {
        return Ok(shape_rank);
    }
    Err(StrideError::MalformedRanks {
        record: StrideRecord::from_usize_strides(shape, strides),
    })
}
/// Checks that `strides` describe a pitched row-major layout of `shape`.
///
/// A pitched layout is row-major with two relaxations and one extra rule:
/// - the innermost stride must be `1`, and no stride may be `0`;
/// - the row pitch (`strides[rank - 2]`) only has to be at least the row length
///   (`shape[rank - 1]`), so rows may be padded;
/// - every dimension further out must be tightly packed over the next one, i.e.
///   `strides[i] == shape[i + 1] * strides[i + 1]` for `i < rank - 2`.
///
/// # Errors
///
/// - [`StrideError::MalformedRanks`] if `shape` and `strides` have different ranks.
/// - [`StrideError::UnsupportedRank`] if the rank is `0`.
/// - [`StrideError::Invalid`] if the strides do not satisfy the rules above. This
///   includes the case where an expected stride `shape[i + 1] * strides[i + 1]`
///   overflows the integer type.
pub fn try_check_pitched_row_major_strides(
    shape: &Shape,
    strides: &Strides,
) -> Result<(), StrideError> {
    let rank = try_check_matching_ranks(shape, strides)?;
    if rank == 0 {
        return Err(StrideError::UnsupportedRank {
            rank,
            record: StrideRecord::from_usize_strides(shape, strides),
        });
    }

    let innermost_ok = strides[rank - 1] == 1 && strides.iter().all(|s| *s != 0);

    // With a single dimension there is no row pitch or outer dimension to check.
    // `rank == 1 ||` short-circuits, so `rank - 2` is only evaluated when rank >= 2.
    // `checked_mul` turns an overflowing expected stride into a mismatch. Plain `*`
    // would panic in debug builds, or wrap in release builds and possibly match.
    let outer_ok = rank == 1
        || (strides[rank - 2] >= shape[rank - 1]
            && (0..rank - 2)
                .all(|i| shape[i + 1].checked_mul(strides[i + 1]) == Some(strides[i])));

    if innermost_ok && outer_ok {
        Ok(())
    } else {
        Err(StrideError::Invalid {
            message: "strides are not valid pitched row major order".to_string(),
            record: StrideRecord::from_usize_strides(shape, strides),
        })
    }
}
/// Returns whether `strides` describe a pitched row-major layout of `shape`.
///
/// This is the boolean form of [`try_check_pitched_row_major_strides`]: an
/// [`StrideError::Invalid`] result becomes `false`.
///
/// # Panics
///
/// Panics if `shape` and `strides` have different ranks, or if the rank is `0`.
pub fn has_pitched_row_major_strides(shape: &Shape, strides: &Strides) -> bool {
    match try_check_pitched_row_major_strides(shape, strides) {
        Ok(()) => true,
        Err(StrideError::Invalid { .. }) => false,
        // Rank errors are not a "not this layout" answer, so they are surfaced loudly.
        Err(err @ (StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. })) => {
            panic!("{err}")
        }
    }
}
/// Checks that `strides` describe a contiguous row-major layout of `shape`.
///
/// A layout is contiguous row-major when the innermost stride is `1` and every
/// outer stride equals the next dimension's size times its stride:
/// `strides[i] == shape[i + 1] * strides[i + 1]`.
///
/// # Errors
///
/// - [`StrideError::MalformedRanks`] if `shape` and `strides` have different ranks.
/// - [`StrideError::UnsupportedRank`] if the rank is `0`.
/// - [`StrideError::Invalid`] if the strides are not contiguous row-major. This
///   includes the case where an expected stride `shape[i + 1] * strides[i + 1]`
///   overflows the integer type.
pub fn try_check_contiguous_row_major_strides(
    shape: &Shape,
    strides: &Strides,
) -> Result<(), StrideError> {
    let rank = try_check_matching_ranks(shape, strides)?;
    if rank == 0 {
        return Err(StrideError::UnsupportedRank {
            rank,
            record: StrideRecord::from_usize_strides(shape, strides),
        });
    }

    // For rank 1 the range is empty, so only the unit innermost stride is checked.
    // `checked_mul` turns an overflowing expected stride into a mismatch. Plain `*`
    // would panic in debug builds, or wrap in release builds and possibly match.
    let valid_layout = strides[rank - 1] == 1
        && (0..rank - 1).all(|i| shape[i + 1].checked_mul(strides[i + 1]) == Some(strides[i]));

    if valid_layout {
        Ok(())
    } else {
        Err(StrideError::Invalid {
            message: "strides are not contiguous in row major order".to_string(),
            record: StrideRecord::from_usize_strides(shape, strides),
        })
    }
}
/// Returns whether `strides` describe a contiguous row-major layout of `shape`.
///
/// This is the boolean form of [`try_check_contiguous_row_major_strides`]: an
/// [`StrideError::Invalid`] result becomes `false`.
///
/// # Panics
///
/// Panics if `shape` and `strides` have different ranks, or if the rank is `0`.
pub fn has_contiguous_row_major_strides(shape: &Shape, strides: &Strides) -> bool {
    match try_check_contiguous_row_major_strides(shape, strides) {
        Ok(()) => true,
        Err(StrideError::Invalid { .. }) => false,
        // Rank errors are not a "not this layout" answer, so they are surfaced loudly.
        Err(err @ (StrideError::UnsupportedRank { .. } | StrideError::MalformedRanks { .. })) => {
            panic!("{err}")
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{shape, strides};

    use super::*;

    #[test]
    fn test_try_check_matching_ranks() {
        assert_eq!(
            try_check_matching_ranks(&shape![1, 2, 3], &strides![1, 2, 3]).unwrap(),
            3
        );
        assert_eq!(
            &try_check_matching_ranks(&shape![1, 2], &strides![1, 2, 3]),
            &Err(StrideError::MalformedRanks {
                record: StrideRecord {
                    shape: shape![1, 2],
                    strides: strides![1, 2, 3]
                }
            })
        );
    }

    #[test]
    fn test_try_check_contiguous_row_major_strides() {
        try_check_contiguous_row_major_strides(&shape![0], &strides![1]).unwrap();
        try_check_contiguous_row_major_strides(&shape![2], &strides![1]).unwrap();
        try_check_contiguous_row_major_strides(&shape![3, 2], &strides![2, 1]).unwrap();
        try_check_contiguous_row_major_strides(&shape![4, 3, 2], &strides![6, 2, 1]).unwrap();
        assert_eq!(
            try_check_contiguous_row_major_strides(&shape![], &strides![]),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: shape![],
                    strides: strides![]
                }
            })
        );
        assert_eq!(
            try_check_contiguous_row_major_strides(&shape![2, 2], &strides![3, 1]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: shape![2, 2],
                    strides: strides![3, 1]
                }
            })
        );
        assert_eq!(
            try_check_contiguous_row_major_strides(&shape![1, 2], &strides![1, 2]),
            Err(StrideError::Invalid {
                message: "strides are not contiguous in row major order".to_string(),
                record: StrideRecord {
                    shape: shape![1, 2],
                    strides: strides![1, 2]
                }
            })
        );
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_malformed_ranks() {
        has_contiguous_row_major_strides(&shape![1, 2], &strides![1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_contiguous_row_major_strides_unsupported_rank() {
        has_contiguous_row_major_strides(&shape![], &strides![]);
    }

    #[test]
    fn test_has_contiguous_row_major_strides() {
        assert!(has_contiguous_row_major_strides(&shape![0], &strides![1]));
        assert!(has_contiguous_row_major_strides(&shape![2], &strides![1]));
        assert!(has_contiguous_row_major_strides(
            &shape![3, 2],
            &strides![2, 1]
        ));
        assert!(has_contiguous_row_major_strides(
            &shape![4, 3, 2],
            &strides![6, 2, 1]
        ));
        assert!(!has_contiguous_row_major_strides(&shape![1], &strides![2]));
        assert!(!has_contiguous_row_major_strides(
            &shape![1, 2],
            &strides![1, 2]
        ));
    }

    #[test]
    fn test_try_check_pitched_row_major_strides() {
        try_check_pitched_row_major_strides(&shape![2], &strides![1]).unwrap();
        // A contiguous layout is also a valid pitched layout.
        try_check_pitched_row_major_strides(&shape![3, 2], &strides![2, 1]).unwrap();
        // Rows padded: the pitch (4) exceeds the row length (2).
        try_check_pitched_row_major_strides(&shape![3, 2], &strides![4, 1]).unwrap();
        // Outer dimension packed over the padded rows: 12 == 3 * 4.
        try_check_pitched_row_major_strides(&shape![4, 3, 2], &strides![12, 4, 1]).unwrap();
        assert_eq!(
            try_check_pitched_row_major_strides(&shape![], &strides![]),
            Err(StrideError::UnsupportedRank {
                rank: 0,
                record: StrideRecord {
                    shape: shape![],
                    strides: strides![]
                }
            })
        );
        // Row pitch (1) is shorter than a row (2).
        assert_eq!(
            try_check_pitched_row_major_strides(&shape![3, 2], &strides![1, 1]),
            Err(StrideError::Invalid {
                message: "strides are not valid pitched row major order".to_string(),
                record: StrideRecord {
                    shape: shape![3, 2],
                    strides: strides![1, 1]
                }
            })
        );
        // Outer stride (10) does not equal 3 * 4.
        assert_eq!(
            try_check_pitched_row_major_strides(&shape![4, 3, 2], &strides![10, 4, 1]),
            Err(StrideError::Invalid {
                message: "strides are not valid pitched row major order".to_string(),
                record: StrideRecord {
                    shape: shape![4, 3, 2],
                    strides: strides![10, 4, 1]
                }
            })
        );
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_malformed_ranks() {
        has_pitched_row_major_strides(&shape![1, 2], &strides![1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_has_pitched_row_major_strides_unsupported_rank() {
        has_pitched_row_major_strides(&shape![], &strides![]);
    }

    #[test]
    fn test_has_pitched_row_major_strides() {
        assert!(has_pitched_row_major_strides(&shape![2], &strides![1]));
        assert!(has_pitched_row_major_strides(&shape![3, 2], &strides![4, 1]));
        assert!(has_pitched_row_major_strides(&shape![4, 3, 2], &strides![6, 2, 1]));
        assert!(!has_pitched_row_major_strides(&shape![1], &strides![2]));
        assert!(!has_pitched_row_major_strides(&shape![3, 2], &strides![1, 1]));
        // Zero strides are rejected even where the packing rule would hold.
        assert!(!has_pitched_row_major_strides(&shape![2, 0], &strides![0, 1]));
    }
}