use crate::ops::elementwise::scale;
use crate::ops::linear::linear;
use crate::ops::matmul::{matmul, matmul_t_b};
use crate::ops::softmax::softmax_last_dim;
use crate::tensor::Tensor;
/// Gathers the columns belonging to one attention head into a new tensor.
///
/// `x` is read as a row-major `[seq_len, hidden]` buffer whose hidden axis is
/// split into `num_heads` contiguous groups of `hidden / num_heads` columns.
/// Group `head_idx` of every row is copied, in row order, into a freshly
/// allocated `[seq_len, head_dim]` tensor.
///
/// # Panics
///
/// Panics if `x` is not 2D, if `num_heads` is zero, if `hidden` is not
/// divisible by `num_heads`, or if `head_idx >= num_heads`.
fn extract_head(x: &Tensor, head_idx: usize, num_heads: usize) -> Tensor {
    let shape = x.shape().as_slice();
    assert_eq!(shape.len(), 2, "extract_head: x must be 2D");
    let seq_len = shape[0];
    let hidden = shape[1];
    // `0.is_multiple_of(0)` is true, so without this guard an empty hidden axis
    // would reach the division below and panic with "divide by zero".
    assert!(num_heads > 0, "extract_head: num_heads must be non-zero");
    assert!(
        hidden.is_multiple_of(num_heads),
        "extract_head: hidden ({}) must be divisible by num_heads ({})",
        hidden,
        num_heads
    );
    // An out-of-range head would otherwise read columns of the following row
    // and only fail on the last row with an opaque slice-range panic.
    assert!(
        head_idx < num_heads,
        "extract_head: head_idx ({}) must be less than num_heads ({})",
        head_idx,
        num_heads
    );
    let head_dim = hidden / num_heads;
    let src = x.data();
    let col_start = head_idx * head_dim;
    let mut out = Vec::with_capacity(seq_len * head_dim);
    for t in 0..seq_len {
        let row_start = t * hidden + col_start;
        out.extend_from_slice(&src[row_start..row_start + head_dim]);
    }
    Tensor::from_vec(out, &[seq_len, head_dim])
}
/// Scatters one head's `[seq_len, head_dim]` result into its column group of a
/// flat row-major `[seq_len, num_heads * head_dim]` buffer.
///
/// Row `t` of `head_result` is copied to
/// `dst[t * hidden + head_idx * head_dim ..][..head_dim]`, where
/// `hidden = num_heads * head_dim`. Other columns of `dst` are left untouched.
///
/// # Panics
///
/// Panics if `head_result` is not 2D, if `head_idx >= num_heads`, or if
/// `dst.len()` differs from `seq_len * num_heads * head_dim`.
fn write_head_back(dst: &mut [f32], head_result: &Tensor, head_idx: usize, num_heads: usize) {
    let shape = head_result.shape().as_slice();
    assert_eq!(shape.len(), 2, "write_head_back: head must be 2D");
    let seq_len = shape[0];
    let head_dim = shape[1];
    // Checked up front: an out-of-range head would overwrite the neighbouring
    // row's columns before (possibly) panicking on the final row.
    assert!(
        head_idx < num_heads,
        "write_head_back: head_idx ({}) must be less than num_heads ({})",
        head_idx,
        num_heads
    );
    let hidden = num_heads * head_dim;
    // A mis-sized destination means the row stride below is wrong; writing
    // anyway would silently place values in the wrong cells.
    assert!(
        dst.len() == seq_len * hidden,
        "write_head_back: dst length ({}) must equal seq_len * num_heads * head_dim ({})",
        dst.len(),
        seq_len * hidden
    );
    let src = head_result.data();
    let col_start = head_idx * head_dim;
    for t in 0..seq_len {
        let src_off = t * head_dim;
        let dst_off = t * hidden + col_start;
        dst[dst_off..dst_off + head_dim].copy_from_slice(&src[src_off..src_off + head_dim]);
    }
}
/// Full multi-head self-attention layer: projections, attention, output projection.
///
/// `x` is passed through `linear` three times (with `wq`/`bq`, `wk`/`bk` and
/// `wv`/`bv`) to obtain queries, keys and values. Those are combined by
/// [`multi_head_attention_from_qkv`], and the merged heads go through a final
/// `linear` with `wo`/`bo`, whose result is returned.
///
/// Weight and bias shapes are not inspected here; whatever `linear` requires
/// of them applies.
///
/// # Panics
///
/// Panics if `x` is not 2D or if its second dimension is not divisible by
/// `num_heads`.
// One weight plus one optional bias per projection is the natural signature
// for this layer, so the long argument list is intentional.
#[allow(clippy::too_many_arguments)]
pub fn multi_head_attention(
    x: &Tensor,
    wq: &Tensor,
    bq: Option<&Tensor>,
    wk: &Tensor,
    bk: Option<&Tensor>,
    wv: &Tensor,
    bv: Option<&Tensor>,
    wo: &Tensor,
    bo: Option<&Tensor>,
    num_heads: usize,
) -> Tensor {
    let dims = x.shape().as_slice();
    assert_eq!(dims.len(), 2, "multi_head_attention: x must be 2D");
    let width = dims[1];
    assert!(
        width.is_multiple_of(num_heads),
        "multi_head_attention: hidden ({}) must be divisible by num_heads ({})",
        width,
        num_heads
    );

    // Input projections, evaluated in q, k, v order.
    let queries = linear(x, wq, bq);
    let keys = linear(x, wk, bk);
    let values = linear(x, wv, bv);

    let merged_heads = multi_head_attention_from_qkv(&queries, &keys, &values, num_heads);
    linear(&merged_heads, wo, bo)
}
/// Unmasked multi-head scaled dot-product attention over pre-projected inputs.
///
/// `q`, `k` and `v` must all have the same 2D shape `[seq_len, hidden]`. The
/// hidden axis is split into `num_heads` equal column groups. For each group
/// the function evaluates
/// `softmax_last_dim(scale(matmul_t_b(q_h, k_h), 1 / sqrt(head_dim)))`,
/// multiplies that by `v_h` with `matmul`, and stores the product in the same
/// column group of the `[seq_len, hidden]` result.
///
/// # Panics
///
/// Panics if `q` is not 2D, if `k` or `v` has a different shape than `q`, or
/// if `hidden` is not divisible by `num_heads`.
pub fn multi_head_attention_from_qkv(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    num_heads: usize,
) -> Tensor {
    let dims = q.shape().as_slice();
    assert_eq!(
        dims.len(),
        2,
        "multi_head_attention_from_qkv: q must be 2D"
    );
    assert_eq!(
        k.shape().as_slice(),
        dims,
        "multi_head_attention_from_qkv: k shape mismatch"
    );
    assert_eq!(
        v.shape().as_slice(),
        dims,
        "multi_head_attention_from_qkv: v shape mismatch"
    );

    let (rows, width) = (dims[0], dims[1]);
    assert!(
        width.is_multiple_of(num_heads),
        "multi_head_attention_from_qkv: hidden ({}) must be divisible by num_heads ({})",
        width,
        num_heads
    );

    let per_head = width / num_heads;
    let inv_sqrt_dim = 1.0 / (per_head as f32).sqrt();
    let mut merged = vec![0.0f32; rows * width];

    for head in 0..num_heads {
        let queries = extract_head(q, head, num_heads);
        let keys = extract_head(k, head, num_heads);
        let values = extract_head(v, head, num_heads);

        // Scaled scores -> attention weights -> weighted combination of values.
        let weights = softmax_last_dim(&scale(&matmul_t_b(&queries, &keys), inv_sqrt_dim));
        let mixed = matmul(&weights, &values);

        write_head_back(&mut merged, &mixed, head, num_heads);
    }

    Tensor::from_vec(merged, &[rows, width])
}
/// Averages the rows of a 2D `[seq_len, hidden]` tensor into a `[hidden]` vector.
///
/// When `attention_mask` is given, only rows whose mask entry is non-zero take
/// part in the average. If no row is selected (empty input or an all-zero
/// mask) the result is all zeros.
///
/// # Panics
///
/// Panics if `x` is not 2D or if the mask length differs from `seq_len`.
pub fn mean_pool(x: &Tensor, attention_mask: Option<&[u32]>) -> Tensor {
    let dims = x.shape().as_slice();
    assert_eq!(dims.len(), 2, "mean_pool: x must be 2D");
    let (rows, width) = (dims[0], dims[1]);
    if let Some(mask) = attention_mask {
        assert_eq!(
            mask.len(),
            rows,
            "mean_pool: mask length must equal sequence length"
        );
    }

    let values = x.data();
    let mut pooled = vec![0.0f32; width];
    let mut kept = 0.0f32;
    for row_idx in 0..rows {
        // Skip rows the mask marks as padding.
        match attention_mask {
            Some(mask) if mask[row_idx] == 0 => continue,
            _ => {}
        }
        let row = &values[row_idx * width..(row_idx + 1) * width];
        for (acc, &value) in pooled.iter_mut().zip(row) {
            *acc += value;
        }
        kept += 1.0;
    }

    // Leave the zero vector untouched when nothing was accumulated.
    if kept > 0.0 {
        let inv = 1.0 / kept;
        for acc in &mut pooled {
            *acc *= inv;
        }
    }
    Tensor::from_vec(pooled, &[width])
}
/// Multi-head scaled dot-product attention with a causal mask.
///
/// Works like the unmasked variant, except that each head's scaled score
/// matrix goes through `apply_causal_mask` before `softmax_last_dim`: every
/// entry whose key index is greater than its query index is set to
/// `f32::NEG_INFINITY`. Query row `i` is aligned with key row `i`; no offset is
/// applied when `k` has a different number of rows than `q`.
///
/// `q` is `[seq_len, hidden]`; `k` and `v` must share one 2D shape whose second
/// dimension equals `hidden`. The result is `[seq_len, hidden]`.
///
/// # Panics
///
/// Panics if `q` or `k` is not 2D, if `k`'s hidden size differs from `q`'s, if
/// `v`'s shape differs from `k`'s, if `num_heads` is zero, or if `hidden` is
/// not divisible by `num_heads`.
pub fn causal_multi_head_attention_from_qkv(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    num_heads: usize,
) -> Tensor {
    let q_shape = q.shape().as_slice();
    let k_shape = k.shape().as_slice();
    let v_shape = v.shape().as_slice();
    assert_eq!(q_shape.len(), 2, "causal_attention: q must be 2D");
    assert_eq!(k_shape.len(), 2, "causal_attention: k must be 2D");
    let seq_len = q_shape[0];
    let hidden = q_shape[1];
    assert_eq!(
        k_shape[1], hidden,
        "causal_attention: k hidden size must match q"
    );
    // A narrower `v` would otherwise be scattered into the output with the
    // wrong row stride instead of failing.
    assert_eq!(v_shape, k_shape, "causal_attention: v shape must match k");
    // `0.is_multiple_of(0)` is true, so guard the division below explicitly.
    assert!(num_heads > 0, "causal_attention: num_heads must be non-zero");
    assert!(
        hidden.is_multiple_of(num_heads),
        "causal_attention: hidden must be divisible by num_heads"
    );
    let head_dim = hidden / num_heads;
    let scale_factor = 1.0 / (head_dim as f32).sqrt();
    let mut concat = vec![0.0f32; seq_len * hidden];
    for h in 0..num_heads {
        let q_h = extract_head(q, h, num_heads);
        let k_h = extract_head(k, h, num_heads);
        let v_h = extract_head(v, h, num_heads);
        let raw = matmul_t_b(&q_h, &k_h);
        let raw = scale(&raw, scale_factor);
        // Mask after scaling so the masked entries stay exactly -inf.
        let masked = apply_causal_mask(&raw);
        let attn = softmax_last_dim(&masked);
        let head_out = matmul(&attn, &v_h);
        write_head_back(&mut concat, &head_out, h, num_heads);
    }
    Tensor::from_vec(concat, &[seq_len, hidden])
}
/// Unmasked multi-head scaled dot-product attention where queries and
/// keys/values may have different sequence lengths.
///
/// `q` is `[q_len, hidden]`; `k` and `v` must share one 2D shape
/// `[kv_len, hidden]`. For every head the scaled `matmul_t_b(q_h, k_h)` scores
/// go through `softmax_last_dim`, are multiplied by `v_h`, and the product is
/// written into that head's columns of the `[q_len, hidden]` result.
///
/// # Panics
///
/// Panics if `q` or `k` is not 2D, if `k`'s hidden size differs from `q`'s, if
/// `v`'s shape differs from `k`'s, if `num_heads` is zero, or if `hidden` is
/// not divisible by `num_heads`.
pub fn cross_attention_from_qkv(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    num_heads: usize,
) -> Tensor {
    let q_shape = q.shape().as_slice();
    let k_shape = k.shape().as_slice();
    let v_shape = v.shape().as_slice();
    assert_eq!(q_shape.len(), 2, "cross_attention: q must be 2D");
    assert_eq!(k_shape.len(), 2, "cross_attention: k must be 2D");
    let q_len = q_shape[0];
    let hidden = q_shape[1];
    assert_eq!(
        k_shape[1], hidden,
        "cross_attention: k hidden size must match q"
    );
    // A narrower `v` would otherwise be scattered into the output with the
    // wrong row stride instead of failing.
    assert_eq!(v_shape, k_shape, "cross_attention: v shape must match k");
    // `0.is_multiple_of(0)` is true, so guard the division below explicitly.
    assert!(num_heads > 0, "cross_attention: num_heads must be non-zero");
    assert!(
        hidden.is_multiple_of(num_heads),
        "cross_attention: hidden must be divisible by num_heads"
    );
    let head_dim = hidden / num_heads;
    let scale_factor = 1.0 / (head_dim as f32).sqrt();
    let mut concat = vec![0.0f32; q_len * hidden];
    for h in 0..num_heads {
        let q_h = extract_head(q, h, num_heads);
        let k_h = extract_head(k, h, num_heads);
        let v_h = extract_head(v, h, num_heads);
        let raw = matmul_t_b(&q_h, &k_h);
        let raw = scale(&raw, scale_factor);
        let attn = softmax_last_dim(&raw);
        let head_out = matmul(&attn, &v_h);
        write_head_back(&mut concat, &head_out, h, num_heads);
    }
    Tensor::from_vec(concat, &[q_len, hidden])
}
/// Multi-head scaled dot-product attention with an optional additive score
/// bias and an optional causal mask.
///
/// `q` is `[q_len, hidden]`; `k` and `v` must share one 2D shape
/// `[kv_len, hidden]`. Per head, the score for query `i` and key `j` is
/// `matmul_t_b(q_h, k_h)[i, j] * (1 / sqrt(head_dim))`, plus
/// `bias[h * q_len * kv_len + i * kv_len + j]` when `bias` is given, i.e. the
/// bias buffer is read as a flat row-major `[num_heads, q_len, kv_len]` array.
/// When `causal` is true, every entry with `j > i` is then overwritten with
/// `f32::NEG_INFINITY` (query row `i` is aligned with key row `i`; no offset
/// is applied when `kv_len != q_len`). The scores go through
/// `softmax_last_dim` and are multiplied by `v_h`.
///
/// Returns a `[q_len, hidden]` tensor.
///
/// # Panics
///
/// Panics if `q` or `k` is not 2D, if `k`'s hidden size differs from `q`'s, if
/// `v`'s shape differs from `k`'s, if `num_heads` is zero, if `hidden` is not
/// divisible by `num_heads`, or if `bias` does not hold exactly
/// `num_heads * q_len * kv_len` values.
pub fn multi_head_attention_with_bias(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    num_heads: usize,
    bias: Option<&Tensor>,
    causal: bool,
) -> Tensor {
    let q_shape = q.shape().as_slice();
    let k_shape = k.shape().as_slice();
    let v_shape = v.shape().as_slice();
    assert_eq!(q_shape.len(), 2, "attention_with_bias: q must be 2D");
    assert_eq!(k_shape.len(), 2, "attention_with_bias: k must be 2D");
    let q_len = q_shape[0];
    let hidden = q_shape[1];
    let kv_len = k_shape[0];
    assert_eq!(
        k_shape[1], hidden,
        "attention_with_bias: k hidden size must match q"
    );
    // A narrower `v` would otherwise be scattered into the output with the
    // wrong row stride instead of failing.
    assert_eq!(v_shape, k_shape, "attention_with_bias: v shape must match k");
    // `0.is_multiple_of(0)` is true, so guard the division below explicitly.
    assert!(
        num_heads > 0,
        "attention_with_bias: num_heads must be non-zero"
    );
    assert!(
        hidden.is_multiple_of(num_heads),
        "attention_with_bias: hidden must be divisible by num_heads"
    );
    let head_dim = hidden / num_heads;
    let scale_factor = 1.0 / (head_dim as f32).sqrt();
    let scores_per_head = q_len * kv_len;
    let bias_data = bias.map(|b| b.data());
    if let Some(bd) = bias_data {
        // Validate once so a mis-shaped bias fails with a clear message
        // instead of an out-of-bounds index (or a silently wrong layout).
        assert!(
            bd.len() == num_heads * scores_per_head,
            "attention_with_bias: bias has {} values, expected num_heads * q_len * kv_len = {}",
            bd.len(),
            num_heads * scores_per_head
        );
    }
    let mut concat = vec![0.0f32; q_len * hidden];
    for h in 0..num_heads {
        let q_h = extract_head(q, h, num_heads);
        let k_h = extract_head(k, h, num_heads);
        let v_h = extract_head(v, h, num_heads);
        let raw = matmul_t_b(&q_h, &k_h);
        let mut scores = raw.data().to_vec();
        match bias_data {
            Some(bd) => {
                // Scale and add this head's bias slice in a single pass.
                let head_bias = &bd[h * scores_per_head..(h + 1) * scores_per_head];
                for (s, &b) in scores.iter_mut().zip(head_bias) {
                    *s = *s * scale_factor + b;
                }
            }
            None => {
                for s in &mut scores {
                    *s *= scale_factor;
                }
            }
        }
        if causal {
            // Applied after the bias so a finite bias cannot undo the mask.
            for i in 0..q_len {
                for j in (i + 1)..kv_len {
                    scores[i * kv_len + j] = f32::NEG_INFINITY;
                }
            }
        }
        let scores_t = Tensor::from_vec(scores, &[q_len, kv_len]);
        let attn = softmax_last_dim(&scores_t);
        let head_out = matmul(&attn, &v_h);
        write_head_back(&mut concat, &head_out, h, num_heads);
    }
    Tensor::from_vec(concat, &[q_len, hidden])
}
/// Returns a copy of `scores` in which every entry whose column index is
/// greater than its row index is replaced by `f32::NEG_INFINITY`.
///
/// The first two dimensions of `scores` are used as the row and column
/// counts, and the data is addressed as a flat row-major buffer. The output
/// keeps the input's shape.
fn apply_causal_mask(scores: &Tensor) -> Tensor {
    let dims = scores.shape().as_slice();
    let (rows, cols) = (dims[0], dims[1]);
    let mut masked = scores.data().to_vec();
    for row in 0..rows {
        let first_hidden_col = row + 1;
        // Rows at or beyond the last column have nothing to the right of the
        // diagonal, so there is nothing to mask.
        if first_hidden_col < cols {
            masked[row * cols + first_hidden_col..(row + 1) * cols].fill(f32::NEG_INFINITY);
        }
    }
    Tensor::from_vec(masked, dims)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n x n` tensor with ones on the diagonal and zeros elsewhere.
    fn identity_matrix(n: usize) -> Tensor {
        let mut cells = vec![0.0f32; n * n];
        for diag in 0..n {
            cells[diag * n + diag] = 1.0;
        }
        Tensor::from_vec(cells, &[n, n])
    }

    /// Splitting a tensor into heads and scattering them back must reproduce
    /// the original buffer exactly.
    #[test]
    fn extract_then_write_back_is_identity() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let x = Tensor::from_vec(values, &[3, 4]);
        let mut rebuilt = vec![0.0f32; 12];
        for head in 0..2 {
            let slice = extract_head(&x, head, 2);
            write_head_back(&mut rebuilt, &slice, head, 2);
        }
        assert_eq!(rebuilt, x.data());
    }

    /// With identity projections and zero biases the layer must still return a
    /// tensor of the input's shape.
    #[test]
    fn attention_output_shape_matches_input() {
        let seq_len = 3;
        let hidden = 8;
        let x = Tensor::from_vec(vec![0.1f32; seq_len * hidden], &[seq_len, hidden]);
        let weight = identity_matrix(hidden);
        let bias = Tensor::from_vec(vec![0.0; hidden], &[hidden]);
        let y = multi_head_attention(
            &x,
            &weight,
            Some(&bias),
            &weight,
            Some(&bias),
            &weight,
            Some(&bias),
            &weight,
            Some(&bias),
            2,
        );
        assert_eq!(y.shape().as_slice(), &[seq_len, hidden]);
    }

    /// Without a mask every row contributes to the column-wise mean.
    #[test]
    fn mean_pool_basic() {
        let x = Tensor::from_vec(vec![1., 2., 3., 4., 5., 6.], &[3, 2]);
        let pooled = mean_pool(&x, None);
        assert_eq!(pooled.shape().as_slice(), &[2]);
        assert_eq!(pooled.data(), &[3.0, 4.0]);
    }

    /// Rows whose mask entry is zero must not influence the mean.
    #[test]
    fn mean_pool_respects_mask() {
        let x = Tensor::from_vec(vec![1., 2., 99., 99., 5., 6.], &[3, 2]);
        let mask = [1u32, 0, 1];
        let pooled = mean_pool(&x, Some(&mask));
        assert_eq!(pooled.data(), &[3.0, 4.0]);
    }
}