use cubecl::prelude::CubeType;
use crate::components::{
global::read::{FullLoadingStrategy, PartialLoadingStrategy},
stage::StageFamily,
};
/// Identifies one of two stage buffers, `A` or `B`.
///
/// [`StageBuffer::to_index`] maps `A` to `0` and `B` to `1`.
///
/// `Debug`, `PartialEq`, `Eq` and `Hash` are derived in addition to `Copy`/`Clone` so the
/// buffer id can be compared, used as a map/set key, and printed in diagnostics.
/// NOTE(review): presumably the two halves of a double-buffered stage — confirm against callers.
#[derive(Copy, Clone, CubeType, Debug, PartialEq, Eq, Hash)]
pub enum StageBuffer {
    /// First buffer (index `0`).
    A,
    /// Second buffer (index `1`).
    B,
}
impl StageBuffer {
    /// Returns the numeric index of this buffer.
    ///
    /// The index is `0` for [`StageBuffer::A`] and `1` for [`StageBuffer::B`]. The function is a
    /// `const fn`, so it can also be evaluated in constant contexts.
    #[must_use]
    pub const fn to_index(&self) -> u32 {
        match self {
            StageBuffer::A => 0,
            StageBuffer::B => 1,
        }
    }
}
/// A single `u32` counter exposed to cubecl through the `CubeType` derive.
///
/// Its only field is tagged `#[cube(comptime)]`, so the derive treats it as a comptime
/// value rather than a runtime one.
/// NOTE(review): presumably counts loading tasks — confirm against callers.
#[derive(Clone, CubeType)]
pub struct TaskCounter {
    /// The counter value (comptime field).
    #[cube(comptime)]
    pub counter: u32,
}
/// Concrete stage type used by a full loader driven by loading strategy `L`.
///
/// The alias resolves in two steps:
/// 1. take the [`StageFamily`] named by `L`'s `Stage` associated type (via
///    [`FullLoadingStrategy`] parameterized by `RC`);
/// 2. instantiate that family's `Stage` with `E`, `N`, and `L`'s own `TilingLayout`.
pub type FullLoaderStage<RC, L, E, N> =
    <<L as FullLoadingStrategy<RC>>::Stage as StageFamily>::Stage<
        E,
        N,
        <L as FullLoadingStrategy<RC>>::TilingLayout,
    >;
/// Concrete stage type used by a partial loader driven by loading strategy `L`.
///
/// The alias resolves in two steps:
/// 1. take the [`StageFamily`] named by `L`'s `Stage` associated type (via
///    [`PartialLoadingStrategy`] parameterized by `RC`);
/// 2. instantiate that family's `Stage` with `E`, `N`, and `L`'s own `TilingLayout`.
pub type PartialLoaderStage<RC, L, E, N> =
    <<L as PartialLoadingStrategy<RC>>::Stage as StageFamily>::Stage<
        E,
        N,
        <L as PartialLoadingStrategy<RC>>::TilingLayout,
    >;