cubecl_matmul/components/global/write/
plane.rs

1use crate::components::{
2    MatrixPrecision, StageIdent,
3    global::{
4        GlobalWriter, GlobalWriterFamily, PartitionedStage, PartitionedStageFamily, WriteEvent,
5        WriteEventExpand, WriteEventListener,
6        memory::GlobalMemoryConfig,
7        read::tiled::{TiledCoords, TiledLayout},
8    },
9    stage::{PlanePartitioner, StageConfig, StageMemoryConfig, StagePartitioner},
10    tile::StridedTile,
11};
12use cubecl_core as cubecl;
13use cubecl_core::prelude::*;
14use cubecl_std::tensor::View;
15use cubecl_std::tensor::layout::Coords2d;
16
#[derive(CubeType)]
/// Writes tiles from the output stage in shared memory to output global memory,
/// using a plane for each tile.
pub struct PlaneWriter<IP: MatrixPrecision> {
    /// Output view addressed with tiled coordinates: a tile position paired with
    /// an element offset inside that tile.
    global: View<Line<IP::Global>, TiledCoords, ReadWrite>,
    /// Stage the writer reads from; its `unit_tile` is the source of every write.
    stage: PartitionedStage<IP::Stage>,

    /// Plane dimension, read from the stage config when the writer is built.
    #[cube(comptime)]
    plane_dim: u32,
    /// Global memory configuration; supplies the tile dimensions used when writing.
    #[cube(comptime)]
    config: GlobalMemoryConfig,
}
29
30#[cube]
31impl<IP: MatrixPrecision> PlaneWriter<IP> {
32    pub fn new<S: StageConfig>(
33        global: View<Line<IP::Global>, Coords2d, ReadWrite>,
34        #[comptime] global_config: GlobalMemoryConfig,
35        #[comptime] stage_config: S,
36    ) -> Self {
37        let stage_mem_config = comptime![stage_memory_config(stage_config)];
38        let stage = PartitionedStage::new(tile_pos::<S>(stage_config), stage_mem_config);
39
40        PlaneWriter::<IP> {
41            global: global.view_mut(TiledLayout::new(global_config)),
42            stage,
43            plane_dim: stage_config.plane_dim(),
44            config: global_config,
45        }
46    }
47
48    fn write(&mut self, tile_pos: Coords2d) {
49        plane_write::<IP::Stage, IP::Global>(
50            &mut self.global,
51            &self.stage.unit_tile,
52            tile_pos,
53            comptime![self.plane_dim],
54            comptime![self.config],
55        )
56    }
57}
58
/// Returns the coordinates that `PlanePartitioner` yields for the given stage
/// config; the writer uses them as the position of its stage.
#[cube]
fn tile_pos<S: StageConfig>(#[comptime] config: S) -> (u32, u32) {
    PlanePartitioner::coordinates::<S>(config)
}
63
64fn stage_memory_config<S: StageConfig>(config: S) -> StageMemoryConfig {
65    let planes = config.num_main_flow_planes();
66    let size_n = config.tiling_scheme().stage_partitions_in_stage_n();
67    let base = config.stage_memory_config(StageIdent::Acc);
68    StageMemoryConfig {
69        tiles_in_stage_row: planes / size_n,
70        tiles_in_stage_col: size_n,
71        ..base
72    }
73}
74
#[cube]
impl<IP: MatrixPrecision> WriteEventListener for PlaneWriter<IP> {
    /// Handles a write event: `TileStored` triggers the write of that tile to
    /// global memory; every other event is ignored.
    fn on_event(this: &mut Self, event: super::WriteEvent) {
        // Only one variant is acted upon, hence the single-arm `match`.
        #[allow(clippy::single_match)]
        match event {
            WriteEvent::TileStored { tile } => {
                this.write(tile);
            }
            _ => {}
        }
    }
}
87
#[cube]
impl<IP: MatrixPrecision> GlobalWriter<IP> for PlaneWriter<IP> {
    type Stage = PartitionedStage<IP::Stage>;

    /// Creates the writer; delegates to `PlaneWriter::new` with the same arguments.
    fn init<S: StageConfig>(
        tensor: View<Line<IP::Global>, Coords2d, ReadWrite>,
        #[comptime] config: GlobalMemoryConfig,
        #[comptime] stage_config: S,
    ) -> Self {
        Self::new::<S>(tensor, config, stage_config)
    }

    /// Returns the stage held by this writer.
    fn stage(this: &Self) -> Self::Stage {
        this.stage
    }
}
104
105#[cube]
106pub fn plane_write<ES: Numeric, EG: Numeric>(
107    global: &mut View<Line<EG>, TiledCoords, ReadWrite>,
108    smem_tile: &StridedTile<ES, ReadWrite>,
109    tile_pos: Coords2d,
110    #[comptime] plane_dim: u32,
111    #[comptime] config: GlobalMemoryConfig,
112) {
113    let tile_size = config.elements_in_tile_row * config.elements_in_tile_col;
114    let output_line_size = global.line_size();
115
116    let unit_step = comptime![plane_dim * output_line_size];
117    let num_unit_writes = comptime!(tile_size.div_ceil(unit_step));
118    let balanced_workload = comptime!(tile_size.is_multiple_of(unit_step));
119
120    #[unroll(num_unit_writes == 1)]
121    for i in 0..num_unit_writes {
122        let unit_write = UNIT_POS_X * output_line_size + i * unit_step;
123
124        #[allow(clippy::collapsible_else_if)]
125        if comptime!(balanced_workload) {
126            write_line(global, &smem_tile.slice, unit_write, tile_pos);
127        } else {
128            if unit_write < tile_size {
129                write_line(global, &smem_tile.slice, unit_write, tile_pos);
130            }
131        }
132    }
133}
134
135#[cube]
136fn write_line<ES: Numeric, EG: Numeric>(
137    view: &mut View<Line<EG>, TiledCoords, ReadWrite>,
138    out_smem_slice: &Slice<Line<ES>, ReadWrite>,
139    unit_write: u32,
140    tile: Coords2d,
141) {
142    let output_line_size = view.line_size();
143    let out_smem_line_size = out_smem_slice.line_size();
144
145    let value = if comptime!(output_line_size == out_smem_line_size) {
146        out_smem_slice[unit_write / output_line_size]
147    } else if comptime!(
148        out_smem_line_size < output_line_size
149            && output_line_size.is_multiple_of(out_smem_line_size)
150    ) {
151        let mut value = Line::empty(output_line_size);
152        #[unroll]
153        for i in 0..comptime!(output_line_size / out_smem_line_size) {
154            #[unroll]
155            for j in 0..out_smem_line_size {
156                value[i * out_smem_line_size + j] = out_smem_slice[unit_write + i][j];
157            }
158        }
159        value
160    } else {
161        unimplemented!()
162    };
163
164    view.write_checked((tile, unit_write), Line::cast_from(value));
165}
166
/// Family marker whose `GlobalWriterFamily` implementation selects `PlaneWriter`.
pub struct PlaneWriterFamily;
168
/// Binds the family to `PartitionedStageFamily` stages and to `PlaneWriter` as
/// the writer for any matrix precision.
impl GlobalWriterFamily for PlaneWriterFamily {
    type Stage = PartitionedStageFamily;
    type Writer<IP: MatrixPrecision> = PlaneWriter<IP>;
}