impl FlashAttentionBrick {
    /// Creates a flash-attention brick with a tile size of 128 KV positions,
    /// a `TokenBudget::from_latency(5.0)` budget and online softmax enabled.
    ///
    /// `num_heads` is the number of query heads and `num_kv_heads` the number of
    /// key/value heads. Query heads are mapped onto KV heads in contiguous groups
    /// of [`Self::group_size`] (grouped-query attention).
    #[must_use]
    pub fn new(num_heads: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        // Default tile: 128 KV positions per tile.
        Self::with_tile_size(num_heads, num_kv_heads, head_dim, 128)
    }

    /// Creates a flash-attention brick with an explicit `tile_size` (KV positions
    /// per tile), a `TokenBudget::from_latency(5.0)` budget and online softmax enabled.
    ///
    /// A `tile_size` of zero is accepted here, but [`Self::forward`] rejects it and
    /// [`Self::num_tiles`] panics on it.
    #[must_use]
    pub fn with_tile_size(
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        tile_size: usize,
    ) -> Self {
        Self {
            num_heads,
            num_kv_heads,
            head_dim,
            tile_size,
            budget: TokenBudget::from_latency(5.0),
            use_online_softmax: true,
        }
    }

    /// Replaces the latency budget, returning the updated brick.
    #[must_use]
    pub fn with_budget(mut self, budget: TokenBudget) -> Self {
        self.budget = budget;
        self
    }

    /// Number of query heads that share one KV head.
    ///
    /// A `num_kv_heads` of zero is treated as one. The result is `0` when
    /// `num_kv_heads > num_heads`.
    #[must_use]
    pub fn group_size(&self) -> usize {
        self.num_heads / self.num_kv_heads.max(1)
    }

    /// Floating-point operations for one query token attending over `seq_len`
    /// positions.
    ///
    /// This is `4 * num_heads * head_dim * seq_len`: two dot-product passes
    /// (Q·K and weights·V) at two FLOPs per multiply-add.
    #[must_use]
    pub fn flops(&self, seq_len: usize) -> u64 {
        4 * self.num_heads as u64 * self.head_dim as u64 * seq_len as u64
    }

    /// Estimated bytes moved for one decode step, returned as `(naive, flash)`.
    ///
    /// Both estimates include reading the f32 K and V caches. The naive estimate
    /// additionally counts one f32 attention score per head per position, which
    /// the flash variant never stores.
    #[must_use]
    pub fn memory_bytes(&self, seq_len: usize) -> (u64, u64) {
        const F32_BYTES: u64 = 4;
        let seq = seq_len as u64;
        // K and V caches (hence the factor 2).
        let kv_bytes = 2 * self.num_kv_heads as u64 * self.head_dim as u64 * seq * F32_BYTES;
        // Score row per head that a naive kernel would materialise.
        let naive = kv_bytes + self.num_heads as u64 * seq * F32_BYTES;
        let flash = kv_bytes;
        (naive, flash)
    }

    /// FLOPs per byte moved by the flash variant.
    ///
    /// Returns `0.0` instead of NaN or infinity when no bytes are moved
    /// (`seq_len`, `num_kv_heads` or `head_dim` is zero).
    #[must_use]
    pub fn arithmetic_intensity(&self, seq_len: usize) -> f64 {
        let (_, flash_bytes) = self.memory_bytes(seq_len);
        if flash_bytes == 0 {
            return 0.0;
        }
        self.flops(seq_len) as f64 / flash_bytes as f64
    }

    /// Number of KV tiles needed to cover `seq_len` positions.
    ///
    /// # Panics
    ///
    /// Panics if `tile_size` is zero.
    #[must_use]
    pub fn num_tiles(&self, seq_len: usize) -> usize {
        seq_len.div_ceil(self.tile_size)
    }

    /// Single-query attention over a KV cache using tiled online softmax.
    ///
    /// Layouts, as implied by the indexing below:
    /// - `query` holds `num_heads` contiguous vectors of `head_dim` values.
    /// - `keys`/`values` hold `seq_len` positions, each containing `num_kv_heads`
    ///   contiguous vectors of `head_dim` values.
    ///
    /// Returns `num_heads * head_dim` values: per head, the softmax-weighted sum of
    /// the value vectors (scores are scaled by `1 / sqrt(head_dim)`). With
    /// `seq_len == 0` the output is all zeros.
    ///
    /// # Errors
    ///
    /// Returns [`BrickError::InvalidInput`] when:
    /// - `num_heads`, `head_dim` or `num_kv_heads` is zero;
    /// - `tile_size` is zero;
    /// - `num_heads` is not a multiple of `num_kv_heads`;
    /// - `query`, `keys` or `values` has the wrong length.
    pub fn forward(
        &self,
        query: &[f32],
        keys: &[f32],
        values: &[f32],
        seq_len: usize,
    ) -> Result<Vec<f32>, BrickError> {
        if self.num_heads == 0 || self.head_dim == 0 || self.num_kv_heads == 0 {
            return Err(BrickError::InvalidInput("Zero dimension".to_string()));
        }
        if self.tile_size == 0 {
            // `step_by(0)` below would panic.
            return Err(BrickError::InvalidInput("Zero tile size".to_string()));
        }
        if self.num_heads % self.num_kv_heads != 0 {
            // Otherwise some query heads would map past the last KV head.
            return Err(BrickError::InvalidInput(format!(
                "num_heads {} is not a multiple of num_kv_heads {}",
                self.num_heads, self.num_kv_heads
            )));
        }

        let head_dim = self.head_dim;
        let q_len = self.num_heads * head_dim;
        if query.len() != q_len {
            return Err(BrickError::InvalidInput(format!(
                "Query length {} != num_heads * head_dim = {}",
                query.len(),
                q_len
            )));
        }

        // Elements per sequence position in the KV caches.
        let kv_stride = self.num_kv_heads * head_dim;
        let expected_kv_len = seq_len * kv_stride;
        // Report whichever cache is actually mismatched.
        for kv_len in [keys.len(), values.len()] {
            if kv_len != expected_kv_len {
                return Err(BrickError::InvalidInput(format!(
                    "KV length {} != seq_len * num_kv_heads * head_dim = {}",
                    kv_len, expected_kv_len
                )));
            }
        }

        let scale = 1.0 / (head_dim as f32).sqrt();
        let group_size = self.num_heads / self.num_kv_heads;
        let mut output = vec![0.0f32; q_len];
        // Unnormalised weighted sum of values; reused across heads.
        let mut acc = vec![0.0f32; head_dim];

        for (h, (q, out)) in query
            .chunks_exact(head_dim)
            .zip(output.chunks_exact_mut(head_dim))
            .enumerate()
        {
            // Query heads share KV heads in contiguous groups of `group_size`.
            let kv_offset = (h / group_size) * head_dim;
            let mut running_max = f32::NEG_INFINITY;
            let mut denom = 0.0f32;
            acc.fill(0.0);

            for tile_start in (0..seq_len).step_by(self.tile_size) {
                let tile_end = tile_start.saturating_add(self.tile_size).min(seq_len);
                for pos in tile_start..tile_end {
                    let base = pos * kv_stride + kv_offset;
                    let k = &keys[base..base + head_dim];
                    let v = &values[base..base + head_dim];
                    let score = q
                        .iter()
                        .zip(k)
                        .fold(0.0f32, |sum, (&qd, &kd)| sum + qd * kd)
                        * scale;

                    // Online softmax: rescale the running state to the new maximum,
                    // so no full score row is ever stored.
                    let new_max = running_max.max(score);
                    let rescale = (running_max - new_max).exp();
                    let weight = (score - new_max).exp();
                    denom = denom * rescale + weight;
                    for (a, &vd) in acc.iter_mut().zip(v) {
                        *a = *a * rescale + weight * vd;
                    }
                    running_max = new_max;
                }
            }

            if denom > 0.0 {
                for (o, &a) in out.iter_mut().zip(&acc) {
                    *o = a / denom;
                }
            }
        }
        Ok(output)
    }

    /// Runs [`Self::forward`] and reports wall-clock timing for one token.
    ///
    /// `budget_met` compares the elapsed microseconds with the brick's budget.
    /// `tokens_per_sec` is infinite if the elapsed time measures as zero.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::forward`].
    pub fn forward_timed(
        &self,
        query: &[f32],
        keys: &[f32],
        values: &[f32],
        seq_len: usize,
    ) -> Result<TokenResult<Vec<f32>>, BrickError> {
        let start = Instant::now();
        let output = self.forward(query, keys, values, seq_len)?;
        let elapsed_us = start.elapsed().as_secs_f64() * 1_000_000.0;
        Ok(TokenResult {
            output,
            tokens_processed: 1,
            us_per_token: elapsed_us,
            tokens_per_sec: 1_000_000.0 / elapsed_us,
            budget_met: elapsed_us <= self.budget.us_per_token,
        })
    }

    /// Placeholder that returns `num_heads * head_dim` zeros without computing
    /// attention. `_seq_len` is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`BrickError::InvalidInput`] if `num_heads` or `head_dim` is zero.
    #[deprecated(note = "Use forward() for real implementation")]
    pub fn execute(&self, _seq_len: usize) -> Result<Vec<f32>, BrickError> {
        if self.num_heads == 0 || self.head_dim == 0 {
            return Err(BrickError::InvalidInput("Zero dimension".to_string()));
        }
        Ok(vec![0.0; self.num_heads * self.head_dim])
    }
}
impl ComputeBrick for FlashAttentionBrick {
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "flash_attention"
    }

    fn budget(&self) -> TokenBudget {
        self.budget
    }

    /// Standard numeric checks (no NaN/inf, budget, output bounds of ±100) plus
    /// two custom checks describing the kernel's memory behaviour.
    fn assertions(&self) -> Vec<BrickAssertion> {
        // Custom assertions use the same string for `name` and `check_name`.
        let custom = |name: &str, description: &str| BrickAssertion {
            name: name.to_string(),
            description: description.to_string(),
            kind: AssertionKind::Custom {
                check_name: name.to_string(),
            },
        };
        vec![
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
            BrickAssertion::bounds(-100.0, 100.0),
            custom("online_softmax", "Uses online softmax (no full attention matrix)"),
            custom("tiled_kv_access", "KV cache accessed in tiles for cache locality"),
        ]
    }

    /// `true` only for configurations the attention loop can execute:
    /// - non-zero head count, head dimension and tile size;
    /// - `num_heads` an exact multiple of a non-zero `num_kv_heads`, so every
    ///   query head maps to an existing KV head.
    fn can_run(&self) -> bool {
        self.num_heads > 0
            && self.head_dim > 0
            && self.tile_size > 0
            && self.num_kv_heads > 0
            && self.num_heads % self.num_kv_heads == 0
    }
}
/// Compute brick for the feed-forward (FFN) stage.
///
/// It records the hidden and intermediate widths plus a latency budget.
#[derive(Debug)]
pub struct FfnBrick {
    /// Hidden dimension.
    pub hidden_dim: usize,
    /// Intermediate dimension.
    pub intermediate_dim: usize,
    /// Latency budget returned by `ComputeBrick::budget`.
    budget: TokenBudget,
}

impl FfnBrick {
    /// Creates an FFN brick with the default budget `TokenBudget::from_latency(12.2)`.
    #[must_use]
    pub fn new(hidden_dim: usize, intermediate_dim: usize) -> Self {
        Self {
            hidden_dim,
            intermediate_dim,
            budget: TokenBudget::from_latency(12.2),
        }
    }

    /// Replaces the latency budget, returning the updated brick.
    #[must_use]
    pub fn with_budget(self, budget: TokenBudget) -> Self {
        Self { budget, ..self }
    }
}

impl ComputeBrick for FfnBrick {
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "ffn"
    }

    fn budget(&self) -> TokenBudget {
        self.budget
    }

    /// Numeric sanity checks (no NaN, no inf) plus the latency-budget check.
    fn assertions(&self) -> Vec<BrickAssertion> {
        vec![
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
        ]
    }
}
/// Compute brick for the output projection (`o_proj`).
///
/// It records the input and output widths plus a latency budget.
#[derive(Debug)]
pub struct OProjBrick {
    /// Input dimension.
    pub in_dim: usize,
    /// Output dimension.
    pub out_dim: usize,
    /// Latency budget returned by `ComputeBrick::budget`.
    budget: TokenBudget,
}

impl OProjBrick {
    /// Creates an output-projection brick with the default budget
    /// `TokenBudget::from_latency(3.5)`.
    #[must_use]
    pub fn new(in_dim: usize, out_dim: usize) -> Self {
        Self {
            in_dim,
            out_dim,
            budget: TokenBudget::from_latency(3.5),
        }
    }

    /// Replaces the latency budget, returning the updated brick.
    #[must_use]
    pub fn with_budget(self, budget: TokenBudget) -> Self {
        Self { budget, ..self }
    }
}

impl ComputeBrick for OProjBrick {
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "o_proj"
    }

    fn budget(&self) -> TokenBudget {
        self.budget
    }

    /// Numeric sanity checks (no NaN, no inf) plus the latency-budget check.
    fn assertions(&self) -> Vec<BrickAssertion> {
        vec![
            BrickAssertion::no_nan(),
            BrickAssertion::no_inf(),
            BrickAssertion::budget_met(),
        ]
    }
}
/// Brick configuration for activation quantization (per the type name).
///
/// Holds a dimension, a latency budget and a per-channel flag.
#[derive(Debug, Clone)]
pub struct ActivationQuantBrick {
    /// Activation dimension; presumably the length of the quantized vector (TODO: confirm).
    pub dim: usize,
    /// Latency budget for this brick.
    budget: TokenBudget,
    /// Per-channel mode flag; presumably per-tensor scaling when `false` (TODO: confirm).
    pub per_channel: bool,
}