ndarray_layout/transform/split.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
use crate::ArrayLayout;
/// Lazy iterator returned by [`ArrayLayout::split`].
///
/// Each call to `next` yields the sub-layout of `src` that covers the next
/// entry of `parts` along `axis`. `start` is the running index along `axis`
/// at which that next part begins.
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct Split<'a, const N: usize> {
    /// The layout being split.
    src: &'a ArrayLayout<N>,
    /// The axis along which `src` is split.
    axis: usize,
    /// Index along `axis` where the next part begins.
    start: usize,
    /// Lengths (along `axis`) of the parts not yet yielded.
    parts: &'a [usize],
}
impl<const N: usize> ArrayLayout<N> {
    /// Splits this layout along `axis` into consecutive sub-layouts whose
    /// lengths along that axis are given by `parts`, so uneven splits are
    /// supported.
    ///
    /// The split is lazy: the returned [`Split`] iterator yields one
    /// [`ArrayLayout`] per entry of `parts`, in order.
    ///
    /// # Panics
    ///
    /// Panics if `axis` is not a valid index into `self.shape()`, or if the
    /// lengths in `parts` do not add up exactly to `self.shape()[axis]`
    /// (including when their sum would overflow `usize`).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0);
    /// let mut splits = layout.split(2, &[1, 3]);
    ///
    /// let layout = splits.next().unwrap();
    /// assert_eq!(layout.shape(), &[2, 3, 1]);
    /// assert_eq!(layout.strides(), &[12, 4, 1]);
    /// assert_eq!(layout.offset(), 0);
    ///
    /// let layout = splits.next().unwrap();
    /// assert_eq!(layout.shape(), &[2, 3, 3]);
    /// assert_eq!(layout.strides(), &[12, 4, 1]);
    /// assert_eq!(layout.offset(), 1);
    /// ```
    #[inline]
    pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> {
        let len = self.shape()[axis];
        // Checked sum: a plain `sum()` wraps in release builds, and a wrapped
        // total could spuriously equal `len` and let oversized parts through.
        let total = parts
            .iter()
            .try_fold(0usize, |acc, &part| acc.checked_add(part));
        assert_eq!(
            total,
            Some(len),
            "split parts must sum to the length of axis {}",
            axis
        );
        Split {
            src: self,
            axis,
            start: 0,
            parts,
        }
    }
}
impl<const N: usize> Iterator for Split<'_, N> {
    type Item = ArrayLayout<N>;

    /// Yields the next part: `src` sliced along `axis` starting at the
    /// running `start` cursor, step 1, covering the next entry of `parts`.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let (&head, tail) = self.parts.split_first()?;
        let start = self.start;
        // Advance the cursor past this part so the next one begins after it.
        self.start += head;
        self.parts = tail;
        Some(self.src.slice(self.axis, start, 1, head))
    }

    /// Exactly one item remains per entry left in `parts`.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.parts.len();
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so callers can use `len()` and collect without regrowth.
impl<const N: usize> ExactSizeIterator for Split<'_, N> {}