/// Collapses the innermost contiguous (row-major) run of axes into a single
/// unit-stride axis, leaving the outer axes untouched.
///
/// Starting from the last axis and moving outward, an axis joins the run when
/// its stride equals the number of elements the run already spans. Because the
/// run starts empty (spanning 1 element), the innermost axis must have stride 1.
/// The run is then replaced by one axis whose length is the product of the
/// merged lengths and whose stride is 1.
///
/// The input is returned unchanged (as owned copies) when the innermost stride
/// is not 1, which includes a zero-dimensional input. It is also returned
/// unchanged when the merged run spans exactly one element.
///
/// Example: shape `[3, 2, 3]` with stride `[6, 3, 1]` becomes `[18]` / `[1]`.
/// Shape `[2, 2]` with stride `[6, 1]` stays `[2, 2]` / `[6, 1]`.
pub(in crate::ndarray) fn collapse_contiguous(
    shape: &[usize],
    stride: &[usize],
) -> (Vec<usize>, Vec<usize>) {
    if stride.last() != Some(&1) {
        return (shape.to_vec(), stride.to_vec());
    }

    // Fold from the innermost axis outward. The state is (number of outer axes
    // not merged yet, element count of the contiguous run so far). The fold
    // stops at the first axis whose stride breaks the chain; `Err` carries the
    // state at that point.
    let (kept, run_len) = shape
        .iter()
        .zip(stride.iter())
        .rev()
        .try_fold((shape.len(), 1usize), |(kept, run_len), (&len, &st)| {
            if st == run_len {
                Ok((kept - 1, run_len * len))
            } else {
                Err((kept, run_len))
            }
        })
        .unwrap_or_else(|stopped| stopped);

    // The run covers a single element, so there is nothing worth merging.
    if run_len == 1 {
        return (shape.to_vec(), stride.to_vec());
    }

    let collapsed_shape = [&shape[..kept], &[run_len][..]].concat();
    let collapsed_stride = [&stride[..kept], &[1][..]].concat();
    (collapsed_shape, collapsed_stride)
}
/// Merges runs of neighbouring axes whose strides chain in row-major order.
/// The result addresses the same elements in the same order using fewer axes.
///
/// The walk goes from the outermost axis inward. Axis `i` is folded into the
/// axis currently being built when either:
/// * the built axis' stride equals `shape[i] * stride[i]`, so one step of it
///   skips exactly one full pass over axis `i`; or
/// * the built axis has length 1 and stride 0, so it never contributes an offset.
///
/// A merged axis takes the product of the lengths and axis `i`'s stride.
/// Otherwise, axis `i` starts a new output axis.
///
/// A zero-dimensional input yields two empty vectors. A length-1 axis with a
/// non-zero stride that does not chain is kept as is: shape `[1, 2]` with
/// stride `[5, 1]` comes back unchanged.
///
/// # Panics
/// Panics if `shape` is non-empty and `stride` has fewer entries than `shape`.
pub(crate) fn collapse_to_uniform_stride(
    shape: &[usize],
    stride: &[usize],
) -> (Vec<usize>, Vec<usize>) {
    if shape.is_empty() {
        return (Vec::new(), Vec::new());
    }

    let mut merged_shape = Vec::with_capacity(shape.len());
    let mut merged_stride = Vec::with_capacity(shape.len());

    // The axis currently being grown. It is flushed into the output as soon
    // as it cannot absorb the next axis, and once more after the loop.
    let mut run_len = shape[0];
    let mut run_stride = stride[0];

    for (axis, &len) in shape.iter().enumerate().skip(1) {
        let st = stride[axis];
        if run_stride == len * st || (run_stride == 0 && run_len == 1) {
            run_len *= len;
        } else {
            merged_shape.push(run_len);
            merged_stride.push(run_stride);
            run_len = len;
        }
        // In both branches the innermost stride of the current run is this axis'.
        run_stride = st;
    }

    merged_shape.push(run_len);
    merged_stride.push(run_stride);
    (merged_shape, merged_stride)
}
/// Returns the single stride `s` such that, visiting the elements in row-major
/// order (last axis fastest), the k-th element sits at offset `k * s`. Returns
/// `None` when the layout cannot be reduced to such a stride.
///
/// A zero-dimensional input yields `Some(0)`.
///
/// The merge rules are the same ones `collapse_to_uniform_stride` uses. The
/// answer is `Some(s)` exactly when that function would reduce the layout to a
/// single axis with stride `s`. This version does it in one pass without
/// allocating.
///
/// The check is conservative. A length-1 axis with a non-zero stride that does
/// not chain, such as shape `[1, 2]` with stride `[5, 1]`, is reported as `None`.
///
/// # Panics
/// Panics if `shape` is non-empty and `stride` is empty.
pub(crate) fn has_uniform_stride(shape: &[usize], stride: &[usize]) -> Option<usize> {
    if shape.is_empty() {
        return Some(0);
    }

    // Grow one "run" axis from the outermost axis inward; any axis that cannot
    // be absorbed means the layout needs more than one stride.
    let mut run_len = shape[0];
    let mut run_stride = stride[0];
    for (&len, &st) in shape.iter().zip(stride).skip(1) {
        if run_stride == len * st {
            // Row-major chaining: one step of the run skips a full pass of this axis.
            run_len *= len;
        } else if run_stride == 0 && run_len == 1 {
            // The run so far is a single element, so its zero stride never adds
            // an offset and this axis can replace it.
            //
            // A zero stride on an axis longer than 1 must NOT be skipped: it
            // revisits the same offsets. For example, shape [3, 2] with stride
            // [0, 2] gives offsets 0, 2, 0, 2, 0, 2, which is not uniform. The
            // previous pairwise check wrongly accepted that layout.
            run_len = len;
        } else {
            return None;
        }
        run_stride = st;
    }
    Some(run_stride)
}
#[cfg(test)]
mod tests {
    use super::collapse_contiguous;
    use crate::common::constructors::Constructors;
    use crate::iterator::collapse_contiguous::{collapse_to_uniform_stride, has_uniform_stride};
    use crate::{s, NdArray};

    /// Runs `collapse_to_uniform_stride` and `has_uniform_stride` on one layout
    /// and checks both against the expected results.
    fn check_uniform(
        shape: &[usize],
        stride: &[usize],
        expected_shape: &[usize],
        expected_stride: &[usize],
        expected_uniform: Option<usize>,
    ) {
        let (collapsed_shape, collapsed_stride) = collapse_to_uniform_stride(shape, stride);
        assert_eq!(
            collapsed_shape, expected_shape,
            "collapsed shape of {:?} / {:?}",
            shape, stride
        );
        assert_eq!(
            collapsed_stride, expected_stride,
            "collapsed stride of {:?} / {:?}",
            shape, stride
        );
        assert_eq!(
            has_uniform_stride(shape, stride),
            expected_uniform,
            "uniform stride of {:?} / {:?}",
            shape, stride
        );
    }

    #[test]
    fn test_collapse_contiguous() {
        // A freshly built array is fully contiguous, so everything merges into one axis.
        let a = NdArray::new([
            [[0, 1, 2], [3, 4, 5]],
            [[6, 7, 8], [9, 10, 11]],
            [[12, 13, 14], [15, 16, 17]],
        ]);
        assert_eq!(collapse_contiguous(&a.shape, &a.stride), (vec![18], vec![1]));

        // Fixing the middle index leaves rows that are 6 elements apart.
        let view = a.slice(s![.., 0]);
        assert_eq!(collapse_contiguous(&view.shape, &view.stride), (vec![3, 3], vec![6, 1]));

        // A single outer slab is itself contiguous.
        let view = a.slice(s![1]);
        assert_eq!(collapse_contiguous(&view.shape, &view.stride), (vec![6], vec![1]));

        // Trimming the last axis breaks contiguity with the outer axis.
        let view = a.slice(s![..2, 1, 1..]);
        assert_eq!(collapse_contiguous(&view.shape, &view.stride), (vec![2, 2], vec![6, 1]));
    }

    #[test]
    fn test_collapse_to_uniform_stride() {
        // Arguments: shape, stride, collapsed shape, collapsed stride, uniform stride.
        check_uniform(&[2, 3], &[3, 1], &[6], &[1], Some(1));
        check_uniform(&[2, 3], &[6, 2], &[6], &[2], Some(2));
        check_uniform(&[2, 3], &[5, 2], &[2, 3], &[5, 2], None);
        check_uniform(&[2, 2, 2], &[6, 3, 2], &[4, 2], &[3, 2], None);
        check_uniform(&[3, 4, 5], &[20, 5, 1], &[60], &[1], Some(1));
        check_uniform(&[4, 5, 6], &[30, 6, 1], &[120], &[1], Some(1));
        check_uniform(&[3, 3, 3], &[9, 3, 1], &[27], &[1], Some(1));
    }

    #[test]
    fn test_collapse_to_uniform_stride_edge() {
        // Leading length-1 axes with zero stride are absorbed.
        check_uniform(&[1, 2], &[0, 2], &[2], &[2], Some(2));
        check_uniform(&[1, 1, 1, 2], &[0, 0, 0, 2], &[2], &[2], Some(2));
        // Zero-dimensional input.
        check_uniform(&[], &[], &[], &[], Some(0));
        // Already one-dimensional.
        check_uniform(&[10], &[1], &[10], &[1], Some(1));
        // A gap between rows: nothing merges.
        check_uniform(&[2, 3], &[4, 2], &[2, 3], &[4, 2], None);
        // A length-1 axis with a non-zero stride that does not chain is kept.
        check_uniform(&[1, 2], &[5, 1], &[1, 2], &[5, 1], None);
        // Mixed zero and non-zero strides: only chained neighbours merge.
        check_uniform(&[3, 3, 3], &[0, 1, 0], &[3, 3, 3], &[0, 1, 0], None);
        check_uniform(
            &[5, 2, 3, 3, 4, 3],
            &[6, 3, 0, 4, 1, 0],
            &[10, 3, 12, 3],
            &[3, 0, 1, 0],
            None,
        );
    }
}