ndarray_layout/transform/tile.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
use crate::{ArrayLayout, Endian};
use std::iter::zip;
/// Parameters of a tiling transform applied to a single axis.
///
/// Consumed by [`ArrayLayout::tile_many`]; `tile_be` / `tile_le` build a
/// single instance of it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileArg<'a> {
    /// Index of the axis that is split into tiles.
    pub axis: usize,
    /// Order in which the resulting tile dimensions appear in the shape.
    pub endian: Endian,
    /// Extents of the tiles; their product must equal the extent of `axis`.
    pub tiles: &'a [usize],
}
impl<const N: usize> ArrayLayout<N> {
    /// Splits dimension `axis` into `tiles.len()` dimensions whose extents are
    /// `tiles` (big-endian tiling).
    ///
    /// In big-endian order, the tile with the largest stride comes first in
    /// the resulting shape and the last tile is the innermost one. The product
    /// of `tiles` must equal the extent of `axis`.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 3, 1]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `axis` is out of range or the product of `tiles` differs from
    /// the extent of `axis`.
    #[inline]
    pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::BigEndian,
            tiles,
        }])
    }

    /// Splits dimension `axis` into `tiles.len()` dimensions whose extents are
    /// `tiles` (little-endian tiling).
    ///
    /// In little-endian order, the tile with the smallest stride comes first
    /// in the resulting shape and the first tile is the innermost one. The
    /// product of `tiles` must equal the extent of `axis`.
    ///
    /// ```rust
    /// # use ndarray_layout::ArrayLayout;
    /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]);
    /// assert_eq!(layout.shape(), &[2, 3, 2, 3]);
    /// assert_eq!(layout.strides(), &[18, 6, 1, 2]);
    /// assert_eq!(layout.offset(), 0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `axis` is out of range or the product of `tiles` differs from
    /// the extent of `axis`.
    #[inline]
    pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self {
        self.tile_many(&[TileArg {
            axis,
            endian: Endian::LittleEndian,
            tiles,
        }])
    }

    /// Tiles several axes in a single pass.
    ///
    /// Each [`TileArg`] replaces dimension `axis` of `self` with
    /// `tiles.len()` new dimensions whose extents are `tiles`, ordered
    /// according to `endian`. Axes not named in `args` are copied unchanged,
    /// and the offset is preserved. An empty `args` returns a clone of `self`.
    ///
    /// # Panics
    ///
    /// - if some `axis` is out of range, or the product of its `tiles` differs
    ///   from the extent of that axis;
    /// - if the axes in `args` are not strictly increasing.
    pub fn tile_many(&self, args: &[TileArg]) -> Self {
        let content = self.content();
        let shape = content.shape();
        let strides = content.strides();

        if args.is_empty() {
            return self.clone();
        }
        for arg in args {
            let product = arg.tiles.iter().product::<usize>();
            assert!(
                matches!(shape.get(arg.axis), Some(&d) if d == product),
                "tile: axis {} is out of range or its extent is not the product of {:?}",
                arg.axis,
                arg.tiles,
            );
        }
        assert!(
            args.windows(2).all(|w| w[0].axis < w[1].axis),
            "tile: axes must be strictly increasing",
        );

        // Every tiled axis is replaced by `tiles.len()` dimensions.
        let new_dims: usize = args.iter().map(|arg| arg.tiles.len()).sum();
        let mut ans = Self::with_ndim(self.ndim + new_dims - args.len());
        let mut dst = ans.content_mut();
        dst.set_offset(self.offset());

        // `pending` holds the arguments whose axis has not been reached yet;
        // `j` is the next output dimension to write.
        let mut pending = args;
        let mut j = 0;
        for (i, (&d, &s)) in zip(shape, strides).enumerate() {
            match pending {
                [arg, rest @ ..] if arg.axis == i => {
                    // Strides are accumulated from the innermost tile outwards
                    // (`s`, `s * t_inner`, ...). No division is involved, so a
                    // zero-sized tile on an empty axis cannot divide by zero.
                    let mut stride = s;
                    let mut place = |k: usize, t: usize| {
                        dst.set_shape(j + k, t);
                        dst.set_stride(j + k, stride);
                        stride *= t as isize;
                    };
                    match arg.endian {
                        // tiles [a, b, c] -> strides [s * c * b, s * c, s]
                        Endian::BigEndian => {
                            for (k, &t) in arg.tiles.iter().enumerate().rev() {
                                place(k, t);
                            }
                        }
                        // tiles [a, b, c] -> strides [s, s * a, s * a * b]
                        Endian::LittleEndian => {
                            for (k, &t) in arg.tiles.iter().enumerate() {
                                place(k, t);
                            }
                        }
                    }
                    j += arg.tiles.len();
                    pending = rest;
                }
                _ => {
                    dst.set_shape(j, d);
                    dst.set_stride(j, s);
                    j += 1;
                }
            }
        }
        ans
    }
}