ndarray_layout/transform/
merge.rs1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct MergeArg {
7 pub start: usize,
9 pub len: usize,
11 pub endian: Option<Endian>,
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 #[inline]
27 pub fn merge_be(&self, start: usize, len: usize) -> Option<Self> {
28 self.merge_many(&[MergeArg {
29 start,
30 len,
31 endian: Some(Endian::BigEndian),
32 }])
33 }
34
35 #[inline]
46 pub fn merge_le(&self, start: usize, len: usize) -> Option<Self> {
47 self.merge_many(&[MergeArg {
48 start,
49 len,
50 endian: Some(Endian::LittleEndian),
51 }])
52 }
53
54 #[inline]
65 pub fn merge_free(&self, start: usize, len: usize) -> Option<Self> {
66 self.merge_many(&[MergeArg {
67 start,
68 len,
69 endian: None,
70 }])
71 }
72
73 pub fn merge_many(&self, args: &[MergeArg]) -> Option<Self> {
75 let content = self.content();
76 let shape = content.shape();
77 let strides = content.strides();
78
79 let merged = args.iter().map(|arg| arg.len).sum::<usize>();
80 let mut ans = Self::with_ndim(self.ndim + args.len() - merged);
81
82 let mut content = ans.content_mut();
83 content.set_offset(self.offset());
84 let mut i = 0;
85 let mut push = |d, s| {
86 content.set_shape(i, d);
87 content.set_stride(i, s);
88 i += 1;
89 };
90
91 let mut last_end = 0;
92 for arg in args {
93 let &MergeArg { start, len, endian } = arg;
94 let end = start + len;
95
96 if len == 0 {
97 continue;
98 }
99
100 for j in last_end..arg.start {
101 push(shape[j], strides[j]);
102 }
103
104 let mut pairs = Vec::with_capacity(len);
105 for (&d, &s) in zip(&shape[start..end], &strides[start..end]) {
106 match d {
107 0 => todo!(),
108 1 => {}
109 _ => pairs.push((d, s)),
110 }
111 }
112 if pairs.is_empty() {
113 push(1, 0);
114 continue;
115 }
116 match endian {
117 Some(Endian::BigEndian) => pairs.reverse(),
118 Some(Endian::LittleEndian) => {}
119 None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()),
120 }
121
122 let ((d, s), pairs) = pairs.split_first().unwrap();
123 let mut d = *d;
124
125 for &(d_, s_) in pairs {
126 if s_ == s * d as isize {
127 d *= d_
128 } else {
129 return None;
130 }
131 }
132
133 push(d, *s);
134 last_end = end;
135 }
136 for j in last_end..shape.len() {
137 push(shape[j], strides[j]);
138 }
139
140 Some(ans)
141 }
142}
143
#[test]
fn test_merge() {
    // Axis 1 has size 1, so its large stride never matters. Merging axes 0..2
    // collapses them into the single size-16 axis, and axis 2 is copied through.
    let src = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0);
    let merged = src.merge_be(0, 2).unwrap();

    let (shape, strides) = (merged.shape(), merged.strides());
    assert_eq!(shape, &[16, 4]);
    assert_eq!(strides, &[16, 4]);
    assert_eq!(merged.offset(), 0);
}