cubecl_matmul/components/global/write/unit.rs

1use crate::components::Ident;
2use crate::components::global::GlobalConfig;
3use crate::components::global::global_memory::TensorWriter;
4use cubecl_core as cubecl;
5use cubecl_core::prelude::*;
6use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
7
8use super::GlobalWriter;
9
10#[derive(CubeType)]
11/// Writes tiles from out shared memory to output global memory
12/// using a unit for each tile
13pub struct UnitWriter<EG: Numeric> {
14    pub tensor_view: TensorWriter<EG>,
15}
16
17#[cube]
18impl<EG: Numeric> UnitWriter<EG> {
19    pub fn new(
20        tensor: VirtualTensor<EG, ReadWrite>,
21        x_offset: u32,
22        y_offset: u32,
23        batch_offset: u32,
24    ) -> Self {
25        UnitWriter::<EG> {
26            tensor_view: TensorWriter::new(tensor, x_offset, y_offset, batch_offset),
27        }
28    }
29}
30
31#[cube]
32impl<EG: Numeric> GlobalWriter<EG> for UnitWriter<EG> {
33    fn write<G: GlobalConfig>(
34        this: &mut Self,
35        out_smem_slice: Slice<Line<EG>>,
36        tile_row: u32,
37        tile_col: u32,
38        #[comptime] config: G,
39    ) {
40        let tile_size = config.tiling_scheme().elements_in_tile_mn();
41        let output_line_size = config.global_line_size(Ident::Out);
42        let out_smem_slice = out_smem_slice.with_line_size(output_line_size);
43
44        let num_lines = tile_size / output_line_size;
45
46        for i in 0..num_lines {
47            let value = out_smem_slice[i];
48            this.tensor_view.write_coalesced::<G>(
49                tile_row,
50                tile_col,
51                i * output_line_size,
52                value,
53                config,
54            );
55        }
56    }
57}