use cubek_std::TileSize;
use cubek_std::cube_count::{Count3d, CubeCountPlan};
use crate::definition::{AttentionDims, AttentionVectorSizes, HypercubeBlueprint};
/// Bundle of settings describing one attention kernel configuration.
///
/// The type is hashable and comparable, so it can serve as a lookup key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AttentionBlueprint {
    /// Hypercube settings forwarded to `CubeCountPlan::from_blueprint` when planning the launch.
    pub hypercube_blueprint: HypercubeBlueprint,
    /// Tile, partition and stage sizes used by this configuration.
    pub tiling_scheme: AttentionTilingScheme,
    /// Plane dimension (presumably the number of units per plane; confirm with the kernel code).
    pub plane_dim: u32,
    // NOTE(review): the exact meaning of this flag is not visible in this file.
    pub two_rows_in_array_tile: bool,
    /// Vector sizes associated with the attention operands.
    pub vector_sizes: AttentionVectorSizes,
    /// Presumably `true` when a mask is applied; confirm against the kernel code.
    pub masked: bool,
    /// Presumably `true` when causal masking is applied; confirm against the kernel code.
    pub causal: bool,
    /// Per-axis bounds flags, of the same type returned by `AttentionTilingScheme::check_bounds`.
    pub check_bounds: AttentionCheckBounds,
}
impl AttentionBlueprint {
    /// Builds the [`CubeCountPlan`] for a problem with the given dimensions.
    ///
    /// The target count is:
    /// - `x`: `dims.seq_q` divided (rounding up) by the number of `seq_q` elements in one stage,
    /// - `y`: `dims.batch * dims.num_heads`,
    /// - `z`: `1`.
    ///
    /// The target is handed to `CubeCountPlan::from_blueprint` together with
    /// `hypercube_blueprint` and `max_cube_count`.
    pub fn cube_count_plan(
        &self,
        dims: &AttentionDims,
        max_cube_count: &(u32, u32, u32),
    ) -> CubeCountPlan {
        // tile * partition * stage along seq_q, i.e. the `seq_q` extent of one stage.
        let stage_rows = self.tiling_scheme.elements_in_stage_seq_q();

        let target_count = Count3d {
            x: (dims.seq_q as u32).div_ceil(stage_rows),
            y: (dims.batch * dims.num_heads) as u32,
            z: 1,
        };

        CubeCountPlan::from_blueprint(&self.hypercube_blueprint, target_count, max_cube_count)
    }
}
/// Three-level tiling description: tiles are grouped into partitions, and partitions
/// into stages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AttentionTilingScheme {
    /// Extent of a single tile, in elements, along each axis.
    pub tile_size: AttentionTileSize,
    /// Number of tiles per partition along each axis.
    pub partition_size: AttentionPartitionSize,
    /// Number of partitions per stage.
    pub stage_size: AttentionStageSize,
}
impl AttentionTilingScheme {
    /// Number of elements along `seq_q` in one tile.
    pub fn elements_in_tile_seq_q(&self) -> u32 {
        self.tile_size.seq_q
    }

    /// Number of elements along `seq_kv` in one tile.
    pub fn elements_in_tile_seq_kv(&self) -> u32 {
        self.tile_size.seq_kv
    }

    /// Number of elements along `seq_q` in one partition (tiles per partition × tile extent).
    pub fn elements_in_partition_seq_q(&self) -> u32 {
        self.partition_size.seq_q * self.elements_in_tile_seq_q()
    }

    /// Number of elements along `seq_kv` in one partition (tiles per partition × tile extent).
    pub fn elements_in_partition_seq_kv(&self) -> u32 {
        self.partition_size.seq_kv * self.elements_in_tile_seq_kv()
    }

    /// Number of elements along `head_dim` in one partition (tiles per partition × tile extent).
    pub fn elements_in_partition_head_dim(&self) -> u32 {
        self.partition_size.head_dim * self.tile_size.head_dim
    }

    /// Number of elements along `val_dim` in one partition (tiles per partition × tile extent).
    pub fn elements_in_partition_val_dim(&self) -> u32 {
        self.partition_size.val_dim * self.tile_size.val_dim
    }

    /// Number of elements along `seq_q` in one stage (partitions per stage × partition extent).
    pub fn elements_in_stage_seq_q(&self) -> u32 {
        self.stage_size.seq_q * self.elements_in_partition_seq_q()
    }

    /// Reports, per axis, whether `problem` fails to divide evenly into whole blocks.
    ///
    /// A flag is `true` when the problem dimension is **not** a multiple of the block
    /// extent along that axis, i.e. the last block is only partially filled:
    /// - `seq_q` is compared against the stage extent,
    /// - `seq_kv`, `head_dim` and `val_dim` are compared against the partition extent.
    ///
    /// The problem dimension is the dividend and the block extent is the divisor. The
    /// reverse test would report "no checks needed" for a problem smaller than one block
    /// whenever it happens to divide the block extent.
    pub fn check_bounds(&self, problem: &AttentionDims) -> AttentionCheckBounds {
        // `is_multiple_of` does not panic on a zero divisor: a zero block extent yields
        // `true` (checks needed) for every non-zero dimension.
        let has_partial_block =
            |dim_len: u32, block_len: u32| -> bool { !dim_len.is_multiple_of(block_len) };

        // The `as u32` casts mirror the rest of the file; dimensions are assumed to fit in u32.
        AttentionCheckBounds {
            seq_q: has_partial_block(problem.seq_q as u32, self.elements_in_stage_seq_q()),
            seq_kv: has_partial_block(problem.seq_kv as u32, self.elements_in_partition_seq_kv()),
            head_dim: has_partial_block(
                problem.head_dim as u32,
                self.elements_in_partition_head_dim(),
            ),
            val_dim: has_partial_block(
                problem.val_dim as u32,
                self.elements_in_partition_val_dim(),
            ),
        }
    }
}
/// Extent of a single tile, in elements, along each attention axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AttentionTileSize {
    /// Tile extent along `seq_q`.
    pub seq_q: u32,
    /// Tile extent along `head_dim`.
    pub head_dim: u32,
    /// Tile extent along `seq_kv`.
    pub seq_kv: u32,
    /// Tile extent along `val_dim`.
    pub val_dim: u32,
}
impl AttentionTileSize {
    /// Derives a tile size from the given vector sizes.
    ///
    /// Each derived extent is the least common multiple of two vector sizes:
    /// - `head_dim`: `query` and `key`,
    /// - `val_dim`: `value` and `out`,
    /// - `seq_kv`: `key` and `mask`.
    ///
    /// `seq_q` is fixed at 8.
    ///
    /// # Panics
    ///
    /// Panics with a division by zero if both vector sizes of a pair are 0.
    pub fn from_max_vector_sizes(vector_sizes: &AttentionVectorSizes) -> Self {
        // Euclid's algorithm; `gcd(x, 0) == x`.
        fn gcd(a: usize, b: usize) -> usize {
            if b == 0 { a } else { gcd(b, a % b) }
        }

        // Divide before multiplying to keep the intermediate value small.
        fn lcm(a: usize, b: usize) -> usize {
            a / gcd(a, b) * b
        }

        // NOTE(review): the rationale for the fixed `seq_q` of 8 is not visible in this file.
        Self {
            seq_q: 8,
            head_dim: lcm(vector_sizes.query, vector_sizes.key) as u32,
            val_dim: lcm(vector_sizes.value, vector_sizes.out) as u32,
            seq_kv: lcm(vector_sizes.key, vector_sizes.mask) as u32,
        }
    }

    /// Matmul tile shape built from this tile size: `m = seq_q`, `n = seq_kv`, `k = head_dim`.
    pub fn to_score_matmul_tile_size(&self) -> TileSize {
        let Self {
            seq_q,
            seq_kv,
            head_dim,
            ..
        } = *self;
        TileSize {
            m: seq_q,
            n: seq_kv,
            k: head_dim,
        }
    }

    /// Matmul tile shape built from this tile size: `m = seq_q`, `n = val_dim`, `k = seq_kv`.
    pub fn to_value_matmul_tile_size(&self) -> TileSize {
        let Self {
            seq_q,
            seq_kv,
            val_dim,
            ..
        } = *self;
        TileSize {
            m: seq_q,
            n: val_dim,
            k: seq_kv,
        }
    }
}
/// Number of tiles that make up one partition, along each attention axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AttentionPartitionSize {
    /// Tiles per partition along `seq_q`.
    pub seq_q: u32,
    /// Tiles per partition along `head_dim`.
    pub head_dim: u32,
    /// Tiles per partition along `seq_kv`.
    pub seq_kv: u32,
    /// Tiles per partition along `val_dim`.
    pub val_dim: u32,
}
/// Number of partitions that make up one stage.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AttentionStageSize {
    /// Partitions per stage along `seq_q`.
    pub seq_q: u32,
}
/// Per-axis boolean flags, of the kind produced by `AttentionTilingScheme::check_bounds`.
///
/// Presumably a `true` flag means accesses along that axis must be bounds-checked;
/// confirm against the kernels that consume it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct AttentionCheckBounds {
    /// Flag for the `seq_q` axis.
    pub seq_q: bool,
    /// Flag for the `seq_kv` axis.
    pub seq_kv: bool,
    /// Flag for the `head_dim` axis.
    pub head_dim: bool,
    /// Flag for the `val_dim` axis.
    pub val_dim: bool,
}