ndarray_layout/transform/
merge.rs1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct MergeArg {
7 pub start: usize,
9 pub len: usize,
11 pub endian: Option<Endian>,
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 #[inline]
27 pub fn merge_be(&self, start: usize, len: usize) -> Option<Self> {
28 self.merge_many(&[MergeArg {
29 start,
30 len,
31 endian: Some(Endian::BigEndian),
32 }])
33 }
34
35 #[inline]
46 pub fn merge_le(&self, start: usize, len: usize) -> Option<Self> {
47 self.merge_many(&[MergeArg {
48 start,
49 len,
50 endian: Some(Endian::LittleEndian),
51 }])
52 }
53
54 #[inline]
65 pub fn merge_free(&self, start: usize, len: usize) -> Option<Self> {
66 self.merge_many(&[MergeArg {
67 start,
68 len,
69 endian: None,
70 }])
71 }
72
73 pub fn merge_many(&self, args: &[MergeArg]) -> Option<Self> {
75 let content = self.content();
76 let shape = content.shape();
77 let strides = content.strides();
78
79 let merged = args.iter().map(|arg| arg.len).sum::<usize>();
80 let mut ans = Self::with_ndim(self.ndim + args.len() - merged);
81
82 let mut content = ans.content_mut();
83 content.set_offset(self.offset());
84 let mut i = 0;
85 let mut push = |d, s| {
86 content.set_shape(i, d);
87 content.set_stride(i, s);
88 i += 1;
89 };
90
91 let mut last_end = 0;
92 for arg in args {
93 let &MergeArg { start, len, endian } = arg;
94 let end = start + len;
95
96 if len == 0 {
97 continue;
98 }
99
100 for j in last_end..arg.start {
101 push(shape[j], strides[j]);
102 }
103
104 let mut pairs = zip(&shape[start..end], &strides[start..end]).collect::<Vec<_>>();
105 match endian {
106 Some(Endian::BigEndian) => pairs.reverse(),
107 Some(Endian::LittleEndian) => {}
108 None => pairs.sort_unstable_by_key(|(_, &s)| s.unsigned_abs()),
109 }
110
111 let ((&d, &s), pairs) = pairs.split_first().unwrap();
112 let mut d = d;
113
114 for (&d_, &s_) in pairs {
115 if d <= 1 || d_ <= 1 || s_ == s * d as isize {
117 d *= d_
118 } else {
119 return None;
120 }
121 }
122
123 push(d, s);
124 last_end = end;
125 }
126 for j in last_end..shape.len() {
127 push(shape[j], strides[j]);
128 }
129
130 Some(ans)
131 }
132}