use crate::{BoundChecks, ReduceInstruction, ReducePrecision};
use cubecl::{
prelude::*,
std::tensor::{View, layout::Coords1d},
};
/// Bound-checking strategy used by [`ReaderBoundChecks::read`] when loading reduce inputs
/// from a [`View`].
#[derive(CubeType)]
// NOTE(review): the reason for `#[allow(unused)]` is not visible in this file; presumably some
// item is unused in certain configurations — confirm and document the specific cause.
#[allow(unused)]
pub enum ReaderBoundChecks<P: ReducePrecision> {
    /// Reads index the view directly; no position is compared against a bound.
    NotRequired,
    /// Reads are guarded by the bound, strategy and fallback value stored in the checks.
    Required(RequiredReaderBoundChecks<P>),
}
/// State needed to guard a read: the strategy, the exclusive position bound, and the fallback
/// value returned for out-of-bounds positions.
#[derive(CubeType)]
pub struct RequiredReaderBoundChecks<P: ReducePrecision> {
    /// Strategy used to guard the read. `ReaderBoundChecks::new` only stores `Mask` or
    /// `Branch` here. Marked `#[cube(comptime)]`, so it is a compile-time value.
    #[cube(comptime)]
    bound_checks: BoundChecks,
    /// Exclusive upper bound: a read at `pos` is in bounds only when `pos < pos_max`.
    pos_max: usize,
    /// Value returned in place of a view element when `pos` is out of bounds; obtained from
    /// the reduce instruction's `null_input`.
    null_input: Vector<P::EI, P::SI>,
}
#[cube]
impl<P: ReducePrecision> ReaderBoundChecks<P> {
    /// Builds the bound-checking state for a reader.
    ///
    /// # Arguments
    ///
    /// * `inst` - Reduce instruction; `I::null_input(inst)` supplies the value returned for
    ///   out-of-bounds reads.
    /// * `pos_max` - Exclusive upper bound on the `pos` later passed to [`Self::read`].
    /// * `idle` - When `Some(flag)`, the bound is multiplied by `!flag` cast to `usize`.
    ///   Assuming that cast yields 0 or 1, a `true` flag produces a bound of zero, so every
    ///   checked read returns the null input. The strategy is also forced to
    ///   [`BoundChecks::Mask`]. When `None`, `pos_max` and `bound_checks` are used as given.
    /// * `bound_checks` - Requested strategy (comptime), subject to the `idle` override above.
    ///   `None` yields `NotRequired`; `Mask` and `Branch` yield `Required`.
    pub fn new<I: ReduceInstruction<P>>(
        inst: &I,
        pos_max: usize,
        idle: ComptimeOption<bool>,
        #[comptime] bound_checks: BoundChecks,
    ) -> ReaderBoundChecks<P> {
        // Multiplying by `!idle` (cast to `usize`) zeroes the bound for an idle unit, so no
        // `pos` can satisfy `pos < pos_max` in `read`.
        #[comptime]
        let pos_max = match idle {
            ComptimeOption::Some(idle) => pos_max * usize::cast_from(!idle),
            ComptimeOption::None => pos_max,
        };
        // Whenever an idle flag is present, checks must stay enabled even if the caller asked
        // for `None`: the zeroed bound above only takes effect through a check.
        // NOTE(review): the reason `Mask` is chosen over `Branch` is not stated here —
        // presumably to avoid a data-dependent branch; confirm.
        let bound_checks = comptime!(match idle.is_some() {
            true => BoundChecks::Mask,
            false => bound_checks,
        });
        match bound_checks {
            BoundChecks::None => ReaderBoundChecks::new_NotRequired(),
            BoundChecks::Mask | BoundChecks::Branch => {
                ReaderBoundChecks::new_Required(RequiredReaderBoundChecks::<P> {
                    bound_checks,
                    pos_max,
                    null_input: I::null_input(inst),
                })
            }
        }
    }

    /// Reads `view[offset]`, or returns the stored null input when `pos` is out of bounds.
    ///
    /// `pos` is compared against the stored `pos_max` (exclusive) and `offset` is the index
    /// into `view`. With `NotRequired`, no comparison is made and `view[offset]` is returned
    /// directly.
    pub fn read(
        &self,
        pos: usize,
        offset: usize,
        view: &View<Vector<P::EI, P::SI>, Coords1d>,
    ) -> Vector<P::EI, P::SI> {
        // `#[comptime]` resolves the variant (and `.comptime()` the nested strategy) during
        // kernel expansion, so presumably only the selected path is emitted.
        #[comptime]
        match self {
            ReaderBoundChecks::NotRequired => view[offset],
            ReaderBoundChecks::Required(checks) => match checks.bound_checks.comptime() {
                // `new` never builds `Required` with `None`; this arm keeps the match
                // exhaustive.
                BoundChecks::None => view[offset],
                BoundChecks::Mask => {
                    // No `if`: the load always happens, but an out-of-bounds position reads
                    // index 0 instead of `offset`, and `select` discards that value.
                    // NOTE(review): this relies on index 0 of `view` being readable — confirm.
                    let mask = pos < checks.pos_max;
                    let index = offset * usize::cast_from(mask);
                    select(mask, view[index], checks.null_input)
                }
                BoundChecks::Branch => {
                    // `if` on `pos`: the view is only indexed when `pos` is in bounds.
                    if pos < checks.pos_max {
                        view[offset]
                    } else {
                        checks.null_input
                    }
                }
            },
        }
    }
}