cubecl_matmul/components/global/write/
unit.rs1use crate::components::Ident;
2use crate::components::global::GlobalConfig;
3use crate::components::global::global_memory::TensorWriter;
4use cubecl_core as cubecl;
5use cubecl_core::prelude::*;
6use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
7
8use super::GlobalWriter;
9
10#[derive(CubeType)]
11pub struct UnitWriter<EG: Numeric> {
14 pub tensor_view: TensorWriter<EG>,
15}
16
17#[cube]
18impl<EG: Numeric> UnitWriter<EG> {
19 pub fn new(
20 tensor: VirtualTensor<EG, ReadWrite>,
21 x_offset: u32,
22 y_offset: u32,
23 batch_offset: u32,
24 ) -> Self {
25 UnitWriter::<EG> {
26 tensor_view: TensorWriter::new(tensor, x_offset, y_offset, batch_offset),
27 }
28 }
29}
30
31#[cube]
32impl<EG: Numeric> GlobalWriter<EG> for UnitWriter<EG> {
33 fn write<G: GlobalConfig>(
34 this: &mut Self,
35 out_smem_slice: Slice<Line<EG>>,
36 tile_row: u32,
37 tile_col: u32,
38 #[comptime] config: G,
39 ) {
40 let tile_size = config.tiling_scheme().elements_in_tile_mn();
41 let output_line_size = config.global_line_size(Ident::Out);
42 let out_smem_slice = out_smem_slice.with_line_size(output_line_size);
43
44 let num_lines = tile_size / output_line_size;
45
46 for i in 0..num_lines {
47 let value = out_smem_slice[i];
48 this.tensor_view.write_coalesced::<G>(
49 tile_row,
50 tile_col,
51 i * output_line_size,
52 value,
53 config,
54 );
55 }
56 }
57}