use crate::error::{CognitionError, Result};
use crate::mask;
use crate::tensor::Tensor;
/// Hyperparameters for multi-head attention.
///
/// Invariant (enforced by [`AttentionConfig::new`]): `model_dim` is evenly
/// divisible by `num_heads`, and `head_dim == model_dim / num_heads`.
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Width of the model's token representations (input/output feature size).
    pub model_dim: usize,
    /// Number of parallel attention heads; must evenly divide `model_dim`.
    pub num_heads: usize,
    /// Per-head feature width, derived as `model_dim / num_heads`.
    pub head_dim: usize,
}
impl AttentionConfig {
    /// Creates a validated attention configuration, deriving
    /// `head_dim = model_dim / num_heads`.
    ///
    /// # Errors
    /// Returns `CognitionError::InvalidConfig` when `num_heads` is zero or
    /// when `model_dim` is not evenly divisible by `num_heads`.
    pub fn new(model_dim: usize, num_heads: usize) -> Result<Self> {
        // Guard first: `model_dim % 0` would otherwise panic with an integer
        // divide-by-zero instead of reporting a configuration error.
        if num_heads == 0 {
            return Err(CognitionError::InvalidConfig(
                "num_heads must be greater than zero".to_string(),
            ));
        }
        if model_dim % num_heads != 0 {
            return Err(CognitionError::InvalidConfig(format!(
                "model_dim ({model_dim}) must be divisible by num_heads ({num_heads})"
            )));
        }
        Ok(Self {
            model_dim,
            num_heads,
            head_dim: model_dim / num_heads,
        })
    }
}
/// A single attention head holding its learned projection matrices.
#[derive(Debug, Clone)]
pub struct AttentionHead {
    /// Query projection, shape `[model_dim, head_dim]`.
    pub w_query: Tensor,
    /// Key projection, shape `[model_dim, head_dim]`.
    pub w_key: Tensor,
    /// Value projection, shape `[model_dim, head_dim]`.
    pub w_value: Tensor,
    /// Feature width of this head; used for the `1/sqrt(head_dim)` score scaling.
    pub head_dim: usize,
}
impl AttentionHead {
    /// Builds one attention head with Xavier-uniform-initialized query, key,
    /// and value projections, each of shape `[model_dim, head_dim]`.
    ///
    /// # Errors
    /// Propagates any error from `Tensor::xavier_uniform`.
    pub fn new(model_dim: usize, head_dim: usize, rng: &mut impl rand::Rng) -> Result<Self> {
        Ok(Self {
            w_query: Tensor::xavier_uniform(&[model_dim, head_dim], rng)?,
            w_key: Tensor::xavier_uniform(&[model_dim, head_dim], rng)?,
            w_value: Tensor::xavier_uniform(&[model_dim, head_dim], rng)?,
            head_dim,
        })
    }

    /// Scaled dot-product attention over `x` (rows = sequence positions).
    ///
    /// Computes `softmax(Q K^T / sqrt(head_dim)) V`. When `causal` is true a
    /// causal mask is applied to the scores so a position cannot attend to
    /// later positions.
    ///
    /// # Errors
    /// Propagates any shape/compute error from the underlying tensor ops.
    pub fn forward(&self, x: &Tensor, causal: bool) -> Result<AttentionOutput> {
        let seq_len = x.shape()[0];
        let q = x.matmul(&self.w_query)?;
        let k = x.matmul(&self.w_key)?;
        let v = x.matmul(&self.w_value)?;
        let k_t = k.transpose()?;
        // 1/sqrt(d_k) scaling keeps logit variance independent of head_dim.
        let scale = (self.head_dim as f64).sqrt();
        let mut scores = q.matmul(&k_t)?.scale(1.0 / scale);
        if causal {
            let m = mask::causal_mask(seq_len);
            scores = mask::apply_mask(&scores, &m)?;
        }
        let weights = scores.softmax()?;
        // Compute the output while `weights` is only borrowed, then move
        // `weights` into the result — avoids the previous redundant clone
        // (a full tensor allocation per forward pass).
        let output = weights.matmul(&v)?;
        Ok(AttentionOutput { output, weights })
    }
}
/// Result of a single head's forward pass.
#[derive(Debug, Clone)]
pub struct AttentionOutput {
    /// Attended values, shape `[seq_len, head_dim]`.
    pub output: Tensor,
    /// Post-softmax attention weights, shape `[seq_len, seq_len]`.
    pub weights: Tensor,
}
/// Multi-head attention: several independent heads whose concatenated
/// outputs are mixed by a final linear projection.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// One head per `config.num_heads`, all sharing the same input.
    pub heads: Vec<AttentionHead>,
    /// Output projection, shape `[model_dim, model_dim]`.
    pub w_output: Tensor,
    /// The validated configuration this instance was built from.
    pub config: AttentionConfig,
}
impl MultiHeadAttention {
    /// Assembles `config.num_heads` independent heads plus the final output
    /// projection `w_output` of shape `[model_dim, model_dim]`.
    ///
    /// # Errors
    /// Propagates any error from head or weight initialization.
    pub fn new(config: AttentionConfig, rng: &mut impl rand::Rng) -> Result<Self> {
        let heads = (0..config.num_heads)
            .map(|_| AttentionHead::new(config.model_dim, config.head_dim, &mut *rng))
            .collect::<Result<Vec<_>>>()?;
        let w_output = Tensor::xavier_uniform(&[config.model_dim, config.model_dim], rng)?;
        Ok(Self {
            heads,
            w_output,
            config,
        })
    }

    /// Runs every head on `x`, concatenates their outputs along the feature
    /// axis, and applies the output projection.
    ///
    /// # Errors
    /// Propagates any error from the per-head forward passes or tensor ops.
    pub fn forward(&self, x: &Tensor, causal: bool) -> Result<MultiHeadOutput> {
        // Each head attends to the same input independently.
        let mut per_head = Vec::with_capacity(self.heads.len());
        for head in &self.heads {
            per_head.push(head.forward(x, causal)?);
        }
        let seq_len = x.shape()[0];
        // Concatenate head outputs row by row: each output row is the
        // head-by-head concatenation of that position's head outputs.
        let mut concat = Vec::with_capacity(seq_len * self.config.model_dim);
        for row_idx in 0..seq_len {
            for head_out in &per_head {
                let row_tensor = head_out.output.row(row_idx)?;
                concat.extend_from_slice(row_tensor.data());
            }
        }
        let stacked = Tensor::new(concat, vec![seq_len, self.config.model_dim])?;
        // Final projection mixes information across heads.
        let output = stacked.matmul(&self.w_output)?;
        let head_weights = per_head.into_iter().map(|ho| ho.weights).collect();
        Ok(MultiHeadOutput {
            output,
            head_weights,
        })
    }
}
/// Result of a multi-head forward pass.
#[derive(Debug, Clone)]
pub struct MultiHeadOutput {
    /// Projected concatenation of all head outputs, shape `[seq_len, model_dim]`.
    pub output: Tensor,
    /// Each head's attention-weight matrix, in head order.
    pub head_weights: Vec<Tensor>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single head maps `[seq, model]` input to `[seq, head_dim]` output
    /// with `[seq, seq]` attention weights.
    #[test]
    fn test_single_head_shape() {
        let mut rng = rand::rng();
        let head = AttentionHead::new(8, 4, &mut rng).unwrap();
        let input = Tensor::randn(&[3, 8], &mut rng);
        let result = head.forward(&input, true).unwrap();
        assert_eq!(result.output.shape(), &[3, 4]);
        assert_eq!(result.weights.shape(), &[3, 3]);
    }

    /// Multi-head output keeps the model dimension and reports one weight
    /// matrix per head.
    #[test]
    fn test_multi_head_shape() {
        let mut rng = rand::rng();
        let config = AttentionConfig::new(8, 2).unwrap();
        let mha = MultiHeadAttention::new(config, &mut rng).unwrap();
        let input = Tensor::randn(&[3, 8], &mut rng);
        let result = mha.forward(&input, true).unwrap();
        assert_eq!(result.output.shape(), &[3, 8]);
        assert_eq!(result.head_weights.len(), 2);
    }

    /// Softmax rows must each sum to 1 (no mask applied).
    #[test]
    fn test_attention_weights_sum_to_one() {
        let mut rng = rand::rng();
        let head = AttentionHead::new(8, 4, &mut rng).unwrap();
        let input = Tensor::randn(&[3, 8], &mut rng);
        let result = head.forward(&input, false).unwrap();
        for r in 0..3 {
            let sum: f64 = result.weights.row(r).unwrap().data().iter().copied().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "row {r} attention weights sum to {sum}"
            );
        }
    }

    /// With a causal mask, position 0 may only attend to itself.
    #[test]
    fn test_causal_mask_enforced() {
        let mut rng = rand::rng();
        let head = AttentionHead::new(8, 4, &mut rng).unwrap();
        let input = Tensor::randn(&[4, 8], &mut rng);
        let result = head.forward(&input, true).unwrap();
        let first_row = result.weights.row(0).unwrap();
        assert!(
            first_row.data()[1] < 1e-6,
            "token 0 should not attend to token 1"
        );
    }

    /// Indivisible model_dim / num_heads is rejected at construction.
    #[test]
    fn test_config_validation() {
        assert!(AttentionConfig::new(7, 3).is_err());
    }
}