mem_rearrange/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(warnings, missing_docs)]
3
4use itertools::izip;
5use ndarray_layout::ArrayLayout;
6use rayon::iter::{IntoParallelIterator, ParallelIterator};
7use std::{cmp::Ordering, ptr::copy_nonoverlapping};
8
9pub extern crate ndarray_layout;
10
/// A precomputed memory-rearrangement task.
///
/// Everything is stored in one flat `isize` slice laid out as
/// `| unit | dst offset | src offset | count | idx strides | dst strides | src strides |`,
/// where each of the three stride groups holds one entry per dimension.
#[repr(transparent)]
#[derive(Debug, Clone)]
pub struct Rearranging(Box<[isize]>);
16
/// Reasons a rearrangement scheme cannot be built.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[repr(u8)]
pub enum SchemeError {
    /// The input and output layouts differ in rank or shape.
    ShapeMismatch,
    /// The output layout contains a broadcast (zero-stride) dimension, which would
    /// make several elements be written to the same location.
    DimReduce,
}

impl std::fmt::Display for SchemeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::ShapeMismatch => "input and output layouts have different shapes",
            Self::DimReduce => {
                "output layout contains a broadcast dimension, which would reduce writes"
            }
        })
    }
}

// Lets `SchemeError` flow through `Box<dyn Error>` / `?` in caller code.
impl std::error::Error for SchemeError {}
26
27impl Rearranging {
28    /// 从输出布局、输入布局和单元规模构造重排方案。
29    ///
30    /// 单元规模 `unit` 是数组中单个元素的字节数。
31    pub fn new<const M: usize, const N: usize>(
32        dst: &ArrayLayout<M>,
33        src: &ArrayLayout<N>,
34        unit: usize,
35    ) -> Result<Self, SchemeError> {
36        // # 检查基本属性
37        let ndim = dst.ndim();
38        if src.ndim() != ndim {
39            return Err(SchemeError::ShapeMismatch);
40        }
41        // # 输入形状
42        #[derive(Clone, PartialEq, Eq, Debug)]
43        struct Dim {
44            len: usize,
45            dst: isize,
46            src: isize,
47        }
48        let mut dims = Vec::with_capacity(ndim);
49        for (&dd, &ds, &sd, &ss) in izip!(dst.shape(), src.shape(), dst.strides(), src.strides()) {
50            if dd != ds {
51                return Err(SchemeError::ShapeMismatch);
52            }
53            // 剔除初始的 1 长维度
54            if dd != 1 {
55                if sd == 0 {
56                    return Err(SchemeError::DimReduce);
57                }
58                dims.push(Dim {
59                    len: dd,
60                    dst: sd,
61                    src: ss,
62                })
63            }
64        }
65        // # 排序
66        dims.sort_unstable_by(|a, b| {
67            use Ordering::Equal as Eq;
68            match a.dst.abs().cmp(&b.dst.abs()) {
69                Eq => match a.src.abs().cmp(&b.src.abs()) {
70                    Eq => a.len.cmp(&b.len),
71                    ord => ord.reverse(),
72                },
73                ord => ord.reverse(),
74            }
75        });
76        // # 合并连续维度
77        let mut unit = unit as isize;
78        let mut ndim = dims.len();
79        // ## 合并末尾连续维度到 unit
80        for dim in dims.iter_mut().rev() {
81            if dim.dst == unit && dim.src == unit {
82                unit *= dim.len as isize;
83                ndim -= 1
84            } else {
85                break;
86            }
87        }
88        dims.truncate(ndim);
89        // ## 合并任意连续维度
90        for i in (1..dims.len()).rev() {
91            let (head, tail) = dims.split_at_mut(i);
92            let f = &mut head[i - 1]; // f for front
93            let b = &mut tail[0]; // b for back
94            let len = b.len as isize;
95            if b.dst * len == f.dst && b.src * len == f.src {
96                *f = Dim {
97                    len: b.len * f.len,
98                    dst: b.dst,
99                    src: b.src,
100                };
101                *b = Dim {
102                    len: 1,
103                    dst: 0,
104                    src: 0,
105                };
106                ndim -= 1
107            }
108        }
109        // # 合并空间
110        let mut ans = Self(vec![0isize; 4 + ndim * 3].into_boxed_slice());
111        ans.0[0] = unit as _;
112        ans.0[1] = dst.offset();
113        ans.0[2] = src.offset();
114        let layout = &mut ans.0[3..];
115        layout[ndim] = 1;
116        for (i, Dim { len, dst, src }) in dims.into_iter().filter(|d| d.len != 1).enumerate() {
117            layout[i] = len as _;
118            layout[i + 1 + ndim] = dst;
119            layout[i + 1 + ndim * 2] = src;
120        }
121        for i in (1..=ndim).rev() {
122            layout[i - 1] *= layout[i]
123        }
124        Ok(ans)
125    }
126
127    /// 从候选值中选择一个能整除当前单元规模的值作为新的单元规模
128    pub fn distribute_unit(&self, candidates: impl IntoIterator<Item = usize>) -> Option<Self> {
129        let unit = candidates.into_iter().find(|n| self.unit() % n == 0)?;
130        if unit == self.unit() {
131            return Some(self.clone());
132        }
133
134        let ndim = self.ndim();
135        let mut layout = vec![0isize; 4 + (ndim + 1) * 3].into_boxed_slice();
136        layout[0] = unit as _;
137        layout[1] = self.dst_offset();
138        layout[2] = self.src_offset();
139
140        let (_, tail) = layout.split_at_mut(3);
141        let (idx, tail) = tail.split_at_mut(ndim + 2);
142        let (dst, src) = tail.split_at_mut(ndim + 1);
143
144        let (_, tail) = self.0.split_at(3);
145        let (idx_, tail) = tail.split_at(ndim + 1);
146        let (dst_, src_) = tail.split_at(ndim);
147
148        idx[ndim + 1] = 1;
149        let extra = (self.unit() / unit) as isize;
150        for (new, old) in izip!(idx, idx_) {
151            *new = *old * extra;
152        }
153
154        fn copy_value(new: &mut [isize], old: &[isize], unit: usize) {
155            let [head @ .., tail] = new else {
156                unreachable!()
157            };
158            head.copy_from_slice(old);
159            *tail = unit as _;
160        }
161        copy_value(dst, dst_, unit);
162        copy_value(src, src_, unit);
163
164        Some(Self(layout))
165    }
166
167    /// 执行方案维数。
168    #[inline]
169    pub fn ndim(&self) -> usize {
170        (self.0.len() - 4) / 3
171    }
172
173    /// 读写单元规模。
174    #[inline]
175    pub fn unit(&self) -> usize {
176        self.0[0] as _
177    }
178
179    /// 输出基址偏移。
180    #[inline]
181    pub fn dst_offset(&self) -> isize {
182        self.0[1]
183    }
184
185    /// 输入基址偏移。
186    #[inline]
187    pub fn src_offset(&self) -> isize {
188        self.0[2]
189    }
190
191    /// 读写单元数量。
192    #[inline]
193    pub fn count(&self) -> usize {
194        self.0[3] as _
195    }
196
197    /// 索引步长。
198    #[inline]
199    pub fn idx_strides(&self) -> &[isize] {
200        let ndim = self.ndim();
201        &self.0[4..][..ndim]
202    }
203
204    /// 输出数据步长。
205    #[inline]
206    pub fn dst_strides(&self) -> &[isize] {
207        let ndim = self.ndim();
208        &self.0[4 + ndim..][..ndim]
209    }
210
211    /// 输入数据步长。
212    #[inline]
213    pub fn src_strides(&self) -> &[isize] {
214        let ndim = self.ndim();
215        &self.0[4 + ndim * 2..][..ndim]
216    }
217
218    /// 计算方案涉及的形状。
219    pub fn shape(&self) -> impl Iterator<Item = usize> + '_ {
220        let ndim = self.ndim();
221        self.0[3..][..ndim + 1]
222            .windows(2)
223            .map(|pair| (pair[0] / pair[1]) as usize)
224    }
225
226    /// 执行存储重排。
227    ///
228    /// # Safety
229    ///
230    /// `dst` and `src` must be valid pointers and must able to access with the scheme.
231    pub unsafe fn launch(&self, dst: *mut u8, src: *const u8) {
232        let dst = unsafe { dst.byte_offset(self.dst_offset()) };
233        let src = unsafe { src.byte_offset(self.src_offset()) };
234        match self.count() {
235            1 => unsafe { copy_nonoverlapping(src, dst, self.unit()) },
236            count => {
237                let dst = dst as isize;
238                let src = src as isize;
239                let idx_strides = self.idx_strides();
240                let dst_strides = self.dst_strides();
241                let src_strides = self.src_strides();
242                (0..count as isize).into_par_iter().for_each(|mut rem| {
243                    let mut dst = dst;
244                    let mut src = src;
245                    for (i, &s) in idx_strides.iter().enumerate() {
246                        let k = rem / s;
247                        dst += k * dst_strides[i];
248                        src += k * src_strides[i];
249                        rem %= s
250                    }
251                    unsafe { copy_nonoverlapping::<u8>(src as _, dst as _, self.unit()) }
252                })
253            }
254        }
255    }
256}
257
#[test]
fn test_scheme() {
    // A 7-D layout with one length-1 axis. After dropping that axis, folding the
    // innermost contiguous axis into the unit and merging the axes that tile each
    // other on both sides, three dimensions remain.
    let shape = [4, 3, 2, 1, 2, 3, 4];
    let dst_strides = [288, 96, 48, 48, 24, 8, 2];
    let src_strides = [576, 192, 96, 48, 8, 16, 2];
    let dst_layout = ArrayLayout::<7>::new(&shape, &dst_strides, 0);
    let src_layout = ArrayLayout::<7>::new(&shape, &src_strides, 0);

    let plan = Rearranging::new(&dst_layout, &src_layout, 2).unwrap();

    assert_eq!(plan.ndim(), 3);
    assert_eq!((plan.dst_offset(), plan.src_offset()), (0, 0));
    assert_eq!(plan.unit(), 8);
    assert_eq!(plan.count(), 24 * 2 * 3);
    assert_eq!(plan.idx_strides(), [6, 3, 1]);
    assert_eq!(plan.dst_strides(), [48, 24, 8]);
    assert_eq!(plan.src_strides(), [96, 8, 16]);
    let dims: Vec<usize> = plan.shape().collect();
    assert_eq!(dims, [24, 2, 3]);
}
276
#[test]
fn test_distribute_unit() {
    // Base scheme: 3-D layout whose innermost axis is not contiguous in the source,
    // so the unit stays at 2 bytes.
    let shape = [4, 3, 2];
    let dst_strides = [24, 8, 2];
    let src_strides = [48, 8, 16];
    let dst_layout = ArrayLayout::<3>::new(&shape, &dst_strides, 0);
    let src_layout = ArrayLayout::<3>::new(&shape, &src_strides, 0);
    let base = Rearranging::new(&dst_layout, &src_layout, 2).unwrap();
    let n = base.idx_strides().len();

    // Case 1: the same unit size yields an identical scheme.
    let same = base.distribute_unit([2]).unwrap();
    assert_eq!(same.unit(), 2);
    assert_eq!(same.count(), base.count());
    assert_eq!(same.idx_strides(), base.idx_strides());
    assert_eq!(same.dst_strides(), base.dst_strides());
    assert_eq!(same.src_strides(), base.src_strides());

    // Case 2: halving the unit doubles the count and index strides, while the
    // original byte strides are kept as the leading entries.
    let halved = base.distribute_unit([1]).unwrap();
    assert_eq!(halved.unit(), 1);
    assert_eq!(halved.count(), 2 * base.count());
    let scaled_back: Vec<isize> = halved.idx_strides()[..n].iter().map(|&x| x / 2).collect();
    assert_eq!(scaled_back, base.idx_strides());
    assert_eq!(&halved.dst_strides()[..n], base.dst_strides());
    assert_eq!(&halved.src_strides()[..n], base.src_strides());

    // Case 3: among several candidates, the first divisor of the unit wins.
    let picked = base.distribute_unit([4, 2, 1]).unwrap();
    assert_eq!(picked.unit(), 2);
}