use crate::autograd::Variable;
use crate::tensor::{Device, Result, Tensor};
use super::init;
use super::parameter::Parameter;
use super::Module;
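
/// Multi-head scaled dot-product attention.
///
/// Projects queries, keys, and values into `num_heads` subspaces of
/// `head_dim = embed_dim / num_heads` dimensions each, runs scaled
/// dot-product attention independently per head, and recombines the heads
/// through a final output projection.
///
/// A minimal self-attention sketch (mirrors the tests below; `opts` is a
/// crate-local `TensorOptions` and error handling is abbreviated):
///
/// ```ignore
/// let mha = MultiheadAttention::new(8, 2)?;
/// let x = Variable::new(Tensor::randn(&[2, 4, 8], opts)?, false);
/// let y = mha.forward(&x)?; // output shape matches the input: [2, 4, 8]
/// ```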
pub struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: i64,
    head_dim: i64,
    scale: f64,
}
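
/// A private fully connected layer used for the Q/K/V and output
/// projections; weights use Xavier-uniform init, biases start at zero.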
struct Linear {
    weight: Parameter,
    bias: Parameter,
}

impl Linear {
    fn on_device(in_features: i64, out_features: i64, device: Device) -> Result<Self> {
        let w = init::xavier_uniform(
            &[out_features, in_features],
            in_features,
            out_features,
            device,
        )?;
        let b = Tensor::zeros(
            &[out_features],
            crate::tensor::TensorOptions {
                dtype: crate::tensor::DType::Float32,
                device,
            },
        )?;
        Ok(Linear {
            weight: Parameter::new(w, "weight"),
            bias: Parameter::new(b, "bias"),
        })
    }
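
    /// Applies the layer's affine transform through the autograd-aware
    /// `linear` op, so gradients flow to both weight and bias.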
    fn forward(&self, input: &Variable) -> Result<Variable> {
        crate::autograd::linear(input, &self.weight.variable, Some(&self.bias.variable))
    }
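
    /// Returns `Parameter` handles whose names are qualified with `prefix`
    /// (e.g. "q_proj.weight"); the `Variable` handles are cloned.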
    fn parameters(&self, prefix: &str) -> Vec<Parameter> {
        vec![
            Parameter {
                variable: self.weight.variable.clone(),
                name: format!("{prefix}.weight"),
            },
            Parameter {
                variable: self.bias.variable.clone(),
                name: format!("{prefix}.bias"),
            },
        ]
    }
}

impl MultiheadAttention {
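    /// Creates a multi-head attention layer on the default CPU device.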
    pub fn new(embed_dim: i64, num_heads: i64) -> Result<Self> {
        Self::on_device(embed_dim, num_heads, Device::CPU)
    }
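
    /// Creates the layer on `device`, building four square
    /// `embed_dim`-by-`embed_dim` projections. Panics if `embed_dim` is not
    /// divisible by `num_heads`.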
    pub fn on_device(embed_dim: i64, num_heads: i64, device: Device) -> Result<Self> {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        );
        let head_dim = embed_dim / num_heads;
        Ok(MultiheadAttention {
            q_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            k_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            v_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            out_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            num_heads,
            head_dim,
            scale: 1.0 / (head_dim as f64).sqrt(),
        })
    }
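
    /// Attention with separate query, key, and value inputs plus an optional
    /// attention mask.
    ///
    /// `query` is `[batch, seq_q, embed_dim]`; `key` and `value` are
    /// `[batch, seq_k, embed_dim]`. Positions where `mask` is nonzero are
    /// filled with `-inf` before the softmax, so they receive zero attention
    /// weight.
    ///
    /// A causal-mask sketch (mirrors `test_mha_causal_mask` below; `opts` and
    /// error handling abbreviated):
    ///
    /// ```ignore
    /// // Ones strictly above the diagonal block attention to future positions.
    /// let mask = Tensor::ones(&[4, 4], opts)?.triu(1)?;
    /// let y = mha.forward_ext(&x, &x, &x, Some(&mask))?;
    /// ```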
    pub fn forward_ext(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        mask: Option<&Tensor>,
    ) -> Result<Variable> {
        let batch = query.shape()[0];
        let seq_q = query.shape()[1];
        let seq_k = key.shape()[1];
        let q = self.q_proj.forward(query)?;
        let k = self.k_proj.forward(key)?;
        let v = self.v_proj.forward(value)?;
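        // Split into heads: [batch, seq, embed_dim] ->
        // [batch, num_heads, seq, head_dim].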
        let q = q
            .reshape(&[batch, seq_q, self.num_heads, self.head_dim])?
            .transpose(1, 2)?;
        let k = k
            .reshape(&[batch, seq_k, self.num_heads, self.head_dim])?
            .transpose(1, 2)?;
        let v = v
            .reshape(&[batch, seq_k, self.num_heads, self.head_dim])?
            .transpose(1, 2)?;
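        // Scaled dot-product scores: Q @ K^T / sqrt(head_dim), shape
        // [batch, num_heads, seq_q, seq_k].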
        let k_t = k.transpose(2, 3)?;
        let mut scores = q.matmul(&k_t)?.mul_scalar(self.scale)?;
        if let Some(m) = mask {
            scores = scores.masked_fill(m, f64::NEG_INFINITY)?;
        }
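        // Normalize over the keys, then take the weighted sum of the values.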
        let attn = scores.softmax(-1)?;
        let out = attn.matmul(&v)?;
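        // Merge the heads back: [batch, num_heads, seq_q, head_dim] ->
        // [batch, seq_q, embed_dim].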
        let out = out
            .transpose(1, 2)?
            .reshape(&[batch, seq_q, self.num_heads * self.head_dim])?;
        self.out_proj.forward(&out)
    }
}

impl Module for MultiheadAttention {
    fn name(&self) -> &str {
        "multihead_attention"
    }

    fn forward(&self, input: &Variable) -> Result<Variable> {
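        // Self-attention: the input serves as query, key, and value.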
        self.forward_ext(input, input, input, None)
    }
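
    /// Collects all eight parameters (a weight and a bias for each of the
    /// four projections) under fully qualified names.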
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters("q_proj"));
        params.extend(self.k_proj.parameters("k_proj"));
        params.extend(self.v_proj.parameters("v_proj"));
        params.extend(self.out_proj.parameters("out_proj"));
        params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::test_device;

    #[test]
    fn test_mha_self_attention() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(Tensor::randn(&[2, 4, 8], opts).unwrap(), false);
        let y = mha.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 4, 8]);
    }

    #[test]
    fn test_mha_cross_attention() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let q = Variable::new(Tensor::randn(&[1, 3, 8], opts).unwrap(), false);
        let kv = Variable::new(Tensor::randn(&[1, 5, 8], opts).unwrap(), false);
        let y = mha.forward_ext(&q, &kv, &kv, None).unwrap();
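        // The output sequence length follows the query, not the key/value.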
        assert_eq!(y.shape(), vec![1, 3, 8]);
    }

    #[test]
    fn test_mha_causal_mask() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(Tensor::randn(&[1, 4, 8], opts).unwrap(), false);
        let mask = Tensor::ones(&[4, 4], opts).unwrap().triu(1).unwrap();
        let y = mha.forward_ext(&x, &x, &x, Some(&mask)).unwrap();
        assert_eq!(y.shape(), vec![1, 4, 8]);
    }

    #[test]
    fn test_mha_gradient() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(Tensor::randn(&[1, 3, 8], opts).unwrap(), true);
        let y = mha.forward(&x).unwrap();
        let loss = y.sum().unwrap();
        loss.backward().unwrap();
        let grad = x.grad().unwrap();
        assert_eq!(grad.shape(), vec![1, 3, 8]);
    }

    #[test]
    fn test_mha_parameters() {
        let mha = MultiheadAttention::new(16, 4).unwrap();
        let params = mha.parameters();
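        // Four projections (q, k, v, out), each with a weight and a bias.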
        assert_eq!(params.len(), 8);
    }
}