ndarray_layout/transform/
tile.rs1use crate::{ArrayLayout, Endian};
2use std::iter::zip;
3
4#[derive(Clone, PartialEq, Eq, Debug)]
6pub struct TileArg<'a> {
7 pub axis: usize,
9 pub endian: Endian,
11 pub tiles: &'a [usize],
13}
14
15impl<const N: usize> ArrayLayout<N> {
16 #[inline]
27 pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
28 self.tile_many(&[TileArg {
29 axis,
30 endian: Endian::BigEndian,
31 tiles,
32 }])
33 }
34
35 #[inline]
46 pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
47 self.tile_many(&[TileArg {
48 axis,
49 endian: Endian::LittleEndian,
50 tiles,
51 }])
52 }
53
54 pub fn tile_many(&self, mut args: &[TileArg]) -> Self {
56 let content = self.content();
57 let shape = content.shape();
58 let iter = zip(shape, content.strides()).enumerate();
59
60 let check = |&TileArg { axis, tiles, .. }| {
61 shape
62 .get(axis)
63 .filter(|&&d| d == tiles.iter().product())
64 .is_some()
65 };
66
67 let (mut new, mut last_axis) = match args {
68 [first, ..] => {
69 assert!(check(first));
70 (first.tiles.len(), first.axis)
71 }
72 [..] => return self.clone(),
73 };
74 for arg in &args[1..] {
75 assert!(check(arg));
76 assert!(arg.axis > last_axis);
77 new += arg.tiles.len();
78 last_axis = arg.axis;
79 }
80
81 let mut ans = Self::with_ndim(self.ndim + new - args.len());
82
83 let mut content = ans.content_mut();
84 content.set_offset(self.offset());
85 let mut j = 0;
86 let mut push = |t, s| {
87 content.set_shape(j, t);
88 content.set_stride(j, s);
89 j += 1;
90 };
91
92 for (i, (&d, &s)) in iter {
93 match *args {
94 [
95 TileArg {
96 axis,
97 endian,
98 tiles,
99 },
100 ref tail @ ..,
101 ] if axis == i => {
102 match endian {
103 Endian::BigEndian => {
104 let mut s = s * d as isize;
107 for &t in tiles {
108 s /= t as isize;
109 push(t, s);
110 }
111 }
112 Endian::LittleEndian => {
113 let mut s = s;
116 for &t in tiles {
117 push(t, s);
118 s *= t as isize;
119 }
120 }
121 }
122 args = tail;
123 }
124 [..] => push(d, s),
125 }
126 }
127 ans
128 }
129}
130
#[test]
fn test_tile_be() {
    // Split the innermost axis (length 6, stride 1) into 2 x 3, outermost tile first.
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_be(2, &[2, 3]);

    assert_eq!(tiled.offset(), 0);
    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 3, 1]);
}

#[test]
fn test_tile_le() {
    // Split the innermost axis (length 6, stride 1) into 2 x 3, innermost tile first.
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_le(2, &[2, 3]);

    assert_eq!(tiled.offset(), 0);
    assert_eq!(tiled.shape(), &[2, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 6, 1, 2]);
}

#[test]
fn test_empty_tile() {
    // With no tiling arguments the layout must come back unchanged.
    let no_args: [TileArg; 0] = [];
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let same = base.tile_many(&no_args);

    assert_eq!(same.offset(), 0);
    assert_eq!(same.shape(), &[2, 3, 6]);
    assert_eq!(same.strides(), &[18, 6, 1]);
}

#[test]
fn test_multiple_tiles() {
    // Tile axis 0 into 2 x 1 and axis 2 into 2 x 3 in a single call.
    let args = [
        TileArg {
            axis: 0,
            endian: Endian::BigEndian,
            tiles: &[2, 1],
        },
        TileArg {
            axis: 2,
            endian: Endian::BigEndian,
            tiles: &[2, 3],
        },
    ];
    let base = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0);
    let tiled = base.tile_many(&args);

    assert_eq!(tiled.offset(), 0);
    assert_eq!(tiled.shape(), &[2, 1, 3, 2, 3]);
    assert_eq!(tiled.strides(), &[18, 18, 6, 3, 1]);
}