use cubecl::{prelude::*, std::tensor::layout::Coords2d};
use cubek_std::tile::{Tile, TileScope, Value};
use crate::components::stage::{Stage, StageFamily, TilingLayout};
pub struct FilledStageFamily;
impl StageFamily for FilledStageFamily {
type Stage<ES: Numeric, NS: Size, T: TilingLayout> = FilledStage<ES>;
}
#[derive(CubeType, Clone)]
pub struct FilledStage<ES: Numeric> {
value: ES,
}
#[cube]
impl<ES: Numeric> FilledStage<ES> {
pub fn new(value: ES) -> Self {
FilledStage::<ES> { value }
}
}
#[cube]
impl<ES: Numeric> Stage<ES, ReadOnly> for FilledStage<ES> {
fn tile<Sc: TileScope>(this: &Self, _tile: Coords2d) -> Tile<ES, Sc, ReadOnly> {
Tile::new_Broadcasted(Value::<ES> { val: this.value })
}
}