cubecl_matmul/components/global/write/unit.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5use crate::components::{
6    MatrixPrecision, StageIdent,
7    global::{
8        GlobalWriter, GlobalWriterFamily, PartitionedStage, PartitionedStageFamily, WriteEvent,
9        WriteEventExpand, WriteEventListener,
10        memory::GlobalMemoryConfig,
11        read::tiled::{TiledCoords, TiledLayout},
12    },
13    stage::{StageConfig, StageMemoryConfig, StagePartitioner, UnitPartitioner},
14};
15
16#[derive(CubeType)]
17/// Writes tiles from out shared memory to output global memory
18/// using a unit for each tile
19pub struct UnitWriter<IP: MatrixPrecision> {
20    global: View<Line<IP::Global>, TiledCoords, ReadWrite>,
21    stage: PartitionedStage<IP::Stage>,
22
23    #[cube(comptime)]
24    config: GlobalMemoryConfig,
25}
26
27#[cube]
28impl<IP: MatrixPrecision> UnitWriter<IP> {
29    pub fn new<S: StageConfig>(
30        global: View<Line<IP::Global>, Coords2d, ReadWrite>,
31        #[comptime] global_config: GlobalMemoryConfig,
32        #[comptime] stage_config: S,
33    ) -> Self {
34        let stage_mem_config = comptime![stage_memory_config(stage_config)];
35        let stage = PartitionedStage::new(tile_pos::<S>(stage_config), stage_mem_config);
36
37        UnitWriter::<IP> {
38            global: global.view_mut(TiledLayout::new(global_config)),
39            stage,
40            config: global_config,
41        }
42    }
43
44    fn write(&mut self, tile: Coords2d) {
45        let smem_tile = &self.stage.unit_tile;
46        let config = comptime![self.config];
47
48        let tile_size = config.elements_in_tile_row * config.elements_in_tile_col;
49        let output_line_size = self.global.line_size();
50        let out_smem_slice = smem_tile.slice.with_line_size(output_line_size);
51
52        let num_lines = tile_size / output_line_size;
53
54        for i in 0..num_lines {
55            let value = out_smem_slice[i];
56            self.global
57                .write_checked((tile, i * output_line_size), Line::cast_from(value));
58        }
59    }
60}
61
62#[cube]
63fn tile_pos<S: StageConfig>(#[comptime] config: S) -> (u32, u32) {
64    UnitPartitioner::coordinates::<S>(config)
65}
66
67fn stage_memory_config<S: StageConfig>(config: S) -> StageMemoryConfig {
68    let units = config.num_main_flow_planes() * config.plane_dim();
69    let size_n = config.tiling_scheme().stage_partitions_in_stage_n();
70    let base = config.stage_memory_config(StageIdent::Acc);
71    StageMemoryConfig {
72        tiles_in_stage_row: units / size_n,
73        tiles_in_stage_col: size_n,
74        ..base
75    }
76}
77
78#[cube]
79impl<IP: MatrixPrecision> WriteEventListener for UnitWriter<IP> {
80    fn on_event(this: &mut Self, event: super::WriteEvent) {
81        #[allow(clippy::single_match)]
82        match event {
83            WriteEvent::TileStored { tile } => this.write(tile),
84            _ => {}
85        }
86    }
87}
88
89#[cube]
90impl<IP: MatrixPrecision> GlobalWriter<IP> for UnitWriter<IP> {
91    type Stage = PartitionedStage<IP::Stage>;
92
93    fn init<S: StageConfig>(
94        tensor: View<Line<IP::Global>, Coords2d, ReadWrite>,
95        #[comptime] config: GlobalMemoryConfig,
96        #[comptime] stage_config: S,
97    ) -> Self {
98        Self::new::<S>(tensor, config, stage_config)
99    }
100
101    fn stage(this: &Self) -> Self::Stage {
102        this.stage
103    }
104}
105
106pub struct UnitWriterFamily;
107
108impl GlobalWriterFamily for UnitWriterFamily {
109    type Stage = PartitionedStageFamily;
110    type Writer<IP: MatrixPrecision> = UnitWriter<IP>;
111}