cubecl_matmul/components/global/write/plane.rs

1use crate::components::Ident;
2use crate::components::global::GlobalConfig;
3use crate::components::global::global_memory::TensorWriter;
4use cubecl_core as cubecl;
5use cubecl_core::prelude::*;
6use cubecl_std::div_ceil;
7use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
8
9use super::GlobalWriter;
10
11#[derive(CubeType)]
12/// Writes tiles from out shared memory to output global memory
13/// using a plane for each tile
14pub struct PlaneWriter<EG: Numeric> {
15    pub tensor_writer: TensorWriter<EG>,
16}
17
18#[cube]
19impl<EG: Numeric> PlaneWriter<EG> {
20    pub fn new(
21        tensor: VirtualTensor<EG, ReadWrite>,
22        x_offset: u32,
23        y_offset: u32,
24        batch_offset: u32,
25    ) -> Self {
26        PlaneWriter::<EG> {
27            tensor_writer: TensorWriter::new(tensor, x_offset, y_offset, batch_offset),
28        }
29    }
30}
31
32#[cube]
33impl<EG: Numeric> GlobalWriter<EG> for PlaneWriter<EG> {
34    fn write<G: GlobalConfig>(
35        this: &mut Self,
36        out_smem_slice: Slice<Line<EG>>,
37        tile_row: u32,
38        tile_col: u32,
39        #[comptime] config: G,
40    ) {
41        let tile_size = config.tiling_scheme().elements_in_tile_mn();
42        let output_line_size = config.global_line_size(Ident::Out);
43        let out_smem_slice = out_smem_slice.with_line_size(output_line_size);
44
45        let unit_step = config.plane_dim() * output_line_size;
46        let num_unit_writes = comptime!(div_ceil(tile_size, unit_step));
47        let balanced_workload = comptime!(tile_size % unit_step == 0);
48
49        #[unroll(num_unit_writes == 1)]
50        for i in 0..num_unit_writes {
51            let unit_write = UNIT_POS_X * output_line_size + i * unit_step;
52
53            #[allow(clippy::collapsible_else_if)]
54            if comptime!(balanced_workload) {
55                let value = out_smem_slice[unit_write / output_line_size];
56                this.tensor_writer
57                    .write_coalesced::<G>(tile_row, tile_col, unit_write, value, config);
58            } else {
59                if unit_write < tile_size {
60                    let value = out_smem_slice[unit_write / output_line_size];
61                    this.tensor_writer
62                        .write_coalesced::<G>(tile_row, tile_col, unit_write, value, config);
63                }
64            }
65        }
66    }
67}