use super::tensor_view::TensorView;
use std::marker::PhantomData;
/// A view that divides a tensor into a grid of tiles over four dimensions.
///
/// Tiles at the upper end of a dimension whose extent is not a multiple of the tile
/// extent are truncated to the remaining elements (reported via `TileInfo::is_edge`).
#[derive(Debug)]
pub struct PartitionView<T> {
    /// The tensor being partitioned.
    tensor: TensorView<T>,
    /// Nominal tile extent per dimension; `new` rejects zero entries.
    tile_shape: [usize; 4],
    /// Marker for the element type `T`.
    _marker: PhantomData<T>,
}
/// Placement of a single tile within a partitioned tensor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileInfo {
    /// Position of the tile in the tile grid, per dimension (in units of tiles).
    pub tile_idx: [usize; 4],
    /// Element offset of the tile's first element, per dimension.
    pub start: [usize; 4],
    /// Actual tile extent per dimension; smaller than the nominal tile shape for tiles
    /// truncated at the tensor's upper edge.
    pub size: [usize; 4],
    /// `true` if the tile is truncated in at least one dimension.
    pub is_edge: bool,
}
impl<T> PartitionView<T> {
    /// Creates a partition of `tensor` into tiles of `tile_shape` (one extent per dimension).
    ///
    /// # Panics
    ///
    /// Panics if any entry of `tile_shape` is zero.
    pub fn new(tensor: TensorView<T>, tile_shape: [usize; 4]) -> Self {
        assert!(tile_shape.iter().all(|&d| d > 0), "Tile dimensions must be non-zero");
        Self { tensor, tile_shape, _marker: PhantomData }
    }

    /// Creates a partition whose tile extent in dimension `i` is `1 << tile_log2[i]`.
    ///
    /// # Panics
    ///
    /// Panics if any exponent is `>= usize::BITS`, i.e. the tile extent would not fit in
    /// a `usize`.
    pub fn new_power_of_two(tensor: TensorView<T>, tile_log2: [usize; 4]) -> Self {
        // A plain `1 << n` with `n >= usize::BITS` is a shift overflow: it panics in debug
        // builds but masks the shift amount in release builds, silently producing a wrong
        // tile extent. Check explicitly so both build profiles fail loudly.
        let tile_shape = tile_log2.map(|log2| {
            u32::try_from(log2)
                .ok()
                .and_then(|shift| 1usize.checked_shl(shift))
                .expect("Tile log2 exponent must be less than usize::BITS")
        });
        Self::new(tensor, tile_shape)
    }

    /// Convenience constructor for 2-D tiling: tiles of `tile_rows` x `tile_cols` over
    /// dimensions 0 and 1, with extent 1 in dimensions 2 and 3.
    ///
    /// # Panics
    ///
    /// Panics if `tile_rows` or `tile_cols` is zero.
    pub fn new_2d(tensor: TensorView<T>, tile_rows: usize, tile_cols: usize) -> Self {
        Self::new(tensor, [tile_rows, tile_cols, 1, 1])
    }

    /// Returns the tensor being partitioned.
    pub fn tensor(&self) -> &TensorView<T> {
        &self.tensor
    }

    /// Returns the nominal tile extent in each dimension.
    pub fn tile_shape(&self) -> &[usize; 4] {
        &self.tile_shape
    }

    /// Returns the number of tiles along each dimension.
    ///
    /// A trailing partial tile counts as a whole tile (ceiling division of the tensor
    /// extent by the tile extent).
    pub fn tile_count(&self) -> [usize; 4] {
        let tensor_shape = self.tensor.shape();
        std::array::from_fn(|i| tensor_shape[i].div_ceil(self.tile_shape[i]))
    }

    /// Returns the total number of tiles across all four dimensions.
    pub fn total_tiles(&self) -> usize {
        self.tile_count().iter().product()
    }

    /// Returns the placement of the tile at `tile_idx`.
    ///
    /// Tiles at the upper end of a dimension that does not divide evenly are truncated to
    /// the remaining extent, and `is_edge` is set if any dimension was truncated.
    ///
    /// Returns `None` if any component of `tile_idx` is out of range for
    /// [`tile_count`](Self::tile_count).
    pub fn get_tile(&self, tile_idx: [usize; 4]) -> Option<TileInfo> {
        let tile_count = self.tile_count();
        if tile_idx.iter().zip(&tile_count).any(|(&idx, &count)| idx >= count) {
            return None;
        }
        let tensor_shape = self.tensor.shape();
        let mut start = [0usize; 4];
        let mut size = [0usize; 4];
        let mut is_edge = false;
        for i in 0..4 {
            start[i] = tile_idx[i] * self.tile_shape[i];
            // Cannot underflow: `tile_idx[i] < tile_count[i]` implies `start[i] < tensor_shape[i]`.
            let remaining = tensor_shape[i] - start[i];
            size[i] = remaining.min(self.tile_shape[i]);
            is_edge |= size[i] < self.tile_shape[i];
        }
        Some(TileInfo { tile_idx, start, size, is_edge })
    }

    /// Returns a view restricted to the tile at `tile_idx`.
    ///
    /// Dimensions of extent 1 are not sliced, because the single tile along them already
    /// covers the whole dimension.
    ///
    /// Returns `None` if the index is out of range (see [`get_tile`](Self::get_tile)).
    pub fn get_tile_view(&self, tile_idx: [usize; 4]) -> Option<TensorView<T>> {
        let info = self.get_tile(tile_idx)?;
        let tensor_shape = self.tensor.shape();
        let mut view = self.tensor.clone();
        for i in 0..4 {
            if tensor_shape[i] > 1 {
                view = view.slice_dim(i, info.start[i]..info.start[i] + info.size[i]);
            }
        }
        Some(view)
    }

    /// Returns an iterator over every tile in row-major order (dimension 3 varies fastest).
    pub fn iter_tiles(&self) -> TileIterator<'_, T> {
        TileIterator { partition: self, current: [0; 4], done: false }
    }

    /// Returns `true` if every tile extent is a power of two.
    pub fn is_power_of_two_tiles(&self) -> bool {
        self.tile_shape.iter().all(|&d| d.is_power_of_two())
    }

    /// Returns the number of elements in a full (non-truncated) tile.
    pub fn elements_per_tile(&self) -> usize {
        self.tile_shape.iter().product()
    }

    /// Suggests a workgroup size `(x, y, z)` for processing one tile.
    ///
    /// When the tile extent is greater than 1 in both dimension 0 and dimension 1, returns
    /// a 2-D shape:
    /// - `x` comes from dimension 1 and `y` from dimension 0, each capped at 16.
    ///
    /// Otherwise returns a 1-D shape of `elements_per_tile()` capped at 256.
    pub fn recommended_workgroup_size(&self) -> (u32, u32, u32) {
        const MAX_WORKGROUP_SIZE: usize = 256;
        const MAX_DIM: usize = 16;
        let [rows, cols, ..] = self.tile_shape;
        if rows > 1 && cols > 1 {
            // Both values are capped at MAX_DIM, so the `as u32` casts cannot truncate and
            // the product never exceeds MAX_DIM * MAX_DIM = MAX_WORKGROUP_SIZE.
            (cols.min(MAX_DIM) as u32, rows.min(MAX_DIM) as u32, 1)
        } else {
            // Capped at MAX_WORKGROUP_SIZE, so the cast cannot truncate.
            let size = self.elements_per_tile().min(MAX_WORKGROUP_SIZE);
            (size as u32, 1, 1)
        }
    }
}
/// Clones the partition by cloning the underlying tensor view and copying the tile shape.
///
/// This impl is written by hand rather than derived. `#[derive(Clone)]` would add a
/// `T: Clone` bound; this impl has none.
impl<T> Clone for PartitionView<T> {
    fn clone(&self) -> Self {
        let tensor = self.tensor.clone();
        let tile_shape = self.tile_shape;
        Self { tensor, tile_shape, _marker: PhantomData }
    }
}
/// Iterator over every tile of a [`PartitionView`], created by
/// [`PartitionView::iter_tiles`].
pub struct TileIterator<'a, T> {
    /// The partition being traversed.
    partition: &'a PartitionView<T>,
    /// Tile index of the next tile to yield.
    current: [usize; 4],
    /// Set once the cursor has advanced past the last tile.
    done: bool,
}
impl<T> Iterator for TileIterator<'_, T> {
    type Item = TileInfo;

    /// Yields the tile at the cursor, then advances the cursor in row-major order
    /// (dimension 3 varies fastest).
    ///
    /// Returns `None` once every tile has been yielded, or immediately if the partition
    /// has no tiles.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let tile = self.partition.get_tile(self.current)?;
        let tile_count = self.partition.tile_count();
        // Odometer-style increment: bump the fastest dimension, then carry into slower
        // dimensions. A carry out of dimension 0 means every tile has been visited.
        self.current[3] += 1;
        for i in (0..4).rev() {
            if self.current[i] >= tile_count[i] {
                self.current[i] = 0;
                if i > 0 {
                    self.current[i - 1] += 1;
                } else {
                    self.done = true;
                }
            } else {
                break;
            }
        }
        Some(tile)
    }

    /// Reports the exact number of tiles not yet yielded.
    ///
    /// `ExactSizeIterator::len` relies on this value, so it must shrink as the iterator
    /// is consumed rather than always reporting the partition's total tile count.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = if self.done {
            // The cursor wraps back to all zeros when finishing, so it cannot be used here.
            0
        } else {
            let tile_count = self.partition.tile_count();
            // The cursor's row-major linear index equals the number of tiles already yielded.
            let consumed = self
                .current
                .iter()
                .zip(tile_count.iter())
                .fold(0usize, |acc, (&idx, &count)| acc * count + idx);
            self.partition.total_tiles().saturating_sub(consumed)
        };
        (remaining, Some(remaining))
    }
}
// `len()` is provided by the default method, which reads `Iterator::size_hint`. That
// method must therefore report the exact number of tiles remaining for this impl to be
// correct.
impl<T> ExactSizeIterator for TileIterator<'_, T> {}
// Unit tests live in the external `tests` submodule file and are compiled only for test builds.
#[cfg(test)]
mod tests;