ndarray_layout/transform/split.rs
1use crate::ArrayLayout;
2
3/// 切分变换参数。
4pub struct Split<'a, const N: usize> {
5 src: &'a ArrayLayout<N>,
6 axis: usize,
7 start: usize,
8 parts: &'a [usize],
9}
10
11impl<const N: usize> ArrayLayout<N> {
12 /// 切分变换讲单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。
13 ///
14 /// ```rust
15 /// # use ndarray_layout::ArrayLayout;
16 /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
17 /// let mut splits = layout.split(2, &[1, 3]);
18 ///
19 /// let layout = splits.next().unwrap();
20 /// assert_eq!(layout.shape(), &[2, 3, 1]);
21 /// assert_eq!(layout.strides(), &[12, 4, 1]);
22 /// assert_eq!(layout.offset(), 0);
23 ///
24 /// let layout = splits.next().unwrap();
25 /// assert_eq!(layout.shape(), &[2, 3, 3]);
26 /// assert_eq!(layout.strides(), &[12, 4, 1]);
27 /// assert_eq!(layout.offset(), 1);
28 /// ```
29 #[inline]
30 pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> {
31 assert_eq!(self.shape()[axis], parts.iter().sum());
32 Split {
33 src: self,
34 axis,
35 start: 0,
36 parts,
37 }
38 }
39}
40
41impl<const N: usize> Iterator for Split<'_, N> {
42 type Item = ArrayLayout<N>;
43
44 #[inline]
45 fn next(&mut self) -> Option<Self::Item> {
46 self.parts.split_first().map(|(&head, tail)| {
47 let start = self.start;
48 self.start += head;
49 self.parts = tail;
50 self.src.slice(self.axis, start, 1, head)
51 })
52 }
53}