// ndarray_layout/transform/index.rs
use crate::ArrayLayout;
use std::iter::zip;
/// One indexing instruction: pick position `index` along axis `axis`.
///
/// Instances are consumed by `ArrayLayout::index_many`, which checks that
/// `axis` names an existing axis and that `index` is below that axis's extent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexArg {
    /// The axis being indexed; it does not appear in the resulting layout.
    pub axis: usize,
    /// The position selected along `axis`.
    pub index: usize,
}
impl<const N: usize> ArrayLayout<N> {
    /// Fixes position `index` along `axis` and returns the resulting layout,
    /// which has one axis fewer than `self`.
    ///
    /// Shorthand for [`index_many`](Self::index_many) with a single [`IndexArg`].
    ///
    /// # Panics
    ///
    /// Panics if `axis` is not an axis of `self`, or if `index` is not below
    /// the extent of that axis.
    pub fn index(&self, axis: usize, index: usize) -> Self {
        self.index_many(&[IndexArg { axis, index }])
    }

    /// Fixes a position along several axes at once.
    ///
    /// The returned layout keeps, in their original order, the extent and
    /// stride of every axis not named in `args`. Its offset is the offset of
    /// `self` plus `index * stride` for each arg. An empty `args` returns a
    /// clone of `self`.
    ///
    /// # Panics
    ///
    /// Panics with `Invalid index arg: ..` if an arg names an axis that does
    /// not exist or an index that is not below that axis's extent, and with
    /// `Index args must be in ascending order` if the axes in `args` are not
    /// strictly ascending (which also rejects duplicate axes).
    pub fn index_many(&self, mut args: &[IndexArg]) -> Self {
        if args.is_empty() {
            return self.clone();
        }

        let content = self.content();
        let shape = content.shape();

        // Validate every arg before anything is allocated or written.
        // Strictly ascending, in-range axes imply `args.len() <= shape.len()`,
        // so the size computation below cannot underflow and the copy loop
        // consumes every arg, writing exactly one slot per remaining axis.
        let mut prev_axis = None;
        for arg in args {
            assert!(
                matches!(shape.get(arg.axis), Some(&extent) if arg.index < extent),
                "Invalid index arg: {arg:?}"
            );
            if let Some(prev) = prev_axis {
                assert!(arg.axis > prev, "Index args must be in ascending order");
            }
            prev_axis = Some(arg.axis);
        }

        let mut offset = content.offset();
        let mut ans = Self::with_ndim(self.ndim - args.len());
        let mut out = ans.content_mut();
        // Next free axis slot in the result.
        let mut kept = 0;
        for (i, (&d, &s)) in zip(shape, content.strides()).enumerate() {
            match args.split_first() {
                // This axis is indexed away: fold its contribution into the
                // offset and move on to the next pending arg.
                Some((arg, rest)) if arg.axis == i => {
                    offset += arg.index as isize * s;
                    args = rest;
                }
                // This axis survives: copy its extent and stride unchanged.
                _ => {
                    out.set_shape(kept, d);
                    out.set_stride(kept, s);
                    kept += 1;
                }
            }
        }
        out.set_offset(offset as _);
        ans
    }
}
#[test]
fn test() {
    // Positive strides: indexing axis 1 at position 2 drops that axis and
    // moves the offset by 2 * 4 = 8.
    {
        let source = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0);
        let indexed = source.index(1, 2);
        assert_eq!(indexed.shape(), &[2, 4]);
        assert_eq!(indexed.strides(), &[12, 1]);
        assert_eq!(indexed.offset(), 8);
    }

    // Negative stride on the indexed axis: the offset moves backwards,
    // 20 + 2 * (-4) = 12.
    {
        let source = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);
        let indexed = source.index(1, 2);
        assert_eq!(indexed.shape(), &[2, 4]);
        assert_eq!(indexed.strides(), &[12, 1]);
        assert_eq!(indexed.offset(), 12);
    }
}