// ndarray_layout/transform/merge.rs

1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 合并变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct MergeArg {
7    /// 合并的起点。
8    pub start: usize,
9    /// 合并的宽度
10    pub len: usize,
11    /// 分块的顺序。
12    pub endian: Option<Endian>,
13}
14
15impl<const N: usize> ArrayLayout<N> {
16    /// 合并变换是将多个连续维度划分合并的变换。
17    /// 大端合并对维度从后到前依次合并。
18    ///
19    /// ```rust
20    /// # use ndarray_layout::ArrayLayout;
21    /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge_be(0, 3).unwrap();
22    /// assert_eq!(layout.shape(), &[24]);
23    /// assert_eq!(layout.strides(), &[1]);
24    /// assert_eq!(layout.offset(), 0);
25    /// ```
26    #[inline]
27    pub fn merge_be(&self, start: usize, len: usize) -> Option<Self> {
28        self.merge_many(&[MergeArg {
29            start,
30            len,
31            endian: Some(Endian::BigEndian),
32        }])
33    }
34
35    /// 合并变换是将多个连续维度划分合并的变换。
36    /// 小端合并对维度从前到后依次合并。
37    ///
38    /// ```rust
39    /// # use ndarray_layout::ArrayLayout;
40    /// let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0).merge_le(0, 3).unwrap();
41    /// assert_eq!(layout.shape(), &[24]);
42    /// assert_eq!(layout.strides(), &[1]);
43    /// assert_eq!(layout.offset(), 0);
44    /// ```
45    #[inline]
46    pub fn merge_le(&self, start: usize, len: usize) -> Option<Self> {
47        self.merge_many(&[MergeArg {
48            start,
49            len,
50            endian: Some(Endian::LittleEndian),
51        }])
52    }
53
54    /// 合并变换是将多个连续维度划分合并的变换。
55    /// 任意合并只考虑维度的存储连续性。
56    ///
57    /// ```rust
58    /// # use ndarray_layout::ArrayLayout;
59    /// let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0).merge_free(0, 3).unwrap();
60    /// assert_eq!(layout.shape(), &[24]);
61    /// assert_eq!(layout.strides(), &[1]);
62    /// assert_eq!(layout.offset(), 0);
63    /// ```
64    #[inline]
65    pub fn merge_free(&self, start: usize, len: usize) -> Option<Self> {
66        self.merge_many(&[MergeArg {
67            start,
68            len,
69            endian: None,
70        }])
71    }
72
73    /// 一次对多个阶进行合并变换。
74    pub fn merge_many(&self, args: &[MergeArg]) -> Option<Self> {
75        let content = self.content();
76        let shape = content.shape();
77        let strides = content.strides();
78
79        let merged = args.iter().map(|arg| arg.len.max(1)).sum::<usize>();
80        if merged == args.len() {
81            return Some(self.clone());
82        }
83
84        let mut ans = Self::with_ndim(self.ndim + args.len() - merged);
85
86        let mut content = ans.content_mut();
87        content.set_offset(self.offset());
88        let mut i = 0;
89        let mut push = |d, s| {
90            content.set_shape(i, d);
91            content.set_stride(i, s);
92            i += 1;
93        };
94
95        let mut last_end = 0;
96        for arg in args {
97            let &MergeArg { start, len, endian } = arg;
98            let end = start + len;
99
100            if len < 2 {
101                continue;
102            }
103
104            for j in last_end..arg.start {
105                push(shape[j], strides[j]);
106            }
107
108            let mut pairs = Vec::with_capacity(len);
109            for (&d, &s) in zip(&shape[start..end], &strides[start..end]) {
110                match d {
111                    0 => todo!(),
112                    1 => {}
113                    _ => pairs.push((d, s)),
114                }
115            }
116
117            last_end = end;
118
119            if pairs.is_empty() {
120                push(1, 0);
121                continue;
122            }
123            match endian {
124                Some(Endian::BigEndian) => pairs.reverse(),
125                Some(Endian::LittleEndian) => {}
126                None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()),
127            }
128
129            let ((d, s), pairs) = pairs.split_first().unwrap();
130            let mut d = *d;
131
132            for &(d_, s_) in pairs {
133                if s_ == s * d as isize {
134                    d *= d_
135                } else {
136                    return None;
137                }
138            }
139
140            push(d, *s);
141        }
142        for j in last_end..shape.len() {
143            push(shape[j], strides[j]);
144        }
145
146        Some(ans)
147    }
148}
149
#[test]
fn test_merge_return_none() {
    // Big-endian: the innermost dimension spans 2 elements, but the middle
    // dimension's stride is 4, so the range is not contiguous.
    let merged = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0).merge_be(0, 3);
    assert!(merged.is_none(), "non-contiguous dimensions must not merge");
}
155
#[test]
fn test_merge_pairs_empty() {
    // Merging a range made only of size-1 dimensions yields one size-1
    // dimension with stride 0; the trailing dimension is copied through.
    let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0)
        .merge_be(0, 2)
        .unwrap();
    assert_eq!(layout.shape(), &[1, 1]);
    assert_eq!(layout.strides(), &[0, 1]);
    assert_eq!(layout.offset(), 0);
}
165
#[test]
fn test_merge_be_example() {
    // The size-1 middle dimension is ignored despite its large stride, so the
    // merged range collapses to the outer dimension alone.
    let source = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0);
    let merged = source.merge_be(0, 2).unwrap();

    assert_eq!(merged.offset(), 0);
    assert_eq!(merged.shape(), &[16, 4]);
    assert_eq!(merged.strides(), &[16, 4]);
}
175
#[test]
fn test_merge_le_example() {
    // Column-major strides: the first dimension is innermost, so a
    // little-endian merge of all three dimensions is contiguous.
    let merged = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0)
        .merge_le(0, 3)
        .unwrap();

    assert_eq!(merged.offset(), 0);
    assert_eq!(merged.strides(), &[1]);
    assert_eq!(merged.shape(), &[24]);
}
185
#[test]
fn test_merge_len_zero() {
    // Merging zero dimensions is a no-op: the layout comes back unchanged.
    let original = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0);
    let result = original.merge_le(0, 0).unwrap();

    assert_eq!(result.offset(), 0);
    assert_eq!(result.strides(), &[1, 4, 12]);
    assert_eq!(result.shape(), &[4, 3, 2]);
}
195
#[test]
fn test_partial_merge() {
    // Only dimensions 1 and 2 are merged (3 * 4 = 12 with the inner stride 5);
    // dimensions 0 and 3 pass through untouched.
    let result = ArrayLayout::<4>::new(&[2, 3, 4, 5], &[60, 20, 5, 1], 0)
        .merge_be(1, 2)
        .unwrap();

    assert_eq!(result.offset(), 0);
    assert_eq!(result.strides(), &[60, 5, 1]);
    assert_eq!(result.shape(), &[2, 12, 5]);
}
205
#[test]
fn test_merge_free_example() {
    // The strides are out of order, but sorted by magnitude (1, 4, 12) they
    // tile memory contiguously, so a free merge succeeds.
    let source = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0);
    let result = source.merge_free(0, 3).unwrap();

    assert_eq!(result.offset(), 0);
    assert_eq!(result.strides(), &[1]);
    assert_eq!(result.shape(), &[24]);
}