cubecl_matmul/components/global/write/base.rs

1use crate::components::{
2    MatrixPrecision,
3    global::{WriteEventListener, WriteTiling, memory::GlobalMemoryConfig},
4    stage::{Stage, StageConfig, StageFamily},
5};
6use cubecl_core as cubecl;
7use cubecl_core::prelude::*;
8use cubecl_std::tensor::{View, layout::Coords2d};
9
10pub trait GlobalWriterFamily: 'static + Send + Sync {
11    type Stage: StageFamily<ReadWrite>;
12    type Writer<IP: MatrixPrecision>: GlobalWriter<
13            IP,
14            Stage = <Self::Stage as StageFamily<ReadWrite>>::Stage<IP::Stage, WriteTiling>,
15        >;
16}
17
18#[cube]
19/// Responsible of writing the accumulated stage matmul output
20/// to global memory
21pub trait GlobalWriter<IP: MatrixPrecision>:
22    WriteEventListener + CubeType + 'static + Send + Sync
23{
24    /// Tile stage that stores the data for this writer
25    type Stage: Stage<IP::Stage, ReadWrite>;
26
27    /// Init this writer from a global tensor and config
28    fn init<S: StageConfig>(
29        tensor: View<Line<IP::Global>, Coords2d, ReadWrite>,
30        #[comptime] config: GlobalMemoryConfig,
31        #[comptime] stage_config: S,
32    ) -> Self;
33
34    /// Stage used by this writer
35    fn stage(this: &Self) -> Self::Stage;
36}