use cubecl::{
prelude::*,
{self, ir::DeviceProperties},
};
use crate::{
components::global::simple::AttentionWriter,
definition::{
AttentionBlueprint, AttentionElems, AttentionPrecision, AttentionSetupError,
attention_types::*,
},
};
use cubecl::std::tensor::r#virtual::VirtualTensor;
use crate::components::{global::simple::QueryReader, stage::StageAttentionConfig};
use std::{fmt::Debug, hash::Hash};
/// Type-level selector for a global attention implementation.
///
/// The trait exposes no method taking `self`; it only names the
/// [`GlobalAttention`] used for each precision `AP` and knows how to build
/// the configuration shared by all of those instantiations.
pub trait GlobalAttentionFamily: 'static + Send + Sync {
    /// Configuration built by `expand_config` and shared by every
    /// `Attention<AP>` instantiation of this family.
    type Config: GlobalAttentionConfig;

    /// Concrete global attention for precision `AP`. Its `Config` is pinned
    /// to this family's `Config`.
    type Attention<AP: AttentionPrecision>: GlobalAttention<AP, Config = Self::Config>;

    /// Builds the family configuration from the device properties, the
    /// attention blueprint and the element types in use.
    ///
    /// # Errors
    ///
    /// Returns an [`AttentionSetupError`] when the implementation cannot
    /// produce a configuration for these inputs.
    fn expand_config(
        device_props: &DeviceProperties,
        blueprint: &AttentionBlueprint,
        dtypes: &AttentionElems,
    ) -> Result<Self::Config, AttentionSetupError>;
}
/// Global-level attention, declared as a `#[cube]` trait.
///
/// Implementors choose the concrete reader types for keys, values and masks,
/// and the writer type for the output. The query reader is always
/// `QueryReader<AP>`. The `init_*` functions build these objects from the
/// input and output tensors, and `execute` consumes them.
///
/// Every function receives the implementation's `Config` as a comptime
/// argument.
#[cube]
pub trait GlobalAttention<AP: AttentionPrecision>: 'static + Send + Sync {
/// Writer built by `init_writer` and consumed by `execute`.
/// `OG`/`OGS` match the element types of the `out` tensor passed to
/// `init_writer`. `OS`/`OSS` are presumably the stage-level output
/// counterparts — NOTE(review): confirm.
type Writer: AttentionWriter<OS<AP>, OSS<AP>, OG<AP>, OGS<AP>>;
/// Reader over the key tensor, built by `init_key_reader`.
type KeyReader: CubeType;
/// Reader over the value tensor, built by `init_value_reader`.
type ValueReader: CubeType;
/// Reader over the (optional) mask tensor, built by `init_mask_reader`.
type MaskReader: CubeType;
/// Comptime configuration passed to every function of this trait.
type Config: GlobalAttentionConfig;
/// Runs the attention computation, reading query/key/value/mask through the
/// given readers and emitting results through `writer`.
///
/// `seq_q` and `seq_kv` are presumably the query and key/value sequence
/// lengths (inferred from the names — confirm against callers).
fn execute(
query_reader: QueryReader<AP>,
key_reader: Self::KeyReader,
value_reader: Self::ValueReader,
mask_reader: Self::MaskReader,
writer: Self::Writer,
seq_q: u32,
seq_kv: u32,
#[comptime] config: Self::Config,
);
/// Builds the query reader for batch `batch_index`.
///
/// `stage_q_offset` is presumably this stage's starting position along the
/// query sequence — TODO confirm.
fn init_query_reader(
batch_index: u32,
stage_q_offset: u32,
query: VirtualTensor<QG<AP>, QGS<AP>>,
#[comptime] config: Self::Config,
) -> QueryReader<AP>;
/// Builds the key reader for batch `batch_index`.
///
/// Unlike the query reader, it takes no sequence offset.
/// NOTE(review): presumably the reader walks the key sequence itself
/// during `execute` — confirm.
fn init_key_reader(
batch_index: u32,
key: VirtualTensor<KG<AP>, KGS<AP>>,
#[comptime] config: Self::Config,
) -> Self::KeyReader;
/// Builds the value reader for batch `batch_index`. Like the key reader,
/// it takes no sequence offset.
fn init_value_reader(
batch_index: u32,
value: VirtualTensor<VG<AP>, VGS<AP>>,
#[comptime] config: Self::Config,
) -> Self::ValueReader;
/// Builds the mask reader for batch `batch_index`.
///
/// The mask tensor is a `ComptimeOption`, so whether a mask is present is
/// known at compile time. `seq_kv_shape` is presumably the key/value
/// sequence extent used to address the mask — TODO confirm.
fn init_mask_reader(
batch_index: u32,
stage_q_offset: u32,
mask: ComptimeOption<VirtualTensor<MSK<AP>, MSKS<AP>>>,
seq_kv_shape: u32,
#[comptime] config: Self::Config,
) -> Self::MaskReader;
/// Builds the writer for batch `batch_index` over the `ReadWrite` output
/// tensor `out`, starting at query offset `stage_q_offset`.
fn init_writer(
batch_index: u32,
stage_q_offset: u32,
out: VirtualTensor<OG<AP>, OGS<AP>, ReadWrite>,
#[comptime] config: Self::Config,
) -> Self::Writer;
}
/// Comptime configuration of a [`GlobalAttention`] implementation.
///
/// The config can be copied, compared, hashed, debug-printed and shared
/// across threads. `Clone` and `PartialEq` are not listed explicitly because
/// `Copy` and `Eq` already require them as supertraits.
pub trait GlobalAttentionConfig: Copy + Eq + Hash + Debug + Send + Sync + 'static {
    /// Configuration type of the stage level, as returned by `stage_config`.
    type StageConfig: StageAttentionConfig;

    /// Returns the stage-level configuration carried by this config.
    fn stage_config(&self) -> Self::StageConfig;

    /// Returns the cube dimensions associated with this configuration
    /// (presumably the launch dimensions — confirm against callers).
    fn cube_dim(&self) -> CubeDim;
}