cubecl_matmul/components/global/write/
plane.rs

1use crate::components::{
2    MatrixPrecision,
3    global::{
4        GlobalWriter, GlobalWriterConfig, GlobalWriterFamily, PartitionedStage,
5        PartitionedStageFamily, WriteEvent, WriteEventExpand, WriteEventListener,
6        read::tiled::{TiledCoords, TiledLayout},
7    },
8    stage::{PlanePartitioner, StageMemoryConfig, StagePartitioner},
9    tile::StridedTile,
10};
11use cubecl_core as cubecl;
12use cubecl_core::prelude::*;
13use cubecl_std::tensor::View;
14use cubecl_std::tensor::layout::Coords2d;
15
16#[derive(CubeType)]
17/// Writes tiles from out shared memory to output global memory
18/// using a plane for each tile
19pub struct PlaneWriter<IP: MatrixPrecision> {
20    global: View<Line<IP::Global>, TiledCoords, ReadWrite>,
21    stage: PartitionedStage<IP::Stage>,
22
23    #[cube(comptime)]
24    plane_dim: u32,
25    #[cube(comptime)]
26    smem_config: StageMemoryConfig,
27}
28
29#[cube]
30impl<IP: MatrixPrecision> PlaneWriter<IP> {
31    pub fn new(
32        global: View<Line<IP::Global>, Coords2d, ReadWrite>,
33        #[comptime] config: GlobalWriterConfig,
34    ) -> Self {
35        let stage = PartitionedStage::new(
36            PlanePartitioner::coordinates(
37                config.role_rule_config,
38                config.plane_dim,
39                config.smem_config.partitions_per_stage_along_col,
40            ),
41            config.smem_config,
42        );
43
44        PlaneWriter::<IP> {
45            global: global.view_mut(TiledLayout::new(config.smem_config)),
46            stage,
47            plane_dim: config.plane_dim,
48            smem_config: config.smem_config,
49        }
50    }
51
52    fn write(&mut self, tile_pos: Coords2d) {
53        plane_write::<IP::Stage, IP::Global>(
54            &mut self.global,
55            &self.stage.unit_tile,
56            tile_pos,
57            comptime!(self.plane_dim),
58            comptime!(self.smem_config.elements_per_tile()),
59        )
60    }
61}
62
63#[cube]
64impl<IP: MatrixPrecision> WriteEventListener for PlaneWriter<IP> {
65    fn on_event(this: &mut Self, event: super::WriteEvent) {
66        #[allow(clippy::single_match)]
67        match event {
68            WriteEvent::TileStored { tile } => {
69                this.write(tile);
70            }
71            _ => {}
72        }
73    }
74}
75
76#[cube]
77impl<IP: MatrixPrecision> GlobalWriter<IP> for PlaneWriter<IP> {
78    type Stage = PartitionedStage<IP::Stage>;
79
80    fn init(
81        tensor: View<Line<IP::Global>, Coords2d, ReadWrite>,
82        #[comptime] config: GlobalWriterConfig,
83    ) -> Self {
84        Self::new(tensor, config)
85    }
86
87    fn stage(this: &Self) -> Self::Stage {
88        this.stage
89    }
90}
91
92#[cube]
93pub fn plane_write<ES: Numeric, EG: Numeric>(
94    global: &mut View<Line<EG>, TiledCoords, ReadWrite>,
95    smem_tile: &StridedTile<ES, ReadWrite>,
96    tile_pos: Coords2d,
97    #[comptime] plane_dim: u32,
98    #[comptime] elements_in_tile: u32,
99) {
100    let output_line_size = global.line_size();
101
102    let unit_step = comptime![plane_dim * output_line_size];
103    let num_unit_writes = comptime!(elements_in_tile.div_ceil(unit_step));
104    let balanced_workload = comptime!(elements_in_tile.is_multiple_of(unit_step));
105
106    #[unroll(num_unit_writes == 1)]
107    for i in 0..num_unit_writes {
108        let unit_write = UNIT_POS_X * output_line_size + i * unit_step;
109
110        #[allow(clippy::collapsible_else_if)]
111        if comptime!(balanced_workload) {
112            write_line(global, smem_tile, unit_write, tile_pos);
113        } else {
114            if unit_write < elements_in_tile {
115                write_line(global, smem_tile, unit_write, tile_pos);
116            }
117        }
118    }
119}
120
121#[cube]
122fn write_line<ES: Numeric, EG: Numeric>(
123    view: &mut View<Line<EG>, TiledCoords, ReadWrite>,
124    out_smem_tile: &StridedTile<ES, ReadWrite>,
125    unit_write: u32,
126    tile: Coords2d,
127) {
128    let output_line_size = view.line_size();
129    let out_smem_line_size = out_smem_tile.stage.line_size();
130
131    let value = if comptime!(output_line_size == out_smem_line_size) {
132        out_smem_tile.stage[out_smem_tile.stage_offset(unit_write / output_line_size)]
133    } else if comptime!(
134        out_smem_line_size < output_line_size
135            && output_line_size.is_multiple_of(out_smem_line_size)
136    ) {
137        let mut value = Line::empty(output_line_size);
138        #[unroll]
139        for i in 0..comptime!(output_line_size / out_smem_line_size) {
140            let offs = out_smem_tile.stage_offset(unit_write + i);
141            #[unroll]
142            for j in 0..out_smem_line_size {
143                value[i * out_smem_line_size + j] = out_smem_tile.stage[offs][j];
144            }
145        }
146        value
147    } else {
148        unimplemented!()
149    };
150
151    view.write_checked((tile, unit_write), Line::cast_from(value));
152}
153
154pub struct PlaneWriterFamily;
155
156impl GlobalWriterFamily for PlaneWriterFamily {
157    type Stage = PartitionedStageFamily;
158    type Writer<IP: MatrixPrecision> = PlaneWriter<IP>;
159}