use crate::error::{LmError, LmResult};
use crate::layer::embedding::RotaryEmbedding;
use crate::layer::ffn::linear_batch;
use crate::weights::WeightTensor;
/// Key/value cache for a single attention layer.
///
/// `keys` and `values` are flat buffers that grow by one chunk per
/// [`LayerKvCache::append`] call. `MultiHeadAttention::forward` reads the entry
/// for token `t` and key/value head `h` at offset
/// `(t * n_kv_heads + h) * head_dim`, so both buffers are expected to hold
/// exactly `past_len * n_kv_heads * head_dim` elements.
#[derive(Debug, Clone)]
pub struct LayerKvCache {
    /// Cached key data, in append order.
    pub keys: Vec<f32>,
    /// Cached value data, in append order.
    pub values: Vec<f32>,
    /// Number of tokens accumulated so far.
    pub past_len: usize,
    /// Number of key/value heads this cache was created for.
    pub n_kv_heads: usize,
    /// Per-head width this cache was created for.
    pub head_dim: usize,
}

impl LayerKvCache {
    /// Creates an empty cache (no data, `past_len == 0`) for the given head layout.
    pub fn new(n_kv_heads: usize, head_dim: usize) -> Self {
        Self {
            keys: Vec::new(),
            values: Vec::new(),
            past_len: 0,
            n_kv_heads,
            head_dim,
        }
    }

    /// Appends `seq_len` tokens worth of key and value data to the cache.
    ///
    /// `new_k` and `new_v` are copied to the end of `keys` / `values` and
    /// `past_len` grows by `seq_len`.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `new_k` and `new_v`
    /// differ in length or if their length is not
    /// `seq_len * n_kv_heads * head_dim`. Release builds perform no check, so
    /// callers remain responsible for passing consistent chunks.
    pub fn append(&mut self, new_k: &[f32], new_v: &[f32], seq_len: usize) {
        // A mismatched chunk would silently shift every later token's offset,
        // so surface the mistake as early as possible during development.
        debug_assert_eq!(
            new_k.len(),
            new_v.len(),
            "key and value chunks must have the same length"
        );
        debug_assert_eq!(
            new_k.len(),
            seq_len * self.n_kv_heads * self.head_dim,
            "chunk length must equal seq_len * n_kv_heads * head_dim"
        );
        self.keys.extend_from_slice(new_k);
        self.values.extend_from_slice(new_v);
        self.past_len += seq_len;
    }

    /// Returns the number of tokens currently held in the cache (`past_len`).
    pub fn total_len(&self) -> usize {
        self.past_len
    }
}
/// Multi-head attention layer whose key/value heads may be shared between
/// groups of query heads.
///
/// `MultiHeadAttention::new` only accepts configurations where `n_kv_heads`
/// divides `n_heads` and `n_heads` divides `hidden_dim`; it fills the four
/// weight tensors with zeros and leaves every bias as `None`.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
/// Number of query heads.
pub n_heads: usize,
/// Number of key/value heads; in `forward`, query head `h` reads key/value
/// head `h / (n_heads / n_kv_heads)`.
pub n_kv_heads: usize,
/// Width of one head; `new` sets it to `hidden_dim / n_heads`.
pub head_dim: usize,
/// Width of one input/output token vector.
pub hidden_dim: usize,
/// Query projection weight; `new` creates it with shape `[hidden_dim, hidden_dim]`.
pub w_q: WeightTensor,
/// Key projection weight; `new` creates it with shape
/// `[n_kv_heads * head_dim, hidden_dim]`.
pub w_k: WeightTensor,
/// Value projection weight; `new` creates it with shape
/// `[n_kv_heads * head_dim, hidden_dim]`.
pub w_v: WeightTensor,
/// Output projection weight; `new` creates it with shape `[hidden_dim, hidden_dim]`.
pub w_o: WeightTensor,
/// Optional bias handed to `linear_batch` together with `w_q`.
pub b_q: Option<Vec<f32>>,
/// Optional bias handed to `linear_batch` together with `w_k`.
pub b_k: Option<Vec<f32>>,
/// Optional bias handed to `linear_batch` together with `w_v`.
pub b_v: Option<Vec<f32>>,
/// Optional bias handed to `linear_batch` together with `w_o`.
pub b_o: Option<Vec<f32>>,
/// When `true`, a query at absolute position `p` ignores every key position
/// greater than `p`.
pub causal: bool,
}
impl MultiHeadAttention {
    /// Builds an attention layer with zero-filled projection weights and no biases.
    ///
    /// `head_dim` is derived as `hidden_dim / n_heads`. The query and output
    /// weights get shape `[hidden_dim, hidden_dim]`; the key and value weights
    /// get shape `[n_kv_heads * head_dim, hidden_dim]`.
    ///
    /// # Errors
    ///
    /// * `LmError::InvalidConfig` if `n_heads` or `hidden_dim` is zero.
    /// * `LmError::GqaHeadMismatch` if `n_kv_heads` is zero or does not divide
    ///   `n_heads`.
    /// * `LmError::HeadDimMismatch` if `n_heads` does not divide `hidden_dim`.
    pub fn new(
        n_heads: usize,
        n_kv_heads: usize,
        hidden_dim: usize,
        causal: bool,
    ) -> LmResult<Self> {
        if n_heads == 0 || hidden_dim == 0 {
            return Err(LmError::InvalidConfig {
                msg: "n_heads and hidden_dim must be > 0".into(),
            });
        }
        // The zero test must come first so the modulo below cannot divide by zero.
        if n_kv_heads == 0 || n_heads % n_kv_heads != 0 {
            return Err(LmError::GqaHeadMismatch {
                n_heads,
                n_kv_heads,
            });
        }
        if hidden_dim % n_heads != 0 {
            return Err(LmError::HeadDimMismatch {
                hidden_dim,
                n_heads,
            });
        }
        let head_dim = hidden_dim / n_heads;
        let kv_proj_dim = n_kv_heads * head_dim;
        Ok(Self {
            n_heads,
            n_kv_heads,
            head_dim,
            hidden_dim,
            w_q: WeightTensor::zeros(&[hidden_dim, hidden_dim]),
            w_k: WeightTensor::zeros(&[kv_proj_dim, hidden_dim]),
            w_v: WeightTensor::zeros(&[kv_proj_dim, hidden_dim]),
            w_o: WeightTensor::zeros(&[hidden_dim, hidden_dim]),
            b_q: None,
            b_k: None,
            b_v: None,
            b_o: None,
            causal,
        })
    }

    /// Runs scaled dot-product attention over `seq_len` new tokens.
    ///
    /// `x` holds the new tokens back to back, `hidden_dim` values per token.
    /// Keys and values from `past_kv` (if any) are placed in front of the
    /// freshly projected ones, so the new tokens occupy absolute positions
    /// `past_len..past_len + seq_len`. When `rope` is given, its `apply` is
    /// called on the queries and on the new keys with `past_len` as the last
    /// argument (presumably the position offset; confirm against
    /// `RotaryEmbedding::apply`). Cached keys are used as stored.
    ///
    /// With `causal` set, a query at absolute position `p` only attends to
    /// key positions `0..=p`; otherwise it attends to all of them.
    ///
    /// Returns the output-projected attention result (`hidden_dim` values per
    /// new token) together with a new cache that contains the previous
    /// entries followed by the new keys/values. `past_kv` itself is not
    /// modified.
    ///
    /// # Errors
    ///
    /// * `LmError::EmptyInput` if `seq_len` is zero.
    /// * `LmError::DimensionMismatch` if `x.len() != seq_len * hidden_dim`, if
    ///   the cache buffers do not hold exactly
    ///   `past_len * n_kv_heads * head_dim` values, or if a Q/K/V projection
    ///   returns an unexpected number of values (e.g. after a weight tensor
    ///   was replaced with one of the wrong shape).
    /// * Any error returned by `linear_batch` or `RotaryEmbedding::apply`.
    pub fn forward(
        &self,
        x: &[f32],
        seq_len: usize,
        past_kv: Option<&LayerKvCache>,
        rope: Option<&RotaryEmbedding>,
    ) -> LmResult<(Vec<f32>, LayerKvCache)> {
        if seq_len == 0 {
            return Err(LmError::EmptyInput {
                context: "MultiHeadAttention::forward seq_len",
            });
        }
        Self::ensure_len(seq_len * self.hidden_dim, x.len())?;

        let kv_proj_dim = self.n_kv_heads * self.head_dim;
        let past_len = past_kv.map_or(0, |c| c.past_len);
        if let Some(cache) = past_kv {
            // The offsets computed below assume one `kv_proj_dim`-wide row per
            // cached token; anything else would misalign every key and value.
            let cached_len = past_len * kv_proj_dim;
            Self::ensure_len(cached_len, cache.keys.len())?;
            Self::ensure_len(cached_len, cache.values.len())?;
        }

        let mut q = linear_batch(&self.w_q, self.b_q.as_deref(), x, seq_len)?;
        let mut k_new = linear_batch(&self.w_k, self.b_k.as_deref(), x, seq_len)?;
        let v_new = linear_batch(&self.w_v, self.b_v.as_deref(), x, seq_len)?;
        // The weight fields are public, so guard the slicing below against
        // projections whose output width no longer matches the head layout.
        Self::ensure_len(seq_len * self.hidden_dim, q.len())?;
        Self::ensure_len(seq_len * kv_proj_dim, k_new.len())?;
        Self::ensure_len(seq_len * kv_proj_dim, v_new.len())?;

        if let Some(r) = rope {
            r.apply(&mut q, self.n_heads, seq_len, past_len)?;
            r.apply(&mut k_new, self.n_kv_heads, seq_len, past_len)?;
        }

        // Without a cache the fresh projections already are the full K/V, so
        // they are moved rather than copied.
        let (full_k, full_v) = match past_kv {
            Some(cache) => (
                [cache.keys.as_slice(), k_new.as_slice()].concat(),
                [cache.values.as_slice(), v_new.as_slice()].concat(),
            ),
            None => (k_new, v_new),
        };

        let total_len = past_len + seq_len;
        let scale = 1.0 / (self.head_dim as f32).sqrt();
        let gqa_factor = self.n_heads / self.n_kv_heads;
        let mut out = vec![0.0_f32; seq_len * self.hidden_dim];
        // One scratch buffer for scores/weights, reused for every (token, head).
        let mut weights: Vec<f32> = Vec::with_capacity(total_len);

        for t in 0..seq_len {
            let abs_q_pos = past_len + t;
            // Masked positions would receive weight zero anyway, so only the
            // visible prefix is scored.
            let visible = if self.causal {
                abs_q_pos + 1
            } else {
                total_len
            };
            for h in 0..self.n_heads {
                // Consecutive groups of `gqa_factor` query heads share one K/V head.
                let kv_off = (h / gqa_factor) * self.head_dim;
                let head_off = t * self.hidden_dim + h * self.head_dim;
                let q_vec = &q[head_off..head_off + self.head_dim];

                weights.clear();
                weights.extend((0..visible).map(|kpos| {
                    let k_off = kpos * kv_proj_dim + kv_off;
                    let k_vec = &full_k[k_off..k_off + self.head_dim];
                    let dot: f32 = q_vec.iter().zip(k_vec).map(|(&qi, &ki)| qi * ki).sum();
                    dot * scale
                }));
                Self::softmax_in_place(&mut weights);

                let out_head = &mut out[head_off..head_off + self.head_dim];
                for (kpos, &aw) in weights.iter().enumerate() {
                    let v_off = kpos * kv_proj_dim + kv_off;
                    let v_vec = &full_v[v_off..v_off + self.head_dim];
                    for (o, &vi) in out_head.iter_mut().zip(v_vec) {
                        *o += aw * vi;
                    }
                }
            }
        }

        let out = linear_batch(&self.w_o, self.b_o.as_deref(), &out, seq_len)?;
        let mut new_cache = LayerKvCache::new(self.n_kv_heads, self.head_dim);
        new_cache.keys = full_k;
        new_cache.values = full_v;
        new_cache.past_len = total_len;
        Ok((out, new_cache))
    }

    /// Runs [`MultiHeadAttention::forward`] and adds its output element-wise
    /// onto a copy of `x` (residual connection).
    ///
    /// Returns the summed vector and the updated cache.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`MultiHeadAttention::forward`].
    pub fn forward_residual(
        &self,
        x: &[f32],
        seq_len: usize,
        past_kv: Option<&LayerKvCache>,
        rope: Option<&RotaryEmbedding>,
    ) -> LmResult<(Vec<f32>, LayerKvCache)> {
        let (attn_out, cache) = self.forward(x, seq_len, past_kv, rope)?;
        let mut out = x.to_vec();
        for (o, &a) in out.iter_mut().zip(attn_out.iter()) {
            *o += a;
        }
        Ok((out, cache))
    }

    /// Returns `LmError::DimensionMismatch` unless `got` equals `expected`.
    fn ensure_len(expected: usize, got: usize) -> LmResult<()> {
        if got == expected {
            Ok(())
        } else {
            Err(LmError::DimensionMismatch { expected, got })
        }
    }

    /// Replaces `scores` with their softmax, subtracting the maximum first
    /// for numerical stability.
    ///
    /// If the exponentials do not sum to a positive number (e.g. because of
    /// NaN), the un-normalised exponentials are left in place.
    fn softmax_in_place(scores: &mut [f32]) {
        let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum_exp = 0.0_f32;
        for s in scores.iter_mut() {
            *s = (*s - max_s).exp();
            sum_exp += *s;
        }
        if sum_exp > 0.0 {
            for s in scores.iter_mut() {
                *s /= sum_exp;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Causal layer with zero weights and as many K/V heads as query heads.
    fn make_mha(n_heads: usize, hidden_dim: usize) -> MultiHeadAttention {
        MultiHeadAttention::new(n_heads, n_heads, hidden_dim, true).unwrap()
    }

    /// Single-head, 4-wide layer whose four projections are identity matrices.
    fn identity_mha(causal: bool) -> MultiHeadAttention {
        let mut layer = MultiHeadAttention::new(1, 1, 4, causal).unwrap();
        layer.w_q = WeightTensor::eye(4, 4);
        layer.w_k = WeightTensor::eye(4, 4);
        layer.w_v = WeightTensor::eye(4, 4);
        layer.w_o = WeightTensor::eye(4, 4);
        layer
    }

    // ---- LayerKvCache -----------------------------------------------------

    #[test]
    fn kv_cache_empty_on_init() {
        let cache = LayerKvCache::new(4, 16);
        assert_eq!(cache.past_len, 0);
        assert!(cache.keys.is_empty());
    }

    #[test]
    fn kv_cache_append() {
        let mut cache = LayerKvCache::new(2, 4);
        let chunk = vec![1.0_f32; 2 * 4];
        cache.append(&chunk, &chunk, 1);
        assert_eq!(cache.past_len, 1);
        assert_eq!(cache.keys.len(), 8);
    }

    // ---- shapes and cache growth with zero weights ------------------------

    #[test]
    fn mha_zero_weights_zero_output() {
        let layer = make_mha(2, 4);
        let input = vec![1.0_f32; 4];
        let (out, _) = layer.forward(&input, 1, None, None).unwrap();
        assert!(out.iter().all(|&v| v.abs() < 1e-6), "out={out:?}");
    }

    #[test]
    fn mha_output_shape_single_token() {
        let layer = make_mha(2, 4);
        let input = vec![0.0_f32; 4];
        let (result, kv) = layer.forward(&input, 1, None, None).unwrap();
        assert_eq!(result.len(), 4);
        assert_eq!(kv.past_len, 1);
    }

    #[test]
    fn mha_output_shape_multi_token() {
        let layer = make_mha(2, 4);
        let input = vec![0.0_f32; 3 * 4];
        let (result, kv) = layer.forward(&input, 3, None, None).unwrap();
        assert_eq!(result.len(), 3 * 4);
        assert_eq!(kv.past_len, 3);
    }

    #[test]
    fn mha_kv_cache_extends() {
        let layer = make_mha(2, 4);

        let first = vec![0.0_f32; 4];
        let (_, after_first) = layer.forward(&first, 1, None, None).unwrap();
        assert_eq!(after_first.past_len, 1);

        let second = vec![0.0_f32; 4];
        let (_, after_second) = layer
            .forward(&second, 1, Some(&after_first), None)
            .unwrap();
        assert_eq!(after_second.past_len, 2);
    }

    #[test]
    fn mha_gqa_forward_shape() {
        let layer = MultiHeadAttention::new(4, 2, 8, true).unwrap();
        let input = vec![0.0_f32; 3 * 8];
        let (result, _) = layer.forward(&input, 3, None, None).unwrap();
        assert_eq!(result.len(), 3 * 8);
    }

    // ---- configuration and input errors -----------------------------------

    #[test]
    fn mha_gqa_head_mismatch_error() {
        let built = MultiHeadAttention::new(4, 5, 8, true);
        assert!(built.is_err());
    }

    #[test]
    fn mha_invalid_config_error() {
        let built = MultiHeadAttention::new(0, 1, 4, true);
        assert!(built.is_err());
    }

    #[test]
    fn mha_empty_seq_error() {
        let layer = make_mha(2, 4);
        let outcome = layer.forward(&[], 0, None, None);
        assert!(outcome.is_err());
    }

    // ---- numerical behaviour with identity projections --------------------

    #[test]
    fn mha_w_o_identity_propagates_value() {
        // A single token attends only to itself, so identity projections
        // must hand the input straight through.
        let layer = identity_mha(false);
        let x = vec![1.0_f32, 2.0, 3.0, 4.0];
        let (out, _) = layer.forward(&x, 1, None, None).unwrap();
        for (&o, &xi) in out.iter().zip(x.iter()) {
            assert!((o - xi).abs() < 1e-5, "out={out:?} x={x:?}");
        }
    }

    #[test]
    fn mha_causal_mask_applied() {
        // Token 0 must not see token 1, so its output equals its own value.
        let layer = identity_mha(true);
        let tokens = vec![1.0_f32, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0];
        let (out, _) = layer.forward(&tokens, 2, None, None).unwrap();
        assert!((out[0] - 1.0).abs() < 1e-5, "out[0]={}", out[0]);
        assert!(out[1].abs() < 1e-5);
    }

    #[test]
    fn mha_rope_applied_no_error() {
        let layer = make_mha(2, 4);
        let rope = RotaryEmbedding::new(2, 32, 10_000.0).unwrap();
        let input = vec![0.5_f32; 4];
        let outcome = layer.forward(&input, 1, None, Some(&rope));
        assert!(outcome.is_ok());
    }

    #[test]
    fn mha_incremental_vs_full_consistency() {
        let layer = identity_mha(false);
        let x0 = vec![1.0_f32, 0.0, 0.0, 0.0];
        let x1 = vec![0.0_f32, 1.0, 0.0, 0.0];

        // Both tokens in one call.
        let both: Vec<f32> = x0.iter().chain(x1.iter()).copied().collect();
        let (out_full, _) = layer.forward(&both, 2, None, None).unwrap();

        // Token 0 first, then token 1 on top of the returned cache.
        let (_, cache0) = layer.forward(&x0, 1, None, None).unwrap();
        let (out_incr_1, _) = layer.forward(&x1, 1, Some(&cache0), None).unwrap();

        for (&full_v, &incr_v) in out_full[4..].iter().zip(out_incr_1.iter()) {
            assert!(
                (full_v - incr_v).abs() < 1e-4,
                "full={full_v} incr={incr_v}"
            );
        }
    }
}