ndarray_layout/transform/
merge.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
use crate::ArrayLayout;
use std::{iter::zip, ops::Range};

impl<const N: usize> ArrayLayout<N> {
    /// Merge transform: fuses a run of consecutive dimensions into one dimension.
    ///
    /// Shorthand for `self.merge_many(&[range])`; see [`Self::merge_many`] for the
    /// exact rules, the `None` cases and the panics.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge(0..3).unwrap();
    /// assert_eq!(layout.shape(), &[24]);
    /// assert_eq!(layout.strides(), &[1]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    #[inline]
    pub fn merge(&self, range: Range<usize>) -> Option<Self> {
        self.merge_many(&[range])
    }

    /// Applies several merge transforms in one pass.
    ///
    /// Each non-empty range in `args` is a run of dimension indices that becomes
    /// a single dimension of the result. Dimensions outside all ranges are copied
    /// through unchanged and in order, and the offset is preserved. Empty ranges
    /// are ignored.
    ///
    /// Inside a range, dimensions are fused in order of increasing `|stride|`, not
    /// in their logical order. The fused dimension may therefore visit elements in
    /// a different order than the original dimensions did. Length-1 dimensions
    /// never block a merge. Once any length is 0, the fused length is 0 and strides
    /// are no longer checked.
    ///
    /// Returns `None` when some range cannot be addressed with a single stride.
    ///
    /// # Panics
    ///
    /// Panics if the non-empty ranges are not ascending and disjoint, or if a
    /// range extends past the last dimension.
    pub fn merge_many(&self, args: &[Range<usize>]) -> Option<Self> {
        // Fuses one non-empty group of dimensions into a single `(length, stride)`
        // pair, or returns `None` if no single stride addresses all of them.
        fn fuse(shape: &[usize], strides: &[isize]) -> Option<(usize, isize)> {
            let mut pairs = zip(shape, strides)
                .map(|(&d, &s)| (d, s))
                .collect::<Vec<_>>();
            pairs.sort_unstable_by_key(|&(_, s)| s.unsigned_abs());

            let mut pairs = pairs.into_iter();
            let (mut d, mut s) = pairs.next().expect("merge group is non-empty");
            for (d_, s_) in pairs {
                if d_ <= 1 || d == 0 {
                    // A length-1 dimension does not change addressing. Once the
                    // length is 0 the fused dimension is empty, so no stride check
                    // is needed in either case.
                    d *= d_;
                } else if d == 1 {
                    // Everything fused so far has length 1, so this is the first
                    // real dimension and its stride becomes the fused stride.
                    d = d_;
                    s = s_;
                } else if s_ == s * d as isize {
                    d *= d_;
                } else {
                    return None;
                }
            }
            Some((d, s))
        }

        let content = self.content();
        let shape = content.shape();
        let strides = content.strides();

        // Validate every range before sizing the result, so the ndim arithmetic
        // below cannot underflow on overlapping input.
        let mut groups = 0;
        let mut merged = 0;
        let mut prev_end = 0;
        for range in args.iter().filter(|r| !r.is_empty()) {
            assert!(range.start >= prev_end, "merge ranges must be ascending and disjoint");
            assert!(range.end <= shape.len(), "merge range exceeds the number of dimensions");
            groups += 1;
            merged += range.len();
            prev_end = range.end;
        }

        // Only non-empty ranges produce an output dimension.
        let mut ans = Self::with_ndim(self.ndim + groups - merged);

        let mut content = ans.content_mut();
        content.set_offset(self.offset());
        let mut i = 0;
        let mut push = |d, s| {
            content.set_shape(i, d);
            content.set_stride(i, s);
            i += 1;
        };

        let mut last_end = 0;
        for range in args.iter().filter(|r| !r.is_empty()) {
            // Copy through the untouched dimensions preceding this group.
            for j in last_end..range.start {
                push(shape[j], strides[j]);
            }
            let (d, s) = fuse(&shape[range.clone()], &strides[range.clone()])?;
            push(d, s);
            last_end = range.end;
        }
        // Copy through the untouched trailing dimensions.
        for j in last_end..shape.len() {
            push(shape[j], strides[j]);
        }

        Some(ans)
    }
}