ndarray_layout/transform/tile.rs
1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 分块变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct TileArg<'a> {
7 /// 分块的轴。
8 pub axis: usize,
9 /// 分块的顺序。
10 pub endian: Endian,
11 /// 分块的大小。
12 pub tiles: &'a [usize],
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 /// 分块变换是将单个维度划分为多个分块的变换。
17 /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。
18 ///
19 /// ```rust
20 /// # use ndarray_layout::ArrayLayout;
21 /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
22 /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
23 /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
24 /// assert_eq!(layout.offset(), 0);
25 /// ```
26 #[inline]
27 pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
28 self.tile_many(&[TileArg {
29 axis,
30 endian: Endian::BigEndian,
31 tiles,
32 }])
33 }
34
35 /// 分块变换是将单个维度划分为多个分块的变换。
36 /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。
37 ///
38 /// ```rust
39 /// # use ndarray_layout::ArrayLayout;
40 /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
41 /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
42 /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
43 /// assert_eq!(layout.offset(), 0);
44 /// ```
45 #[inline]
46 pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
47 self.tile_many(&[TileArg {
48 axis,
49 endian: Endian::LittleEndian,
50 tiles,
51 }])
52 }
53
54 /// 一次对多个阶进行分块变换。
55 pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
56 let content = self.content();
57 let shape = content.shape();
58 let iter = zip(shape, content.strides()).enumerate();
59
60 let check = |&TileArg { axis, tiles, .. }| {
61 shape
62 .get(axis)
63 .filter(|&&d| d == tiles.iter().product())
64 .is_some()
65 };
66
67 let (mut new, mut last_axis) = match args {
68 [first, ..] => {
69 assert!(check(first));
70 (first.tiles.len(), first.axis)
71 }
72 [..] => return self.clone(),
73 };
74 for arg in &args[1..] {
75 assert!(check(arg));
76 assert!(arg.axis > last_axis);
77 new += arg.tiles.len();
78 last_axis = arg.axis;
79 }
80
81 let mut ans = Self::with_ndim(self.ndim + new - args.len());
82
83 let mut content = ans.content_mut();
84 content.set_offset(self.offset());
85 let mut j = 0;
86 let mut push = |t, s| {
87 content.set_shape(j, t);
88 content.set_stride(j, s);
89 j += 1;
90 };
91
92 for (i, (&d, &s)) in iter {
93 match *args {
94 [
95 TileArg {
96 axis,
97 endian,
98 tiles,
99 },
100 ref tail @ ..,
101 ] if axis == i => {
102 match endian {
103 Endian::BigEndian => {
104 // tile : [a, b , c]
105 // strides: [s * c * b, s * c, s]
106 let mut s = s * d as isize;
107 for &t in tiles {
108 s /= t as isize;
109 push(t, s);
110 }
111 }
112 Endian::LittleEndian => {
113 // tile : [a, b , c ]
114 // strides: [s, s * a, s * a * b]
115 let mut s = s;
116 for &t in tiles {
117 push(t, s);
118 s *= t as isize;
119 }
120 }
121 }
122 args = tail;
123 }
124 [..] => push(d, s),
125 }
126 }
127 ans
128 }
129}