//! Multi-head attention layers.
//!
//! Provides [`MultiHeadAttention`], a [`CrossAttention`] wrapper, and
//! [`scaled_dot_product_attention_fused`], a fused attention entry point with
//! a CPU reference implementation.

use std::collections::HashMap;
use axonml_autograd::Variable;
#[cfg(feature = "cuda")]
use axonml_autograd::functions::FusedAttentionBackward;
#[cfg(feature = "cuda")]
use axonml_autograd::grad_fn::GradFn;
use axonml_tensor::Tensor;
use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
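/// Multi-head scaled dot-product attention.
///
/// The query, key, and value are linearly projected, split into `num_heads`
/// heads of `head_dim = embed_dim / num_heads` channels, attended to
/// independently, then concatenated and passed through an output projection.
///
/// A minimal self-attention sketch, mirroring the tests at the bottom of this
/// file (marked `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let mha = MultiHeadAttention::new(64, 4);
/// // Batch of 2 sequences, 10 tokens each, embedding size 64.
/// let x = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let y = mha.forward(&x);
/// assert_eq!(y.shape(), vec![2, 10, 64]);
/// ```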
pub struct MultiHeadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    embed_dim: usize,
    num_heads: usize,
    /// Per-head channel count: `embed_dim / num_heads`.
    head_dim: usize,
    /// Precomputed `1 / sqrt(head_dim)` attention scaling factor.
    scale: f32,
    /// When true, inputs/outputs are `[batch, seq, embed]`;
    /// otherwise `[seq, batch, embed]`.
    batch_first: bool,
}
impl MultiHeadAttention {
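    /// Creates a multi-head attention layer with no dropout and batch-first
    /// layout. Panics if `embed_dim` is not divisible by `num_heads`.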
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
Self::with_options(embed_dim, num_heads, 0.0, true)
}
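    /// Creates a multi-head attention layer with explicit options.
    ///
    /// `_dropout` is accepted for API compatibility but is currently unused.
    /// When `batch_first` is true, inputs and outputs are laid out as
    /// `[batch, seq, embed]`; otherwise as `[seq, batch, embed]`.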
pub fn with_options(
embed_dim: usize,
num_heads: usize,
_dropout: f32,
batch_first: bool,
) -> Self {
assert!(
embed_dim % num_heads == 0,
"embed_dim must be divisible by num_heads"
);
        let head_dim = embed_dim / num_heads;
        // Standard attention scaling factor: 1 / sqrt(head_dim).
        let scale = (head_dim as f32).sqrt().recip();
Self {
q_proj: Linear::new(embed_dim, embed_dim),
k_proj: Linear::new(embed_dim, embed_dim),
v_proj: Linear::new(embed_dim, embed_dim),
out_proj: Linear::new(embed_dim, embed_dim),
embed_dim,
num_heads,
head_dim,
scale,
batch_first,
}
}
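    /// Broadcasts an attention mask to
    /// `[batch_size, num_heads, tgt_len, src_len]`. Currently unused:
    /// `attention` expands masks inline on the CPU and CUDA paths.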
#[allow(dead_code)]
fn expand_mask(
mask: &Variable,
batch_size: usize,
num_heads: usize,
tgt_len: usize,
src_len: usize,
) -> Variable {
let mask_shape = mask.shape();
let target = [batch_size, num_heads, tgt_len, src_len];
if mask_shape == target {
return mask.clone();
}
        if mask_shape.len() == 2 {
            // A [tgt_len, src_len] mask broadcasts over batch and heads.
            let reshaped = mask.reshape(&[1, 1, tgt_len, src_len]);
            return reshaped.expand(&target);
        }
        if mask_shape.len() == 4 && mask_shape[1] == 1 {
            // A [batch, 1, tgt, src] or [1, 1, tgt, src] mask broadcasts over
            // the head (and batch) dimensions. This also covers the
            // [1, 1, tgt, src] case, so no separate branch is needed.
            return mask.expand(&target);
        }
        // Unrecognized shape: return the mask unchanged.
        mask.clone()
}
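    /// Computes multi-head attention,
    /// `softmax((Q K^T) / sqrt(head_dim) + mask) V`,
    /// over separately projected query, key, and value inputs, followed by
    /// the output projection.
    ///
    /// `attn_mask`, if given, uses `0.0` for masked positions and any
    /// non-zero value for attended positions; `[tgt_len, src_len]` and
    /// `[batch_size, src_len]` shapes are broadcast over the remaining
    /// dimensions.
    ///
    /// A cross-attention sketch, mirroring the tests below (marked `ignore`
    /// so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let mha = MultiHeadAttention::new(64, 4);
    /// let query = Variable::new(
    ///     Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
    ///     false,
    /// );
    /// let memory = Variable::new(
    ///     Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
    ///     false,
    /// );
    /// let out = mha.attention(&query, &memory, &memory, None);
    /// assert_eq!(out.shape(), vec![2, 5, 64]);
    /// ```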
    pub fn attention(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        attn_mask: Option<&Variable>,
    ) -> Variable {
        // Normalize to batch-first layout [batch, seq, embed]. The reshapes
        // below index the batch dimension first, so seq-first inputs are
        // transposed here and the result is transposed back before returning.
        let (query, key, value) = if self.batch_first {
            (query.clone(), key.clone(), value.clone())
        } else {
            (
                query.transpose(0, 1),
                key.transpose(0, 1),
                value.transpose(0, 1),
            )
        };
        let q_shape = query.shape();
        let (batch_size, tgt_len) = (q_shape[0], q_shape[1]);
        let src_len = key.shape()[1];
        // Project, then split the embedding into heads:
        // [batch, seq, embed] -> [batch, num_heads, seq, head_dim].
        let q = self.q_proj.forward(&query);
        let k = self.k_proj.forward(&key);
        let v = self.v_proj.forward(&value);
        let q = q
            .reshape(&[batch_size, tgt_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let k = k
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let v = v
            .reshape(&[batch_size, src_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
#[cfg(feature = "cuda")]
if q.data().device().is_gpu() && attn_mask.is_none() {
let is_training = axonml_autograd::no_grad::is_grad_enabled();
let q_tensor = q.data();
let k_tensor = k.data();
let v_tensor = v.data();
if let Some(attn_out) = q_tensor.fused_attention_cuda(
&k_tensor, &v_tensor, self.scale,
false, ) {
let attn_output = if is_training
&& (q.requires_grad() || k.requires_grad() || v.requires_grad())
{
let backward = FusedAttentionBackward::new(
q.grad_fn().cloned(),
k.grad_fn().cloned(),
v.grad_fn().cloned(),
q_tensor,
k_tensor,
v_tensor,
attn_out.clone(),
self.scale,
false,
);
Variable::from_operation(attn_out, GradFn::new(backward), true)
} else {
Variable::new(attn_out, false)
};
let attn_output =
attn_output
.transpose(1, 2)
.reshape(&[batch_size, tgt_len, self.embed_dim]);
return self.out_proj.forward(&attn_output);
}
}
        // General path: materialize scores = (Q K^T) * scale.
        let k_t = k.transpose(2, 3);
        let scores = q.matmul(&k_t).mul_scalar(self.scale);
let scores = if let Some(mask) = attn_mask {
let mask_shape = mask.shape();
let mask_data = mask.data();
let scores_shape = scores.shape();
let total = scores_shape.iter().product::<usize>();
#[cfg(feature = "cuda")]
if scores.data().device().is_gpu() {
let mask_gpu = if mask_data.device().is_gpu() {
mask_data.clone()
} else {
mask_data.to_device(scores.data().device()).unwrap()
};
if let Some(expanded_tensor) = mask_gpu.mask_expand_cuda(
&scores_shape,
batch_size,
self.num_heads,
tgt_len,
src_len,
) {
let additive_mask = Variable::new(expanded_tensor, false);
return self.finish_attention(
scores.add_var(&additive_mask),
&v,
batch_size,
tgt_len,
);
}
}
            // CPU fallback: convert the mask (0.0 = masked) into an additive
            // mask (-1e9 at masked positions), then broadcast it to the shape
            // of `scores`.
            let mask_vec = mask_data.to_vec();
            let additive: Vec<f32> = mask_vec
                .iter()
                .map(|&v| if v == 0.0 { -1e9 } else { 0.0 })
                .collect();
            let mut expanded = vec![0.0f32; total];
            if mask_shape.len() == 2 && mask_shape[0] == tgt_len && mask_shape[1] == src_len {
                // [tgt_len, src_len] attention mask: shared across batch and heads.
                for b in 0..batch_size {
for h in 0..self.num_heads {
for i in 0..tgt_len {
for j in 0..src_len {
let idx = b * self.num_heads * tgt_len * src_len
+ h * tgt_len * src_len
+ i * src_len
+ j;
expanded[idx] = additive[i * src_len + j];
}
}
}
}
            } else if mask_shape.len() == 2
                && mask_shape[0] == batch_size
                && mask_shape[1] == src_len
            {
                // [batch_size, src_len] padding mask: shared across heads and
                // query positions.
                for b in 0..batch_size {
for h in 0..self.num_heads {
for i in 0..tgt_len {
for j in 0..src_len {
let idx = b * self.num_heads * tgt_len * src_len
+ h * tgt_len * src_len
+ i * src_len
+ j;
expanded[idx] = additive[b * src_len + j];
}
}
}
}
            } else {
                // Unrecognized mask shape: tile the additive values cyclically
                // as a best-effort broadcast.
                for (i, val) in expanded.iter_mut().enumerate() {
*val = additive[i % additive.len()];
}
}
let mut additive_tensor =
Tensor::from_vec(expanded, &scores_shape).expect("tensor creation failed");
            // Move the additive mask onto the same device as `scores` before
            // adding it.
            let scores_device = scores.data().device();
            if scores_device.is_gpu() {
additive_tensor = additive_tensor
.to_device(scores_device)
.expect("device transfer failed");
}
let additive_mask = Variable::new(additive_tensor, false);
scores.add_var(&additive_mask)
} else {
scores
};
self.finish_attention(scores, &v, batch_size, tgt_len)
}
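    /// Shared tail of the attention computation: softmax over the (masked)
    /// scores, weighted sum of the values, head merge, and output projection.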
    fn finish_attention(
        &self,
        scores: Variable,
        v: &Variable,
        batch_size: usize,
        tgt_len: usize,
    ) -> Variable {
        let attn_weights = scores.softmax(-1);
        let attn_output = attn_weights.matmul(v);
        // Merge heads: [batch, heads, tgt, head_dim] -> [batch, tgt, embed].
        let attn_output = attn_output
            .transpose(1, 2)
            .reshape(&[batch_size, tgt_len, self.embed_dim]);
        let out = self.out_proj.forward(&attn_output);
        // Restore seq-first layout if the module was constructed with
        // batch_first = false.
        if self.batch_first {
            out
        } else {
            out.transpose(0, 1)
        }
    }
}
impl Module for MultiHeadAttention {
    fn forward(&self, input: &Variable) -> Variable {
        // Self-attention: the input serves as query, key, and value.
        self.attention(input, input, input, None)
    }
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.q_proj.parameters());
params.extend(self.k_proj.parameters());
params.extend(self.v_proj.parameters());
params.extend(self.out_proj.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.q_proj.named_parameters() {
params.insert(format!("q_proj.{name}"), param);
}
for (name, param) in self.k_proj.named_parameters() {
params.insert(format!("k_proj.{name}"), param);
}
for (name, param) in self.v_proj.named_parameters() {
params.insert(format!("v_proj.{name}"), param);
}
for (name, param) in self.out_proj.named_parameters() {
params.insert(format!("out_proj.{name}"), param);
}
params
}
fn name(&self) -> &'static str {
"MultiHeadAttention"
}
}
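/// Cross-attention: queries come from one sequence while keys and values come
/// from another (the "memory"), as when a transformer decoder attends over
/// encoder outputs. A thin wrapper around [`MultiHeadAttention`].
///
/// A minimal sketch, mirroring the tests below (marked `ignore` so it is not
/// compiled as a doctest):
///
/// ```ignore
/// let ca = CrossAttention::new(64, 4);
/// let query = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
///     false,
/// );
/// let memory = Variable::new(
///     Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let out = ca.cross_attention(&query, &memory, None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```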
pub struct CrossAttention {
mha: MultiHeadAttention,
}
impl CrossAttention {
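    /// Creates a cross-attention layer with default options (no dropout,
    /// batch-first).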
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
Self {
mha: MultiHeadAttention::new(embed_dim, num_heads),
}
}
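    /// Creates a cross-attention layer with explicit options; see
    /// [`MultiHeadAttention::with_options`].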
pub fn with_options(
embed_dim: usize,
num_heads: usize,
dropout: f32,
batch_first: bool,
) -> Self {
Self {
mha: MultiHeadAttention::with_options(embed_dim, num_heads, dropout, batch_first),
}
}
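    /// Attends from `query` over `memory`, using `memory` as both key and
    /// value.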
pub fn cross_attention(
&self,
query: &Variable,
memory: &Variable,
attn_mask: Option<&Variable>,
) -> Variable {
self.mha.attention(query, memory, memory, attn_mask)
}
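    /// Returns the embedding dimension.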
pub fn embed_dim(&self) -> usize {
self.mha.embed_dim
}
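    /// Returns the number of attention heads.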
pub fn num_heads(&self) -> usize {
self.mha.num_heads
}
}
impl Module for CrossAttention {
    fn forward(&self, input: &Variable) -> Variable {
        // Module::forward takes a single input, so fall back to
        // self-attention over it.
        self.mha.forward(input)
    }
fn parameters(&self) -> Vec<Parameter> {
self.mha.parameters()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.mha.named_parameters() {
params.insert(format!("mha.{name}"), param);
}
params
}
fn name(&self) -> &'static str {
"CrossAttention"
}
}
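/// Scaled dot-product attention, `softmax((Q K^T) * scale) V`, over
/// `[batch, heads, seq, head_dim]` tensors, with optional causal masking.
///
/// Dispatches to the fused CUDA kernel when the inputs live on the GPU and
/// falls back to a straightforward CPU reference implementation with a
/// numerically stable softmax.
///
/// A minimal CPU sketch, mirroring the tests below (marked `ignore` so it is
/// not compiled as a doctest):
///
/// ```ignore
/// let (b, h, s, d) = (1, 2, 4, 8);
/// let q = Tensor::from_vec(vec![0.1; b * h * s * d], &[b, h, s, d]).unwrap();
/// let k = q.clone();
/// let v = Tensor::from_vec(vec![0.5; b * h * s * d], &[b, h, s, d]).unwrap();
/// let out = scaled_dot_product_attention_fused(&q, &k, &v, 1.0 / (d as f32).sqrt(), false);
/// assert_eq!(out.shape(), &[b, h, s, d]);
/// ```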
pub fn scaled_dot_product_attention_fused(
q: &Tensor<f32>,
k: &Tensor<f32>,
v: &Tensor<f32>,
scale: f32,
is_causal: bool,
) -> Tensor<f32> {
#[cfg(feature = "cuda")]
if q.device().is_gpu() {
if let Some(result) = q.fused_attention_cuda(k, v, scale, is_causal) {
return result;
}
}
let shape = q.shape();
let batch_size = shape[0];
let num_heads = shape[1];
let tgt_len = shape[2];
let head_dim = shape[3];
let src_len = k.shape()[2];
let q_data = q.to_vec();
let k_data = k.to_vec();
let v_data = v.to_vec();
    // CPU reference: compute attention one query row at a time.
    let mut output = vec![0.0f32; batch_size * num_heads * tgt_len * head_dim];
for b in 0..batch_size {
for h in 0..num_heads {
            for i in 0..tgt_len {
                // Scaled scores for one query row; track the running max for
                // a numerically stable softmax.
                let mut scores = vec![0.0f32; src_len];
                let mut max_score = f32::NEG_INFINITY;
                for j in 0..src_len {
                    // Causal masking: query position i may only attend to key
                    // positions j <= i.
                    if is_causal && j > i {
                        scores[j] = f32::NEG_INFINITY;
                        continue;
                    }
let mut score = 0.0f32;
for d in 0..head_dim {
let q_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
let k_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
score += q_data[q_idx] * k_data[k_idx];
}
score *= scale;
scores[j] = score;
if score > max_score {
max_score = score;
}
}
                // Softmax: exponentiate relative to the row max, then normalize.
                let mut sum_exp = 0.0f32;
for s in &mut scores {
if *s > f32::NEG_INFINITY {
*s = (*s - max_score).exp();
sum_exp += *s;
} else {
*s = 0.0;
}
}
let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
                // Weighted sum of the value rows for this query position.
                for d in 0..head_dim {
let mut val = 0.0f32;
for j in 0..src_len {
let v_idx = ((b * num_heads + h) * src_len + j) * head_dim + d;
val += scores[j] * v_data[v_idx];
}
let out_idx = ((b * num_heads + h) * tgt_len + i) * head_dim + d;
output[out_idx] = val * inv_sum;
}
}
}
}
Tensor::from_vec(output, &[batch_size, num_heads, tgt_len, head_dim])
.expect("tensor creation failed")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multihead_attention_creation() {
let mha = MultiHeadAttention::new(512, 8);
assert_eq!(mha.embed_dim, 512);
assert_eq!(mha.num_heads, 8);
assert_eq!(mha.head_dim, 64);
}
#[test]
fn test_multihead_attention_forward() {
let mha = MultiHeadAttention::new(64, 4);
let input = Variable::new(
Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
false,
);
let output = mha.forward(&input);
assert_eq!(output.shape(), vec![2, 10, 64]);
}
#[test]
fn test_cross_attention() {
let mha = MultiHeadAttention::new(64, 4);
let query = Variable::new(
Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
false,
);
let key_value = Variable::new(
Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
false,
);
let output = mha.attention(&query, &key_value, &key_value, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_multihead_attention_parameters() {
let mha = MultiHeadAttention::new(64, 4);
let params = mha.parameters();
assert_eq!(params.len(), 8);
}
#[test]
fn test_cross_attention_creation() {
let ca = CrossAttention::new(256, 8);
assert_eq!(ca.embed_dim(), 256);
assert_eq!(ca.num_heads(), 8);
}
#[test]
fn test_cross_attention_forward() {
let ca = CrossAttention::new(64, 4);
let query = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
false,
);
let memory = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
false,
);
let output = ca.cross_attention(&query, &memory, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_cross_attention_self_attention_fallback() {
let ca = CrossAttention::new(64, 4);
let input = Variable::new(
Tensor::from_vec(vec![1.0; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
false,
);
let output = ca.forward(&input);
assert_eq!(output.shape(), vec![2, 8, 64]);
}
#[test]
fn test_cross_attention_parameters() {
let ca = CrossAttention::new(64, 4);
let params = ca.parameters();
        assert_eq!(params.len(), 8);
        let named = ca.named_parameters();
assert!(named.contains_key("mha.q_proj.weight"));
assert!(named.contains_key("mha.out_proj.bias"));
}
#[test]
fn test_fused_attention_cpu() {
let batch = 2;
let heads = 4;
let seq = 8;
let dim = 16;
let scale = 1.0 / (dim as f32).sqrt();
let q = Tensor::from_vec(
vec![0.1; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let k = Tensor::from_vec(
vec![0.1; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let v = Tensor::from_vec(
vec![0.5; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
assert_eq!(out.shape(), &[batch, heads, seq, dim]);
let out_vec = out.to_vec();
for val in &out_vec {
assert!((*val - 0.5).abs() < 0.01, "Expected ~0.5, got {}", val);
}
}
#[test]
fn test_fused_attention_causal() {
let batch = 1;
let heads = 1;
let seq = 4;
let dim = 4;
let scale = 1.0 / (dim as f32).sqrt();
let q = Tensor::from_vec(
vec![0.1; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let k = Tensor::from_vec(
vec![0.1; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let v = Tensor::from_vec(
vec![
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
],
&[batch, heads, seq, dim],
)
.unwrap();
let out = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
assert_eq!(out.shape(), &[batch, heads, seq, dim]);
let out_vec = out.to_vec();
assert!(
(out_vec[0] - 1.0).abs() < 1e-5,
"row 0, col 0 should be 1.0"
);
assert!((out_vec[1]).abs() < 1e-5, "row 0, col 1 should be 0.0");
}
#[test]
fn test_multihead_attention_backward_cpu() {
use axonml_autograd::backward;
let mha = MultiHeadAttention::new(32, 4);
let input = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 4 * 32], &[2, 4, 32]).expect("tensor creation failed"),
true,
);
let output = mha.forward(&input);
assert_eq!(output.shape(), vec![2, 4, 32]);
let loss = output.sum();
let ones = Tensor::from_vec(vec![1.0f32], &[1]).expect("tensor creation failed");
backward(&loss, &ones);
let grad = input.grad();
assert!(grad.is_some(), "Input gradient should exist");
let grad_data = grad.unwrap();
assert_eq!(grad_data.shape(), &[2, 4, 32]);
let grad_vec = grad_data.to_vec();
let non_zero = grad_vec.iter().any(|&v| v.abs() > 1e-10);
assert!(non_zero, "Gradients should be non-zero");
}
#[test]
fn test_fused_attention_backward_cpu() {
use axonml_autograd::functions::FusedAttentionBackward;
use axonml_autograd::grad_fn::GradientFunction;
let batch = 1;
let heads = 2;
let seq = 4;
let dim = 8;
let scale = 1.0 / (dim as f32).sqrt();
let q_data: Vec<f32> = (0..batch * heads * seq * dim)
.map(|i| ((i as f32) * 0.01).sin())
.collect();
let k_data: Vec<f32> = (0..batch * heads * seq * dim)
.map(|i| ((i as f32) * 0.02).cos())
.collect();
let v_data: Vec<f32> = (0..batch * heads * seq * dim)
.map(|i| ((i as f32) * 0.03).sin() + 0.5)
.collect();
let q =
Tensor::from_vec(q_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
let k =
Tensor::from_vec(k_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
let v =
Tensor::from_vec(v_data, &[batch, heads, seq, dim]).expect("tensor creation failed");
let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, false);
assert_eq!(output.shape(), &[batch, heads, seq, dim]);
let backward_fn = FusedAttentionBackward::new(
None,
None,
None,
q.clone(),
k.clone(),
v.clone(),
output.clone(),
scale,
false,
);
let grad_output = Tensor::from_vec(
vec![1.0f32; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let grads = backward_fn.apply(&grad_output);
assert_eq!(grads.len(), 3);
let gq = grads[0].as_ref().expect("grad_Q should exist");
let gk = grads[1].as_ref().expect("grad_K should exist");
let gv = grads[2].as_ref().expect("grad_V should exist");
assert_eq!(gq.shape(), &[batch, heads, seq, dim]);
assert_eq!(gk.shape(), &[batch, heads, seq, dim]);
assert_eq!(gv.shape(), &[batch, heads, seq, dim]);
for val in gq
.to_vec()
.iter()
.chain(gk.to_vec().iter())
.chain(gv.to_vec().iter())
{
assert!(val.is_finite(), "Gradient should be finite, got {}", val);
}
let gv_nonzero = gv.to_vec().iter().any(|&v| v.abs() > 1e-10);
assert!(gv_nonzero, "grad_V should be non-zero");
}
#[test]
fn test_fused_attention_backward_causal_cpu() {
use axonml_autograd::functions::FusedAttentionBackward;
use axonml_autograd::grad_fn::GradientFunction;
let batch = 1;
let heads = 1;
let seq = 4;
let dim = 4;
let scale = 1.0 / (dim as f32).sqrt();
let q = Tensor::from_vec(
vec![0.1f32; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let k = Tensor::from_vec(
vec![0.2f32; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let v = Tensor::from_vec(
vec![0.5f32; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let output = scaled_dot_product_attention_fused(&q, &k, &v, scale, true);
let backward_fn = FusedAttentionBackward::new(
None,
None,
None,
q.clone(),
k.clone(),
v.clone(),
output.clone(),
scale,
true,
);
let grad_output = Tensor::from_vec(
vec![1.0f32; batch * heads * seq * dim],
&[batch, heads, seq, dim],
)
.unwrap();
let grads = backward_fn.apply(&grad_output);
assert_eq!(grads.len(), 3);
let gq = grads[0].as_ref().unwrap();
let gk = grads[1].as_ref().unwrap();
let gv = grads[2].as_ref().unwrap();
for val in gq
.to_vec()
.iter()
.chain(gk.to_vec().iter())
.chain(gv.to_vec().iter())
{
assert!(val.is_finite(), "Gradient should be finite, got {}", val);
}
}
}