1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
4use itertools::izip;
5use ndarray_layout::ArrayLayout;
6use rayon::iter::{IntoParallelIterator, ParallelIterator};
7use std::{cmp::Ordering, ptr::copy_nonoverlapping};
8
9pub extern crate ndarray_layout;
10
/// A precomputed plan for copying data between two strided memory layouts.
///
/// The whole plan lives in one flat boxed slice of `isize`:
///
/// | slots                          | meaning                                   |
/// |--------------------------------|-------------------------------------------|
/// | `0`                            | unit: bytes moved by a single copy        |
/// | `1`                            | byte offset applied to the destination    |
/// | `2`                            | byte offset applied to the source         |
/// | `3`                            | count: total number of unit copies        |
/// | `4 .. 4 + ndim`                | index strides (the last one is `1`)       |
/// | `4 + ndim .. 4 + 2 * ndim`     | destination byte strides                  |
/// | `4 + 2 * ndim .. 4 + 3 * ndim` | source byte strides                       |
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct Rearranging(Box<[isize]>);
16
/// Reasons why a copy plan cannot be built from a pair of layouts.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(u8)]
pub enum SchemeError {
    /// The two layouts differ in their number of dimensions or in the length
    /// of at least one dimension.
    ShapeMismatch,
    /// A destination dimension longer than 1 has stride 0, so every index
    /// along that dimension would land on the same destination address.
    DimReduce,
}

impl std::fmt::Display for SchemeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::ShapeMismatch => "source and destination shapes do not match",
            Self::DimReduce => "destination has a zero stride on a dimension longer than 1",
        })
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for SchemeError {}
26
27impl Rearranging {
28 pub fn new<const M: usize, const N: usize>(
32 dst: &ArrayLayout<M>,
33 src: &ArrayLayout<N>,
34 unit: usize,
35 ) -> Result<Self, SchemeError> {
36 let ndim = dst.ndim();
38 if src.ndim() != ndim {
39 return Err(SchemeError::ShapeMismatch);
40 }
41 #[derive(Clone, PartialEq, Eq, Debug)]
43 struct Dim {
44 len: usize,
45 dst: isize,
46 src: isize,
47 }
48 let mut dims = Vec::with_capacity(ndim);
49 for (&dd, &ds, &sd, &ss) in izip!(dst.shape(), src.shape(), dst.strides(), src.strides()) {
50 if dd != ds {
51 return Err(SchemeError::ShapeMismatch);
52 }
53 if dd != 1 {
55 if sd == 0 {
56 return Err(SchemeError::DimReduce);
57 }
58 dims.push(Dim {
59 len: dd,
60 dst: sd,
61 src: ss,
62 })
63 }
64 }
65 dims.sort_unstable_by(|a, b| {
67 use Ordering::Equal as Eq;
68 match a.dst.abs().cmp(&b.dst.abs()) {
69 Eq => match a.src.abs().cmp(&b.src.abs()) {
70 Eq => a.len.cmp(&b.len),
71 ord => ord.reverse(),
72 },
73 ord => ord.reverse(),
74 }
75 });
76 let mut unit = unit as isize;
78 let mut ndim = dims.len();
79 for dim in dims.iter_mut().rev() {
81 if dim.dst == unit && dim.src == unit {
82 unit *= dim.len as isize;
83 ndim -= 1
84 } else {
85 break;
86 }
87 }
88 dims.truncate(ndim);
89 for i in (1..dims.len()).rev() {
91 let (head, tail) = dims.split_at_mut(i);
92 let f = &mut head[i - 1]; let b = &mut tail[0]; let len = b.len as isize;
95 if b.dst * len == f.dst && b.src * len == f.src {
96 *f = Dim {
97 len: b.len * f.len,
98 dst: b.dst,
99 src: b.src,
100 };
101 *b = Dim {
102 len: 1,
103 dst: 0,
104 src: 0,
105 };
106 ndim -= 1
107 }
108 }
109 let mut ans = Self(vec![0isize; 4 + ndim * 3].into_boxed_slice());
111 ans.0[0] = unit as _;
112 ans.0[1] = dst.offset();
113 ans.0[2] = src.offset();
114 let layout = &mut ans.0[3..];
115 layout[ndim] = 1;
116 for (i, Dim { len, dst, src }) in dims.into_iter().filter(|d| d.len != 1).enumerate() {
117 layout[i] = len as _;
118 layout[i + 1 + ndim] = dst;
119 layout[i + 1 + ndim * 2] = src;
120 }
121 for i in (1..=ndim).rev() {
122 layout[i - 1] *= layout[i]
123 }
124 Ok(ans)
125 }
126
127 pub fn distribute_unit(&self, candidates: impl IntoIterator<Item = usize>) -> Option<Self> {
129 let unit = candidates.into_iter().find(|n| self.unit() % n == 0)?;
130 if unit == self.unit() {
131 return Some(self.clone());
132 }
133
134 let ndim = self.ndim();
135 let mut layout = vec![0isize; 4 + (ndim + 1) * 3].into_boxed_slice();
136 layout[0] = unit as _;
137 layout[1] = self.dst_offset();
138 layout[2] = self.src_offset();
139
140 let (_, tail) = layout.split_at_mut(3);
141 let (idx, tail) = tail.split_at_mut(ndim + 2);
142 let (dst, src) = tail.split_at_mut(ndim + 1);
143
144 let (_, tail) = self.0.split_at(3);
145 let (idx_, tail) = tail.split_at(ndim + 1);
146 let (dst_, src_) = tail.split_at(ndim);
147
148 idx[ndim + 1] = 1;
149 let extra = (self.unit() / unit) as isize;
150 for (new, old) in izip!(idx, idx_) {
151 *new = *old * extra;
152 }
153
154 fn copy_value(new: &mut [isize], old: &[isize], unit: usize) {
155 let [head @ .., tail] = new else {
156 unreachable!()
157 };
158 head.copy_from_slice(old);
159 *tail = unit as _;
160 }
161 copy_value(dst, dst_, unit);
162 copy_value(src, src_, unit);
163
164 Some(Self(layout))
165 }
166
167 #[inline]
169 pub fn ndim(&self) -> usize {
170 (self.0.len() - 4) / 3
171 }
172
173 #[inline]
175 pub fn unit(&self) -> usize {
176 self.0[0] as _
177 }
178
179 #[inline]
181 pub fn dst_offset(&self) -> isize {
182 self.0[1]
183 }
184
185 #[inline]
187 pub fn src_offset(&self) -> isize {
188 self.0[2]
189 }
190
191 #[inline]
193 pub fn count(&self) -> usize {
194 self.0[3] as _
195 }
196
197 #[inline]
199 pub fn idx_strides(&self) -> &[isize] {
200 let ndim = self.ndim();
201 &self.0[4..][..ndim]
202 }
203
204 #[inline]
206 pub fn dst_strides(&self) -> &[isize] {
207 let ndim = self.ndim();
208 &self.0[4 + ndim..][..ndim]
209 }
210
211 #[inline]
213 pub fn src_strides(&self) -> &[isize] {
214 let ndim = self.ndim();
215 &self.0[4 + ndim * 2..][..ndim]
216 }
217
218 pub fn shape(&self) -> impl Iterator<Item = usize> + '_ {
220 let ndim = self.ndim();
221 self.0[3..][..ndim + 1]
222 .windows(2)
223 .map(|pair| (pair[0] / pair[1]) as usize)
224 }
225
226 pub unsafe fn launch(&self, dst: *mut u8, src: *const u8) {
232 let dst = unsafe { dst.byte_offset(self.dst_offset()) };
233 let src = unsafe { src.byte_offset(self.src_offset()) };
234 match self.count() {
235 1 => unsafe { copy_nonoverlapping(src, dst, self.unit()) },
236 count => {
237 let dst = dst as isize;
238 let src = src as isize;
239 let idx_strides = self.idx_strides();
240 let dst_strides = self.dst_strides();
241 let src_strides = self.src_strides();
242 (0..count as isize).into_par_iter().for_each(|mut rem| {
243 let mut dst = dst;
244 let mut src = src;
245 for (i, &s) in idx_strides.iter().enumerate() {
246 let k = rem / s;
247 dst += k * dst_strides[i];
248 src += k * src_strides[i];
249 rem %= s
250 }
251 unsafe { copy_nonoverlapping::<u8>(src as _, dst as _, self.unit()) }
252 })
253 }
254 }
255 }
256}
257
#[test]
fn test_scheme() {
    // Seven dimensions: the innermost (length 4, stride 2 on both sides)
    // folds into the unit (2 -> 8), the length-1 dimension is dropped, and
    // the three outermost dimensions fuse into a single one of length 24.
    let lens = [4, 3, 2, 1, 2, 3, 4];
    let dst_strides = [288, 96, 48, 48, 24, 8, 2];
    let src_strides = [576, 192, 96, 48, 8, 16, 2];
    let dst_layout = ArrayLayout::<7>::new(&lens, &dst_strides, 0);
    let src_layout = ArrayLayout::<7>::new(&lens, &src_strides, 0);

    let plan = Rearranging::new(&dst_layout, &src_layout, 2).unwrap();

    // Header.
    assert_eq!(plan.ndim(), 3);
    assert_eq!(plan.dst_offset(), 0);
    assert_eq!(plan.src_offset(), 0);
    assert_eq!(plan.unit(), 8);
    assert_eq!(plan.count(), 24 * 2 * 3);

    // Per-dimension tables.
    assert_eq!(plan.idx_strides(), [6, 3, 1]);
    assert_eq!(plan.dst_strides(), [48, 24, 8]);
    assert_eq!(plan.src_strides(), [96, 8, 16]);
    assert_eq!(plan.shape().collect::<Vec<_>>(), [24, 2, 3]);
}
276
#[test]
fn test_distribute_unit() {
    let lens = [4, 3, 2];
    let dst_strides = [24, 8, 2];
    let src_strides = [48, 8, 16];
    let dst_layout = ArrayLayout::<3>::new(&lens, &dst_strides, 0);
    let src_layout = ArrayLayout::<3>::new(&lens, &src_strides, 0);
    let base = Rearranging::new(&dst_layout, &src_layout, 2).unwrap();
    let base_ndim = base.idx_strides().len();

    // Asking for the unit the plan already has changes nothing.
    let same = base.distribute_unit(vec![2]).unwrap();
    assert_eq!(same.unit(), 2);
    assert_eq!(same.count(), base.count());
    assert_eq!(same.idx_strides(), base.idx_strides());
    assert_eq!(same.dst_strides(), base.dst_strides());
    assert_eq!(same.src_strides(), base.src_strides());

    // Halving the unit doubles the count; the leading dimensions keep their
    // byte strides and get index strides scaled by two.
    let halved = base.distribute_unit(vec![1]).unwrap();
    assert_eq!(halved.unit(), 1);
    assert_eq!(halved.count(), base.count() * 2);

    let idx_prefix: Vec<_> = halved
        .idx_strides()
        .iter()
        .take(base_ndim)
        .map(|&x| x / 2)
        .collect();
    assert_eq!(idx_prefix, base.idx_strides());

    let dst_prefix: Vec<_> = halved
        .dst_strides()
        .iter()
        .take(base_ndim)
        .cloned()
        .collect();
    assert_eq!(dst_prefix, base.dst_strides());

    let src_prefix: Vec<_> = halved
        .src_strides()
        .iter()
        .take(base_ndim)
        .cloned()
        .collect();
    assert_eq!(src_prefix, base.src_strides());

    // The first candidate that divides the unit wins: 4 does not divide 2,
    // 2 does.
    let picked = base.distribute_unit(vec![4, 2, 1]).unwrap();
    assert_eq!(picked.unit(), 2);
}