use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
pub struct MultiheadAttention {
base: ModuleBase,
embed_dim: usize,
num_heads: usize,
dropout: f32,
bias: bool,
add_bias_kv: bool,
add_zero_attn: bool,
kdim: Option<usize>,
vdim: Option<usize>,
batch_first: bool,
}
impl MultiheadAttention {
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
let mut base = ModuleBase::new();
assert!(
embed_dim % num_heads == 0,
"embed_dim must be divisible by num_heads"
);
let in_proj_weight = crate::init::xavier_uniform(&[3 * embed_dim, embed_dim])
.expect("Failed to create in_proj_weight");
let out_proj_weight = crate::init::xavier_uniform(&[embed_dim, embed_dim])
.expect("Failed to create out_proj_weight");
base.register_parameter("in_proj_weight".to_string(), Parameter::new(in_proj_weight));
base.register_parameter(
"out_proj.weight".to_string(),
Parameter::new(out_proj_weight),
);
let in_proj_bias =
zeros(&[3 * embed_dim]).expect("zeros tensor for in_proj_bias should succeed");
let out_proj_bias =
zeros(&[embed_dim]).expect("zeros tensor for out_proj_bias should succeed");
base.register_parameter("in_proj_bias".to_string(), Parameter::new(in_proj_bias));
base.register_parameter("out_proj.bias".to_string(), Parameter::new(out_proj_bias));
Self {
base,
embed_dim,
num_heads,
dropout: 0.0,
bias: true,
add_bias_kv: false,
add_zero_attn: false,
kdim: None,
vdim: None,
batch_first: false,
}
}
#[allow(clippy::too_many_arguments)]
pub fn with_config(
embed_dim: usize,
num_heads: usize,
dropout: f32,
bias: bool,
add_bias_kv: bool,
add_zero_attn: bool,
kdim: Option<usize>,
vdim: Option<usize>,
batch_first: bool,
) -> Self {
let mut attention = Self::new(embed_dim, num_heads);
attention.dropout = dropout;
attention.bias = bias;
attention.add_bias_kv = add_bias_kv;
attention.add_zero_attn = add_zero_attn;
attention.kdim = kdim;
attention.vdim = vdim;
attention.batch_first = batch_first;
attention
}
}
impl MultiheadAttention {
pub fn forward_with_kv(
&self,
query: &Tensor,
_key: &Tensor,
_value: &Tensor,
attn_mask: Option<&Tensor>,
is_causal: bool,
) -> Result<Tensor> {
let batch_size = query.shape().dims()[0];
let seq_len = query.shape().dims()[1];
let head_dim = self.embed_dim / self.num_heads;
let in_proj_weight = self.base.parameters["in_proj_weight"]
.tensor()
.read()
.clone();
let out_proj_weight = self.base.parameters["out_proj.weight"]
.tensor()
.read()
.clone();
let qkv = query.matmul(&in_proj_weight.transpose(0, 1)?)?;
let qkv = if self.bias {
let in_proj_bias = self.base.parameters["in_proj_bias"].tensor().read().clone();
qkv.add_op(&in_proj_bias)?
} else {
qkv
};
let q = qkv.narrow(2, 0, self.embed_dim)?;
let k = qkv.narrow(2, self.embed_dim as i64, self.embed_dim)?;
let v = qkv.narrow(2, (2 * self.embed_dim) as i64, self.embed_dim)?;
let q = q
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let k = k
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let v = v
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let (attended, _attn_weights) =
scaled_dot_product_attention(&q, &k, &v, attn_mask, self.dropout, is_causal)?;
let attended = attended.transpose(1, 2)?.contiguous()?.reshape(&[
batch_size as i32,
seq_len as i32,
self.embed_dim as i32,
])?;
let output = attended.matmul(&out_proj_weight.transpose(0, 1)?)?;
if self.bias {
let out_proj_bias = self.base.parameters["out_proj.bias"]
.tensor()
.read()
.clone();
output.add_op(&out_proj_bias)
} else {
Ok(output)
}
}
pub fn forward_cross_attention(
&self,
query: &Tensor,
key_value: &Tensor,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
self.forward_with_kv(query, key_value, key_value, attn_mask, false)
}
}
impl Module for MultiheadAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward_with_kv(input, input, input, None, false)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for MultiheadAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MultiheadAttention")
.field("embed_dim", &self.embed_dim)
.field("num_heads", &self.num_heads)
.field("dropout", &self.dropout)
.finish()
}
}
pub fn scaled_dot_product_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
attn_mask: Option<&Tensor>,
dropout_p: f32,
is_causal: bool,
) -> Result<(Tensor, Tensor)> {
let d_k = query.shape().dims()[query.shape().ndim() - 1] as f32;
let scale = 1.0 / d_k.sqrt();
let key_transposed = key.transpose(-2, -1)?;
let mut scores = query.matmul(&key_transposed)?.mul_scalar(scale)?;
if is_causal {
let seq_len = scores.shape().dims()[scores.shape().ndim() - 1];
let causal_mask = create_causal_mask(seq_len)?;
let large_neg = causal_mask.mul_scalar(-1e9)?;
scores = scores.add_op(&large_neg)?;
}
if let Some(mask) = attn_mask {
let large_neg = mask.mul_scalar(-1e9)?;
scores = scores.add_op(&large_neg)?;
}
let attn_weights = scores.softmax(-1)?;
let attn_weights = if dropout_p > 0.0 && is_training() {
crate::functional::dropout(&attn_weights, dropout_p, is_training())?
} else {
attn_weights
};
let output = attn_weights.matmul(value)?;
Ok((output, attn_weights))
}
pub fn flash_attention(
query: &Tensor,
key: &Tensor,
value: &Tensor,
block_size: Option<usize>,
causal: bool,
) -> Result<Tensor> {
let block_size = block_size.unwrap_or(64);
let shape = query.shape();
let batch_size = shape.dims()[0];
let num_heads = shape.dims()[1];
let seq_len = shape.dims()[2];
let head_dim = shape.dims()[3];
let d_k = head_dim as f32;
let scale = 1.0 / d_k.sqrt();
let output_shape = [batch_size, num_heads, seq_len, head_dim];
let mut output = zeros::<f32>(&output_shape)?;
let mut _max_vals = full(&[batch_size, num_heads, seq_len], f32::NEG_INFINITY)?;
let mut _sum_exp = zeros::<f32>(&[batch_size, num_heads, seq_len])?;
let num_blocks = seq_len.div_ceil(block_size);
for i in 0..num_blocks {
let q_start = i * block_size;
let q_end = (q_start + block_size).min(seq_len);
let q_size = q_end - q_start;
let q_block = query.narrow(2, q_start as i64, q_size)?;
let mut o_block = zeros::<f32>(&[batch_size, num_heads, q_size, head_dim])?;
let mut max_block = full(&[batch_size, num_heads, q_size], f32::NEG_INFINITY)?;
let mut sum_block = zeros::<f32>(&[batch_size, num_heads, q_size])?;
for j in 0..num_blocks {
let k_start = j * block_size;
let k_end = (k_start + block_size).min(seq_len);
let k_size = k_end - k_start;
if causal && k_start >= q_end {
continue;
}
let k_block = key.narrow(2, k_start as i64, k_size)?;
let v_block = value.narrow(2, k_start as i64, k_size)?;
let mut scores = q_block
.matmul(&k_block.transpose(-2, -1)?)?
.mul_scalar(scale)?;
if causal {
let mask = create_block_causal_mask(q_start, q_end, k_start, k_end)?;
if mask.shape().dims().iter().product::<usize>() > 0 {
let large_neg = mask.mul_scalar(-1e9)?;
scores = scores.add_op(&large_neg)?;
}
}
let new_max = scores.max_dim(-1, true)?;
let old_max = max_block.unsqueeze(-1)?;
let updated_max = new_max.maximum(&old_max)?;
let old_exp_factor = old_max.sub(&updated_max)?.exp()?;
let new_exp_factor = scores.sub(&updated_max)?.exp()?;
let old_sum = sum_block.unsqueeze(-1)?;
let new_sum = new_exp_factor.sum_dim(&[-1], true)?;
let updated_sum = old_sum.mul_op(&old_exp_factor)?.add_op(&new_sum)?;
let weighted_values = new_exp_factor.matmul(&v_block)?;
o_block = o_block
.mul_op(&old_exp_factor.unsqueeze(-1)?)?
.add_op(&weighted_values)?;
max_block = updated_max.squeeze(-1)?;
sum_block = updated_sum.squeeze(-1)?;
}
let norm_factor = sum_block.unsqueeze(-1)?.reciprocal()?;
o_block = o_block.mul_op(&norm_factor)?;
if i == 0 {
output = o_block.clone();
_max_vals = max_block.clone();
_sum_exp = sum_block.clone();
}
}
Ok(output)
}
pub struct FlashMultiHeadAttention {
base: ModuleBase,
embed_dim: usize,
num_heads: usize,
dropout: f32,
bias: bool,
block_size: usize,
}
impl FlashMultiHeadAttention {
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
let mut base = ModuleBase::new();
assert!(
embed_dim % num_heads == 0,
"embed_dim must be divisible by num_heads"
);
let qkv_weight = crate::init::xavier_uniform(&[3 * embed_dim, embed_dim])
.expect("Failed to create qkv_weight");
let out_proj_weight = crate::init::xavier_uniform(&[embed_dim, embed_dim])
.expect("Failed to create out_proj_weight");
base.register_parameter("qkv_weight".to_string(), Parameter::new(qkv_weight));
base.register_parameter(
"out_proj_weight".to_string(),
Parameter::new(out_proj_weight),
);
let qkv_bias = zeros(&[3 * embed_dim]).expect("zeros tensor for qkv_bias should succeed");
let out_proj_bias =
zeros(&[embed_dim]).expect("zeros tensor for out_proj_bias should succeed");
base.register_parameter("qkv_bias".to_string(), Parameter::new(qkv_bias));
base.register_parameter("out_proj_bias".to_string(), Parameter::new(out_proj_bias));
Self {
base,
embed_dim,
num_heads,
dropout: 0.0,
bias: true,
block_size: 64,
}
}
pub fn with_block_size(mut self, block_size: usize) -> Self {
self.block_size = block_size;
self
}
pub fn with_dropout(mut self, dropout: f32) -> Self {
self.dropout = dropout;
self
}
pub fn forward_flash(&self, input: &Tensor, causal: bool) -> Result<Tensor> {
let batch_size = input.shape().dims()[0];
let seq_len = input.shape().dims()[1];
let head_dim = self.embed_dim / self.num_heads;
let qkv_weight = self.base.parameters["qkv_weight"].tensor().read().clone();
let mut qkv = input.matmul(&qkv_weight.transpose(0, 1)?)?;
if self.bias {
let qkv_bias = self.base.parameters["qkv_bias"].tensor().read().clone();
qkv = qkv.add_op(&qkv_bias)?;
}
let chunk_size = self.embed_dim;
let q = qkv
.narrow(2, 0, chunk_size)?
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let k = qkv
.narrow(2, chunk_size as i64, chunk_size)?
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let v = qkv
.narrow(2, (2 * chunk_size) as i64, chunk_size)?
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let attended = flash_attention(&q, &k, &v, Some(self.block_size), causal)?;
let attended = attended.transpose(1, 2)?.contiguous()?.reshape(&[
batch_size as i32,
seq_len as i32,
self.embed_dim as i32,
])?;
let out_proj_weight = self.base.parameters["out_proj_weight"]
.tensor()
.read()
.clone();
let mut output = attended.matmul(&out_proj_weight.transpose(0, 1)?)?;
if self.bias {
let out_proj_bias = self.base.parameters["out_proj_bias"]
.tensor()
.read()
.clone();
output = output.add_op(&out_proj_bias)?;
}
if self.dropout > 0.0 && self.training() {
output = crate::functional::dropout(&output, self.dropout, self.training())?;
}
Ok(output)
}
}
impl Module for FlashMultiHeadAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward_flash(input, false)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
fn create_block_causal_mask(
q_start: usize,
q_end: usize,
k_start: usize,
k_end: usize,
) -> Result<Tensor> {
let q_size = q_end - q_start;
let k_size = k_end - k_start;
let mut mask_data = vec![0.0f32; q_size * k_size];
for i in 0..q_size {
for j in 0..k_size {
let global_i = q_start + i;
let global_j = k_start + j;
if global_j > global_i {
mask_data[i * k_size + j] = 1.0;
}
}
}
Ok(
Tensor::from_data(mask_data, vec![q_size, k_size], DeviceType::Cpu)
.expect("tensor creation should succeed"),
)
}
pub struct FlexAttention {
base: ModuleBase,
embed_dim: usize,
num_heads: usize,
attention_pattern: AttentionPattern,
#[allow(dead_code)]
window_size: Option<usize>,
#[allow(dead_code)]
sparsity_factor: Option<usize>,
}
#[derive(Debug, Clone)]
pub enum AttentionPattern {
Full,
SlidingWindow(usize),
Sparse(usize),
BlockSparse {
block_size: usize,
num_blocks: usize,
},
RandomSparse { sparsity: f32 },
}
impl FlexAttention {
pub fn new(embed_dim: usize, num_heads: usize, attention_pattern: AttentionPattern) -> Self {
let mut base = ModuleBase::new();
assert!(
embed_dim % num_heads == 0,
"embed_dim must be divisible by num_heads"
);
let q_proj =
crate::init::xavier_uniform(&[embed_dim, embed_dim]).expect("Failed to create q_proj");
let k_proj =
crate::init::xavier_uniform(&[embed_dim, embed_dim]).expect("Failed to create k_proj");
let v_proj =
crate::init::xavier_uniform(&[embed_dim, embed_dim]).expect("Failed to create v_proj");
let out_proj = crate::init::xavier_uniform(&[embed_dim, embed_dim])
.expect("Failed to create out_proj");
base.register_parameter("q_proj".to_string(), Parameter::new(q_proj));
base.register_parameter("k_proj".to_string(), Parameter::new(k_proj));
base.register_parameter("v_proj".to_string(), Parameter::new(v_proj));
base.register_parameter("out_proj".to_string(), Parameter::new(out_proj));
Self {
base,
embed_dim,
num_heads,
attention_pattern,
window_size: None,
sparsity_factor: None,
}
}
fn apply_pattern_mask(&self, scores: &Tensor, seq_len: usize) -> Result<Tensor> {
match &self.attention_pattern {
AttentionPattern::Full => Ok(scores.clone()),
AttentionPattern::SlidingWindow(window) => {
let mask = create_sliding_window_mask(seq_len, *window)?;
let large_neg = mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)
}
AttentionPattern::Sparse(factor) => {
let mask = create_sparse_mask(seq_len, *factor)?;
let large_neg = mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)
}
AttentionPattern::BlockSparse {
block_size,
num_blocks,
} => {
let mask = create_block_sparse_mask(seq_len, *block_size, *num_blocks)?;
let large_neg = mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)
}
AttentionPattern::RandomSparse { sparsity } => {
let mask = create_random_sparse_mask(seq_len, *sparsity)?;
let large_neg = mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)
}
}
}
}
impl Module for FlexAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let q_proj = self.base.parameters["q_proj"].tensor().read().clone();
let k_proj = self.base.parameters["k_proj"].tensor().read().clone();
let v_proj = self.base.parameters["v_proj"].tensor().read().clone();
let out_proj = self.base.parameters["out_proj"].tensor().read().clone();
let query = input.matmul(&q_proj)?;
let key = input.matmul(&k_proj)?;
let value = input.matmul(&v_proj)?;
let batch_size = input.shape().dims()[0];
let seq_len = input.shape().dims()[1];
let head_dim = self.embed_dim / self.num_heads;
let query = query
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let key = key
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let value = value
.reshape(&[
batch_size as i32,
seq_len as i32,
self.num_heads as i32,
head_dim as i32,
])?
.transpose(1, 2)?;
let d_k = head_dim as f32;
let scale = 1.0 / d_k.sqrt();
let scores = query.matmul(&key.transpose(-2, -1)?)?.mul_scalar(scale)?;
let masked_scores = self.apply_pattern_mask(&scores, seq_len)?;
let attn_weights = masked_scores.softmax(-1)?;
let attended = attn_weights.matmul(&value)?;
let attended = attended.transpose(1, 2)?.contiguous()?.reshape(&[
batch_size as i32,
seq_len as i32,
self.embed_dim as i32,
])?;
attended.matmul(&out_proj)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
fn create_causal_mask(seq_len: usize) -> Result<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
mask_data[i * seq_len + j] = 1.0;
}
}
Tensor::from_data(mask_data, vec![seq_len, seq_len], DeviceType::Cpu)
}
fn create_sliding_window_mask(seq_len: usize, window_size: usize) -> Result<Tensor> {
let mut mask_data = vec![1.0f32; seq_len * seq_len];
for i in 0..seq_len {
let start = i.saturating_sub(window_size);
let end = (i + window_size + 1).min(seq_len);
for j in start..end {
mask_data[i * seq_len + j] = 0.0;
}
}
Tensor::from_data(mask_data, vec![seq_len, seq_len], DeviceType::Cpu)
}
fn create_sparse_mask(seq_len: usize, sparsity_factor: usize) -> Result<Tensor> {
let mut mask_data = vec![1.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (0..seq_len).step_by(sparsity_factor) {
mask_data[i * seq_len + j] = 0.0;
}
}
Tensor::from_data(mask_data, vec![seq_len, seq_len], DeviceType::Cpu)
}
fn create_block_sparse_mask(
seq_len: usize,
block_size: usize,
num_blocks: usize,
) -> Result<Tensor> {
let mut mask_data = vec![1.0f32; seq_len * seq_len];
for block in 0..num_blocks {
let start_i = block * block_size;
let end_i = ((block + 1) * block_size).min(seq_len);
for block_j in 0..num_blocks {
let start_j = block_j * block_size;
let end_j = ((block_j + 1) * block_size).min(seq_len);
for i in start_i..end_i {
for j in start_j..end_j {
if i < seq_len && j < seq_len {
mask_data[i * seq_len + j] = 0.0;
}
}
}
}
}
Tensor::from_data(mask_data, vec![seq_len, seq_len], DeviceType::Cpu)
}
fn create_random_sparse_mask(seq_len: usize, sparsity: f32) -> Result<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len * seq_len {
if scirs2_core::random::quick::random_f32() > sparsity {
mask_data[i] = 1.0;
}
}
Tensor::from_data(mask_data, vec![seq_len, seq_len], DeviceType::Cpu)
}
fn is_training() -> bool {
false
}
pub struct AdditiveAttention {
base: ModuleBase,
hidden_size: usize,
encoder_size: usize,
attention_size: usize,
}
impl AdditiveAttention {
pub fn new(hidden_size: usize, encoder_size: usize, attention_size: usize) -> Self {
let mut base = ModuleBase::new();
let w_a = crate::init::xavier_uniform(&[attention_size, hidden_size])
.expect("Failed to create w_a");
let u_a = crate::init::xavier_uniform(&[attention_size, encoder_size])
.expect("Failed to create u_a");
let v = crate::init::xavier_uniform(&[attention_size, 1]).expect("Failed to create v");
base.register_parameter("w_a".to_string(), Parameter::new(w_a));
base.register_parameter("u_a".to_string(), Parameter::new(u_a));
base.register_parameter("v".to_string(), Parameter::new(v));
Self {
base,
hidden_size,
encoder_size,
attention_size,
}
}
pub fn forward_with_weights(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
) -> Result<(Tensor, Tensor)> {
let w_a = self.base.parameters["w_a"].tensor().read().clone();
let u_a = self.base.parameters["u_a"].tensor().read().clone();
let v = self.base.parameters["v"].tensor().read().clone();
let seq_len = encoder_outputs.shape().dims()[0];
let _batch_size = encoder_outputs.shape().dims()[1];
let decoder_proj = decoder_hidden.matmul(&w_a.transpose(0, 1)?)?;
let mut encoder_projs = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let proj_t = encoder_t.matmul(&u_a.transpose(0, 1)?)?;
encoder_projs.push(proj_t);
}
let mut attention_scores = Vec::new();
for t in 0..seq_len {
let combined = decoder_proj.add_op(&encoder_projs[t])?.tanh()?;
let score = combined.matmul(&v)?.squeeze(-1)?; attention_scores.push(score);
}
let scores_stacked = self.stack_attention_scores(&attention_scores)?;
let attention_weights = scores_stacked.softmax(-1)?;
let mut weighted_encoder = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let weight_t = attention_weights.narrow(1, t as i64, 1)?.squeeze(1)?;
let weighted_t = encoder_t.mul_op(&weight_t.unsqueeze(1)?)?;
weighted_encoder.push(weighted_t);
}
let context = self.sum_weighted_encoder(&weighted_encoder)?;
Ok((context, attention_weights))
}
fn stack_attention_scores(&self, scores: &[Tensor]) -> Result<Tensor> {
let batch_size = scores[0].shape().dims()[0];
let seq_len = scores.len();
let mut stacked_data = Vec::with_capacity(batch_size * seq_len);
for b in 0..batch_size {
for s in scores {
let data = s.to_vec()?;
stacked_data.push(data[b]);
}
}
Tensor::from_vec(stacked_data, &[batch_size, seq_len])
}
fn sum_weighted_encoder(&self, weighted: &[Tensor]) -> Result<Tensor> {
let mut context = weighted[0].clone();
for w in &weighted[1..] {
context = context.add_op(w)?;
}
Ok(context)
}
}
impl Module for AdditiveAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let hidden_size = self.hidden_size;
let decoder_hidden = input.narrow(1, 0, hidden_size)?;
let encoder_outputs =
input.narrow(1, hidden_size as i64, input.shape().dims()[1] - hidden_size)?;
let batch_size = input.shape().dims()[0];
let remaining_size = input.shape().dims()[1] - hidden_size;
let seq_len = remaining_size / self.encoder_size;
let encoder_reshaped = encoder_outputs
.reshape(&[batch_size as i32, seq_len as i32, self.encoder_size as i32])?
.transpose(0, 1)?;
let (context, _) = self.forward_with_weights(&decoder_hidden, &encoder_reshaped)?;
Ok(context)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
pub struct MultiplicativeAttention {
base: ModuleBase,
hidden_size: usize,
encoder_size: usize,
attention_type: LuongAttentionType,
}
#[derive(Debug, Clone)]
pub enum LuongAttentionType {
Dot,
General,
Concat,
}
impl MultiplicativeAttention {
pub fn new(
hidden_size: usize,
encoder_size: usize,
attention_type: LuongAttentionType,
) -> Self {
let mut base = ModuleBase::new();
match &attention_type {
LuongAttentionType::General => {
let w_a = crate::init::xavier_uniform(&[hidden_size, encoder_size])
.expect("Failed to initialize attention weights");
base.register_parameter("w_a".to_string(), Parameter::new(w_a));
}
LuongAttentionType::Concat => {
let w_a = crate::init::xavier_uniform(&[hidden_size, hidden_size + encoder_size])
.expect("Failed to initialize attention weights");
let v_a = crate::init::xavier_uniform(&[hidden_size, 1])
.expect("Failed to initialize attention vector");
base.register_parameter("w_a".to_string(), Parameter::new(w_a));
base.register_parameter("v_a".to_string(), Parameter::new(v_a));
}
LuongAttentionType::Dot => {
}
}
Self {
base,
hidden_size,
encoder_size,
attention_type,
}
}
pub fn forward_with_weights(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
) -> Result<(Tensor, Tensor)> {
let seq_len = encoder_outputs.shape().dims()[0];
let attention_scores = match &self.attention_type {
LuongAttentionType::Dot => {
self.compute_dot_attention(decoder_hidden, encoder_outputs)?
}
LuongAttentionType::General => {
self.compute_general_attention(decoder_hidden, encoder_outputs)?
}
LuongAttentionType::Concat => {
self.compute_concat_attention(decoder_hidden, encoder_outputs)?
}
};
let attention_weights = attention_scores.softmax(-1)?;
let mut weighted_encoder = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let weight_t = attention_weights.narrow(1, t as i64, 1)?.squeeze(1)?;
let weighted_t = encoder_t.mul_op(&weight_t.unsqueeze(1)?)?;
weighted_encoder.push(weighted_t);
}
let context = self.sum_weighted_states(&weighted_encoder)?;
Ok((context, attention_weights))
}
fn compute_dot_attention(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
) -> Result<Tensor> {
let seq_len = encoder_outputs.shape().dims()[0];
let batch_size = decoder_hidden.shape().dims()[0];
let mut scores = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let score = decoder_hidden.mul_op(&encoder_t)?.sum_dim(&[-1], false)?;
scores.push(score);
}
self.stack_scores(&scores, batch_size)
}
fn compute_general_attention(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
) -> Result<Tensor> {
let w_a = self.base.parameters["w_a"].tensor().read().clone();
let seq_len = encoder_outputs.shape().dims()[0];
let batch_size = decoder_hidden.shape().dims()[0];
let transformed_decoder = decoder_hidden.matmul(&w_a)?;
let mut scores = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let score = transformed_decoder
.mul_op(&encoder_t)?
.sum_dim(&[-1], false)?;
scores.push(score);
}
self.stack_scores(&scores, batch_size)
}
fn compute_concat_attention(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
) -> Result<Tensor> {
let w_a = self.base.parameters["w_a"].tensor().read().clone();
let v_a = self.base.parameters["v_a"].tensor().read().clone();
let seq_len = encoder_outputs.shape().dims()[0];
let batch_size = decoder_hidden.shape().dims()[0];
let mut scores = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let decoder_data = decoder_hidden.to_vec()?;
let encoder_data = encoder_t.to_vec()?;
let mut concat_data = Vec::with_capacity(decoder_data.len() + encoder_data.len());
concat_data.extend(decoder_data);
concat_data.extend(encoder_data);
let concat = Tensor::from_vec(
concat_data,
&[batch_size, self.hidden_size + self.encoder_size],
)?;
let transformed = concat.matmul(&w_a.transpose(0, 1)?)?.tanh()?;
let score = transformed.matmul(&v_a)?.squeeze(-1)?;
scores.push(score);
}
self.stack_scores(&scores, batch_size)
}
fn stack_scores(&self, scores: &[Tensor], batch_size: usize) -> Result<Tensor> {
let seq_len = scores.len();
let mut stacked_data = Vec::with_capacity(batch_size * seq_len);
for b in 0..batch_size {
for s in scores {
let data = s.to_vec()?;
stacked_data.push(data[b]);
}
}
Tensor::from_vec(stacked_data, &[batch_size, seq_len])
}
fn sum_weighted_states(&self, weighted: &[Tensor]) -> Result<Tensor> {
let mut context = weighted[0].clone();
for w in &weighted[1..] {
context = context.add_op(w)?;
}
Ok(context)
}
}
impl Module for MultiplicativeAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let hidden_size = self.hidden_size;
let decoder_hidden = input.narrow(1, 0, hidden_size)?;
let encoder_outputs =
input.narrow(1, hidden_size as i64, input.shape().dims()[1] - hidden_size)?;
let batch_size = input.shape().dims()[0];
let remaining_size = input.shape().dims()[1] - hidden_size;
let seq_len = remaining_size / self.encoder_size;
let encoder_reshaped = encoder_outputs
.reshape(&[batch_size as i32, seq_len as i32, self.encoder_size as i32])?
.transpose(0, 1)?;
let (context, _) = self.forward_with_weights(&decoder_hidden, &encoder_reshaped)?;
Ok(context)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
pub struct LocationBasedAttention {
base: ModuleBase,
hidden_size: usize,
conv_filters: usize,
conv_kernel_size: usize,
}
impl LocationBasedAttention {
pub fn new(hidden_size: usize, conv_filters: usize, conv_kernel_size: usize) -> Self {
let mut base = ModuleBase::new();
let conv_weight = crate::init::xavier_uniform(&[conv_filters, 1, conv_kernel_size])
.expect("Failed to create conv_weight");
let location_proj = crate::init::xavier_uniform(&[hidden_size, conv_filters])
.expect("Failed to create location_proj");
let v_att = crate::init::xavier_uniform(&[hidden_size, 1]).expect("Failed to create v_att");
base.register_parameter("conv_weight".to_string(), Parameter::new(conv_weight));
base.register_parameter("location_proj".to_string(), Parameter::new(location_proj));
base.register_parameter("v_att".to_string(), Parameter::new(v_att));
Self {
base,
hidden_size,
conv_filters,
conv_kernel_size,
}
}
pub fn forward_with_location(
&self,
decoder_hidden: &Tensor,
encoder_outputs: &Tensor,
prev_attention: &Tensor, ) -> Result<(Tensor, Tensor)> {
let seq_len = encoder_outputs.shape().dims()[0];
let batch_size = decoder_hidden.shape().dims()[0];
let location_proj = self.base.parameters["location_proj"]
.tensor()
.read()
.clone();
let v_att = self.base.parameters["v_att"].tensor().read().clone();
let location_features = self.compute_location_features(prev_attention)?;
let location_proj_result = location_features.matmul(&location_proj.transpose(0, 1)?)?;
let mut attention_scores = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let location_t = location_proj_result.narrow(1, t as i64, 1)?.squeeze(1)?;
let combined = decoder_hidden
.add_op(&encoder_t)?
.add_op(&location_t)?
.tanh()?;
let score = combined.matmul(&v_att)?.squeeze(-1)?;
attention_scores.push(score);
}
let scores_stacked = self.stack_location_scores(&attention_scores, batch_size)?;
let attention_weights = scores_stacked.softmax(-1)?;
let mut weighted_encoder = Vec::new();
for t in 0..seq_len {
let encoder_t = encoder_outputs.narrow(0, t as i64, 1)?.squeeze(0)?;
let weight_t = attention_weights.narrow(1, t as i64, 1)?.squeeze(1)?;
let weighted_t = encoder_t.mul_op(&weight_t.unsqueeze(1)?)?;
weighted_encoder.push(weighted_t);
}
let context = self.sum_location_weighted(&weighted_encoder)?;
Ok((context, attention_weights))
}
fn compute_location_features(&self, prev_attention: &Tensor) -> Result<Tensor> {
let expanded = prev_attention.unsqueeze(1)?;
let smoothed = expanded.clone(); Ok(smoothed)
}
fn stack_location_scores(&self, scores: &[Tensor], batch_size: usize) -> Result<Tensor> {
let seq_len = scores.len();
let mut stacked_data = Vec::with_capacity(batch_size * seq_len);
for b in 0..batch_size {
for s in scores {
let data = s.to_vec()?;
stacked_data.push(data[b]);
}
}
Tensor::from_vec(stacked_data, &[batch_size, seq_len])
}
fn sum_location_weighted(&self, weighted: &[Tensor]) -> Result<Tensor> {
let mut context = weighted[0].clone();
for w in &weighted[1..] {
context = context.add_op(w)?;
}
Ok(context)
}
}
impl Module for LocationBasedAttention {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let hidden_size = self.hidden_size;
let decoder_hidden = input.narrow(1, 0, hidden_size)?;
let remaining =
input.narrow(1, hidden_size as i64, input.shape().dims()[1] - hidden_size)?;
let batch_size = input.shape().dims()[0];
let encoder_part_size = remaining.shape().dims()[1] / 2;
let encoder_outputs = remaining.narrow(1, 0, encoder_part_size)?;
let prev_attention = remaining.narrow(1, encoder_part_size as i64, encoder_part_size)?;
let seq_len = encoder_part_size / self.hidden_size;
let encoder_reshaped = encoder_outputs
.reshape(&[batch_size as i32, seq_len as i32, self.hidden_size as i32])?
.transpose(0, 1)?;
let prev_att_reshaped = prev_attention.narrow(1, 0, seq_len)?;
let (context, _) =
self.forward_with_location(&decoder_hidden, &encoder_reshaped, &prev_att_reshaped)?;
Ok(context)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl std::fmt::Debug for AdditiveAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AdditiveAttention")
.field("hidden_size", &self.hidden_size)
.field("encoder_size", &self.encoder_size)
.field("attention_size", &self.attention_size)
.finish()
}
}
impl std::fmt::Debug for MultiplicativeAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MultiplicativeAttention")
.field("hidden_size", &self.hidden_size)
.field("encoder_size", &self.encoder_size)
.field("attention_type", &self.attention_type)
.finish()
}
}
impl std::fmt::Debug for LocationBasedAttention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocationBasedAttention")
.field("hidden_size", &self.hidden_size)
.field("conv_filters", &self.conv_filters)
.field("conv_kernel_size", &self.conv_kernel_size)
.finish()
}
}
#[allow(dead_code)]
fn f32_mask_to_bool(mask: &Tensor) -> Result<Tensor> {
Ok(mask.clone())
}