cubek-attention 0.2.0

CubeK: Attention Kernels
Documentation
use cubecl::{
    prelude::*,
    std::tensor::{View, layout::Coords2d},
    {self as cubecl},
};
use cubek_matmul::components::global::{
    GlobalWriterConfig, PartitionedStage, WriteEvent, WriteEventExpand, WriteEventListener,
    plane_write,
    read::tiled::{TiledCoords, TiledLayout},
};
use cubek_std::StageIdent;

use crate::components::{
    global::simple::{AttentionWriter, AttentionWriterExpand},
    stage::{AttentionPartitioner, StageAttentionConfig, plane::PlanePartitioner},
};

/// Attention output writer that pairs a tiled, read-write view of global memory
/// with a partitioned stage.
///
/// Type parameters: `ES`/`ESS` are the element type and vector size used by the
/// stage, `EG`/`EGS` the element type and vector size used by the global view.
#[derive(CubeType)]
pub struct PlaneAttentionWriter<ES: Numeric, ESS: Size, EG: Numeric, EGS: Size> {
    /// Destination view, addressed with [`TiledCoords`].
    global: View<Vector<EG, EGS>, TiledCoords, ReadWrite>,
    /// Stage whose `unit_tile` is handed to `plane_write` on tile-stored events.
    stage: PartitionedStage<ES, ESS>,

    /// Compile-time writer configuration (plane dimension, shared-memory config).
    #[cube(comptime)]
    config: GlobalWriterConfig,
}

// Inherent impl intentionally left empty: all behaviour of `PlaneAttentionWriter`
// in this file is provided through its trait implementations.
#[cube]
impl<ES: Numeric, ESS: Size, EG: Numeric, EGS: Size> PlaneAttentionWriter<ES, ESS, EG, EGS> {}

// Lets the writer react to `WriteEvent`s: a stored tile is forwarded to
// `plane_write`, everything else is ignored.
#[cube]
impl<ES: Numeric, ESS: Size, EG: Numeric, EGS: Size> WriteEventListener
    for PlaneAttentionWriter<ES, ESS, EG, EGS>
{
    /// Handles a single [`WriteEvent`].
    ///
    /// For `WriteEvent::TileStored { tile }` this calls `plane_write` with the
    /// global view (mutably), the stage's `unit_tile`, the tile carried by the
    /// event, the configured plane dimension and the number of elements per tile
    /// taken from the shared-memory config. Any other event is a no-op.
    fn on_event(this: &mut Self, event: WriteEvent) {
        // NOTE(review): the `match` with a catch-all arm is kept on purpose (hence
        // the lint allow) — presumably because of how `#[cube]` expands enum
        // matching; confirm before rewriting as `if let`.
        #[allow(clippy::single_match)]
        match event {
            // Presumably copies the stored tile from the stage out to `global`
            // — verify against `plane_write`.
            WriteEvent::TileStored { tile } => plane_write::<ES, ESS, EG, EGS>(
                &mut this.global,
                &this.stage.unit_tile,
                tile,
                this.config.plane_dim,
                this.config.comptime().smem_config.elements_per_tile(),
            ),
            // Other events require no action from this writer.
            _ => {}
        }
    }
}

// `AttentionWriter` implementation: construction from a 2D global view and
// access to the underlying stage.
#[cube]
impl<ES: Numeric, ESS: Size, EG: Numeric, EGS: Size> AttentionWriter<ES, ESS, EG, EGS>
    for PlaneAttentionWriter<ES, ESS, EG, EGS>
{
    /// Builds a writer over `global` using the compile-time `config`.
    ///
    /// The stage is created at partition coordinates
    /// `(PlanePartitioner::seq_q_index(), 0)` with `config.smem_config`, and
    /// `global` is re-viewed through a [`TiledLayout`] for `StageIdent::Out`
    /// built from the same shared-memory config. The `S` type parameter is not
    /// used by this implementation.
    fn init<S: StageAttentionConfig>(
        global: View<Vector<EG, EGS>, Coords2d, ReadWrite>,
        #[comptime] config: GlobalWriterConfig,
    ) -> Self {
        // The stage is created first, exactly as before, then the tiled view.
        let partition_row = PlanePartitioner::seq_q_index();
        let stage = PartitionedStage::new((partition_row, 0u32), config.smem_config);

        let out_layout = TiledLayout::new(StageIdent::Out, config.smem_config);
        let tiled_global = global.view_mut(out_layout);

        PlaneAttentionWriter::<ES, ESS, EG, EGS> {
            global: tiled_global,
            stage,
            config,
        }
    }

    /// Returns the partitioned stage held by this writer.
    fn stage(&mut self) -> PartitionedStage<ES, ESS> {
        self.stage
    }
}