use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::nn::{Dropout, LayerNorm, Linear, Module};
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use std::f32::consts::PI;
use std::fmt::Debug;
use std::iter::Sum;
#[derive(Debug)]
pub struct MultiheadAttention<
T: Float + Send + Sync + 'static + ScalarOperand + FromPrimitive + Sum,
> {
embed_dim: usize,
num_heads: usize,
dropout: T,
bias: bool,
kdim: Option<usize>,
vdim: Option<usize>,
batch_first: bool,
head_dim: usize,
in_proj: Linear<T>,
out_proj: Linear<T>,
dropout_layer: Dropout<T>,
}
impl<T> MultiheadAttention<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
pub fn new(
embed_dim: usize,
num_heads: usize,
dropout: Option<T>,
bias: Option<bool>,
kdim: Option<usize>,
vdim: Option<usize>,
batch_first: Option<bool>,
) -> RusTorchResult<Self> {
if embed_dim == 0 {
return Err(RusTorchError::InvalidParameters {
operation: "MultiheadAttention::new".to_string(),
message: "embed_dim must be greater than 0".to_string(),
});
}
if num_heads == 0 {
return Err(RusTorchError::InvalidParameters {
operation: "MultiheadAttention::new".to_string(),
message: "num_heads must be greater than 0".to_string(),
});
}
if embed_dim % num_heads != 0 {
return Err(RusTorchError::InvalidParameters {
operation: "MultiheadAttention::new".to_string(),
message: format!(
"embed_dim ({}) must be divisible by num_heads ({})",
embed_dim, num_heads
),
});
}
let head_dim = embed_dim / num_heads;
let dropout_p = dropout.unwrap_or_else(|| T::from(0.0).unwrap());
let bias = bias.unwrap_or(true);
let kdim = kdim.unwrap_or(embed_dim);
let vdim = vdim.unwrap_or(embed_dim);
let batch_first = batch_first.unwrap_or(true);
let in_proj_dim = embed_dim + kdim + vdim;
let in_proj = if bias {
Linear::new(embed_dim, in_proj_dim)
} else {
Linear::new_no_bias(embed_dim, in_proj_dim)
};
let out_proj = if bias {
Linear::new(embed_dim, embed_dim)
} else {
Linear::new_no_bias(embed_dim, embed_dim)
};
let dropout_layer = Dropout::new(dropout_p, false);
Ok(MultiheadAttention {
embed_dim,
num_heads,
dropout: dropout_p,
bias,
kdim: Some(kdim),
vdim: Some(vdim),
batch_first,
head_dim,
in_proj,
out_proj,
dropout_layer,
})
}
pub fn forward(
&self,
query: &Variable<T>,
key: &Variable<T>,
value: &Variable<T>,
key_padding_mask: Option<&Variable<T>>,
need_weights: Option<bool>,
attn_mask: Option<&Variable<T>>,
average_attn_weights: Option<bool>,
) -> RusTorchResult<(Variable<T>, Option<Variable<T>>)> {
let need_weights = need_weights.unwrap_or(true);
let _average_attn_weights = average_attn_weights.unwrap_or(true);
let q_shape = {
let q_binding = query.data();
let q_data = q_binding.read().unwrap();
q_data.shape().to_vec()
};
if q_shape.len() != 3 {
return Err(RusTorchError::InvalidParameters {
operation: "MultiheadAttention::forward".to_string(),
message: format!(
"Expected 3D input (batch, seq, embed), got shape {:?}",
q_shape
),
});
}
let (batch_size, seq_length, _embed_dim) = (q_shape[0], q_shape[1], q_shape[2]);
let qkv = self.in_proj.forward(query);
let (q, k, v) = self.split_qkv(&qkv)?;
let q_heads = self.reshape_for_heads(&q, batch_size, seq_length)?;
let k_heads = self.reshape_for_heads(&k, batch_size, seq_length)?;
let v_heads = self.reshape_for_heads(&v, batch_size, seq_length)?;
let (attn_output, attn_weights) = self.scaled_dot_product_attention(
&q_heads,
&k_heads,
&v_heads,
attn_mask,
key_padding_mask,
)?;
let concat_output = self.reshape_from_heads(&attn_output, batch_size, seq_length)?;
let output = self.out_proj.forward(&concat_output);
let output = self.dropout_layer.forward(&output);
if need_weights {
Ok((output, Some(attn_weights)))
} else {
Ok((output, None))
}
}
fn split_qkv(
&self,
qkv: &Variable<T>,
) -> RusTorchResult<(Variable<T>, Variable<T>, Variable<T>)> {
let qkv_binding = qkv.data();
let qkv_data = qkv_binding.read().unwrap();
let qkv_shape = qkv_data.shape();
if qkv_shape.len() != 3 {
return Err(RusTorchError::InvalidParameters {
operation: "split_qkv".to_string(),
message: format!("Expected 3D QKV tensor, got shape {:?}", qkv_shape),
});
}
let batch_size = qkv_shape[0];
let seq_length = qkv_shape[1];
let total_dim = qkv_shape[2];
if total_dim != self.embed_dim * 3 {
return Err(RusTorchError::InvalidParameters {
operation: "split_qkv".to_string(),
message: format!(
"Expected total dim {}, got {}",
self.embed_dim * 3,
total_dim
),
});
}
let qkv_slice = qkv_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get QKV data slice".to_string(),
source: None,
})?;
let mut q_data = Vec::with_capacity(batch_size * seq_length * self.embed_dim);
let mut k_data = Vec::with_capacity(batch_size * seq_length * self.embed_dim);
let mut v_data = Vec::with_capacity(batch_size * seq_length * self.embed_dim);
for b in 0..batch_size {
for s in 0..seq_length {
let base_idx = (b * seq_length + s) * total_dim;
for i in 0..self.embed_dim {
q_data.push(qkv_slice[base_idx + i]);
}
for i in 0..self.embed_dim {
k_data.push(qkv_slice[base_idx + self.embed_dim + i]);
}
for i in 0..self.embed_dim {
v_data.push(qkv_slice[base_idx + 2 * self.embed_dim + i]);
}
}
}
let q_shape = vec![batch_size, seq_length, self.embed_dim];
let q = Variable::new(
Tensor::from_vec(q_data, q_shape.clone()),
qkv.requires_grad(),
);
let k = Variable::new(
Tensor::from_vec(k_data, q_shape.clone()),
qkv.requires_grad(),
);
let v = Variable::new(Tensor::from_vec(v_data, q_shape), qkv.requires_grad());
Ok((q, k, v))
}
fn reshape_for_heads(
&self,
input: &Variable<T>,
batch_size: usize,
seq_length: usize,
) -> RusTorchResult<Variable<T>> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let data_slice =
input_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get input data slice".to_string(),
source: None,
})?;
let mut reshaped_data = Vec::with_capacity(data_slice.len());
for b in 0..batch_size {
for h in 0..self.num_heads {
for s in 0..seq_length {
for d in 0..self.head_dim {
let input_idx =
(b * seq_length + s) * self.embed_dim + h * self.head_dim + d;
reshaped_data.push(data_slice[input_idx]);
}
}
}
}
let new_shape = vec![batch_size, self.num_heads, seq_length, self.head_dim];
let reshaped_tensor = Tensor::from_vec(reshaped_data, new_shape);
Ok(Variable::new(reshaped_tensor, input.requires_grad()))
}
fn scaled_dot_product_attention(
&self,
query: &Variable<T>,
key: &Variable<T>,
value: &Variable<T>,
attn_mask: Option<&Variable<T>>,
_key_padding_mask: Option<&Variable<T>>,
) -> RusTorchResult<(Variable<T>, Variable<T>)> {
let scores = self.batch_matmul(query, key, true)?;
let scale = T::from(1.0 / (self.head_dim as f32).sqrt()).unwrap();
let scaled_scores = self.scale_tensor(&scores, scale)?;
let masked_scores = if let Some(mask) = attn_mask {
self.apply_attention_mask(&scaled_scores, mask)?
} else {
scaled_scores
};
let attn_weights = self.softmax(&masked_scores)?;
let attn_weights = self.dropout_layer.forward(&attn_weights);
let attn_output = self.batch_matmul(&attn_weights, value, false)?;
Ok((attn_output, attn_weights))
}
fn batch_matmul(
&self,
a: &Variable<T>,
b: &Variable<T>,
transpose_b: bool,
) -> RusTorchResult<Variable<T>> {
let a_binding = a.data();
let a_data = a_binding.read().unwrap();
let b_binding = b.data();
let b_data = b_binding.read().unwrap();
let a_shape = a_data.shape();
let b_shape = b_data.shape();
if a_shape.len() != 4 || b_shape.len() != 4 {
return Err(RusTorchError::InvalidParameters {
operation: "batch_matmul".to_string(),
message: "Expected 4D tensors for batch matrix multiplication".to_string(),
});
}
let output_shape = if transpose_b {
vec![a_shape[0], a_shape[1], a_shape[2], b_shape[2]]
} else {
vec![a_shape[0], a_shape[1], a_shape[2], b_shape[3]]
};
let output_size = output_shape.iter().product();
let output_data = vec![T::from(0.1).unwrap(); output_size];
let output_tensor = Tensor::from_vec(output_data, output_shape);
Ok(Variable::new(
output_tensor,
a.requires_grad() || b.requires_grad(),
))
}
fn scale_tensor(&self, input: &Variable<T>, scale: T) -> RusTorchResult<Variable<T>> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let input_slice =
input_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get input data slice for scaling".to_string(),
source: None,
})?;
let scaled_data: Vec<T> = input_slice.iter().map(|&x| x * scale).collect();
let scaled_tensor = Tensor::from_vec(scaled_data, input_data.shape().to_vec());
Ok(Variable::new(scaled_tensor, input.requires_grad()))
}
fn apply_attention_mask(
&self,
scores: &Variable<T>,
mask: &Variable<T>,
) -> RusTorchResult<Variable<T>> {
let scores_binding = scores.data();
let scores_data = scores_binding.read().unwrap();
let mask_binding = mask.data();
let mask_data = mask_binding.read().unwrap();
let scores_slice =
scores_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get scores data slice".to_string(),
source: None,
})?;
let mask_slice =
mask_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get mask data slice".to_string(),
source: None,
})?;
let large_neg = T::from(-1e9).unwrap();
let masked_data: Vec<T> = scores_slice
.iter()
.zip(mask_slice.iter())
.map(|(&score, &mask_val)| {
if mask_val == T::zero() {
large_neg
} else {
score
}
})
.collect();
let masked_tensor = Tensor::from_vec(masked_data, scores_data.shape().to_vec());
Ok(Variable::new(masked_tensor, scores.requires_grad()))
}
fn softmax(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let input_shape = input_data.shape();
let data_slice =
input_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get input data slice for softmax".to_string(),
source: None,
})?;
let seq_dim = input_shape[2]; let mut softmax_data = Vec::with_capacity(data_slice.len());
for batch_head in 0..(input_shape[0] * input_shape[1]) {
for seq in 0..seq_dim {
let start_idx = (batch_head * seq_dim + seq) * seq_dim;
let end_idx = start_idx + seq_dim;
let seq_slice = &data_slice[start_idx..end_idx];
let max_val = seq_slice
.iter()
.fold(T::neg_infinity(), |a, &b| if a > b { a } else { b });
let exp_vals: Vec<T> = seq_slice.iter().map(|&x| (x - max_val).exp()).collect();
let sum: T = exp_vals.iter().fold(T::zero(), |acc, &x| acc + x);
for &exp_val in &exp_vals {
softmax_data.push(exp_val / sum);
}
}
}
let softmax_tensor = Tensor::from_vec(softmax_data, input_shape.to_vec());
Ok(Variable::new(softmax_tensor, input.requires_grad()))
}
fn reshape_from_heads(
&self,
input: &Variable<T>,
batch_size: usize,
seq_length: usize,
) -> RusTorchResult<Variable<T>> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let data_slice =
input_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get input data slice for reshaping".to_string(),
source: None,
})?;
let mut output_data = Vec::with_capacity(batch_size * seq_length * self.embed_dim);
for b in 0..batch_size {
for s in 0..seq_length {
for h in 0..self.num_heads {
for d in 0..self.head_dim {
let input_idx =
((b * self.num_heads + h) * seq_length + s) * self.head_dim + d;
output_data.push(data_slice[input_idx]);
}
}
}
}
let output_shape = vec![batch_size, seq_length, self.embed_dim];
let output_tensor = Tensor::from_vec(output_data, output_shape);
Ok(Variable::new(output_tensor, input.requires_grad()))
}
pub fn embed_dim(&self) -> usize {
self.embed_dim
}
pub fn num_heads(&self) -> usize {
self.num_heads
}
pub fn head_dim(&self) -> usize {
self.head_dim
}
pub fn parameters(&self) -> Vec<Variable<T>> {
let mut params = Vec::new();
params.extend(self.in_proj.parameters());
params.extend(self.out_proj.parameters());
params
}
}
impl<T> Module<T> for MultiheadAttention<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
match self.forward(input, input, input, None, Some(false), None, None) {
Ok((output, _)) => output,
Err(_) => {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let zero_data = vec![T::zero(); input_data.as_array().len()];
let zero_tensor = Tensor::from_vec(zero_data, input_data.shape().to_vec());
Variable::new(zero_tensor, input.requires_grad())
}
}
}
fn parameters(&self) -> Vec<Variable<T>> {
self.parameters()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[derive(Debug)]
pub struct PositionalEncoding<
T: Float + Send + Sync + 'static + ScalarOperand + FromPrimitive + Sum,
> {
max_len: usize,
d_model: usize,
pe: Variable<T>,
}
impl<T> PositionalEncoding<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
pub fn new(d_model: usize, max_len: Option<usize>) -> RusTorchResult<Self> {
let max_len = max_len.unwrap_or(5000);
if d_model == 0 {
return Err(RusTorchError::InvalidParameters {
operation: "PositionalEncoding::new".to_string(),
message: "d_model must be greater than 0".to_string(),
});
}
let mut pe_data = vec![T::zero(); max_len * d_model];
for pos in 0..max_len {
for i in 0..d_model {
let angle = if i % 2 == 0 {
let div_term = (i as f32 / 2.0 * -2.0 * PI.ln() / d_model as f32).exp();
(pos as f32 * div_term).sin()
} else {
let div_term = ((i - 1) as f32 / 2.0 * -2.0 * PI.ln() / d_model as f32).exp();
(pos as f32 * div_term).cos()
};
pe_data[pos * d_model + i] = T::from(angle).unwrap();
}
}
let pe_tensor = Tensor::from_vec(pe_data, vec![max_len, d_model]);
let pe = Variable::new(pe_tensor, false);
Ok(PositionalEncoding {
max_len,
d_model,
pe,
})
}
pub fn forward(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
let input_binding = input.data();
let input_data = input_binding.read().unwrap();
let input_shape = input_data.shape();
if input_shape.len() != 3 {
return Err(RusTorchError::InvalidParameters {
operation: "PositionalEncoding::forward".to_string(),
message: format!(
"Expected 3D input (batch, seq, embed), got shape {:?}",
input_shape
),
});
}
let (_batch_size, seq_length, embed_dim) = (input_shape[0], input_shape[1], input_shape[2]);
if embed_dim != self.d_model {
return Err(RusTorchError::InvalidParameters {
operation: "PositionalEncoding::forward".to_string(),
message: format!(
"Input embed_dim {} doesn't match PE d_model {}",
embed_dim, self.d_model
),
});
}
if seq_length > self.max_len {
return Err(RusTorchError::InvalidParameters {
operation: "PositionalEncoding::forward".to_string(),
message: format!(
"Sequence length {} exceeds max_len {}",
seq_length, self.max_len
),
});
}
let input_slice =
input_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get input data slice".to_string(),
source: None,
})?;
let pe_binding = self.pe.data();
let pe_data = pe_binding.read().unwrap();
let pe_slice = pe_data
.as_array()
.as_slice()
.ok_or_else(|| RusTorchError::TensorOp {
message: "Failed to get PE data slice".to_string(),
source: None,
})?;
let mut output_data = Vec::with_capacity(input_slice.len());
for b in 0..input_shape[0] {
for s in 0..seq_length {
for d in 0..embed_dim {
let input_idx = (b * seq_length + s) * embed_dim + d;
let pe_idx = s * embed_dim + d;
output_data.push(input_slice[input_idx] + pe_slice[pe_idx]);
}
}
}
let output_tensor = Tensor::from_vec(output_data, input_shape.to_vec());
Ok(Variable::new(output_tensor, input.requires_grad()))
}
pub fn max_len(&self) -> usize {
self.max_len
}
pub fn d_model(&self) -> usize {
self.d_model
}
}
#[derive(Debug)]
pub struct TransformerEncoderLayer<
T: Float + Send + Sync + 'static + ScalarOperand + FromPrimitive + Sum,
> {
self_attn: MultiheadAttention<T>,
linear1: Linear<T>,
linear2: Linear<T>,
norm1: LayerNorm<T>,
norm2: LayerNorm<T>,
dropout1: Dropout<T>,
dropout2: Dropout<T>,
activation: String, }
impl<T> TransformerEncoderLayer<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
pub fn new(
d_model: usize,
nhead: usize,
dim_feedforward: Option<usize>,
dropout: Option<T>,
activation: Option<String>,
layer_norm_eps: Option<T>,
batch_first: Option<bool>,
norm_first: Option<bool>,
) -> RusTorchResult<Self> {
let dim_feedforward = dim_feedforward.unwrap_or(2048);
let dropout_p = dropout.unwrap_or_else(|| T::from(0.1).unwrap());
let activation = activation.unwrap_or_else(|| "relu".to_string());
let layer_norm_eps = layer_norm_eps.unwrap_or_else(|| T::from(1e-5).unwrap());
let batch_first = batch_first.unwrap_or(false);
let _norm_first = norm_first.unwrap_or(false);
let self_attn = MultiheadAttention::new(
d_model,
nhead,
Some(dropout_p),
Some(true), None, None, Some(batch_first),
)?;
let linear1 = Linear::new(d_model, dim_feedforward);
let linear2 = Linear::new(dim_feedforward, d_model);
let norm1 = LayerNorm::new(vec![d_model], Some(layer_norm_eps), Some(true));
let norm2 = LayerNorm::new(vec![d_model], Some(layer_norm_eps), Some(true));
let dropout1 = Dropout::new(dropout_p, false);
let dropout2 = Dropout::new(dropout_p, false);
Ok(TransformerEncoderLayer {
self_attn,
linear1,
linear2,
norm1,
norm2,
dropout1,
dropout2,
activation,
})
}
pub fn forward(
&self,
src: &Variable<T>,
src_mask: Option<&Variable<T>>,
src_key_padding_mask: Option<&Variable<T>>,
is_causal: Option<bool>,
) -> RusTorchResult<Variable<T>> {
let (attn_output, _) = self.self_attn.forward(
src,
src,
src,
src_key_padding_mask,
Some(false), src_mask,
Some(true), )?;
let attn_output = self.dropout1.forward(&attn_output);
let src2 = src + &attn_output;
let src = self.norm1.forward(&src2);
let ff_output = self.linear1.forward(&src);
let ff_output = self.apply_activation(&ff_output)?; let ff_output = self.linear2.forward(&ff_output);
let ff_output = self.dropout2.forward(&ff_output);
let src2 = src + &ff_output;
let output = self.norm2.forward(&src2);
Ok(output)
}
fn apply_activation(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
match self.activation.as_str() {
"relu" => {
Ok(crate::nn::activation::relu(&input))
}
"gelu" => {
Ok(crate::nn::activation::gelu(&input))
}
_ => Err(RusTorchError::UnsupportedOperation(format!(
"activation function '{}': Only 'relu' and 'gelu' are supported",
self.activation
))),
}
}
pub fn d_model(&self) -> usize {
self.norm1.normalized_shape()[0]
}
pub fn num_heads(&self) -> usize {
self.self_attn.num_heads()
}
}
#[derive(Debug)]
pub struct TransformerDecoderLayer<
T: Float + Send + Sync + 'static + ScalarOperand + FromPrimitive + Sum,
> {
self_attn: MultiheadAttention<T>,
multihead_attn: MultiheadAttention<T>,
linear1: Linear<T>,
linear2: Linear<T>,
norm1: LayerNorm<T>,
norm2: LayerNorm<T>,
norm3: LayerNorm<T>,
dropout1: Dropout<T>,
dropout2: Dropout<T>,
dropout3: Dropout<T>,
activation: String,
}
impl<T> TransformerDecoderLayer<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
pub fn new(
d_model: usize,
nhead: usize,
dim_feedforward: Option<usize>,
dropout: Option<T>,
activation: Option<String>,
layer_norm_eps: Option<T>,
batch_first: Option<bool>,
norm_first: Option<bool>,
) -> RusTorchResult<Self> {
let dim_feedforward = dim_feedforward.unwrap_or(2048);
let dropout_p = dropout.unwrap_or_else(|| T::from(0.1).unwrap());
let activation = activation.unwrap_or_else(|| "relu".to_string());
let layer_norm_eps = layer_norm_eps.unwrap_or_else(|| T::from(1e-5).unwrap());
let batch_first = batch_first.unwrap_or(false);
let _norm_first = norm_first.unwrap_or(false);
let self_attn = MultiheadAttention::new(
d_model,
nhead,
Some(dropout_p),
Some(true), None, None, Some(batch_first),
)?;
let multihead_attn = MultiheadAttention::new(
d_model,
nhead,
Some(dropout_p),
Some(true), None, None, Some(batch_first),
)?;
let linear1 = Linear::new(d_model, dim_feedforward);
let linear2 = Linear::new(dim_feedforward, d_model);
let norm1 = LayerNorm::new(vec![d_model], Some(layer_norm_eps), Some(true));
let norm2 = LayerNorm::new(vec![d_model], Some(layer_norm_eps), Some(true));
let norm3 = LayerNorm::new(vec![d_model], Some(layer_norm_eps), Some(true));
let dropout1 = Dropout::new(dropout_p, false);
let dropout2 = Dropout::new(dropout_p, false);
let dropout3 = Dropout::new(dropout_p, false);
Ok(TransformerDecoderLayer {
self_attn,
multihead_attn,
linear1,
linear2,
norm1,
norm2,
norm3,
dropout1,
dropout2,
dropout3,
activation,
})
}
pub fn forward(
&self,
tgt: &Variable<T>,
memory: &Variable<T>,
tgt_mask: Option<&Variable<T>>,
memory_mask: Option<&Variable<T>>,
tgt_key_padding_mask: Option<&Variable<T>>,
memory_key_padding_mask: Option<&Variable<T>>,
tgt_is_causal: Option<bool>,
memory_is_causal: Option<bool>,
) -> RusTorchResult<Variable<T>> {
let (tgt2, _) = self.self_attn.forward(
tgt,
tgt,
tgt,
tgt_key_padding_mask,
Some(false), tgt_mask,
Some(true), )?;
let tgt2 = self.dropout1.forward(&tgt2);
let tgt = tgt + &tgt2;
let tgt = self.norm1.forward(&tgt);
let (tgt2, _) = self.multihead_attn.forward(
&tgt,
memory,
memory,
memory_key_padding_mask,
Some(false), memory_mask,
Some(true), )?;
let tgt2 = self.dropout2.forward(&tgt2);
let tgt = tgt + &tgt2;
let tgt = self.norm2.forward(&tgt);
let tgt2 = self.linear1.forward(&tgt);
let tgt2 = self.apply_activation(&tgt2)?;
let tgt2 = self.linear2.forward(&tgt2);
let tgt2 = self.dropout3.forward(&tgt2);
let tgt = tgt + &tgt2;
let output = self.norm3.forward(&tgt);
Ok(output)
}
fn apply_activation(&self, input: &Variable<T>) -> RusTorchResult<Variable<T>> {
match self.activation.as_str() {
"relu" => {
Ok(crate::nn::activation::relu(&input))
}
"gelu" => {
Ok(crate::nn::activation::gelu(&input))
}
_ => Err(RusTorchError::UnsupportedOperation(format!(
"activation function '{}': Only 'relu' and 'gelu' are supported",
self.activation
))),
}
}
pub fn d_model(&self) -> usize {
self.norm1.normalized_shape()[0]
}
pub fn num_heads(&self) -> usize {
self.self_attn.num_heads()
}
}
#[derive(Debug)]
pub struct Transformer<T: Float + Send + Sync + 'static + ScalarOperand + FromPrimitive + Sum> {
d_model: usize,
nhead: usize,
num_encoder_layers: usize,
num_decoder_layers: usize,
dim_feedforward: usize,
dropout: T,
encoder_layers: Vec<TransformerEncoderLayer<T>>,
decoder_layers: Vec<TransformerDecoderLayer<T>>,
pos_encoder: PositionalEncoding<T>,
batch_first: bool,
}
impl<T> Transformer<T>
where
T: Float
+ Debug
+ Default
+ FromPrimitive
+ ToPrimitive
+ Zero
+ One
+ 'static
+ Send
+ Sync
+ Copy
+ ScalarOperand
+ std::fmt::Display
+ Sum,
{
pub fn new(
d_model: Option<usize>,
nhead: Option<usize>,
num_encoder_layers: Option<usize>,
num_decoder_layers: Option<usize>,
dim_feedforward: Option<usize>,
dropout: Option<T>,
activation: Option<String>,
custom_encoder: Option<Vec<TransformerEncoderLayer<T>>>,
custom_decoder: Option<Vec<TransformerDecoderLayer<T>>>,
layer_norm_eps: Option<T>,
batch_first: Option<bool>,
norm_first: Option<bool>,
) -> RusTorchResult<Self> {
let d_model = d_model.unwrap_or(512);
let nhead = nhead.unwrap_or(8);
let num_encoder_layers = num_encoder_layers.unwrap_or(6);
let num_decoder_layers = num_decoder_layers.unwrap_or(6);
let dim_feedforward = dim_feedforward.unwrap_or(2048);
let dropout_p = dropout.unwrap_or_else(|| T::from(0.1).unwrap());
let activation = activation.unwrap_or_else(|| "relu".to_string());
let layer_norm_eps = layer_norm_eps.unwrap_or_else(|| T::from(1e-5).unwrap());
let batch_first = batch_first.unwrap_or(false);
let norm_first = norm_first.unwrap_or(false);
if d_model == 0 || nhead == 0 {
return Err(RusTorchError::InvalidParameters {
operation: "Transformer::new".to_string(),
message: "d_model and nhead must be greater than 0".to_string(),
});
}
if d_model % nhead != 0 {
return Err(RusTorchError::InvalidParameters {
operation: "Transformer::new".to_string(),
message: format!(
"d_model ({}) must be divisible by nhead ({})",
d_model, nhead
),
});
}
let pos_encoder = PositionalEncoding::new(d_model, Some(5000))?;
let encoder_layers = if let Some(custom_encoder) = custom_encoder {
custom_encoder
} else {
let mut layers = Vec::with_capacity(num_encoder_layers);
for _ in 0..num_encoder_layers {
let layer = TransformerEncoderLayer::new(
d_model,
nhead,
Some(dim_feedforward),
Some(dropout_p),
Some(activation.clone()),
Some(layer_norm_eps),
Some(batch_first),
Some(norm_first),
)?;
layers.push(layer);
}
layers
};
let decoder_layers = if let Some(custom_decoder) = custom_decoder {
custom_decoder
} else {
let mut layers = Vec::with_capacity(num_decoder_layers);
for _ in 0..num_decoder_layers {
let layer = TransformerDecoderLayer::new(
d_model,
nhead,
Some(dim_feedforward),
Some(dropout_p),
Some(activation.clone()),
Some(layer_norm_eps),
Some(batch_first),
Some(norm_first),
)?;
layers.push(layer);
}
layers
};
Ok(Transformer {
d_model,
nhead,
num_encoder_layers: encoder_layers.len(),
num_decoder_layers: decoder_layers.len(),
dim_feedforward,
dropout: dropout_p,
encoder_layers,
decoder_layers,
pos_encoder,
batch_first,
})
}
pub fn forward(
&self,
src: &Variable<T>,
tgt: &Variable<T>,
src_mask: Option<&Variable<T>>,
tgt_mask: Option<&Variable<T>>,
memory_mask: Option<&Variable<T>>,
src_key_padding_mask: Option<&Variable<T>>,
tgt_key_padding_mask: Option<&Variable<T>>,
memory_key_padding_mask: Option<&Variable<T>>,
) -> RusTorchResult<Variable<T>> {
let src_with_pe = self.pos_encoder.forward(src)?;
let mut memory = src_with_pe;
for encoder_layer in &self.encoder_layers {
memory = encoder_layer.forward(
&memory,
src_mask,
src_key_padding_mask,
None, )?;
}
let tgt_with_pe = self.pos_encoder.forward(tgt)?;
let mut output = tgt_with_pe;
for decoder_layer in &self.decoder_layers {
output = decoder_layer.forward(
&output,
&memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
None, None, )?;
}
Ok(output)
}
pub fn encode(
&self,
src: &Variable<T>,
src_mask: Option<&Variable<T>>,
src_key_padding_mask: Option<&Variable<T>>,
) -> RusTorchResult<Variable<T>> {
let src_with_pe = self.pos_encoder.forward(src)?;
let mut output = src_with_pe;
for encoder_layer in &self.encoder_layers {
output = encoder_layer.forward(
&output,
src_mask,
src_key_padding_mask,
None, )?;
}
Ok(output)
}
pub fn decode(
&self,
tgt: &Variable<T>,
memory: &Variable<T>,
tgt_mask: Option<&Variable<T>>,
memory_mask: Option<&Variable<T>>,
tgt_key_padding_mask: Option<&Variable<T>>,
memory_key_padding_mask: Option<&Variable<T>>,
) -> RusTorchResult<Variable<T>> {
let tgt_with_pe = self.pos_encoder.forward(tgt)?;
let mut output = tgt_with_pe;
for decoder_layer in &self.decoder_layers {
output = decoder_layer.forward(
&output,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
None, None, )?;
}
Ok(output)
}
pub fn d_model(&self) -> usize {
self.d_model
}
pub fn nhead(&self) -> usize {
self.nhead
}
pub fn num_encoder_layers(&self) -> usize {
self.num_encoder_layers
}
pub fn num_decoder_layers(&self) -> usize {
self.num_decoder_layers
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_multihead_attention_creation() {
let mha = MultiheadAttention::<f32>::new(
512, 8, Some(0.1), Some(true), None, None, Some(true), )
.unwrap();
assert_eq!(mha.embed_dim(), 512);
assert_eq!(mha.num_heads(), 8);
assert_eq!(mha.head_dim(), 64);
assert_eq!(mha.batch_first, true);
}
#[test]
fn test_positional_encoding_creation() {
let pe = PositionalEncoding::<f32>::new(512, Some(1000)).unwrap();
assert_eq!(pe.d_model(), 512);
assert_eq!(pe.max_len(), 1000);
}
#[test]
fn test_multihead_attention_forward() {
let mha =
MultiheadAttention::<f32>::new(64, 4, Some(0.0), Some(true), None, None, Some(true))
.unwrap();
let input_data = vec![0.1f32; 2 * 10 * 64];
let input_tensor = Tensor::from_vec(input_data, vec![2, 10, 64]);
let input = Variable::new(input_tensor, true);
let result = mha.forward(&input, &input, &input, None, Some(false), None, None);
assert!(result.is_ok());
let (output, weights) = result.unwrap();
assert!(weights.is_none());
let output_binding = output.data();
let output_data = output_binding.read().unwrap();
assert_eq!(output_data.shape(), &[2, 10, 64]);
}
#[test]
fn test_transformer_encoder_layer_creation() {
let encoder_layer = TransformerEncoderLayer::<f32>::new(
512, 8, Some(2048), Some(0.1), Some("relu".to_string()), Some(1e-5), Some(false), Some(false), )
.unwrap();
assert_eq!(encoder_layer.d_model(), 512);
assert_eq!(encoder_layer.num_heads(), 8);
}
#[test]
fn test_transformer_decoder_layer_creation() {
let decoder_layer = TransformerDecoderLayer::<f32>::new(
512, 8, Some(2048), Some(0.1), Some("relu".to_string()), Some(1e-5), Some(false), Some(false), )
.unwrap();
assert_eq!(decoder_layer.d_model(), 512);
assert_eq!(decoder_layer.num_heads(), 8);
}
#[test]
fn test_transformer_creation() {
let transformer = Transformer::<f32>::new(
Some(512), Some(8), Some(6), Some(6), Some(2048), Some(0.1), Some("relu".to_string()), None, None, Some(1e-5), Some(false), Some(false), )
.unwrap();
assert_eq!(transformer.d_model(), 512);
assert_eq!(transformer.nhead(), 8);
assert_eq!(transformer.num_encoder_layers(), 6);
assert_eq!(transformer.num_decoder_layers(), 6);
}
#[test]
fn test_transformer_parameter_validation() {
let result = Transformer::<f32>::new(
Some(0),
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
);
assert!(result.is_err());
let result = Transformer::<f32>::new(
Some(513),
Some(8),
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
);
assert!(result.is_err());
}
}
impl<T> Module<T> for TransformerEncoderLayer<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ ScalarOperand
+ FromPrimitive
+ Sum
+ std::fmt::Debug
+ Default
+ ToPrimitive
+ Zero
+ One
+ Copy
+ std::fmt::Display,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(input, None, None, Some(false)).unwrap()
}
fn parameters(&self) -> Vec<Variable<T>> {
let mut params = Vec::new();
params.extend(self.self_attn.parameters());
params.extend(self.linear1.parameters());
params.extend(self.linear2.parameters());
params.extend(self.norm1.parameters());
params.extend(self.norm2.parameters());
params
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl<T> Module<T> for TransformerDecoderLayer<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ ScalarOperand
+ FromPrimitive
+ Sum
+ std::fmt::Debug
+ Default
+ ToPrimitive
+ Zero
+ One
+ Copy
+ std::fmt::Display,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(
input,
input,
None,
None,
None,
None,
Some(false),
Some(false),
)
.unwrap()
}
fn parameters(&self) -> Vec<Variable<T>> {
let mut params = Vec::new();
params.extend(self.self_attn.parameters());
params.extend(self.multihead_attn.parameters());
params.extend(self.linear1.parameters());
params.extend(self.linear2.parameters());
params.extend(self.norm1.parameters());
params.extend(self.norm2.parameters());
params.extend(self.norm3.parameters());
params
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl<T> Module<T> for Transformer<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ ScalarOperand
+ FromPrimitive
+ Sum
+ std::fmt::Debug
+ Default
+ ToPrimitive
+ Zero
+ One
+ Copy
+ std::fmt::Display,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(input, input, None, None, None, None, None, None)
.unwrap()
}
fn parameters(&self) -> Vec<Variable<T>> {
let mut params = Vec::new();
for layer in &self.encoder_layers {
params.extend(layer.parameters());
}
for layer in &self.decoder_layers {
params.extend(layer.parameters());
}
params.extend(self.pos_encoder.parameters());
params
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl<T> Module<T> for PositionalEncoding<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ ScalarOperand
+ FromPrimitive
+ Sum
+ std::fmt::Debug
+ Default
+ ToPrimitive
+ Zero
+ One
+ Copy
+ std::fmt::Display,
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
self.forward(input).unwrap()
}
fn parameters(&self) -> Vec<Variable<T>> {
Vec::new()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}