// cubecl_matmul/components/stage/filled.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::layout::Coords2d;
4
5use crate::components::{
6    stage::{Stage, StageFamily, TilingLayout},
7    tile::io::Filled,
8};
9
10pub struct FilledStageFamily;
11
12impl StageFamily for FilledStageFamily {
13    type TileKind = Filled;
14
15    type Stage<ES: Numeric, T: TilingLayout> = FilledStage<ES>;
16}
17
18#[derive(CubeType)]
19pub struct FilledStage<ES: Numeric> {
20    value: ES,
21}
22
23#[cube]
24impl<ES: Numeric> FilledStage<ES> {
25    pub fn new(value: ES) -> Self {
26        FilledStage::<ES> { value }
27    }
28}
29
30#[cube]
31impl<ES: Numeric> Stage<ES, ReadOnly> for FilledStage<ES> {
32    type TileKind = Filled;
33
34    fn tile(this: &Self, _tile: Coords2d) -> ES {
35        this.value
36    }
37}