ndarray_layout/transform/
merge.rs

1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 合并变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct MergeArg {
7    /// 合并的起点。
8    pub start: usize,
9    /// 合并的宽度
10    pub len: usize,
11    /// 分块的顺序。
12    pub endian: Option<Endian>,
13}
14
15impl<const N: usize> ArrayLayout<N> {
16    /// 合并变换是将多个连续维度划分合并的变换。
17    /// 大端合并对维度从后到前依次合并。
18    ///
19    /// ```rust
20    /// # use ndarray_layout::ArrayLayout;
21    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge_be(0, 3).unwrap();
22    /// assert_eq!(layout.shape(), &[24]);
23    /// assert_eq!(layout.strides(), &[1]);
24    /// assert_eq!(layout.offset(), 0);
25    /// ```
26    #[inline]
27    pub fn merge_be(&self, start: usize, len: usize) -> Option<Self> {
28        self.merge_many(&[MergeArg {
29            start,
30            len,
31            endian: Some(Endian::BigEndian),
32        }])
33    }
34
35    /// 合并变换是将多个连续维度划分合并的变换。
36    /// 小端合并对维度从前到后依次合并。
37    ///
38    /// ```rust
39    /// # use ndarray_layout::ArrayLayout;
40    /// let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0).merge_le(0, 3).unwrap();
41    /// assert_eq!(layout.shape(), &[24]);
42    /// assert_eq!(layout.strides(), &[1]);
43    /// assert_eq!(layout.offset(), 0);
44    /// ```
45    #[inline]
46    pub fn merge_le(&self, start: usize, len: usize) -> Option<Self> {
47        self.merge_many(&[MergeArg {
48            start,
49            len,
50            endian: Some(Endian::LittleEndian),
51        }])
52    }
53
54    /// 合并变换是将多个连续维度划分合并的变换。
55    /// 任意合并只考虑维度的存储连续性。
56    ///
57    /// ```rust
58    /// # use ndarray_layout::ArrayLayout;
59    /// let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0).merge_free(0, 3).unwrap();
60    /// assert_eq!(layout.shape(), &[24]);
61    /// assert_eq!(layout.strides(), &[1]);
62    /// assert_eq!(layout.offset(), 0);
63    /// ```
64    #[inline]
65    pub fn merge_free(&self, start: usize, len: usize) -> Option<Self> {
66        self.merge_many(&[MergeArg {
67            start,
68            len,
69            endian: None,
70        }])
71    }
72
73    /// 一次对多个阶进行合并变换。
74    pub fn merge_many(&self, args: &[MergeArg]) -> Option<Self> {
75        let content = self.content();
76        let shape = content.shape();
77        let strides = content.strides();
78
79        let merged = args.iter().map(|arg| arg.len).sum::<usize>();
80        let mut ans = Self::with_ndim(self.ndim + args.len() - merged);
81
82        let mut content = ans.content_mut();
83        content.set_offset(self.offset());
84        let mut i = 0;
85        let mut push = |d, s| {
86            content.set_shape(i, d);
87            content.set_stride(i, s);
88            i += 1;
89        };
90
91        let mut last_end = 0;
92        for arg in args {
93            let &MergeArg { start, len, endian } = arg;
94            let end = start + len;
95
96            if len == 0 {
97                continue;
98            }
99
100            for j in last_end..arg.start {
101                push(shape[j], strides[j]);
102            }
103
104            let mut pairs = zip(&shape[start..end], &strides[start..end]).collect::<Vec<_>>();
105            match endian {
106                Some(Endian::BigEndian) => pairs.reverse(),
107                Some(Endian::LittleEndian) => {}
108                None => pairs.sort_unstable_by_key(|(_, &s)| s.unsigned_abs()),
109            }
110
111            let ((&d, &s), pairs) = pairs.split_first().unwrap();
112            let mut d = d;
113
114            for (&d_, &s_) in pairs {
115                // 合并的维度长度若有 0 或 1 则不需要判断步长
116                if d <= 1 || d_ <= 1 || s_ == s * d as isize {
117                    d *= d_
118                } else {
119                    return None;
120                }
121            }
122
123            push(d, s);
124            last_end = end;
125        }
126        for j in last_end..shape.len() {
127            push(shape[j], strides[j]);
128        }
129
130        Some(ans)
131    }
132}