ndarray_layout/transform/
merge.rs1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct MergeArg {
7 pub start: usize,
9 pub len: usize,
11 pub endian: Option<Endian>,
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 #[inline]
27 pub fn merge_be(&self, start: usize, len: usize) -> Option<Self> {
28 self.merge_many(&[MergeArg {
29 start,
30 len,
31 endian: Some(Endian::BigEndian),
32 }])
33 }
34
35 #[inline]
46 pub fn merge_le(&self, start: usize, len: usize) -> Option<Self> {
47 self.merge_many(&[MergeArg {
48 start,
49 len,
50 endian: Some(Endian::LittleEndian),
51 }])
52 }
53
54 #[inline]
65 pub fn merge_free(&self, start: usize, len: usize) -> Option<Self> {
66 self.merge_many(&[MergeArg {
67 start,
68 len,
69 endian: None,
70 }])
71 }
72
73 pub fn merge_many(&self, args: &[MergeArg]) -> Option<Self> {
75 let content = self.content();
76 let shape = content.shape();
77 let strides = content.strides();
78
79 let merged = args.iter().map(|arg| arg.len.max(1)).sum::<usize>();
80 if merged == args.len() {
81 return Some(self.clone());
82 }
83
84 let mut ans = Self::with_ndim(self.ndim + args.len() - merged);
85
86 let mut content = ans.content_mut();
87 content.set_offset(self.offset());
88 let mut i = 0;
89 let mut push = |d, s| {
90 content.set_shape(i, d);
91 content.set_stride(i, s);
92 i += 1;
93 };
94
95 let mut last_end = 0;
96 for arg in args {
97 let &MergeArg { start, len, endian } = arg;
98 let end = start + len;
99
100 if len < 2 {
101 continue;
102 }
103
104 for j in last_end..arg.start {
105 push(shape[j], strides[j]);
106 }
107
108 let mut pairs = Vec::with_capacity(len);
109 for (&d, &s) in zip(&shape[start..end], &strides[start..end]) {
110 match d {
111 0 => todo!(),
112 1 => {}
113 _ => pairs.push((d, s)),
114 }
115 }
116
117 last_end = end;
118
119 if pairs.is_empty() {
120 push(1, 0);
121 continue;
122 }
123 match endian {
124 Some(Endian::BigEndian) => pairs.reverse(),
125 Some(Endian::LittleEndian) => {}
126 None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()),
127 }
128
129 let ((d, s), pairs) = pairs.split_first().unwrap();
130 let mut d = *d;
131
132 for &(d_, s_) in pairs {
133 if s_ == s * d as isize {
134 d *= d_
135 } else {
136 return None;
137 }
138 }
139
140 push(d, *s);
141 }
142 for j in last_end..shape.len() {
143 push(shape[j], strides[j]);
144 }
145
146 Some(ans)
147 }
148}
149
#[test]
fn test_merge_return_none() {
    // Innermost-first (big-endian) the axes are (2, stride 1), (4, stride 4), (16, stride 8).
    // Contiguity would need the size-4 axis to have stride 1 * 2 = 2, so the merge fails.
    let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0);
    let merged = layout.merge_be(0, 3);
    assert!(merged.is_none());
}

// A range made only of size-1 axes collapses to one unit axis with stride 0;
// the axis outside the range keeps its original stride.
#[test]
fn test_merge_pairs_empty() {
    let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0)
        .merge_be(0, 2)
        .unwrap();
    assert_eq!(layout.shape(), &[1, 1]);
    assert_eq!(layout.strides(), &[0, 1]);
    assert_eq!(layout.offset(), 0);
}

#[test]
fn test_merge_be_example() {
    // The size-1 middle axis is dropped from the range regardless of its large stride,
    // leaving the single axis (16, stride 16); the last axis lies outside the range.
    let source = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0);
    let merged = source.merge_be(0, 2).unwrap();

    assert_eq!(merged.offset(), 0);
    assert_eq!(merged.strides(), &[16, 4]);
    assert_eq!(merged.shape(), &[16, 4]);
}

#[test]
fn test_merge_le_example() {
    // First axis innermost: strides 1, 4 = 1 * 4 and 12 = 1 * 4 * 3, so all three axes fuse.
    let merged = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0)
        .merge_le(0, 3)
        .unwrap();

    let (shape, strides) = (merged.shape(), merged.strides());
    assert_eq!(shape, &[24]);
    assert_eq!(strides, &[1]);
    assert_eq!(merged.offset(), 0);
}

#[test]
fn test_merge_len_zero() {
    // An empty range merges nothing, so the layout comes back unchanged.
    let unchanged = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0)
        .merge_le(0, 0)
        .unwrap();

    let (shape, strides) = (unchanged.shape(), unchanged.strides());
    assert_eq!(shape, &[4, 3, 2]);
    assert_eq!(strides, &[1, 4, 12]);
    assert_eq!(unchanged.offset(), 0);
}

#[test]
fn test_partial_merge() {
    // Only axes 1 and 2 are fused (axis 1's stride 20 = 5 * 4); axes 0 and 3 pass through.
    let layout = ArrayLayout::<4>::new(&[2, 3, 4, 5], &[60, 20, 5, 1], 0);
    let partially_merged = layout.merge_be(1, 2).unwrap();

    assert_eq!(partially_merged.offset(), 0);
    assert_eq!(partially_merged.shape(), &[2, 12, 5]);
    assert_eq!(partially_merged.strides(), &[60, 5, 1]);
}

#[test]
fn test_merge_free_example() {
    // Ordered by stride the axes are (4, 1), (3, 4), (2, 12): 4 = 1 * 4 and 12 = 1 * 4 * 3,
    // so they fuse even though the declared axis order is not the memory order.
    let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0);
    let fused = layout.merge_free(0, 3).unwrap();

    let (shape, strides) = (fused.shape(), fused.strides());
    assert_eq!(shape, &[24]);
    assert_eq!(strides, &[1]);
    assert_eq!(fused.offset(), 0);
}