use crate::{
components::{
global::{
PlaneFlowPartitionRule, WriteEventListener, WriteTiling, memory::GlobalMemoryConfig,
},
stage::{Stage, StageFamily},
},
definition::MatrixTypes,
};
use cubecl::{
prelude::*,
std::tensor::{View, layout::Coords2d},
};
use cubek_std::stage::StageMemoryConfig;
pub trait GlobalWriterFamily: 'static + Send + Sync {
type Stage: StageFamily<ReadWrite>;
type Writer<IP: MatrixTypes>: GlobalWriter<
IP,
Stage = <Self::Stage as StageFamily<ReadWrite>>::Stage<
IP::Stage,
IP::StageSize,
WriteTiling,
>,
>;
}
#[cube]
pub trait GlobalWriter<IP: MatrixTypes>:
WriteEventListener + CubeType + 'static + Send + Sync
{
type Stage: Stage<IP::Stage, ReadWrite>;
fn init(
tensor: View<Vector<IP::Global, IP::GlobalSize>, Coords2d, ReadWrite>,
#[comptime] config: GlobalWriterConfig,
) -> Self;
fn stage(this: &Self) -> Self::Stage;
}
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct GlobalWriterConfig {
pub gmem_config: GlobalMemoryConfig,
pub smem_config: StageMemoryConfig,
pub plane_flow_partition_rule: PlaneFlowPartitionRule,
pub plane_dim: u32,
}