use crate::error::{ModelError, ModelResult};
/// Number of sequence positions the fused attention routines in this file
/// score and fold into the running softmax state per block.
pub const ATTENTION_BLOCK_SIZE: usize = 32;
/// Running state for a single-pass, numerically stable softmax-weighted sum.
///
/// Scores are folded in one at a time. Whenever a new maximum score is seen,
/// the previously accumulated sum and output are rescaled so that every term
/// is expressed relative to the current maximum.
#[derive(Debug, Clone)]
struct OnlineSoftmaxState {
    /// Largest score folded in so far (`-inf` until the first usable score).
    max_val: f32,
    /// Sum of `exp(score - max_val)` over every score folded in so far.
    sum_exp: f32,
    /// Un-normalised weighted sum of the value vectors, relative to `max_val`.
    output: Vec<f32>,
}

impl OnlineSoftmaxState {
    /// Creates an empty state whose accumulator holds `head_dim` zeros.
    fn new(head_dim: usize) -> Self {
        Self {
            max_val: f32::NEG_INFINITY,
            sum_exp: 0.0,
            output: vec![0.0f32; head_dim],
        }
    }

    /// Folds a block of `scores` and their matching `values` into the state.
    ///
    /// `scores[i]` is the attention score for `values[i]`; only the first
    /// `head_dim` components of the accumulator are touched.
    ///
    /// A score of `-inf` carries zero softmax weight and is skipped. Without
    /// the skip, a `-inf` score seen before any larger score would evaluate
    /// `(-inf) - (-inf)`, which is NaN, and poison the whole accumulator.
    /// NaN and `+inf` scores are not special-cased.
    ///
    /// # Panics
    ///
    /// Panics if `values` is shorter than `scores` or if `head_dim` exceeds
    /// the accumulator length. Debug builds additionally assert that the
    /// lengths agree and that every value vector has `head_dim` elements.
    fn update(&mut self, scores: &[f32], values: &[&[f32]], head_dim: usize) {
        debug_assert_eq!(scores.len(), values.len());
        for (idx, &score) in scores.iter().enumerate() {
            let v = values[idx];
            debug_assert_eq!(v.len(), head_dim);

            if score == f32::NEG_INFINITY {
                // Zero weight: contributes nothing to the sum or the output.
                continue;
            }

            if score > self.max_val {
                // Re-express everything accumulated so far relative to the
                // new maximum. Nothing has been accumulated while the old
                // maximum is still -inf, so the factor is simply zero there.
                let rescale = if self.max_val == f32::NEG_INFINITY {
                    0.0
                } else {
                    (self.max_val - score).exp()
                };
                self.sum_exp *= rescale;
                for out_d in self.output[..head_dim].iter_mut() {
                    *out_d *= rescale;
                }
                self.max_val = score;
            }

            let exp_score = (score - self.max_val).exp();
            self.sum_exp += exp_score;
            for (out_d, &v_d) in self.output[..head_dim].iter_mut().zip(v.iter()) {
                *out_d += exp_score * v_d;
            }
        }
    }

    /// Normalises the accumulator by the running sum of exponentials.
    ///
    /// Leaves the accumulator untouched when the sum is not strictly
    /// positive (for example when nothing was folded in).
    fn finalize(&mut self) {
        if self.sum_exp > 0.0 {
            let inv_sum = 1.0 / self.sum_exp;
            for d in self.output.iter_mut() {
                *d *= inv_sum;
            }
        }
    }
}
/// Computes single-head scaled dot-product attention for one query over
/// per-position key and value slices, using a blocked online softmax.
///
/// Positions are processed in blocks of [`ATTENTION_BLOCK_SIZE`]. Each block
/// is scored against the query with a `1 / sqrt(head_dim)` scale and folded
/// into a running softmax state. The normalised result is written to
/// `output[..head_dim]`.
///
/// Only the first `head_dim` elements of `query`, of each key and of each
/// value are used; longer slices are accepted and the excess is ignored.
///
/// When `keys` is empty, every element of `output` is set to zero.
///
/// # Errors
///
/// Returns [`ModelError::ShapeMismatch`] when
/// * `query` or `output` has fewer than `head_dim` elements,
/// * `keys` and `values` have different lengths, or
/// * any individual key or value has fewer than `head_dim` elements.
///
/// `output` is left unmodified on error.
pub fn fused_attention_head(
    query: &[f32],
    keys: &[&[f32]],
    values: &[&[f32]],
    head_dim: usize,
    output: &mut [f32],
) -> ModelResult<()> {
    if query.len() < head_dim {
        return Err(ModelError::ShapeMismatch {
            name: "query".to_string(),
            expected: vec![head_dim],
            actual: vec![query.len()],
        });
    }
    if output.len() < head_dim {
        return Err(ModelError::ShapeMismatch {
            name: "output".to_string(),
            expected: vec![head_dim],
            actual: vec![output.len()],
        });
    }
    if keys.len() != values.len() {
        return Err(ModelError::ShapeMismatch {
            name: "keys/values length".to_string(),
            expected: vec![keys.len()],
            actual: vec![values.len()],
        });
    }

    let seq_len = keys.len();
    if seq_len == 0 {
        output.fill(0.0);
        return Ok(());
    }

    // Restrict the query to the declared dimension so it always has the same
    // length as the key slices handed to `scaled_dot_product`.
    let query = &query[..head_dim];
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut state = OnlineSoftmaxState::new(head_dim);

    // Scratch buffers are allocated once and reused for every block.
    let block_capacity = ATTENTION_BLOCK_SIZE.min(seq_len);
    let mut block_scores: Vec<f32> = Vec::with_capacity(block_capacity);
    let mut block_values: Vec<&[f32]> = Vec::with_capacity(block_capacity);

    let mut pos = 0;
    while pos < seq_len {
        let block_end = (pos + ATTENTION_BLOCK_SIZE).min(seq_len);
        block_scores.clear();
        block_values.clear();
        for t in pos..block_end {
            let key = keys[t];
            let value = values[t];
            if key.len() < head_dim {
                return Err(ModelError::ShapeMismatch {
                    name: format!("keys[{t}]"),
                    expected: vec![head_dim],
                    actual: vec![key.len()],
                });
            }
            if value.len() < head_dim {
                return Err(ModelError::ShapeMismatch {
                    name: format!("values[{t}]"),
                    expected: vec![head_dim],
                    actual: vec![value.len()],
                });
            }
            block_scores.push(scaled_dot_product(query, &key[..head_dim], scale));
            block_values.push(&value[..head_dim]);
        }
        state.update(&block_scores, &block_values, head_dim);
        pos = block_end;
    }

    state.finalize();
    output[..head_dim].copy_from_slice(&state.output[..head_dim]);
    Ok(())
}
/// Computes single-head scaled dot-product attention for one query over
/// keys and values stored in flat, position-major buffers.
///
/// Position `t` occupies `keys[t * head_dim..(t + 1) * head_dim]`, and the
/// same range of `values`. Positions are processed in blocks of
/// [`ATTENTION_BLOCK_SIZE`]. Each block is scored against the query with a
/// `1 / sqrt(head_dim)` scale and folded into a running softmax state. The
/// normalised result is written to `output[..head_dim]`.
///
/// Only the first `head_dim` elements of `query` are used; a longer query is
/// accepted and the excess is ignored.
///
/// When `seq_len` is zero, every element of `output` is set to zero.
///
/// # Errors
///
/// Returns [`ModelError::ShapeMismatch`] when
/// * `query` or `output` has fewer than `head_dim` elements,
/// * `seq_len * head_dim` does not fit in `usize`, or
/// * `keys` or `values` has fewer than `seq_len * head_dim` elements.
///
/// `output` is left unmodified on error.
pub fn fused_attention_head_contiguous(
    query: &[f32],
    keys: &[f32],
    values: &[f32],
    output: &mut [f32],
    seq_len: usize,
    head_dim: usize,
) -> ModelResult<()> {
    if query.len() < head_dim {
        return Err(ModelError::ShapeMismatch {
            name: "query".to_string(),
            expected: vec![head_dim],
            actual: vec![query.len()],
        });
    }

    // A wrapped product would pass the length checks below and then panic
    // (or read the wrong rows) when slicing, so reject it explicitly.
    let required_len = match seq_len.checked_mul(head_dim) {
        Some(len) => len,
        None => {
            return Err(ModelError::ShapeMismatch {
                name: "keys".to_string(),
                expected: vec![seq_len, head_dim],
                actual: vec![keys.len()],
            });
        }
    };

    if keys.len() < required_len {
        return Err(ModelError::ShapeMismatch {
            name: "keys".to_string(),
            expected: vec![required_len],
            actual: vec![keys.len()],
        });
    }
    if values.len() < required_len {
        return Err(ModelError::ShapeMismatch {
            name: "values".to_string(),
            expected: vec![required_len],
            actual: vec![values.len()],
        });
    }
    if output.len() < head_dim {
        return Err(ModelError::ShapeMismatch {
            name: "output".to_string(),
            expected: vec![head_dim],
            actual: vec![output.len()],
        });
    }
    if seq_len == 0 {
        output.fill(0.0);
        return Ok(());
    }

    // Restrict the query to the declared dimension so it always has the same
    // length as the key rows handed to `scaled_dot_product`.
    let query = &query[..head_dim];
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut state = OnlineSoftmaxState::new(head_dim);

    // Scratch buffers are allocated once and reused for every block.
    let block_capacity = ATTENTION_BLOCK_SIZE.min(seq_len);
    let mut block_scores: Vec<f32> = Vec::with_capacity(block_capacity);
    let mut block_values: Vec<&[f32]> = Vec::with_capacity(block_capacity);

    let mut pos = 0;
    while pos < seq_len {
        let block_end = (pos + ATTENTION_BLOCK_SIZE).min(seq_len);
        block_scores.clear();
        block_values.clear();
        for t in pos..block_end {
            // Cannot overflow: (t + 1) <= seq_len and seq_len * head_dim fits.
            let row_start = t * head_dim;
            let row_end = row_start + head_dim;
            let k_slice = &keys[row_start..row_end];
            block_scores.push(scaled_dot_product(query, k_slice, scale));
            block_values.push(&values[row_start..row_end]);
        }
        state.update(&block_scores, &block_values, head_dim);
        pos = block_end;
    }

    state.finalize();
    output[..head_dim].copy_from_slice(&state.output[..head_dim]);
    Ok(())
}
/// Replaces `logits` with their softmax, in place.
///
/// The largest logit is subtracted before exponentiating so that large
/// inputs do not overflow. An empty slice is left untouched. If the sum of
/// the exponentials is not strictly positive, the exponentiated values are
/// left un-normalised.
pub fn softmax_inplace(logits: &mut [f32]) {
    if logits.is_empty() {
        return;
    }

    // Find the largest logit.
    let mut peak = f32::NEG_INFINITY;
    for &x in logits.iter() {
        peak = peak.max(x);
    }

    // Exponentiate relative to the peak, accumulating the total as we go.
    let total = logits.iter_mut().fold(0.0f32, |acc, slot| {
        *slot = (*slot - peak).exp();
        acc + *slot
    });

    if total > 0.0 {
        let inv_total = 1.0 / total;
        logits.iter_mut().for_each(|slot| *slot *= inv_total);
    }
}
/// Returns `scale * (q · k)` over the leading elements the two slices share.
///
/// Products are accumulated in four independent lanes, four elements at a
/// time. The lanes are combined pairwise, and any leftover elements are
/// added one by one before the final scaling.
///
/// # Panics
///
/// Debug builds assert that `q` and `k` have the same length.
#[inline]
pub fn scaled_dot_product(q: &[f32], k: &[f32], scale: f32) -> f32 {
    debug_assert_eq!(q.len(), k.len());
    let shared = q.len().min(k.len());

    let q_quads = q[..shared].chunks_exact(4);
    let k_quads = k[..shared].chunks_exact(4);
    // Leftover (< 4) elements past the last full group of four.
    let q_tail = q_quads.remainder();
    let k_tail = k_quads.remainder();

    let mut lanes = [0.0f32; 4];
    for (qc, kc) in q_quads.zip(k_quads) {
        lanes[0] += qc[0] * kc[0];
        lanes[1] += qc[1] * kc[1];
        lanes[2] += qc[2] * kc[2];
        lanes[3] += qc[3] * kc[3];
    }

    let mut total = (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]);
    for (&a, &b) in q_tail.iter().zip(k_tail.iter()) {
        total += a * b;
    }
    total * scale
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the non-fused `attention_head` from the sibling `attention`
    /// module and panics if it reports an error.
    fn reference_attention(
        query: &[f32],
        keys: &[f32],
        values: &[f32],
        output: &mut [f32],
        seq_len: usize,
        head_dim: usize,
    ) {
        use super::super::attention::attention_head;
        attention_head(query, keys, values, output, seq_len, head_dim)
            .expect("reference attention should succeed");
    }

    /// Asserts that `fused` is within `tol` of `std` in every dimension.
    fn assert_close(std: &[f32], fused: &[f32], tol: f32) {
        assert_eq!(std.len(), fused.len(), "output lengths differ");
        for (i, (s, f)) in std.iter().zip(fused.iter()).enumerate() {
            assert!((s - f).abs() < tol, "dim {i}: std={s}, fused={f}");
        }
    }

    /// Runs the reference and the contiguous fused implementation on the
    /// same inputs and asserts that their outputs agree within `tol`.
    fn assert_fused_matches_reference(
        query: &[f32],
        keys: &[f32],
        values: &[f32],
        seq_len: usize,
        head_dim: usize,
        tol: f32,
    ) {
        let mut out_std = vec![0.0f32; head_dim];
        let mut out_fused = vec![0.0f32; head_dim];
        reference_attention(query, keys, values, &mut out_std, seq_len, head_dim);
        fused_attention_head_contiguous(query, keys, values, &mut out_fused, seq_len, head_dim)
            .expect("fused attention should succeed");
        assert_close(&out_std, &out_fused, tol);
    }

    #[test]
    fn fused_matches_standard_single_token() {
        let head_dim = 4;
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let keys = vec![1.0, 0.0, 0.0, 0.0];
        let values = vec![0.0, 1.0, 2.0, 3.0];
        assert_fused_matches_reference(&query, &keys, &values, 1, head_dim, 1e-5);
    }

    #[test]
    fn fused_matches_standard_multiple_tokens() {
        let head_dim = 8;
        let seq_len = 10;
        let query: Vec<f32> = (0..head_dim).map(|i| (i as f32 + 1.0) * 0.1).collect();
        let keys: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 17) as f32 - 8.0) * 0.05)
            .collect();
        let values: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i % 13) as f32 - 6.0) * 0.1)
            .collect();
        assert_fused_matches_reference(&query, &keys, &values, seq_len, head_dim, 1e-4);
    }

    #[test]
    fn fused_matches_standard_large_seq() {
        // seq_len spans several attention blocks.
        let head_dim = 16;
        let seq_len = 100;
        let query: Vec<f32> = (0..head_dim).map(|i| (i as f32 * 0.2) - 1.0).collect();
        let keys: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i * 7 + 3) % 23) as f32 * 0.04 - 0.5)
            .collect();
        let values: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i * 11 + 5) % 19) as f32 * 0.06 - 0.6)
            .collect();
        assert_fused_matches_reference(&query, &keys, &values, seq_len, head_dim, 1e-3);
    }

    #[test]
    fn fused_with_slice_api() {
        let head_dim = 4;
        let seq_len = 3;
        let query = vec![1.0, 0.5, -0.5, 0.0];
        let k0 = vec![0.5, 0.5, 0.0, 0.0];
        let k1 = vec![0.0, 1.0, 0.0, 0.0];
        let k2 = vec![-0.5, 0.0, 1.0, 0.0];
        let v0 = vec![1.0, 0.0, 0.0, 0.0];
        let v1 = vec![0.0, 1.0, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 1.0, 0.0];
        let keys_refs: Vec<&[f32]> = vec![&k0, &k1, &k2];
        let values_refs: Vec<&[f32]> = vec![&v0, &v1, &v2];

        let mut output = vec![0.0f32; head_dim];
        fused_attention_head(&query, &keys_refs, &values_refs, head_dim, &mut output)
            .expect("fused attention should succeed");

        // Flatten the per-position slices so the reference can consume them.
        let mut keys_flat = vec![0.0f32; seq_len * head_dim];
        let mut values_flat = vec![0.0f32; seq_len * head_dim];
        for (t, (k, v)) in keys_refs.iter().zip(values_refs.iter()).enumerate() {
            keys_flat[t * head_dim..(t + 1) * head_dim].copy_from_slice(k);
            values_flat[t * head_dim..(t + 1) * head_dim].copy_from_slice(v);
        }
        let mut out_std = vec![0.0f32; head_dim];
        reference_attention(
            &query,
            &keys_flat,
            &values_flat,
            &mut out_std,
            seq_len,
            head_dim,
        );
        assert_close(&out_std, &output, 1e-4);
    }

    #[test]
    fn fused_empty_sequence() {
        let head_dim = 4;
        let keys: Vec<&[f32]> = vec![];
        let values: Vec<&[f32]> = vec![];
        let query = vec![1.0; head_dim];
        let mut output = vec![99.0f32; head_dim];
        fused_attention_head(&query, &keys, &values, head_dim, &mut output)
            .expect("fused attention should handle empty seq");
        for &v in &output {
            assert!((v - 0.0).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn fused_contiguous_empty_sequence() {
        let head_dim = 4;
        let query = vec![1.0; head_dim];
        let keys: Vec<f32> = vec![];
        let values: Vec<f32> = vec![];
        let mut output = vec![99.0f32; head_dim];
        fused_attention_head_contiguous(&query, &keys, &values, &mut output, 0, head_dim)
            .expect("fused attention should handle empty seq");
        for &v in &output {
            assert!((v - 0.0).abs() < f32::EPSILON);
        }
    }

    #[test]
    fn softmax_inplace_basic() {
        let mut vals = vec![1.0, 2.0, 3.0];
        softmax_inplace(&mut vals);
        let sum: f32 = vals.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        assert!(vals[0] < vals[1]);
        assert!(vals[1] < vals[2]);
    }

    #[test]
    fn softmax_inplace_single() {
        let mut vals = vec![5.0];
        softmax_inplace(&mut vals);
        assert!((vals[0] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn softmax_inplace_empty() {
        // Must simply not panic.
        let mut vals: Vec<f32> = vec![];
        softmax_inplace(&mut vals);
        assert!(vals.is_empty());
    }

    #[test]
    fn softmax_inplace_large_logits_are_stable() {
        let mut vals = vec![1000.0f32, 1000.0];
        softmax_inplace(&mut vals);
        assert!((vals[0] - 0.5).abs() < 1e-6);
        assert!((vals[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn scaled_dot_product_basic() {
        let q = vec![1.0, 2.0, 3.0, 4.0];
        let k = vec![4.0, 3.0, 2.0, 1.0];
        let scale = 0.5;
        let result = scaled_dot_product(&q, &k, scale);
        assert!((result - 10.0).abs() < 1e-5);
    }

    #[test]
    fn scaled_dot_product_non_multiple_of_4() {
        let q = vec![1.0, 2.0, 3.0];
        let k = vec![4.0, 5.0, 6.0];
        let scale = 1.0;
        let result = scaled_dot_product(&q, &k, scale);
        assert!((result - 32.0).abs() < 1e-5);
    }

    #[test]
    fn scaled_dot_product_full_chunks_plus_tail() {
        // Nine elements: two full groups of four plus one leftover.
        let q: Vec<f32> = (1..=9).map(|i| i as f32).collect();
        let k = vec![1.0f32; 9];
        let result = scaled_dot_product(&q, &k, 2.0);
        assert!((result - 90.0).abs() < 1e-5);
    }

    #[test]
    fn fused_validation_errors() {
        let head_dim = 4;
        // Query is shorter than head_dim.
        let query = vec![1.0; 2];
        let keys: Vec<&[f32]> = vec![];
        let values: Vec<&[f32]> = vec![];
        let mut output = vec![0.0f32; head_dim];
        let result = fused_attention_head(&query, &keys, &values, head_dim, &mut output);
        assert!(matches!(
            &result,
            Err(ModelError::ShapeMismatch { name, .. }) if name == "query"
        ));
    }

    #[test]
    fn fused_validation_rejects_key_value_count_mismatch() {
        let head_dim = 4;
        let query = vec![1.0; head_dim];
        let k0 = vec![1.0f32; head_dim];
        let keys: Vec<&[f32]> = vec![&k0];
        let values: Vec<&[f32]> = vec![];
        let mut output = vec![0.0f32; head_dim];
        let result = fused_attention_head(&query, &keys, &values, head_dim, &mut output);
        assert!(matches!(
            &result,
            Err(ModelError::ShapeMismatch { name, .. }) if name == "keys/values length"
        ));
    }

    #[test]
    fn fused_contiguous_validation_errors() {
        let head_dim = 4;
        let query = vec![1.0; head_dim];
        let keys = vec![1.0; 4];
        // Values hold fewer than seq_len * head_dim elements.
        let values = vec![1.0; 2];
        let mut output = vec![0.0f32; head_dim];
        let result =
            fused_attention_head_contiguous(&query, &keys, &values, &mut output, 1, head_dim);
        assert!(matches!(
            &result,
            Err(ModelError::ShapeMismatch { name, .. }) if name == "values"
        ));
    }

    #[test]
    fn fused_head_dim_128() {
        let head_dim = 128;
        let seq_len = 50;
        let query: Vec<f32> = (0..head_dim).map(|i| (i as f32 * 0.03) - 2.0).collect();
        let keys: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i * 7 + 3) % 31) as f32 * 0.02 - 0.3)
            .collect();
        let values: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| ((i * 13 + 7) % 23) as f32 * 0.04 - 0.5)
            .collect();
        assert_fused_matches_reference(&query, &keys, &values, seq_len, head_dim, 1e-3);
    }
}