use crate::{Strides, strides};
/// Computes the strides, measured in elements, of a contiguous row-major (C-order) layout
/// for the given `shape`.
///
/// The innermost (last) axis always gets stride `1`. Every earlier axis gets the product of
/// all dimensions that follow it, i.e. `strides[k] = strides[k + 1] * shape[k + 1]`. For
/// example, shape `[1, 2, 3]` yields strides `[6, 3, 1]`, and an empty shape yields empty
/// strides.
///
/// # Panics
///
/// The running product is computed with plain `usize` multiplication. It panics on overflow
/// when overflow checks are enabled (e.g. debug builds) and wraps otherwise.
pub fn row_major_contiguous_strides<S>(shape: S) -> Strides
where
    S: AsRef<[usize]>,
{
    let dims = shape.as_ref();
    let mut out = strides![1; dims.len()];
    // Walk from the innermost axis outward. Starting the range at 1 (instead of computing
    // `len - 1`) makes rank 0 and rank 1 empty loops, so no separate guard is needed.
    for axis in (1..dims.len()).rev() {
        out[axis - 1] = out[axis] * dims[axis];
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_major_contiguous_strides() {
        assert_eq!(row_major_contiguous_strides([]), strides![]);
        assert_eq!(row_major_contiguous_strides([1, 2, 3]), strides![6, 3, 1]);
    }

    #[test]
    fn test_row_major_contiguous_strides_edge_cases() {
        // Rank 1: the single axis is the innermost one, so its stride is 1 whatever its length.
        assert_eq!(row_major_contiguous_strides([5]), strides![1]);
        // A zero-length axis makes the stride of every axis before it zero.
        assert_eq!(row_major_contiguous_strides([2, 0, 4]), strides![0, 4, 1]);
        // Any `AsRef<[usize]>` input is accepted, not just arrays, e.g. an owned `Vec`.
        assert_eq!(row_major_contiguous_strides(vec![4usize, 5]), strides![5, 1]);
    }
}