ndarray_layout/transform/tile.rs
1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4/// 分块变换参数。
5#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct TileArg<'a> {
7 /// 分块的轴。
8 pub axis: usize,
9 /// 分块的顺序。
10 pub endian: Endian,
11 /// 分块的大小。
12 pub tiles: &'a [usize],
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 /// 分块变换是将单个维度划分为多个分块的变换。
17 /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。
18 ///
19 /// ```rust
20 /// # use ndarray_layout::ArrayLayout;
21 /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
22 /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
23 /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
24 /// assert_eq!(layout.offset(), 0);
25 /// ```
26 #[inline]
27 pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
28 self.tile_many(&[TileArg {
29 axis,
30 endian: Endian::BigEndian,
31 tiles,
32 }])
33 }
34
35 /// 分块变换是将单个维度划分为多个分块的变换。
36 /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。
37 ///
38 /// ```rust
39 /// # use ndarray_layout::ArrayLayout;
40 /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
41 /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
42 /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
43 /// assert_eq!(layout.offset(), 0);
44 /// ```
45 #[inline]
46 pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
47 self.tile_many(&[TileArg {
48 axis,
49 endian: Endian::LittleEndian,
50 tiles,
51 }])
52 }
53
54 /// 一次对多个阶进行分块变换。
55 pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
56 let content = self.content();
57 let shape = content.shape();
58 let iter = zip(shape, content.strides()).enumerate();
59
60 let check = |&TileArg { axis, tiles, .. }| {
61 shape
62 .get(axis)
63 .filter(|&&d| d == tiles.iter().product())
64 .is_some()
65 };
66
67 let (mut new, mut last_axis) = match args {
68 [first, ..] => {
69 assert!(check(first));
70 (first.tiles.len(), first.axis)
71 }
72 [..] => return self.clone(),
73 };
74 for arg in &args[1..] {
75 assert!(check(arg));
76 assert!(arg.axis > last_axis);
77 new += arg.tiles.len();
78 last_axis = arg.axis;
79 }
80
81 let mut ans = Self::with_ndim(self.ndim + new - args.len());
82
83 let mut content = ans.content_mut();
84 content.set_offset(self.offset());
85 let mut j = 0;
86 let mut push = |t, s| {
87 content.set_shape(j, t);
88 content.set_stride(j, s);
89 j += 1;
90 };
91
92 for (i, (&d, &s)) in iter {
93 match *args {
94 [TileArg {
95 axis,
96 endian,
97 tiles,
98 }, ref tail @ ..]
99 if axis == i =>
100 {
101 match endian {
102 Endian::BigEndian => {
103 // tile : [a, b , c]
104 // strides: [s * c * b, s * c, s]
105 let mut s = s * d as isize;
106 for &t in tiles {
107 s /= t as isize;
108 push(t, s);
109 }
110 }
111 Endian::LittleEndian => {
112 // tile : [a, b , c ]
113 // strides: [s, s * a, s * a * b]
114 let mut s = s;
115 for &t in tiles {
116 push(t, s);
117 s *= t as isize;
118 }
119 }
120 }
121 args = tail;
122 }
123 [..] => push(d, s),
124 }
125 }
126 ans
127 }
128}