use cubecl;
use cubecl::{ir::DeviceProperties, prelude::*, std::tensor::r#virtual::VirtualTensor};
use crate::definition::{
AttentionElems, AttentionPrecision, AttentionSetupError, CubeMapping, CubeMappingLaunch,
InputRuntimeArg, OutputRuntimeArg,
};
use crate::{
definition::attention_types::*,
launch::AttentionArgs,
{components::global::GlobalAttentionConfig, definition::AttentionVectorSizes},
};
use std::{fmt::Debug, hash::Hash};
/// Family of batch-level attention kernels.
///
/// Acts as a type-level factory: it ties a concrete [`BatchAttention`]
/// implementation (generic over the numeric precision `AP`) to a single shared
/// `Config` type and a `Blueprint` used at launch/setup time.
pub trait BatchAttentionFamily: Send + Sync + 'static {
    /// Concrete attention implementation for a given precision `AP`.
    /// Constrained to use this family's [`Self::Config`] so that launch and
    /// execution agree on configuration.
    type Attention<AP: AttentionPrecision>: BatchAttention<AP, Config = Self::Config>;

    /// Configuration type produced by [`Self::expand_config`] and consumed by
    /// the attention kernel.
    type Config: BatchAttentionConfig;

    /// Launch-time description of the kernel; input to both
    /// [`Self::expand_config`] and [`Self::launch_unchecked`].
    type Blueprint;

    /// Launches the attention kernel on `client` without validating the
    /// arguments against each other.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `cube_dim`, `cube_count`, `cube_mapping`,
    /// `dtypes`, `vector_sizes` and the blueprint are mutually consistent and
    /// match the tensors carried by `input`/`output`; no checks are performed
    /// here. (NOTE(review): exact invariants depend on the implementor —
    /// confirm against concrete `BatchAttentionFamily` impls.)
    #[allow(clippy::too_many_arguments)]
    unsafe fn launch_unchecked<AA: AttentionArgs, R: Runtime>(
        client: &ComputeClient<R>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        address_type: AddressType,
        input: InputRuntimeArg<AA, R>,
        output: OutputRuntimeArg<AA, R>,
        cube_mapping: CubeMappingLaunch<R>,
        dtypes: &AttentionElems,
        vector_sizes: &AttentionVectorSizes,
        attention_blueprint: Self::Blueprint,
    ) -> Result<(), LaunchError>;

    /// Builds the family's [`Self::Config`] from a blueprint, the device's
    /// properties and the element dtypes, failing with
    /// [`AttentionSetupError`] when the combination is unsupported.
    fn expand_config(
        device_props: &DeviceProperties,
        blueprint: Self::Blueprint,
        dtypes: &AttentionElems,
    ) -> Result<Self::Config, AttentionSetupError>;
}
/// Batch-level attention kernel, expanded as GPU code by the `#[cube]` macro.
///
/// Implementors define the device-side entry point that consumes the
/// query/key/value (and optional mask) tensors and writes the result into
/// `out`.
#[cube]
pub trait BatchAttention<AP: AttentionPrecision>: 'static + Send + Sync {
    /// Configuration driving the execution; shared with the owning
    /// [`BatchAttentionFamily`].
    type Config: BatchAttentionConfig;

    /// Runs attention over the inputs and writes the result to `out`.
    ///
    /// `mask` is a compile-time optional tensor; `cube_mapping` describes how
    /// cubes are assigned to work (see [`CubeMapping`]). `config` is a
    /// comptime value, so it is resolved during kernel expansion rather than
    /// at runtime.
    fn execute(
        query: VirtualTensor<QG<AP>, QGS<AP>>,
        key: VirtualTensor<KG<AP>, KGS<AP>>,
        value: VirtualTensor<VG<AP>, VGS<AP>>,
        mask: ComptimeOption<VirtualTensor<MSK<AP>, MSKS<AP>>>,
        out: VirtualTensor<OG<AP>, OGS<AP>, ReadWrite>,
        cube_mapping: CubeMapping,
        #[comptime] config: Self::Config,
    );
}
/// Configuration for a [`BatchAttention`] kernel.
///
/// The supertrait bounds (`Copy + Eq + Hash + Debug + …`) let configs be used
/// as cheap, comparable cache/dispatch keys.
pub trait BatchAttentionConfig:
    Copy + Clone + Eq + PartialEq + Hash + Debug + Send + Sync + 'static
{
    /// Configuration handed down to the global (per-cube) attention level.
    type GlobalConfig: GlobalAttentionConfig;

    /// Returns the nested global-level configuration.
    fn global_config(&self) -> Self::GlobalConfig;

    /// Returns the cube dimensions this configuration expects the kernel to
    /// be launched with.
    fn cube_dim(&self) -> CubeDim;
}