cubecl_matmul/components/global/write/
plane.rs1use crate::components::Ident;
2use crate::components::global::GlobalConfig;
3use crate::components::global::global_memory::TensorWriter;
4use cubecl_core as cubecl;
5use cubecl_core::prelude::*;
6use cubecl_std::div_ceil;
7use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};
8
9use super::GlobalWriter;
10
11#[derive(CubeType)]
12pub struct PlaneWriter<EG: Numeric> {
15 pub tensor_writer: TensorWriter<EG>,
16}
17
18#[cube]
19impl<EG: Numeric> PlaneWriter<EG> {
20 pub fn new(
21 tensor: VirtualTensor<EG, ReadWrite>,
22 x_offset: u32,
23 y_offset: u32,
24 batch_offset: u32,
25 ) -> Self {
26 PlaneWriter::<EG> {
27 tensor_writer: TensorWriter::new(tensor, x_offset, y_offset, batch_offset),
28 }
29 }
30}
31
32#[cube]
33impl<EG: Numeric> GlobalWriter<EG> for PlaneWriter<EG> {
34 fn write<G: GlobalConfig>(
35 this: &mut Self,
36 out_smem_slice: Slice<Line<EG>>,
37 tile_row: u32,
38 tile_col: u32,
39 #[comptime] config: G,
40 ) {
41 let tile_size = config.tiling_scheme().elements_in_tile_mn();
42 let output_line_size = config.global_line_size(Ident::Out);
43 let out_smem_slice = out_smem_slice.with_line_size(output_line_size);
44
45 let unit_step = config.plane_dim() * output_line_size;
46 let num_unit_writes = comptime!(div_ceil(tile_size, unit_step));
47 let balanced_workload = comptime!(tile_size % unit_step == 0);
48
49 #[unroll(num_unit_writes == 1)]
50 for i in 0..num_unit_writes {
51 let unit_write = UNIT_POS_X * output_line_size + i * unit_step;
52
53 #[allow(clippy::collapsible_else_if)]
54 if comptime!(balanced_workload) {
55 let value = out_smem_slice[unit_write / output_line_size];
56 this.tensor_writer
57 .write_coalesced::<G>(tile_row, tile_col, unit_write, value, config);
58 } else {
59 if unit_write < tile_size {
60 let value = out_smem_slice[unit_write / output_line_size];
61 this.tensor_writer
62 .write_coalesced::<G>(tile_row, tile_col, unit_write, value, config);
63 }
64 }
65 }
66 }
67}