// cubek-attention 0.2.0
//
// CubeK: Attention Kernels
// Documentation
use std::fmt::Debug;

use cubecl::{
    {CubeDim, Runtime},
    {client::ComputeClient, ir::AddressType},
};

use crate::components::tile::TileAttentionKind;
use crate::components::{
    batch::BatchAttentionFamily, global::GlobalAttentionFamily, stage::StageAttentionFamily,
};
use crate::definition::{
    AttentionElems, AttentionProblem, AttentionSetupError, AttentionVectorSizes, CubeCountPlan,
};
use crate::launch::BlueprintStrategy;

/// An attention routine: a compile-time bundle of the component families used
/// at the stage, global and batch levels, together with the strategy and
/// blueprint types it works with.
///
/// Implementors must be `Debug + Clone`. The single required function,
/// [`Routine::prepare`], turns a problem description and device settings into
/// a [`LaunchInfo`].
pub trait Routine: Debug + Clone {
    /// Tile-level strategy this routine selects.
    const TILE_KIND: TileAttentionKind;

    /// Stage-level attention family used by this routine.
    type StageAttention: StageAttentionFamily;
    /// Global-level attention family used by this routine.
    type GlobalAttention: GlobalAttentionFamily;
    /// Batch-level attention family used by this routine. Its `Blueprint`
    /// is constrained to be exactly this routine's [`Routine::Blueprint`].
    type BatchAttention: BatchAttentionFamily<Blueprint = Self::Blueprint>;

    /// Routine-specific strategy type.
    // NOTE(review): presumably consumed through `BlueprintStrategy<Self>` —
    // confirm against `crate::launch`.
    type Strategy;
    /// Blueprint type carried in the [`LaunchInfo`] returned by
    /// [`Routine::prepare`]; also the blueprint of `Self::BatchAttention`.
    type Blueprint: Clone;

    /// Builds the launch information for `problem` on the device described by
    /// `device_settings`, following `strategy`.
    ///
    /// # Errors
    ///
    /// Returns an [`AttentionSetupError`] when preparation fails; the exact
    /// conditions are defined by each implementor.
    fn prepare<R: Runtime>(
        problem: &AttentionProblem,
        device_settings: &DeviceSettings<R>,
        strategy: BlueprintStrategy<Self>,
    ) -> Result<LaunchInfo<Self::Blueprint>, AttentionSetupError>;
}

/// Result of a successful [`Routine::prepare`] call, generic over the
/// routine's blueprint type `B`.
pub struct LaunchInfo<B> {
    /// Blueprint produced by the routine.
    pub blueprint: B,
    /// Element types associated with the launch.
    // NOTE(review): presumably the concrete dtypes for the attention
    // operands — confirm against `AttentionElems`.
    pub dtypes: AttentionElems,
    /// Cube dimensions for the launch.
    pub cube_dim: CubeDim,
    /// Plan describing the cube count for the launch.
    pub cube_count_plan: CubeCountPlan,
    /// Address type for the launch.
    // NOTE(review): exact meaning is defined by `cubecl::ir::AddressType` —
    // confirm there.
    pub address_type: AddressType,
}

/// Device-related settings handed to [`Routine::prepare`].
///
/// [`DeviceSettings::new`] fills every field from a compute client and an
/// attention problem.
pub struct DeviceSettings<R: Runtime> {
    /// Plane dimension; `new` sets it to the hardware's `plane_size_max`.
    pub plane_dim: u32,
    /// Maximum cube count as a 3-tuple; `new` copies it from the hardware
    /// properties' `max_cube_count`.
    pub max_cube_count: (u32, u32, u32),
    /// Vector sizes; `new` obtains them from
    /// `AttentionVectorSizes::new_max_for_problem`.
    pub vector_sizes: AttentionVectorSizes,
    /// Compute client these settings were derived from (stored as a clone).
    pub client: ComputeClient<R>,
}

/// Manual `Debug` implementation that prints the plain settings and leaves
/// out the `client` field.
impl<R: Runtime> std::fmt::Debug for DeviceSettings<R> {
    /// Formats `plane_dim`, `max_cube_count` and `vector_sizes`.
    ///
    /// `client` is not printed; the output is closed with
    /// `finish_non_exhaustive` so the trailing `..` makes the omission
    /// visible instead of implying the struct has only three fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceSettings")
            .field("plane_dim", &self.plane_dim)
            .field("max_cube_count", &self.max_cube_count)
            .field("vector_sizes", &self.vector_sizes)
            .finish_non_exhaustive()
    }
}

impl<R: Runtime> DeviceSettings<R> {
    /// Derives the device settings for `problem` from `client`.
    ///
    /// The plane dimension and maximum cube count are read from the client's
    /// hardware properties, the vector sizes come from
    /// `AttentionVectorSizes::new_max_for_problem`, and a clone of `client`
    /// is stored alongside them.
    pub fn new(client: &ComputeClient<R>, problem: &AttentionProblem) -> Self {
        // Look the properties up once and read both hardware limits from them.
        let properties = client.properties();
        let plane_dim = properties.hardware.plane_size_max;
        let max_cube_count = properties.hardware.max_cube_count;
        let vector_sizes = AttentionVectorSizes::new_max_for_problem(client, problem);

        Self {
            plane_dim,
            max_cube_count,
            vector_sizes,
            client: client.clone(),
        }
    }
}