use ndarray::{Array2, Array3, Array4, ArrayView2, ArrayView3};
use crate::layers::linear3d;
use crate::rope::{RopeTables, apply_rotary_emb_ref};
use crate::state_dict::{StateDict, StateDictError};
/// Hyper-parameters describing a multi-head attention layer.
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Embedding width; the masked forward pass asserts that the last axis of
    /// the query input equals this value.
    pub embed_dim: usize,
    /// Number of attention heads the embedding is split into.
    pub num_heads: usize,
    /// Dropout probability carried with the configuration. It is not read by
    /// the forward functions visible in this file.
    pub dropout: f32,
    /// Bias flag carried with the configuration. It is not read by the forward
    /// functions visible in this file; bias use is decided by whether the
    /// parameter struct holds a bias vector.
    pub bias: bool,
}

impl AttentionConfig {
    /// Width of a single attention head: `embed_dim / num_heads`.
    ///
    /// This is plain integer division, so any remainder is discarded.
    ///
    /// # Panics
    ///
    /// Panics with a division-by-zero error when `num_heads` is `0`.
    #[inline]
    pub fn head_dim(&self) -> usize {
        let Self {
            embed_dim,
            num_heads,
            ..
        } = self;
        *embed_dim / *num_heads
    }
}
/// Learned tensors of a multi-head attention layer with a packed Q/K/V
/// input projection.
#[derive(Debug, Clone)]
pub struct AttentionParams {
    /// Packed input projection, loaded with shape `(3 * embed_dim, embed_dim)`.
    /// The forward pass slices its rows into consecutive Q, K and V thirds.
    pub in_proj_weight: Array2<f32>,
    /// Optional packed input bias, loaded with length `3 * embed_dim` and
    /// sliced into the same Q, K, V thirds as the weight.
    pub in_proj_bias: Option<Vec<f32>>,
    /// Output projection, loaded with shape `(embed_dim, embed_dim)`.
    pub out_proj_weight: Array2<f32>,
    /// Optional output bias, loaded with length `embed_dim`.
    pub out_proj_bias: Option<Vec<f32>>,
}

impl AttentionParams {
    /// Overwrites these parameters with the tensors stored under `prefix` in `sd`.
    ///
    /// Keys read:
    /// * `{prefix}.in_proj_weight`  — requested as `(3 * embed_dim, embed_dim)`
    /// * `{prefix}.in_proj_bias`    — optional, requested with length `3 * embed_dim`
    /// * `{prefix}.out_proj.weight` — requested as `(embed_dim, embed_dim)`
    /// * `{prefix}.out_proj.bias`   — optional, requested with length `embed_dim`
    ///
    /// A bias whose key is absent from `sd.tensors` leaves the corresponding
    /// field exactly as it was before the call.
    ///
    /// # Errors
    ///
    /// Propagates any [`StateDictError`] returned by the state-dict accessors.
    /// All tensors are read into locals first and committed only once every
    /// read has succeeded, so on error `self` is left completely unmodified
    /// rather than half-overwritten.
    pub fn load_from(
        &mut self,
        sd: &StateDict,
        prefix: &str,
        embed_dim: usize,
    ) -> Result<(), StateDictError> {
        // Stage 1: fallible reads, in the same order as the fields are declared.
        let in_proj_weight = sd.take_array2(
            &format!("{prefix}.in_proj_weight"),
            3 * embed_dim,
            embed_dim,
        )?;

        let in_b_key = format!("{prefix}.in_proj_bias");
        let in_proj_bias = if sd.tensors.contains_key(&in_b_key) {
            Some(sd.take_vec(&in_b_key, 3 * embed_dim)?)
        } else {
            None
        };

        let out_proj_weight =
            sd.take_array2(&format!("{prefix}.out_proj.weight"), embed_dim, embed_dim)?;

        let out_b_key = format!("{prefix}.out_proj.bias");
        let out_proj_bias = if sd.tensors.contains_key(&out_b_key) {
            Some(sd.take_vec(&out_b_key, embed_dim)?)
        } else {
            None
        };

        // Stage 2: infallible commit.
        self.in_proj_weight = in_proj_weight;
        if in_proj_bias.is_some() {
            self.in_proj_bias = in_proj_bias;
        }
        self.out_proj_weight = out_proj_weight;
        if out_proj_bias.is_some() {
            self.out_proj_bias = out_proj_bias;
        }
        Ok(())
    }
}
/// Multi-head attention without an attention mask.
///
/// Convenience wrapper that forwards every argument unchanged to
/// [`multi_head_attention_forward_masked`] and supplies `None` for the mask.
/// See that function for the meaning of the arguments and for the conditions
/// under which it panics.
pub fn multi_head_attention_forward(
    query: ArrayView3<f32>,
    key: ArrayView3<f32>,
    value: ArrayView3<f32>,
    params: &AttentionParams,
    cfg: &AttentionConfig,
    rope: Option<&RopeTables>,
    query_scale: Option<&Array4<f32>>,
) -> Array3<f32> {
    // No additive mask: every query position may attend to every source position.
    let attn_mask = None;
    multi_head_attention_forward_masked(
        query,
        key,
        value,
        params,
        cfg,
        rope,
        query_scale,
        attn_mask,
    )
}
/// Multi-head attention whose query scale is derived from an optional SSMax
/// configuration.
///
/// When `ssmax` is `Some`, the per-head (and, if `rope` is given, rotated)
/// query projection is rebuilt here and handed to
/// `crate::ssmax::compute_query_scale` together with the source length
/// (`key.shape()[1]`). The resulting scale is then passed as `query_scale` to
/// [`multi_head_attention_forward_masked`]. When `ssmax` is `None`, no scale
/// is applied and the call is equivalent to the masked forward pass.
///
/// Note that the masked forward pass recomputes the query projection itself;
/// the copy built here is used only to derive the scale.
pub fn multi_head_attention_forward_with_ssmax(
    query: ArrayView3<f32>,
    key: ArrayView3<f32>,
    value: ArrayView3<f32>,
    params: &AttentionParams,
    cfg: &AttentionConfig,
    rope: Option<&RopeTables>,
    attn_mask: Option<ArrayView2<f32>>,
    ssmax: Option<&crate::encoders::MabSsmax>,
) -> Array3<f32> {
    let scale = ssmax.map(|ssmax_cfg| {
        let (batch, tgt_len, embed_dim) = query.dim();
        let per_head = cfg.head_dim();
        let heads = cfg.num_heads;

        // Project the query with the Q third of the packed in-projection.
        let weight_q = params.in_proj_weight.slice(ndarray::s![..embed_dim, ..]);
        let bias_q = params
            .in_proj_bias
            .as_ref()
            .map(|packed| &packed[..embed_dim]);
        let projected = crate::layers::linear3d(query, weight_q, bias_q);

        // (batch, tgt_len, heads * per_head) -> (batch, heads, tgt_len, per_head)
        let per_head_q = Array4::<f32>::from_shape_fn(
            (batch, heads, tgt_len, per_head),
            |(bi, hi, ti, di)| projected[(bi, ti, hi * per_head + di)],
        );

        let rotated_q = match rope {
            Some(tables) => crate::rope::apply_rotary_emb_ref(&per_head_q.view(), tables),
            None => per_head_q,
        };

        let source_len = key.shape()[1];
        crate::ssmax::compute_query_scale(
            &ssmax_cfg.spec,
            &ssmax_cfg.params,
            rotated_q.view(),
            source_len,
        )
    });

    multi_head_attention_forward_masked(
        query,
        key,
        value,
        params,
        cfg,
        rope,
        scale.as_ref(),
        attn_mask,
    )
}
/// Reference multi-head attention forward pass with an optional additive mask.
///
/// Pipeline:
/// 1. Slice the packed in-projection into Q/K/V thirds and project
///    `query`, `key` and `value`.
/// 2. Split the projections into heads: `(batch, heads, len, head_dim)`.
/// 3. If `rope` is given, apply the rotary embedding to Q and K.
/// 4. If `query_scale` is given, multiply Q by it (size-1 axes broadcast).
/// 5. Scores = `Q · Kᵀ / sqrt(head_dim)`, plus `attn_mask` if given.
/// 6. Softmax over the source axis, weight V, merge heads, output projection.
///
/// Dot products and the softmax normalisation accumulate in `f64` and are
/// rounded to `f32` only when stored.
///
/// # Arguments
///
/// * `query` — indexed as `(batch, tgt_len, embed_dim)`.
/// * `key`, `value` — indexed as `(batch, src_len, ..)`; both must share
///   `src_len` and the query's batch size.
/// * `params` — projection weights and optional biases.
/// * `cfg` — supplies `embed_dim` and `num_heads`.
/// * `rope` — optional rotary tables applied to Q and K.
/// * `query_scale` — optional multiplier for Q; each axis is either `1` or the
///   matching extent of `(batch, heads, tgt_len, head_dim)`.
/// * `attn_mask` — optional `(tgt_len, src_len)` mask *added* to the scores
///   (shared by all batches and heads).
///
/// # Returns
///
/// The attended output, `(batch, tgt_len, embed_dim)` before the output
/// projection is applied by `linear3d`.
///
/// # Panics
///
/// Panics when
/// * `query`'s last axis differs from `cfg.embed_dim`,
/// * `cfg.num_heads` is zero or does not divide `embed_dim`,
/// * `key`/`value` batch sizes differ from `query`'s,
/// * `key` and `value` disagree on `src_len`,
/// * `params.in_proj_weight` is not `(3 * embed_dim, embed_dim)`,
/// * the packed bias is shorter than `3 * embed_dim`,
/// * `attn_mask` is not `(tgt_len, src_len)`,
/// * `query_scale` has an axis that is neither `1` nor the matching extent.
pub fn multi_head_attention_forward_masked(
    query: ArrayView3<f32>,
    key: ArrayView3<f32>,
    value: ArrayView3<f32>,
    params: &AttentionParams,
    cfg: &AttentionConfig,
    rope: Option<&RopeTables>,
    query_scale: Option<&Array4<f32>>,
    attn_mask: Option<ArrayView2<f32>>,
) -> Array3<f32> {
    let (b, tgt_len, embed_dim) = (query.shape()[0], query.shape()[1], query.shape()[2]);
    assert_eq!(embed_dim, cfg.embed_dim);

    let nh = cfg.num_heads;
    // Without this check integer division would silently drop the trailing
    // `embed_dim % num_heads` channels and leave them zero in the output.
    assert!(
        nh > 0 && embed_dim % nh == 0,
        "embed_dim ({embed_dim}) must be divisible by a non-zero num_heads ({nh})"
    );
    // The head split below indexes key/value with the query's batch index.
    assert_eq!(
        key.shape()[0],
        b,
        "key batch size must match query batch size"
    );
    assert_eq!(
        value.shape()[0],
        b,
        "value batch size must match query batch size"
    );

    let src_len = key.shape()[1];
    assert_eq!(value.shape()[1], src_len);
    let head_dim = cfg.head_dim();

    // --- Packed in-projection: rows [0, E) -> Q, [E, 2E) -> K, [2E, 3E) -> V.
    let w = &params.in_proj_weight;
    let bias = params.in_proj_bias.as_deref();
    assert_eq!(w.shape(), &[3 * embed_dim, embed_dim]);
    let w_q = w.slice(ndarray::s![..embed_dim, ..]);
    let w_k = w.slice(ndarray::s![embed_dim..2 * embed_dim, ..]);
    let w_v = w.slice(ndarray::s![2 * embed_dim..3 * embed_dim, ..]);
    let (b_q, b_k, b_v) = match bias {
        Some(bb) => (
            Some(&bb[..embed_dim]),
            Some(&bb[embed_dim..2 * embed_dim]),
            Some(&bb[2 * embed_dim..3 * embed_dim]),
        ),
        None => (None, None, None),
    };

    let q = linear3d(query, w_q, b_q);
    let k = linear3d(key, w_k, b_k);
    let v = linear3d(value, w_v, b_v);

    // --- Split heads: (b, len, E) -> (b, nh, len, head_dim).
    let q = reshape_heads(&q, nh, head_dim, tgt_len, b);
    let k = reshape_heads(&k, nh, head_dim, src_len, b);
    let v = reshape_heads(&v, nh, head_dim, src_len, b);

    // --- Optional rotary embedding on Q and K (V is left untouched).
    let (q, k) = match rope {
        Some(tab) => (
            apply_rotary_emb_ref(&q.view(), tab),
            apply_rotary_emb_ref(&k.view(), tab),
        ),
        None => (q, k),
    };

    // --- Optional query scaling, applied after the rotary embedding.
    let q = match query_scale {
        Some(scale) => broadcast_mul_bhtd(&q, scale),
        None => q,
    };

    // --- Scaled dot-product scores, accumulated in f64.
    let inv_sqrt = (head_dim as f64).sqrt().recip();
    let mut attn = Array4::<f32>::zeros((b, nh, tgt_len, src_len));
    for bi in 0..b {
        for hi in 0..nh {
            for ti in 0..tgt_len {
                for si in 0..src_len {
                    let mut acc: f64 = 0.0;
                    for d in 0..head_dim {
                        acc += (q[(bi, hi, ti, d)] as f64) * (k[(bi, hi, si, d)] as f64);
                    }
                    attn[(bi, hi, ti, si)] = (acc * inv_sqrt) as f32;
                }
            }
        }
    }

    // --- Additive mask, broadcast over batch and head.
    if let Some(mask) = attn_mask {
        assert_eq!(mask.shape(), &[tgt_len, src_len]);
        for bi in 0..b {
            for hi in 0..nh {
                for ti in 0..tgt_len {
                    for si in 0..src_len {
                        attn[(bi, hi, ti, si)] += mask[(ti, si)];
                    }
                }
            }
        }
    }

    // --- Numerically stabilised softmax over the source axis.
    // One scratch buffer is reused for every row; each row overwrites all of
    // its entries before reading them back. A row whose scores are all `-inf`
    // is not special-cased and yields NaN weights.
    let mut exps = vec![0.0_f64; src_len];
    for bi in 0..b {
        for hi in 0..nh {
            for ti in 0..tgt_len {
                let mut m = f32::NEG_INFINITY;
                for si in 0..src_len {
                    let score = attn[(bi, hi, ti, si)];
                    if score > m {
                        m = score;
                    }
                }
                let mut s_acc: f64 = 0.0;
                for (si, slot) in exps.iter_mut().enumerate() {
                    let e = ((attn[(bi, hi, ti, si)] - m) as f64).exp();
                    *slot = e;
                    s_acc += e;
                }
                let inv = s_acc.recip();
                for (si, &e) in exps.iter().enumerate() {
                    attn[(bi, hi, ti, si)] = (e * inv) as f32;
                }
            }
        }
    }

    // --- Attention-weighted sum of V, accumulated in f64.
    let mut out_heads = Array4::<f32>::zeros((b, nh, tgt_len, head_dim));
    for bi in 0..b {
        for hi in 0..nh {
            for ti in 0..tgt_len {
                for d in 0..head_dim {
                    let mut acc: f64 = 0.0;
                    for si in 0..src_len {
                        acc += (attn[(bi, hi, ti, si)] as f64) * (v[(bi, hi, si, d)] as f64);
                    }
                    out_heads[(bi, hi, ti, d)] = acc as f32;
                }
            }
        }
    }

    // --- Merge heads: (b, nh, tgt_len, head_dim) -> (b, tgt_len, E).
    let mut concat = Array3::<f32>::zeros((b, tgt_len, embed_dim));
    for bi in 0..b {
        for ti in 0..tgt_len {
            for hi in 0..nh {
                for d in 0..head_dim {
                    concat[(bi, ti, hi * head_dim + d)] = out_heads[(bi, hi, ti, d)];
                }
            }
        }
    }

    // --- Output projection.
    linear3d(
        concat.view(),
        params.out_proj_weight.view(),
        params.out_proj_bias.as_deref(),
    )
}
/// Splits the last axis of `x` into `h` heads of width `d`.
///
/// Element `(bi, hi, ti, di)` of the `(b, h, t, d)` result is read from
/// `x[(bi, ti, hi * d + di)]`.
///
/// # Panics
///
/// Panics if `x` is smaller than `(b, t, h * d)` along any axis.
fn reshape_heads(x: &Array3<f32>, h: usize, d: usize, t: usize, b: usize) -> Array4<f32> {
    Array4::<f32>::from_shape_fn((b, h, t, d), |(batch, head, pos, chan)| {
        x[(batch, pos, head * d + chan)]
    })
}
/// Element-wise product of `q` and `scale`, broadcasting size-1 axes of `scale`.
///
/// The result has the shape of `q`. Each axis of `scale` must either equal the
/// corresponding axis of `q` or be `1`, in which case its single entry is
/// reused along that axis.
///
/// # Panics
///
/// Panics if any axis of `scale` is neither `1` nor equal to the matching
/// axis of `q`.
fn broadcast_mul_bhtd(q: &Array4<f32>, scale: &Array4<f32>) -> Array4<f32> {
    let (b, h, t, d) = q.dim();
    let (sb, sh, st, sd) = scale.dim();
    assert!(sb == 1 || sb == b);
    assert!(sh == 1 || sh == h);
    assert!(st == 1 || st == t);
    assert!(sd == 1 || sd == d);

    // A broadcast (size-1) axis always reads index 0; otherwise follow `q`.
    let pick = |extent: usize, idx: usize| if extent == 1 { 0 } else { idx };

    Array4::<f32>::from_shape_fn((b, h, t, d), |(bi, hi, ti, di)| {
        let factor = scale[(pick(sb, bi), pick(sh, hi), pick(st, ti), pick(sd, di))];
        q[(bi, hi, ti, di)] * factor
    })
}
/// No-op whose only effect is to mention `ArrayView2` in a signature.
///
/// NOTE(review): presumably a leftover to keep the `ArrayView2` import "used";
/// the import is already referenced by the masked forward functions in this
/// file, so this helper looks removable — confirm before deleting.
// `dead_code` is allowed because nothing in this file calls the function.
#[allow(dead_code)]
fn _silence(_unused: ArrayView2<f32>) {}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Parameters whose weights and biases are all zero.
    fn zero_params(embed_dim: usize) -> AttentionParams {
        AttentionParams {
            in_proj_weight: Array2::<f32>::zeros((3 * embed_dim, embed_dim)),
            in_proj_bias: Some(vec![0.0; 3 * embed_dim]),
            out_proj_weight: Array2::<f32>::zeros((embed_dim, embed_dim)),
            out_proj_bias: Some(vec![0.0; embed_dim]),
        }
    }

    /// `n x n` identity matrix.
    fn eye(n: usize) -> Array2<f32> {
        let mut m = Array2::<f32>::zeros((n, n));
        for i in 0..n {
            m[(i, i)] = 1.0;
        }
        m
    }

    /// Bias-free parameters with zero Q/K projections, identity V projection
    /// and identity output projection. All scores are zero, so attention
    /// weights depend only on the mask.
    fn value_passthrough_params(embed: usize) -> AttentionParams {
        let mut w = Array2::<f32>::zeros((3 * embed, embed));
        for i in 0..embed {
            w[(2 * embed + i, i)] = 1.0;
        }
        AttentionParams {
            in_proj_weight: w,
            in_proj_bias: None,
            out_proj_weight: eye(embed),
            out_proj_bias: None,
        }
    }

    #[test]
    fn zero_projections_give_zero_output() {
        let cfg = AttentionConfig {
            embed_dim: 4,
            num_heads: 2,
            dropout: 0.0,
            bias: true,
        };
        let params = zero_params(4);
        let x = Array3::<f32>::from_shape_fn((1, 3, 4), |(_, t, d)| (t * 4 + d) as f32);
        let y =
            multi_head_attention_forward(x.view(), x.view(), x.view(), &params, &cfg, None, None);
        assert_eq!(y.shape(), &[1, 3, 4]);
        for v in y.iter() {
            assert!(v.abs() < 1e-6);
        }
    }

    #[test]
    fn single_token_attention_is_identity_with_identity_proj() {
        let embed = 4;
        let cfg = AttentionConfig {
            embed_dim: embed,
            num_heads: 2,
            dropout: 0.0,
            bias: true,
        };
        // Identity in each of the Q, K and V thirds of the packed projection.
        let mut w = Array2::<f32>::zeros((3 * embed, embed));
        for i in 0..embed {
            w[(i, i)] = 1.0;
            w[(embed + i, i)] = 1.0;
            w[(2 * embed + i, i)] = 1.0;
        }
        let params = AttentionParams {
            in_proj_weight: w,
            in_proj_bias: None,
            out_proj_weight: eye(embed),
            out_proj_bias: None,
        };
        let x = Array::from_shape_vec((1, 1, embed), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y =
            multi_head_attention_forward(x.view(), x.view(), x.view(), &params, &cfg, None, None);
        // With a single source token the softmax weight is exactly 1.
        for d in 0..embed {
            assert!(
                (y[(0, 0, d)] - x[(0, 0, d)]).abs() < 1e-6,
                "mismatch at {d}: {} vs {}",
                y[(0, 0, d)],
                x[(0, 0, d)]
            );
        }
    }

    #[test]
    fn uniform_attention_averages_v() {
        let embed = 4;
        let cfg = AttentionConfig {
            embed_dim: embed,
            num_heads: 1,
            dropout: 0.0,
            bias: true,
        };
        let params = value_passthrough_params(embed);
        let x = Array::from_shape_fn((1, 3, embed), |(_, t, d)| (t * embed + d) as f32);
        let y =
            multi_head_attention_forward(x.view(), x.view(), x.view(), &params, &cfg, None, None);
        let mean = Array::from_shape_fn(embed, |d| (0..3).map(|t| x[(0, t, d)]).sum::<f32>() / 3.0);
        for t in 0..3 {
            for d in 0..embed {
                assert!(
                    (y[(0, t, d)] - mean[d]).abs() < 1e-5,
                    "t={t} d={d}: {} vs {}",
                    y[(0, t, d)],
                    mean[d]
                );
            }
        }
    }

    #[test]
    fn causal_mask_averages_only_visible_prefix() {
        let embed = 4;
        let len = 3;
        let cfg = AttentionConfig {
            embed_dim: embed,
            num_heads: 1,
            dropout: 0.0,
            bias: true,
        };
        let params = value_passthrough_params(embed);
        let x = Array::from_shape_fn((1, len, embed), |(_, t, d)| (t * embed + d) as f32);
        // Additive causal mask: position t may attend to sources 0..=t only.
        let mask = Array2::<f32>::from_shape_fn((len, len), |(t, s)| {
            if s > t {
                f32::NEG_INFINITY
            } else {
                0.0
            }
        });
        let y = multi_head_attention_forward_masked(
            x.view(),
            x.view(),
            x.view(),
            &params,
            &cfg,
            None,
            None,
            Some(mask.view()),
        );
        assert_eq!(y.shape(), &[1, len, embed]);
        for t in 0..len {
            for d in 0..embed {
                let expect = (0..=t).map(|s| x[(0, s, d)]).sum::<f32>() / (t + 1) as f32;
                assert!(
                    (y[(0, t, d)] - expect).abs() < 1e-5,
                    "t={t} d={d}: {} vs {}",
                    y[(0, t, d)],
                    expect
                );
            }
        }
    }

    #[test]
    fn query_scale_broadcasts_to_per_head() {
        let q = Array4::<f32>::from_shape_fn((2, 3, 4, 5), |(b, h, t, d)| {
            ((b * 100 + h * 10 + t) as f32) + (d as f32) * 0.1
        });
        // One scale per head, broadcast over batch, position and channel.
        let scale = Array4::<f32>::from_shape_fn((1, 3, 1, 1), |(_, h, _, _)| (h as f32) + 1.0);
        let y = broadcast_mul_bhtd(&q, &scale);
        for b in 0..2 {
            for h in 0..3 {
                for t in 0..4 {
                    for d in 0..5 {
                        let expect = q[(b, h, t, d)] * ((h as f32) + 1.0);
                        assert!((y[(b, h, t, d)] - expect).abs() < 1e-6);
                    }
                }
            }
        }
    }
}