// cubek-attention 0.2.0
//
// CubeK: Attention Kernels
// Documentation
use std::marker::PhantomData;

use cubecl::{
    ir::{AddressType, DeviceProperties},
    server::LaunchError,
};

use crate::{
    components::{
        batch::{
            BatchAttentionFamily,
            entry_point::attention,
            simple::{SimpleBatchAttention, config::SimpleBatchConfig},
        },
        global::GlobalAttentionFamily,
    },
    definition::{
        AttentionBlueprint, AttentionElems, AttentionPrecision, AttentionSetupError,
        AttentionVectorSizes, CubeMappingLaunch, InputRuntimeArg, OutputRuntimeArg,
        launch_types::*,
    },
    launch::AttentionArgs,
};

/// Family marker for the "simple" batch attention, generic over the global
/// attention family `GA` it builds on.
///
/// The only field is a [`PhantomData`], so no value of `GA` is ever stored;
/// the type parameter exists purely so trait implementations can name `GA`.
pub struct SimpleBatchAttentionFamily<GA>
where
    GA: GlobalAttentionFamily,
{
    /// Ties the otherwise unused type parameter `GA` to this struct.
    _phantom: PhantomData<GA>,
}

impl<GA: GlobalAttentionFamily> BatchAttentionFamily for SimpleBatchAttentionFamily<GA> {
    type Attention<AP: AttentionPrecision> = SimpleBatchAttention<AP, GA::Attention<AP>>;
    type Config = SimpleBatchConfig<GA::Config>;
    type Blueprint = AttentionBlueprint;

    unsafe fn launch_unchecked<'a, AA: AttentionArgs, R: cubecl::Runtime>(
        client: &cubecl::prelude::ComputeClient<R>,
        cube_dim: cubecl::CubeDim,
        cube_count: cubecl::CubeCount,
        address_type: AddressType,
        input: InputRuntimeArg<AA, R>,
        output: OutputRuntimeArg<AA, R>,
        cube_mapping: CubeMappingLaunch<R>,
        dtypes: &AttentionElems,
        vector_sizes: &AttentionVectorSizes,
        blueprint: Self::Blueprint,
    ) -> Result<(), LaunchError> {
        unsafe {
            attention::launch_unchecked::<AA, QG, QGS, KG, KGS, VG, VGS, MSK, MSKS, OG, OGS, Self, R>(
                client,
                cube_count,
                cube_dim,
                address_type,
                input,
                output,
                cube_mapping,
                blueprint,
                dtypes.clone(),
                dtypes.into(),
                vector_sizes.into(),
            )
        };

        Ok(())
    }

    fn expand_config(
        device_props: &DeviceProperties,
        blueprint: Self::Blueprint,
        dtypes: &AttentionElems,
    ) -> Result<Self::Config, AttentionSetupError> {
        let global_config = GA::expand_config(device_props, &blueprint, dtypes)?;

        Ok(SimpleBatchConfig::new(global_config))
    }
}