/// Describes how an N-dimensional array is partitioned into rectangular tiles.
///
/// All vectors are indexed by axis. The strides built by `new` make axis 0 the
/// fastest-varying axis of the flat layout (`stride[0] == 1`).
#[derive(Debug)]
pub(super) struct TileGeometry {
/// Extent of the full array along each axis.
dims: Vec<usize>,
/// Nominal tile extent along each axis; tiles touching the upper edge of the
/// array may be smaller.
tiles: Vec<usize>,
/// Flat-index step for moving one element along each axis:
/// `stride[0] == 1` and `stride[i] == stride[i - 1] * dims[i - 1]`.
stride: Vec<usize>,
/// Number of tiles along each axis, i.e. `dims[i].div_ceil(tiles[i])`.
ntiles_axis: Vec<usize>,
}
/// Reusable buffers that `TileGeometry::tile_into` fills with the description
/// of one tile. Every field is overwritten on each call, so a single instance
/// can be reused across tiles.
#[derive(Debug, Default)]
pub(super) struct TileScratch {
/// Per-axis coordinate of the tile's first element within the full array.
origin: Vec<usize>,
/// Per-axis extent of the current tile (clipped at the array's upper edge).
pub(super) tdims: Vec<usize>,
/// Flat offset, in the full array, of the first element of each row of the
/// tile, where a row is a run along axis 0.
pub(super) row_bases: Vec<usize>,
/// Number of elements in each row: the tile's extent along axis 0, or 1 for a
/// zero-dimensional geometry.
pub(super) row_len: usize,
/// Working counter over axes `1..` used while enumerating the rows.
coord: Vec<usize>,
}
impl TileScratch {
    /// Total number of elements in the tile currently described by this
    /// scratch: the number of rows times the length of each row.
    pub(super) fn nelem(&self) -> usize {
        let rows = self.row_bases.len();
        rows * self.row_len
    }
}
impl TileGeometry {
    /// Builds the tiling of an array with per-axis extents `dims` into tiles
    /// of nominal per-axis extents `tiles`.
    ///
    /// The per-axis tile count is `dims[i].div_ceil(tiles[i])`, so the last
    /// tile along an axis may be partial. Strides are the running product of
    /// the preceding extents, which makes axis 0 the fastest-varying axis.
    ///
    /// A zero entry in `tiles` makes `div_ceil` panic.
    pub(super) fn new(dims: &[usize], tiles: &[usize]) -> TileGeometry {
        let rank = dims.len();

        let ntiles_axis: Vec<usize> = dims
            .iter()
            .zip(tiles)
            .map(|(&extent, &tile)| extent.div_ceil(tile))
            .collect();

        // stride[0] = 1, stride[k] = stride[k - 1] * dims[k - 1].
        let mut stride: Vec<usize> = Vec::with_capacity(rank);
        for axis in 0..rank {
            let step = match axis {
                0 => 1,
                _ => stride[axis - 1] * dims[axis - 1],
            };
            stride.push(step);
        }

        TileGeometry {
            dims: dims.to_vec(),
            tiles: tiles.to_vec(),
            stride,
            ntiles_axis,
        }
    }

    /// Total number of tiles: the product of the per-axis tile counts.
    pub(super) fn ntiles(&self) -> usize {
        self.ntiles_axis.iter().copied().product::<usize>()
    }

    /// Fills `s` with the description of tile number `t`.
    ///
    /// `t` is decoded as a mixed-radix number over the per-axis tile counts,
    /// with axis 0 as the least-significant digit. On return:
    ///
    /// * `s.tdims` holds the tile's extent along each axis, clipped at the
    ///   array's upper edge;
    /// * `s.row_len` is the extent along axis 0 (1 if there are no axes);
    /// * `s.row_bases` holds, for each row of the tile, the flat offset of its
    ///   first element in the full array, in order of increasing coordinates
    ///   with axis 1 varying fastest.
    pub(super) fn tile_into(&self, t: usize, s: &mut TileScratch) {
        let rank = self.dims.len();

        // Peel one mixed-radix digit per axis to get the tile's position, then
        // turn it into an element origin and a clipped extent.
        s.origin.clear();
        s.tdims.clear();
        let mut remaining = t;
        for axis in 0..rank {
            let count = self.ntiles_axis[axis];
            let tile = self.tiles[axis];
            let start = (remaining % count) * tile;
            remaining /= count;
            s.origin.push(start);
            s.tdims.push(tile.min(self.dims[axis] - start));
        }

        // Axis 0 forms the contiguous rows; every other axis multiplies the
        // number of rows.
        let nrows: usize = match s.tdims.split_first() {
            Some((&first, rest)) => {
                s.row_len = first;
                rest.iter().product()
            }
            None => {
                s.row_len = 1;
                1
            }
        };

        // Flat offset of the tile's first element.
        let mut offset: usize = s
            .origin
            .iter()
            .zip(&self.stride)
            .map(|(&start, &step)| start * step)
            .sum();

        s.row_bases.clear();
        s.row_bases.reserve(nrows);
        s.coord.clear();
        s.coord.resize(rank, 0);

        for _ in 0..nrows {
            s.row_bases.push(offset);

            // Advance the counter over axes 1.. like an odometer, keeping
            // `offset` in step with it: bump an axis, and if it runs past the
            // tile's extent, rewind it and carry into the next axis.
            for axis in 1..rank {
                let step = self.stride[axis];
                s.coord[axis] += 1;
                offset += step;
                if s.coord[axis] >= s.tdims[axis] {
                    s.coord[axis] = 0;
                    offset -= s.tdims[axis] * step;
                } else {
                    break;
                }
            }
        }
    }
}