// ndarray_layout/transform/slice.rs

use crate::ArrayLayout;
use std::iter::zip;

/// Arguments for slicing a single axis of an [`ArrayLayout`].
///
/// See [`ArrayLayout::slice_many`] for exactly how each field is interpreted.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct SliceArg {
    /// Index of the axis to slice.
    pub axis: usize,
    /// Index of the first selected element along `axis`
    /// (clamped to the last valid index when `step < 0`).
    pub start: usize,
    /// Distance between consecutive selected elements; may be zero or negative.
    pub step: isize,
    /// Upper bound on the number of selected elements (exact when `step == 0`).
    pub len: usize,
}

17impl<const N: usize> ArrayLayout<N> {
18 pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self {
29 self.slice_many(&[SliceArg {
30 axis,
31 start,
32 step,
33 len,
34 }])
35 }
36
37 pub fn slice_many(&self, mut args: &[SliceArg]) -> Self {
39 let content = self.content();
40 let mut offset = content.offset();
41 let iter = zip(content.shape(), content.strides()).enumerate();
42
43 let mut ans = Self::with_ndim(self.ndim);
44 let mut content = ans.content_mut();
45 for (i, (&d, &s)) in iter {
46 match args {
47 [arg, tail @ ..] if arg.axis == i => {
48 let &SliceArg {
49 axis,
50 start,
51 step,
52 len,
53 } = arg;
54 use std::cmp::Ordering::*;
55 let len = match step.cmp(&0) {
56 Greater => {
57 assert!(start < d);
58 offset += start as isize * s;
59 (d - start).div_ceil(step as _).min(len)
60 }
61 Equal => {
62 assert!(start < d);
63 offset += start as isize * s;
64 len
65 }
66 Less => {
67 let start = start.min(d - 1);
68 offset += start as isize * s;
69 (start + 1).div_ceil((-step) as _).min(len)
70 }
71 };
72 content.set_shape(i, len);
73 content.set_stride(i, s * step);
74
75 if let [next, ..] = tail {
76 assert!(
77 axis < next.axis && next.axis < self.ndim,
78 "next.axis = {} !in ({}, {})",
79 next.axis,
80 axis,
81 self.ndim,
82 );
83 }
84 args = tail;
85 }
86 [..] => {
87 content.set_shape(i, d);
88 content.set_stride(i, s);
89 }
90 }
91 }
92 content.set_offset(offset as _);
93 ans
94 }
95}
96
#[test]
fn test_slice() {
    // Contiguous [2, 3, 4] layout shared by every case below.
    let base = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);

    // Axis 1 walked backwards from index 2, two elements.
    let reversed = base.slice(1, 2, -1, 2);
    assert_eq!(reversed.shape(), &[2, 2, 4]);
    assert_eq!(reversed.strides(), &[12, -4, 1]);
    assert_eq!(reversed.offset(), 8);

    // Step 0: index 2 of axis 1 repeated twice through a zero stride.
    let repeated = base.slice(1, 2, 0, 2);
    assert_eq!(repeated.shape(), &[2, 2, 4]);
    assert_eq!(repeated.strides(), &[12, 0, 1]);
    assert_eq!(repeated.offset(), 8);

    // First two elements of axis 1.
    let prefix = base.slice(1, 0, 1, 2);
    assert_eq!(prefix.shape(), &[2, 2, 4]);
    assert_eq!(prefix.strides(), &[12, 4, 1]);
    assert_eq!(prefix.offset(), 0);

    // Same as `prefix`, plus a full (no-op) slice of axis 2.
    let args = [
        SliceArg {
            axis: 1,
            start: 0,
            step: 1,
            len: 2,
        },
        SliceArg {
            axis: 2,
            start: 0,
            step: 1,
            len: 4,
        },
    ];
    let multi = base.slice_many(&args);
    assert_eq!(multi.shape(), &[2, 2, 4]);
    assert_eq!(multi.strides(), &[12, 4, 1]);
    assert_eq!(multi.offset(), 0);
}