use crate::nn::{Linear, Module};
use crate::tensor::{GraphContext, Tensor};
use ndarray::{arr0, ArrayD, IxDyn};
use std::cell::RefCell;
use std::rc::Rc;
/// Kinds of mask that can be applied to attention scores.
///
/// NOTE(review): no visible code in this file constructs or matches on this enum, so the
/// variant descriptions below are inferred from their names — confirm against callers.
#[derive(Clone, Debug)]
pub enum AttentionMask {
    /// Presumably a look-ahead mask (compare `create_causal_mask`).
    Causal,
    /// Presumably a key-padding mask carried as a tensor.
    Padding(Tensor),
    /// An arbitrary caller-supplied mask tensor.
    Custom(Tensor),
}
/// Hyper-parameters describing a multi-head attention layer.
///
/// NOTE(review): `MultiHeadAttention::from_config` currently forwards only `embed_dim` and
/// `num_heads`; `dropout` and `bias` are stored here but not applied by that constructor.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiHeadAttentionConfig {
    /// Model width; `MultiHeadAttention::new` requires it to be divisible by `num_heads`.
    pub embed_dim: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dropout rate (stored only; see the note on the type).
    pub dropout: f32,
    /// Presumably whether the projections carry a bias term (stored only; see the note above).
    pub bias: bool,
}

impl Default for MultiHeadAttentionConfig {
    /// 512-wide model, 8 heads, no dropout, bias enabled.
    fn default() -> Self {
        Self {
            embed_dim: 512,
            num_heads: 8,
            dropout: 0.0,
            bias: true,
        }
    }
}

impl MultiHeadAttentionConfig {
    /// Creates a config with the given width and head count; other fields take their defaults.
    #[must_use]
    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
        Self {
            embed_dim,
            num_heads,
            ..Default::default()
        }
    }

    /// Returns the config with `dropout` replaced.
    ///
    /// Takes `self` by value, so ignoring the result discards the change.
    #[must_use]
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Returns the config with `bias` set to `false`.
    ///
    /// Takes `self` by value, so ignoring the result discards the change.
    #[must_use]
    pub fn without_bias(mut self) -> Self {
        self.bias = false;
        self
    }
}
/// Multi-head attention layer built on the crate's graph `Tensor` API.
///
/// Holds the query/key/value/output projections plus the per-head geometry computed in
/// `MultiHeadAttention::new`.
pub struct MultiHeadAttention {
    /// Number of attention heads.
    num_heads: usize,
    /// Features per head (`embed_dim / num_heads`).
    head_dim: usize,
    /// Model width; input and output width of every projection.
    embed_dim: usize,
    /// Query projection.
    w_q: Linear,
    /// Key projection.
    w_k: Linear,
    /// Value projection.
    w_v: Linear,
    /// Output projection applied after the heads are recombined.
    w_o: Linear,
    /// Attention-logit scaling factor, `1 / sqrt(head_dim)`.
    scale: f32,
    /// Shared graph context used to create literal/scalar tensors (scale factor, masks).
    context: Rc<RefCell<GraphContext>>,
}
impl MultiHeadAttention {
pub fn new(
context: &Rc<RefCell<GraphContext>>,
embed_dim: usize,
num_heads: usize,
name: &str,
) -> Self {
assert!(
embed_dim % num_heads == 0,
"embed_dim ({}) must be divisible by num_heads ({}) without remainder.",
embed_dim,
num_heads
);
let head_dim = embed_dim / num_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
Self {
num_heads,
head_dim,
embed_dim,
w_q: Linear::new(context, &format!("{}.w_q", name), embed_dim, embed_dim),
w_k: Linear::new(context, &format!("{}.w_k", name), embed_dim, embed_dim),
w_v: Linear::new(context, &format!("{}.w_v", name), embed_dim, embed_dim),
w_o: Linear::new(context, &format!("{}.w_o", name), embed_dim, embed_dim),
scale,
context: Rc::clone(context),
}
}
pub fn from_config(
context: &Rc<RefCell<GraphContext>>,
config: MultiHeadAttentionConfig,
name: &str,
) -> Self {
Self::new(context, config.embed_dim, config.num_heads, name)
}
pub fn scaled_dot_product_attention(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: Option<&Tensor>,
) -> Tensor {
let k_transposed = key.transpose(2, 3);
let scores = query.dot(&k_transposed);
let scale_tensor = Tensor::new_literal(&self.context, arr0(self.scale).into_dyn(), "scale");
let scores_scaled = &scores * &scale_tensor;
let scores_masked = if let Some(m) = mask {
&scores_scaled + m
} else {
scores_scaled
};
let attention_weights = scores_masked.softmax();
attention_weights.dot(value)
}
pub fn forward_qkv(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
) -> Tensor {
let q = self.w_q.forward(query);
let k = self.w_k.forward(key);
let v = self.w_v.forward(value);
let q_heads = self.split_heads_dynamic(&q);
let k_heads = self.split_heads_dynamic(&k);
let v_heads = self.split_heads_dynamic(&v);
let combined_mask = self.combine_masks(attn_mask, key_padding_mask);
let attention_output =
self.scaled_dot_product_attention(&q_heads, &k_heads, &v_heads, combined_mask.as_ref());
let concatenated = self.combine_heads_dynamic(&attention_output);
self.w_o.forward(&concatenated)
}
fn split_heads_dynamic(&self, x: &Tensor) -> Tensor {
x.reshape(vec![-1, -1, self.num_heads as i64, self.head_dim as i64])
.transpose(1, 2)
}
fn combine_heads_dynamic(&self, x: &Tensor) -> Tensor {
x.transpose(1, 2)
.reshape(vec![-1, -1, self.embed_dim as i64])
}
fn combine_masks(
&self,
attn_mask: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
) -> Option<Tensor> {
match (attn_mask, key_padding_mask) {
(None, None) => None,
(Some(m), None) => Some(m.clone()),
(None, Some(kpm)) => {
Some(self.expand_padding_mask(kpm))
}
(Some(am), Some(kpm)) => {
let expanded_kpm = self.expand_padding_mask(kpm);
Some(&am.clone() + &expanded_kpm)
}
}
}
fn expand_padding_mask(&self, mask: &Tensor) -> Tensor {
let one = Tensor::scalar(&self.context, 1.0);
let neg_inf = Tensor::scalar(&self.context, -1e9);
let inverted = &one - mask;
&inverted * &neg_inf
}
pub fn create_causal_mask(&self, seq_len: usize) -> Tensor {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
if j > i {
mask_data[i * seq_len + j] = -1e9;
}
}
}
let mask_arr = ArrayD::from_shape_vec(IxDyn(&[1, 1, seq_len, seq_len]), mask_data).unwrap();
Tensor::new_literal(&self.context, mask_arr, "causal_mask")
}
pub fn create_padding_mask_from_lengths(&self, lengths: &[usize], max_len: usize) -> Tensor {
let batch_size = lengths.len();
let mut mask_data = vec![0.0f32; batch_size * max_len];
for (b, &len) in lengths.iter().enumerate() {
for i in 0..len.min(max_len) {
mask_data[b * max_len + i] = 1.0;
}
}
let mask_arr = ArrayD::from_shape_vec(IxDyn(&[batch_size, max_len]), mask_data).unwrap();
Tensor::new_literal(&self.context, mask_arr, "padding_mask")
}
fn split_heads(&self, x: &Tensor) -> Tensor {
x.reshape(vec![1, 1, self.num_heads as i64, self.head_dim as i64])
.transpose(1, 2)
}
fn combine_heads(&self, x: &Tensor) -> Tensor {
x.transpose(1, 2).reshape(vec![1, self.embed_dim as i64])
}
}
impl Module for MultiHeadAttention {
    /// Self-attention: `inputs` serves as query, key and value, with no mask.
    ///
    /// NOTE(review): heads are split and merged with `split_heads`/`combine_heads`, which
    /// reshape to fixed leading dimensions of 1. This path therefore presumably handles only a
    /// single token with batch size 1. `forward_qkv` uses the dynamic reshape helpers instead —
    /// confirm the intended scope.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let projected_q = self.w_q.forward(inputs);
        let projected_k = self.w_k.forward(inputs);
        let projected_v = self.w_v.forward(inputs);
        // Arguments are evaluated left to right, so heads are split in q, k, v order.
        let attended = self.scaled_dot_product_attention(
            &self.split_heads(&projected_q),
            &self.split_heads(&projected_k),
            &self.split_heads(&projected_v),
            None,
        );
        let merged = self.combine_heads(&attended);
        self.w_o.forward(&merged)
    }

    /// Trainable tensors of the q, k, v and output projections, in that order.
    fn parameters(&self) -> Vec<Tensor> {
        [&self.w_q, &self.w_k, &self.w_v, &self.w_o]
            .iter()
            .flat_map(|projection| projection.parameters())
            .collect()
    }
}
pub fn create_causal_mask(context: &Rc<RefCell<GraphContext>>, seq_len: usize) -> Tensor {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
if j > i {
mask_data[i * seq_len + j] = -1e9;
}
}
}
let mask_arr = ArrayD::from_shape_vec(IxDyn(&[1, 1, seq_len, seq_len]), mask_data).unwrap();
Tensor::new_literal(context, mask_arr, "causal_mask")
}
pub fn create_padding_mask_from_ids(
context: &Rc<RefCell<GraphContext>>,
lengths: &[usize],
max_len: usize,
) -> Tensor {
let batch_size = lengths.len();
let mut mask_data = vec![-1e9f32; batch_size * max_len];
for (b, &len) in lengths.iter().enumerate() {
for i in 0..len.min(max_len) {
mask_data[b * max_len + i] = 0.0;
}
}
let mask_arr = ArrayD::from_shape_vec(IxDyn(&[batch_size, 1, 1, max_len]), mask_data).unwrap();
Tensor::new_literal(context, mask_arr, "padding_mask")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::asg::{NodeType, Value};

    /// Creates a fresh graph context so every test starts from an empty graph.
    fn new_context() -> Rc<RefCell<GraphContext>> {
        Rc::new(RefCell::new(GraphContext::new()))
    }

    #[test]
    fn test_mha_creation() {
        let context = new_context();
        let mha = MultiHeadAttention::new(&context, 64, 4, "mha");
        assert_eq!(mha.embed_dim, 64);
        assert_eq!(mha.num_heads, 4);
        assert_eq!(mha.head_dim, 16);
    }

    #[test]
    #[should_panic]
    fn test_mha_rejects_indivisible_embed_dim() {
        let context = new_context();
        let _ = MultiHeadAttention::new(&context, 10, 3, "mha");
    }

    #[test]
    #[should_panic]
    fn test_mha_rejects_zero_heads() {
        let context = new_context();
        let _ = MultiHeadAttention::new(&context, 64, 0, "mha");
    }

    #[test]
    fn test_causal_mask() {
        let context = new_context();
        let seq_len = 4usize;
        let mask = create_causal_mask(&context, seq_len);
        let graph = context.borrow();
        let main_graph = graph.main_graph();
        let node = main_graph.get_node(mask.node_id).unwrap();
        if let NodeType::Literal(Value::Tensor(arr)) = &node.node_type {
            assert_eq!(arr.shape(), &[1, 1, seq_len, seq_len]);
            // Every strictly-upper-triangular entry is masked; all others are zero.
            for i in 0..seq_len {
                for j in 0..seq_len {
                    let v = arr[[0, 0, i, j]];
                    if j > i {
                        assert!(v < -1e8, "entry ({}, {}) should be masked, got {}", i, j, v);
                    } else {
                        assert_eq!(v, 0.0, "entry ({}, {}) should be unmasked", i, j);
                    }
                }
            }
        } else {
            panic!("Expected Literal tensor");
        }
    }

    #[test]
    fn test_padding_mask_from_lengths() {
        let context = new_context();
        let mha = MultiHeadAttention::new(&context, 64, 4, "mha");
        let mask = mha.create_padding_mask_from_lengths(&[3, 2, 4], 4);
        let graph = context.borrow();
        let main_graph = graph.main_graph();
        let node = main_graph.get_node(mask.node_id).unwrap();
        if let NodeType::Literal(Value::Tensor(arr)) = &node.node_type {
            assert_eq!(arr.shape(), &[3, 4]);
            assert_eq!(arr[[0, 0]], 1.0);
            assert_eq!(arr[[0, 2]], 1.0);
            assert_eq!(arr[[0, 3]], 0.0);
            assert_eq!(arr[[1, 1]], 1.0);
            assert_eq!(arr[[1, 2]], 0.0);
            // A length equal to max_len leaves the whole row valid.
            assert_eq!(arr[[2, 3]], 1.0);
        } else {
            panic!("Expected Literal tensor");
        }
    }

    #[test]
    fn test_padding_mask_from_ids() {
        let context = new_context();
        let mask = create_padding_mask_from_ids(&context, &[2, 5], 3);
        let graph = context.borrow();
        let main_graph = graph.main_graph();
        let node = main_graph.get_node(mask.node_id).unwrap();
        if let NodeType::Literal(Value::Tensor(arr)) = &node.node_type {
            assert_eq!(arr.shape(), &[2, 1, 1, 3]);
            assert_eq!(arr[[0, 0, 0, 0]], 0.0);
            assert_eq!(arr[[0, 0, 0, 1]], 0.0);
            assert!(arr[[0, 0, 0, 2]] < -1e8);
            // Lengths beyond max_len are clamped: every position stays visible.
            assert_eq!(arr[[1, 0, 0, 0]], 0.0);
            assert_eq!(arr[[1, 0, 0, 2]], 0.0);
        } else {
            panic!("Expected Literal tensor");
        }
    }

    #[test]
    fn test_mha_parameters() {
        let context = new_context();
        let mha = MultiHeadAttention::new(&context, 64, 4, "mha");
        let params = mha.parameters();
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_mha_config() {
        let config = MultiHeadAttentionConfig::new(256, 4)
            .with_dropout(0.1)
            .without_bias();
        assert_eq!(config.embed_dim, 256);
        assert_eq!(config.num_heads, 4);
        assert_eq!(config.dropout, 0.1);
        assert!(!config.bias);
    }
}