ndarray_layout/transform/
index.rs1use crate::ArrayLayout;
2use std::iter::zip;
3
/// Selects a single position along one axis of an [`ArrayLayout`].
///
/// Consumed by [`ArrayLayout::index`] and [`ArrayLayout::index_many`]; the
/// indexed axis is dropped from the resulting layout.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct IndexArg {
    /// The axis to index into; it must be an existing axis of the layout.
    pub axis: usize,
    /// The position along `axis`; it must be smaller than that axis's length.
    pub index: usize,
}
12
13impl<const N: usize> ArrayLayout<N> {
14 pub fn index(&self, axis: usize, index: usize) -> Self {
25 self.index_many(&[IndexArg { axis, index }])
26 }
27
28 pub fn index_many(&self, mut args: &[IndexArg]) -> Self {
30 let content = self.content();
31 let mut offset = content.offset();
32 let shape = content.shape();
33 let iter = zip(shape, content.strides()).enumerate();
34
35 let check = |&IndexArg { axis, index }| shape.get(axis).filter(|&&d| index < d).is_some();
36
37 if let [first, ..] = args {
38 assert!(check(first), "Invalid index arg: {first:?}");
39 } else {
40 return self.clone();
41 }
42
43 let mut ans = Self::with_ndim(self.ndim - args.len());
44 let mut content = ans.content_mut();
45 let mut j = 0;
46 for (i, (&d, &s)) in iter {
47 match *args {
48 [IndexArg { axis, index }, ref tail @ ..] if axis == i => {
49 offset += index as isize * s;
50 if let [first, ..] = tail {
51 assert!(check(first), "Invalid index arg: {first:?}");
52 assert!(first.axis > axis, "Index args must be in ascending order");
53 }
54 args = tail;
55 }
56 [..] => {
57 content.set_shape(j, d);
58 content.set_stride(j, s);
59 j += 1;
60 }
61 }
62 }
63 content.set_offset(offset as _);
64 ans
65 }
66}
67
#[test]
fn test() {
    // Contiguous row-major 2x3x4 layout: fixing axis 1 at 2 drops that axis
    // and advances the offset by 2 * 4.
    let row_major = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0);
    let picked = row_major.index(1, 2);
    assert_eq!(picked.shape(), &[2, 4]);
    assert_eq!(picked.strides(), &[12, 1]);
    assert_eq!(picked.offset(), 8);

    // Same shape with a negative middle stride and a non-zero base offset.
    // `index`/`index_many` borrow `self`, so one base layout serves every case.
    let base = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20);

    let picked = base.index(1, 2);
    assert_eq!(picked.shape(), &[2, 4]);
    assert_eq!(picked.strides(), &[12, 1]);
    assert_eq!(picked.offset(), 12); // 20 + 2 * -4

    // No index args: the layout comes back unchanged.
    let untouched = base.index_many(&[]);
    assert_eq!(untouched.shape(), &[2, 3, 4]);
    assert_eq!(untouched.strides(), &[12, -4, 1]);
    assert_eq!(untouched.offset(), 20);

    // Fixing axes 0 and 1 leaves only the innermost axis.
    let picked = base.index_many(&[
        IndexArg { axis: 0, index: 1 },
        IndexArg { axis: 1, index: 2 },
    ]);
    assert_eq!(picked.shape(), &[4]);
    assert_eq!(picked.strides(), &[1]);
    assert_eq!(picked.offset(), 24); // 20 + 1 * 12 + 2 * -4
}