cubecl_matmul/components/global/write/
unit.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5use crate::components::{
6    MatrixPrecision,
7    global::{
8        GlobalWriter, GlobalWriterConfig, GlobalWriterFamily, PartitionedStage,
9        PartitionedStageFamily, WriteEvent, WriteEventExpand, WriteEventListener,
10        read::tiled::{TiledCoords, TiledLayout},
11    },
12    stage::{StageMemoryConfig, StagePartitioner, UnitPartitioner},
13    tile::StridedTile,
14};
15
16#[derive(CubeType)]
17/// Writes tiles from out shared memory to output global memory
18/// using a unit for each tile
19pub struct UnitWriter<IP: MatrixPrecision> {
20    global: View<Line<IP::Global>, TiledCoords, ReadWrite>,
21    stage: PartitionedStage<IP::Stage>,
22
23    #[cube(comptime)]
24    smem_config: StageMemoryConfig,
25}
26
27#[cube]
28impl<IP: MatrixPrecision> UnitWriter<IP> {
29    pub fn new(
30        global: View<Line<IP::Global>, Coords2d, ReadWrite>,
31        #[comptime] config: GlobalWriterConfig,
32    ) -> Self {
33        let smem_config = config.smem_config;
34        let stage = PartitionedStage::new(
35            UnitPartitioner::coordinates(
36                config.role_rule_config,
37                config.plane_dim,
38                smem_config.partitions_per_stage_along_col,
39            ),
40            smem_config,
41        );
42
43        UnitWriter::<IP> {
44            global: global.view_mut(TiledLayout::new(smem_config)),
45            stage,
46            smem_config,
47        }
48    }
49
50    fn write(&mut self, tile: Coords2d) {
51        unit_write(
52            &mut self.global,
53            &self.stage.unit_tile,
54            tile,
55            comptime!(self.smem_config.elements_per_tile()),
56        )
57    }
58}
59
60#[cube]
61pub fn unit_write<ES: Numeric, EG: Numeric>(
62    global: &mut View<Line<EG>, TiledCoords, ReadWrite>,
63    smem_tile: &StridedTile<ES, ReadWrite>,
64    tile_pos: Coords2d,
65    #[comptime] elements_in_tile: u32,
66) {
67    let output_line_size = global.line_size();
68    let out_smem_stage = smem_tile.stage.with_line_size(output_line_size);
69
70    let num_lines = elements_in_tile / output_line_size;
71
72    for i in 0..num_lines {
73        let value = out_smem_stage[smem_tile.stage_offset(i)];
74        global.write_checked((tile_pos, i * output_line_size), Line::cast_from(value));
75    }
76}
77
78#[cube]
79impl<IP: MatrixPrecision> WriteEventListener for UnitWriter<IP> {
80    fn on_event(this: &mut Self, event: super::WriteEvent) {
81        #[allow(clippy::single_match)]
82        match event {
83            WriteEvent::TileStored { tile } => this.write(tile),
84            _ => {}
85        }
86    }
87}
88
89#[cube]
90impl<IP: MatrixPrecision> GlobalWriter<IP> for UnitWriter<IP> {
91    type Stage = PartitionedStage<IP::Stage>;
92
93    fn init(
94        tensor: View<Line<IP::Global>, Coords2d, ReadWrite>,
95        #[comptime] config: GlobalWriterConfig,
96    ) -> Self {
97        Self::new(tensor, config)
98    }
99
100    fn stage(this: &Self) -> Self::Stage {
101        this.stage
102    }
103}
104
105pub struct UnitWriterFamily;
106
107impl GlobalWriterFamily for UnitWriterFamily {
108    type Stage = PartitionedStageFamily;
109    type Writer<IP: MatrixPrecision> = UnitWriter<IP>;
110}