//! Differential attention.
//!
//! A multi-head attention variant that computes two softmax attention maps per
//! head and subtracts a learnable multiple of the second from the first,
//! cancelling attention mass that both maps assign to the same positions.

use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
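
/// Multi-head differential attention.
///
/// Each head's query and key projections are split into two halves, giving two
/// softmax attention maps that are combined as
///
/// ```text
/// DiffAttn(Q, K, V) = (softmax(Q1 K1^T / sqrt(d)) - lambda * softmax(Q2 K2^T / sqrt(d))) V
/// ```
///
/// where `d = head_dim / 2` and `lambda` is a learnable scalar shared across
/// heads.
///
/// # Examples
///
/// A minimal self-attention sketch mirroring the tests below (marked `ignore`
/// because it assumes this crate's `Tensor`, `Variable`, and `Module` items
/// are in scope):
///
/// ```ignore
/// let attn = DifferentialAttention::new(64, 4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
///     false,
/// );
/// let output = attn.forward(&input); // self-attention over [batch, seq, embed]
/// assert_eq!(output.shape(), vec![2, 10, 64]);
/// ```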
pub struct DifferentialAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
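    /// Learnable scalar weighting the subtracted (second) attention map.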
lambda: Parameter,
embed_dim: usize,
num_heads: usize,
head_dim: usize,
half_head_dim: usize,
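    /// Precomputed 1 / sqrt(head_dim / 2), applied to both score matrices.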
scale: f32,
}
impl DifferentialAttention {
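    /// Creates a differential attention layer with the default initial
    /// `lambda` of 0.05.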
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
Self::with_lambda(embed_dim, num_heads, 0.05)
}
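    /// Creates a differential attention layer with a caller-chosen initial
    /// `lambda`.
    ///
    /// # Panics
    ///
    /// Panics if `embed_dim` is not divisible by `num_heads`, or if the
    /// resulting head dimension is odd (each head's Q and K must split into
    /// two equal halves).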
pub fn with_lambda(embed_dim: usize, num_heads: usize, lambda_init: f32) -> Self {
assert!(
embed_dim % num_heads == 0,
"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
);
let head_dim = embed_dim / num_heads;
assert!(
head_dim % 2 == 0,
"head_dim ({head_dim}) must be even for Q/K splitting"
);
let half_head_dim = head_dim / 2;
let scale = (half_head_dim as f32).sqrt().recip();
let lambda_tensor =
Tensor::from_vec(vec![lambda_init], &[1]).expect("tensor creation failed");
Self {
q_proj: Linear::new(embed_dim, embed_dim),
k_proj: Linear::new(embed_dim, embed_dim),
v_proj: Linear::new(embed_dim, embed_dim),
out_proj: Linear::new(embed_dim, embed_dim),
lambda: Parameter::named("lambda", lambda_tensor, true),
embed_dim,
num_heads,
head_dim,
half_head_dim,
scale,
}
}
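    /// Computes differential attention over projected queries, keys, and values.
    ///
    /// `query` has shape `[batch, tgt_len, embed_dim]`; `key` and `value` have
    /// shape `[batch, src_len, embed_dim]`. The output has shape
    /// `[batch, tgt_len, embed_dim]`.
    ///
    /// `_attn_mask` is accepted for interface compatibility but is currently
    /// ignored.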
pub fn attention(
&self,
query: &Variable,
key: &Variable,
value: &Variable,
_attn_mask: Option<&Variable>,
) -> Variable {
let q_shape = query.shape();
let batch_size = q_shape[0];
let tgt_len = q_shape[1];
let src_len = key.shape()[1];
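        // Project inputs and reshape each to [batch, num_heads, seq_len, head_dim].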
let q = self.q_proj.forward(query);
let k = self.k_proj.forward(key);
let v = self.v_proj.forward(value);
let q = q
.reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
.transpose(1, 2);
let k = k
.reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
.transpose(1, 2);
let v = v
.reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
.transpose(1, 2);
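        // Split Q and K along head_dim into two halves; each half drives one attention map.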
let q1 = q.narrow(3, 0, self.half_head_dim);
let q2 = q.narrow(3, self.half_head_dim, self.half_head_dim);
let k1 = k.narrow(3, 0, self.half_head_dim);
let k2 = k.narrow(3, self.half_head_dim, self.half_head_dim);
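        // First attention map: softmax(Q1 K1^T * scale).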
let k1_t = k1.transpose(2, 3);
let scores1 = q1.matmul(&k1_t).mul_scalar(self.scale);
let attn1 = scores1.softmax(-1);
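        // Second attention map: softmax(Q2 K2^T * scale).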
let k2_t = k2.transpose(2, 3);
let scores2 = q2.matmul(&k2_t).mul_scalar(self.scale);
let attn2 = scores2.softmax(-1);
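        // Differential combination: attn1 - lambda * attn2.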
let lambda_var = self.lambda.variable();
let attn2_scaled = self.broadcast_mul_scalar(&attn2, &lambda_var);
let neg_attn2 = attn2_scaled.mul_scalar(-1.0);
let diff_attn = attn1.add_var(&neg_attn2);
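        // Apply the combined map to V, merge the heads, and project back to embed_dim.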
let attn_output = diff_attn.matmul(&v);
let attn_output =
attn_output
.transpose(1, 2)
.reshape(&[batch_size, tgt_len, self.embed_dim]);
self.out_proj.forward(&attn_output)
}
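    /// Multiplies `attn` element-wise by the scalar stored in `lambda`.
    ///
    /// The scalar is read out and materialized as a constant tensor with the
    /// same shape as `attn`, so gradients flow to `attn` but not back into
    /// `lambda` through this helper.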
fn broadcast_mul_scalar(&self, attn: &Variable, lambda: &Variable) -> Variable {
let lambda_val = lambda.data().to_vec()[0];
let attn_shape = attn.shape();
let total = attn_shape.iter().product::<usize>();
let lambda_expanded =
Tensor::from_vec(vec![lambda_val; total], &attn_shape).expect("tensor creation failed");
let lambda_var = Variable::new(lambda_expanded, false);
attn.mul_var(&lambda_var)
}
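    /// Returns the current value of the learnable `lambda` scalar.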
pub fn lambda_value(&self) -> f32 {
self.lambda.data().to_vec()[0]
}
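    /// Returns the embedding dimension.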
pub fn embed_dim(&self) -> usize {
self.embed_dim
}
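    /// Returns the number of attention heads.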
pub fn num_heads(&self) -> usize {
self.num_heads
}
}
impl Module for DifferentialAttention {
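    /// Runs self-attention: the input serves as query, key, and value, with no mask.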
fn forward(&self, input: &Variable) -> Variable {
self.attention(input, input, input, None)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.q_proj.parameters());
params.extend(self.k_proj.parameters());
params.extend(self.v_proj.parameters());
params.extend(self.out_proj.parameters());
params.push(self.lambda.clone());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.q_proj.named_parameters() {
params.insert(format!("q_proj.{name}"), param);
}
for (name, param) in self.k_proj.named_parameters() {
params.insert(format!("k_proj.{name}"), param);
}
for (name, param) in self.v_proj.named_parameters() {
params.insert(format!("v_proj.{name}"), param);
}
for (name, param) in self.out_proj.named_parameters() {
params.insert(format!("out_proj.{name}"), param);
}
params.insert("lambda".to_string(), self.lambda.clone());
params
}
fn name(&self) -> &'static str {
"DifferentialAttention"
}
}
impl std::fmt::Debug for DifferentialAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DifferentialAttention")
.field("embed_dim", &self.embed_dim)
.field("num_heads", &self.num_heads)
.field("head_dim", &self.head_dim)
.field("half_head_dim", &self.half_head_dim)
.field("lambda", &self.lambda_value())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diff_attention_creation() {
let attn = DifferentialAttention::new(64, 4);
assert_eq!(attn.embed_dim(), 64);
assert_eq!(attn.num_heads(), 4);
assert_eq!(attn.head_dim, 16);
assert_eq!(attn.half_head_dim, 8);
assert!((attn.lambda_value() - 0.05).abs() < 1e-6);
}
#[test]
fn test_diff_attention_forward() {
let attn = DifferentialAttention::new(64, 4);
let input = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
false,
);
let output = attn.forward(&input);
assert_eq!(output.shape(), vec![2, 10, 64]);
}
#[test]
fn test_diff_attention_cross() {
let attn = DifferentialAttention::new(64, 4);
let query = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
false,
);
let kv = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
false,
);
let output = attn.attention(&query, &kv, &kv, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_diff_attention_parameters() {
let attn = DifferentialAttention::new(64, 4);
let params = attn.parameters();
assert_eq!(params.len(), 9);
}
#[test]
fn test_diff_attention_lambda_in_named_params() {
let attn = DifferentialAttention::new(64, 4);
let named = attn.named_parameters();
assert!(named.contains_key("lambda"));
assert!(named.contains_key("q_proj.weight"));
assert!(named.contains_key("out_proj.bias"));
}
#[test]
fn test_diff_attention_backward() {
use axonml_autograd::backward;
let attn = DifferentialAttention::new(32, 2);
let input = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
true,
);
let output = attn.forward(&input);
assert_eq!(output.shape(), vec![2, 4, 32]);
let loss = output.sum();
let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
backward(&loss, &ones);
let grad = input.grad();
assert!(grad.is_some(), "Input gradient should exist");
let grad_data = grad.unwrap();
assert_eq!(grad_data.shape(), &[2, 4, 32]);
let grad_vec = grad_data.to_vec();
let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
assert!(non_zero, "Gradients should be non-zero");
}
#[test]
fn test_diff_attention_custom_lambda() {
let attn = DifferentialAttention::with_lambda(64, 4, 0.1);
assert!((attn.lambda_value() - 0.1).abs() < 1e-6);
}
}