use cubecl::{
prelude::*,
std::tensor::{View, layout::Coords2d},
};
/// A movable window over a 2D global view.
///
/// The iterator keeps the full view plus a running offset. [`GlobalIterator::view`] slices a
/// window of `view_size` out of the full view at the current offset, and
/// [`GlobalIterator::advance`] moves the offset forward by `step`.
///
/// The offset is stored in a `RuntimeCell`, which is why `advance` can update it through `&self`.
#[derive(Clone, CubeType)]
pub struct GlobalIterator<EI: CubePrimitive> {
/// The full view that every window is sliced from.
global_view: View<EI, Coords2d>,
/// Accumulated offset: initialised to 0 in `new`, increased by `step` on each `advance`.
offset: RuntimeCell<u32>,
/// Amount added to `offset` on each call to `advance`.
step: u32,
/// Size of the window returned by `view`; computed once in `new` and never changed afterwards.
view_size: Coords2d,
/// Which coordinate the offset is applied to when slicing (comptime value).
#[cube(comptime)]
view_direction: ViewDirection,
/// When true `view` uses `slice`, otherwise `slice_unchecked` (comptime value).
#[cube(comptime)]
checked: bool,
}
// SAFETY: NOTE(review): the justification for these impls is not visible in this file.
// They assert that a `GlobalIterator` may be shared between and moved across threads for any
// element type. Presumably this is needed because one of the field types (`View` or
// `RuntimeCell`) is not automatically `Send`/`Sync` — TODO confirm and record the invariant
// that makes this sound.
unsafe impl<EG: CubePrimitive> Sync for GlobalIterator<EG> {}
unsafe impl<EG: CubePrimitive> Send for GlobalIterator<EG> {}
/// The coordinate along which a [`GlobalIterator`] moves its window.
///
/// The variant decides both the window size chosen in `GlobalIterator::new` and which
/// coordinate of the slice origin receives the running offset in `GlobalIterator::view`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub enum ViewDirection {
/// The offset is applied to the first (row) coordinate; the window is `step` rows tall and
/// spans all columns.
Row,
/// The offset is applied to the second (column) coordinate; the window spans all rows and
/// is `step` columns wide.
Col,
/// The window is the whole view anchored at `(0, 0)`; the running offset is not used when
/// slicing. This is the default variant.
#[default]
None,
}
#[cube]
impl<EG: CubePrimitive> GlobalIterator<EG> {
/// Creates an iterator over `global_view` whose offset starts at 0.
///
/// - `global_view`: the full 2D view to take windows from.
/// - `step`: how far the offset moves on each [`advance`](Self::advance); it is also used as
///   the window extent along `view_direction`.
/// - `view_direction`: comptime choice of the coordinate the window moves along.
/// - `checked`: comptime choice between `slice` (true) and `slice_unchecked` (false) when the
///   window is produced in [`view`](Self::view).
pub fn new(
global_view: View<EG, Coords2d>,
step: u32,
#[comptime] view_direction: ViewDirection,
#[comptime] checked: bool,
) -> Self {
let (size_row, size_col) = global_view.shape();
// The window keeps the full extent of the view on the axis that is not iterated and is
// `step` long on the iterated axis. With `None` the window is the whole view.
let view_size = match view_direction {
ViewDirection::Row => (step, size_col),
ViewDirection::Col => (size_row, step),
ViewDirection::None => (size_row, size_col),
};
GlobalIterator::<EG> {
global_view,
offset: RuntimeCell::new(0),
step,
view_size,
view_direction,
checked,
}
}
/// Moves the iterator forward by adding `step` to the stored offset.
///
/// Takes `&self`: the offset is held in a `RuntimeCell` and updated through `read`/`store`.
/// No bound is applied to the offset here.
pub fn advance(&self) {
self.offset.store(self.offset.read() + self.step);
}
/// Returns the window of the global view at the current offset.
///
/// The window has size `view_size` and its origin is `(offset, 0)` for `Row`, `(0, offset)`
/// for `Col` and `(0, 0)` for `None`.
pub fn view(&self) -> View<EG, Coords2d> {
// Origin of the window inside the full view; only the iterated coordinate is non-zero.
let offset = match self.view_direction.comptime() {
ViewDirection::Row => (self.offset.read(), 0u32),
ViewDirection::Col => (0u32, self.offset.read()),
// NOTE(review): `.runtime()` presumably turns the constant tuple into a runtime value so
// that this arm has the same type as the other two — confirm against CubeCL docs.
ViewDirection::None => (0u32, 0u32).runtime(),
};
// `checked` is a comptime flag, so only one of the two calls below is selected.
// NOTE(review): `slice` is presumably the bounds-aware variant and `slice_unchecked` skips
// that handling — verify the exact semantics in the `View` documentation.
if self.checked.comptime() {
self.global_view.slice(offset, self.view_size)
} else {
self.global_view.slice_unchecked(offset, self.view_size)
}
}
/// Returns the vector size reported by the underlying global view (a comptime value).
pub fn vector_size(&self) -> comptime_type!(VectorSize) {
self.global_view.vector_size()
}
/// Returns the current accumulated offset, i.e. `step` times the number of `advance` calls
/// made on this iterator's cell so far.
pub fn offset(&self) -> u32 {
self.offset.read()
}
}