use cubecl::{
Runtime,
client::ComputeClient,
frontend::CubePrimitive,
ir::{AddressType, ElemType, FloatKind, StorageType, Type},
};
/// Description of an attention problem to solve, independent of the actual
/// data: tensor dimensions, global element types, and kernel options.
#[derive(Clone, Debug)]
pub struct AttentionProblem {
    /// Sizes of every dimension involved in the computation.
    pub dims: AttentionDims,
    /// Whether a user-provided mask tensor is supplied.
    pub masked: bool,
    /// Element types of the tensors in global memory.
    pub global_dtypes: AttentionGlobalTypes,
    /// Options such as causal masking and accumulator precision.
    pub options: AttentionOptions,
    /// Address type used for global-memory indexing.
    pub address_type: AddressType,
}
/// Identifies one of the tensors involved in attention. `Softmax` refers to
/// the intermediate score matrix, which is never materialized in global
/// memory and therefore has no global shape.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum AttentionIdent {
    Query,
    Key,
    Softmax,
    Value,
    Mask,
    Out,
}
/// Options that influence kernel selection and behavior.
#[derive(Clone, Debug, Default)]
pub struct AttentionOptions {
    /// Apply a causal mask, hiding future key positions from each query.
    pub causal: bool,
    /// Precision policy for the attention accumulators.
    pub accumulator_precision: AccumulatorPrecision,
}
impl AttentionProblem {
    /// Returns the 4D shape of the given tensor, delegating to
    /// [`AttentionDims::shape`].
    pub fn shape(&self, ident: AttentionIdent) -> [usize; 4] {
        self.dims.shape(ident)
    }
}
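
// Hypothetical usage sketch (not part of the original API): assembling an
// `AttentionProblem` by hand. The dimension values are illustrative, and
// `global_dtypes`/`address_type` are taken as parameters because their
// construction depends on the target runtime.
#[allow(dead_code)]
fn example_problem(
    global_dtypes: AttentionGlobalTypes,
    address_type: AddressType,
) -> AttentionProblem {
    AttentionProblem {
        // 2 batches, 8 heads, 128 query/key positions, 64-wide heads.
        dims: AttentionDims {
            batch: 2,
            num_heads: 8,
            seq_q: 128,
            seq_kv: 128,
            head_dim: 64,
            val_dim: 64,
        },
        // No explicit mask tensor here; masking is assumed to come from the
        // causal option alone (an assumption about how the two interact).
        masked: false,
        options: AttentionOptions {
            causal: true,
            accumulator_precision: AccumulatorPrecision::default(),
        },
        global_dtypes,
        address_type,
    }
}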
/// Sizes of every dimension of an attention problem, following a
/// `[batch, num_heads, seq, dim]` layout.
#[derive(Clone, Debug)]
pub struct AttentionDims {
    /// Batch size.
    pub batch: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Sequence length of the queries.
    pub seq_q: usize,
    /// Sequence length of the keys and values.
    pub seq_kv: usize,
    /// Embedding dimension of the queries and keys.
    pub head_dim: usize,
    /// Embedding dimension of the values (and the output).
    pub val_dim: usize,
}
impl AttentionDims {
    /// Returns the 4D shape `[batch, num_heads, seq, dim]` of the given tensor.
    ///
    /// Panics for [`AttentionIdent::Softmax`], which is never materialized in
    /// global memory.
    pub fn shape(&self, ident: AttentionIdent) -> [usize; 4] {
        match ident {
            AttentionIdent::Query => [self.batch, self.num_heads, self.seq_q, self.head_dim],
            AttentionIdent::Key => [self.batch, self.num_heads, self.seq_kv, self.head_dim],
            AttentionIdent::Value => [self.batch, self.num_heads, self.seq_kv, self.val_dim],
            AttentionIdent::Mask => [self.batch, self.num_heads, self.seq_q, self.seq_kv],
            AttentionIdent::Out => [self.batch, self.num_heads, self.seq_q, self.val_dim],
            AttentionIdent::Softmax => panic!("Softmax is not a materialized tensor"),
        }
    }
}
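
// Sanity check of the shape conventions above: Q is [batch, heads, seq_q,
// head_dim], K/V use the key-value sequence length, the (optional) mask is
// [batch, heads, seq_q, seq_kv], and the output takes the value dimension.
// The dimension values are arbitrary.
#[cfg(test)]
mod shape_tests {
    use super::*;

    #[test]
    fn shapes_follow_bhsd_layout() {
        let dims = AttentionDims {
            batch: 2,
            num_heads: 8,
            seq_q: 128,
            seq_kv: 256,
            head_dim: 64,
            val_dim: 32,
        };
        assert_eq!(dims.shape(AttentionIdent::Query), [2, 8, 128, 64]);
        assert_eq!(dims.shape(AttentionIdent::Key), [2, 8, 256, 64]);
        assert_eq!(dims.shape(AttentionIdent::Value), [2, 8, 256, 32]);
        assert_eq!(dims.shape(AttentionIdent::Mask), [2, 8, 128, 256]);
        assert_eq!(dims.shape(AttentionIdent::Out), [2, 8, 128, 32]);
    }
}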
/// Element types of the attention tensors as stored in global memory.
#[derive(Clone, Debug)]
pub struct AttentionGlobalTypes {
    pub query: StorageType,
    pub key: StorageType,
    pub value: StorageType,
    pub mask: StorageType,
    pub out: StorageType,
}
impl AttentionGlobalTypes {
    /// Uses a single float element type for query, key, value, and output,
    /// paired with a separately chosen mask type.
    pub fn from_single_float_dtype(
float_dtype: Type,
mask_dtype: StorageType,
) -> AttentionGlobalTypes {
let float_dtype = float_dtype.storage_type();
Self {
query: float_dtype,
key: float_dtype,
value: float_dtype,
mask: mask_dtype,
out: float_dtype,
}
}
    /// Picks the mask element type: `u8` when the client supports it
    /// natively, otherwise `u32`.
    pub fn mask_dtype<R: Runtime>(client: &ComputeClient<R>) -> StorageType {
let props = client.properties();
let u8_ty = u8::as_type_native_unchecked().storage_type();
let u32_ty = u32::as_type_native_unchecked().storage_type();
if props.supports_type(u8_ty) {
u8_ty
} else if props.supports_type(u32_ty) {
u32_ty
} else {
panic!("Client does not support u8 or u32 native types");
}
}
}
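
// Hypothetical sketch: deriving all global element types from a client.
// `f32::as_type_native_unchecked()` mirrors the `CubePrimitive` calls used in
// `mask_dtype` above; using `f32` for every float tensor is an arbitrary
// choice for illustration.
#[allow(dead_code)]
fn example_global_types<R: Runtime>(client: &ComputeClient<R>) -> AttentionGlobalTypes {
    let mask_dtype = AttentionGlobalTypes::mask_dtype::<R>(client);
    AttentionGlobalTypes::from_single_float_dtype(f32::as_type_native_unchecked(), mask_dtype)
}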
/// Precision policy for the attention accumulators.
#[derive(Copy, Clone, Debug)]
pub enum AccumulatorPrecision {
    /// Accumulate in exactly the given storage type.
    Strict(StorageType),
    /// Allow the implementation to choose a suitable accumulator type.
    Loose,
}
impl AccumulatorPrecision {
    /// The default accumulator type: scalar `f32`.
    pub fn default_accumulator_type() -> StorageType {
        StorageType::Scalar(ElemType::Float(FloatKind::F32))
    }
}
impl Default for AccumulatorPrecision {
fn default() -> Self {
Self::Strict(Self::default_accumulator_type())
}
}
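
// Hypothetical sketch of how a kernel-selection step might interpret
// `AccumulatorPrecision`. The resolution policy here is an assumption, not
// taken from this file: `Strict` pins the accumulator type exactly, while
// `Loose` is read as "any sufficiently precise type is fine" and falls back
// to the `f32` default.
#[allow(dead_code)]
fn resolve_accumulator_type(precision: AccumulatorPrecision) -> StorageType {
    match precision {
        // The caller demanded a specific accumulator type.
        AccumulatorPrecision::Strict(ty) => ty,
        // Assumed fallback when the caller does not care: the f32 default.
        AccumulatorPrecision::Loose => AccumulatorPrecision::default_accumulator_type(),
    }
}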