cubek-matmul 0.2.0

CubeK: Matrix Multiplication Kernels
Documentation
use cubecl::{prelude::*, std::tensor::layout::Coords2d};
use cubek_std::tile::{Tile, TileScope, Value};

use crate::components::stage::{Stage, StageFamily, TilingLayout};

/// [`StageFamily`] whose stages are [`FilledStage`]s: every tile read from such
/// a stage carries the same constant value.
///
/// The `NS` and `T` parameters of the family's `Stage` type do not influence the
/// resulting stage type — a filled stage stores only a single element value.
pub struct FilledStageFamily;

impl StageFamily for FilledStageFamily {
    // Only the element type `ES` is used; `NS` and `T` are accepted to satisfy
    // the trait's GAT signature but are discarded here.
    type Stage<ES: Numeric, NS: Size, T: TilingLayout> = FilledStage<ES>;
}

/// A read-only stage holding a single value of element type `ES`, which it
/// returns as a broadcast tile for every tile coordinate requested.
///
/// NOTE(review): presumably used where an operand/accumulator should be a
/// constant rather than loaded data (e.g. zero-initialisation) — confirm
/// against callers.
#[derive(CubeType, Clone)]
pub struct FilledStage<ES: Numeric> {
    /// The constant handed out (broadcast) for every tile of this stage.
    value: ES,
}

#[cube]
impl<ES: Numeric> FilledStage<ES> {
    /// Creates a filled stage whose every tile reads as `value`.
    pub fn new(value: ES) -> Self {
        FilledStage::<ES> { value }
    }
}

#[cube]
impl<ES: Numeric> Stage<ES, ReadOnly> for FilledStage<ES> {
    /// Returns a broadcast tile carrying this stage's constant value.
    ///
    /// The requested tile coordinates are deliberately ignored: every position
    /// in a filled stage holds the same value, so all tiles are identical.
    fn tile<Sc: TileScope>(this: &Self, _tile: Coords2d) -> Tile<ES, Sc, ReadOnly> {
        let constant = Value::<ES> { val: this.value };
        Tile::new_Broadcasted(constant)
    }
}