cubek-attention 0.2.0

CubeK: Attention Kernels
Documentation
use cubecl::CubeDim;

use crate::components::{batch::BatchAttentionConfig, global::GlobalAttentionConfig};

/// Batch-level attention configuration that wraps exactly one global-level
/// configuration of type `G` and stores nothing else.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SimpleBatchConfig<G>
where
    G: GlobalAttentionConfig,
{
    /// The wrapped global-level configuration.
    global_config: G,
}

impl<G> BatchAttentionConfig for SimpleBatchConfig<G>
where
    G: GlobalAttentionConfig,
{
    type GlobalConfig = G;

    /// Returns a copy of the wrapped global configuration.
    fn global_config(&self) -> Self::GlobalConfig {
        // Destructuring `*self` copies the field out; this relies on `G: Copy`,
        // which returning the field by value from `&self` already requires.
        let Self { global_config } = *self;
        global_config
    }

    /// Forwards to the wrapped global configuration's `cube_dim`.
    fn cube_dim(&self) -> CubeDim {
        let Self { global_config } = *self;
        global_config.cube_dim()
    }
}

impl<G: GlobalAttentionConfig> SimpleBatchConfig<G> {
    pub fn new(global_config: G) -> Self {
        Self { global_config }
    }
}