use smallvec::SmallVec;
/// Return true if `shape` and `strides` describe a contiguous layout in
/// row-major ("C") order.
///
/// Walking from the innermost (last) dimension outwards, each dimension's
/// stride must equal the product of the sizes of all dimensions after it.
/// Dimensions of size 1 are ignored: their index is always zero, so whatever
/// stride they carry never contributes to an element's offset.
///
/// An empty shape (scalar) is considered contiguous.
pub fn is_contiguous<S: AsRef<[usize]>>(shape: S, strides: S) -> bool {
    let sizes = shape.as_ref().iter().copied();
    let steps = strides.as_ref().iter().copied();

    sizes
        .zip(steps)
        .rev()
        // Size-1 dimensions neither constrain nor advance the expected stride.
        .filter(|&(size, _)| size != 1)
        // Thread the expected stride through the fold, bailing out with `None`
        // at the first mismatch. The multiplication is evaluated lazily so it
        // only happens after the stride has been confirmed to match.
        .try_fold(1usize, |expected, (size, stride)| {
            (stride == expected).then(|| expected * size)
        })
        .is_some()
}
/// Return true if two or more distinct indices may map to the same offset.
///
/// The check is conservative: a `false` result means no two indices share an
/// offset, whereas `true` only means overlap could not be ruled out by the
/// test used here. Some non-overlapping layouts (e.g. shape `[3, 2]` with
/// strides `[2, 3]`) are therefore still reported as possibly overlapping.
///
/// The result is determined as follows:
///
/// - A layout with any zero-sized dimension has no elements, so it cannot
///   overlap.
/// - A contiguous layout never overlaps.
/// - Otherwise the dimensions with more than one element are visited in order
///   of increasing stride, and each stride must be strictly greater than the
///   largest offset reachable using all the previously visited dimensions.
///
/// Dimensions of size 1 are excluded from the stride check because their
/// index is always zero, so their stride can never contribute to an offset.
/// This matches how [`is_contiguous`] treats such dimensions.
pub fn may_have_internal_overlap(shape: &[usize], strides: &[usize]) -> bool {
    // No elements means no pair of indices that could collide.
    if shape.contains(&0) {
        return false;
    }

    // Fast path for the most common layout.
    if is_contiguous(shape, strides) {
        return false;
    }

    // Pair each stride with its dimension size, keeping only dimensions that
    // can actually be stepped along, then order them by increasing stride.
    let mut stride_shape: SmallVec<[(usize, usize); 8]> = strides
        .iter()
        .copied()
        .zip(shape.iter().copied())
        .filter(|&(_, size)| size > 1)
        .collect();
    stride_shape.sort_unstable();

    // Verify that the stride of each dimension fully "steps over" everything
    // addressable by the dimensions with smaller strides.
    let mut max_offset: usize = 0;
    for (stride, size) in stride_shape {
        if stride <= max_offset {
            return true;
        }
        // Saturate rather than overflow. If the true extent exceeds
        // `usize::MAX` then every later stride is necessarily smaller than it,
        // and the saturated value yields exactly that comparison result.
        let extent = (size - 1).saturating_mul(stride);
        max_offset = max_offset.saturating_add(extent);
    }
    false
}
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::is_contiguous;

    #[test]
    fn test_is_contiguous() {
        /// A layout together with the expected `is_contiguous` result.
        #[derive(Debug)]
        struct Case<'a> {
            shape: &'a [usize],
            strides: &'a [usize],
            contiguous: bool,
        }

        // Shorthand constructor so that the table below is one line per case.
        fn case<'a>(shape: &'a [usize], strides: &'a [usize], contiguous: bool) -> Case<'a> {
            Case {
                shape,
                strides,
                contiguous,
            }
        }

        let cases = [
            // Vectors. The stride of a single-element vector is irrelevant.
            case(&[5], &[1], true),
            case(&[5], &[2], false),
            case(&[1], &[2], true),
            // Matrices, including transposed and broadcast layouts. Strides
            // of size-1 dimensions are ignored.
            case(&[5, 5], &[5, 1], true),
            case(&[5, 1], &[1, 2], true),
            case(&[5, 5], &[1, 5], false),
            case(&[5, 5], &[1, 0], false),
            case(&[5, 1], &[1, 0], true),
            // 4-D layouts with a leading size-1 dimension.
            case(&[1, 4, 5, 5], &[100, 25, 5, 1], true),
            case(&[1, 2, 5, 5], &[100, 25, 5, 1], true),
            case(&[1, 4, 5, 5], &[100, 25, 1, 5], false),
        ];

        cases.test_each(|tc| {
            let actual = is_contiguous(tc.shape, tc.strides);
            assert_eq!(actual, tc.contiguous);
        })
    }
}