use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::nn::attention::MultiheadAttention;
use crate::nn::{Dropout, LayerNorm, Linear, Module};
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// One Transformer encoder layer.
///
/// Holds a self-attention sublayer and a two-layer feed-forward sublayer.
/// Each sublayer has its own dropout and its own layer normalization, which
/// the layer's `forward` applies after the residual addition.
#[derive(Debug)]
pub struct TransformerEncoderLayer<T>
where
    T: Float + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive + Sum,
{
    /// Attention sublayer; `forward` feeds it the input as query, key and value.
    self_attention: MultiheadAttention<T>,
    /// First feed-forward projection (`d_model` -> `d_ff`).
    ff_linear1: Linear<T>,
    /// Second feed-forward projection (`d_ff` -> `d_model`).
    ff_linear2: Linear<T>,
    /// Normalization applied after the attention residual.
    norm1: LayerNorm<T>,
    /// Normalization applied after the feed-forward residual.
    norm2: LayerNorm<T>,
    /// Dropout applied to the attention output.
    dropout1: Dropout<T>,
    /// Dropout applied to the feed-forward output.
    dropout2: Dropout<T>,
    /// Model (embedding) width.
    d_model: usize,
    /// Hidden width of the feed-forward sublayer.
    d_ff: usize,
}
impl<T> TransformerEncoderLayer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Builds an encoder layer.
    ///
    /// * `d_model` - model width; must be non-zero and divisible by `num_heads`.
    /// * `num_heads` - number of attention heads; must be non-zero.
    /// * `d_ff` - hidden width of the feed-forward sublayer; must be non-zero.
    /// * `dropout` - dropout probability; defaults to `0.1` when `None`.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::InvalidParameters` when any size is zero or when
    /// `d_model` is not divisible by `num_heads`. Checks run in that order, so
    /// the first violated constraint is the one reported.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout: Option<T>,
    ) -> RusTorchResult<Self> {
        // Single place where the error value is assembled, so every message
        // carries the same `operation` tag.
        let invalid = |message: String| RusTorchError::InvalidParameters {
            operation: "TransformerEncoderLayer::new".to_string(),
            message,
        };

        if d_model == 0 {
            return Err(invalid("d_model must be greater than 0".to_string()));
        }
        if num_heads == 0 {
            return Err(invalid("num_heads must be greater than 0".to_string()));
        }
        if d_ff == 0 {
            return Err(invalid("d_ff must be greater than 0".to_string()));
        }
        if d_model % num_heads != 0 {
            return Err(invalid(format!(
                "d_model ({}) must be divisible by num_heads ({})",
                d_model, num_heads
            )));
        }

        let dropout_p = dropout.unwrap_or_else(|| T::from(0.1).unwrap());

        // NOTE(review): the meaning of the trailing option flags is defined by
        // `MultiheadAttention::new`; they are passed through unchanged here.
        let self_attention = MultiheadAttention::new(
            d_model,
            num_heads,
            Some(dropout_p),
            Some(true),
            None,
            None,
            Some(false),
        );

        Ok(TransformerEncoderLayer {
            self_attention,
            ff_linear1: Linear::new(d_model, d_ff),
            ff_linear2: Linear::new(d_ff, d_model),
            norm1: LayerNorm::new(vec![d_model], None, None),
            norm2: LayerNorm::new(vec![d_model], None, None),
            dropout1: Dropout::new(dropout_p, false),
            dropout2: Dropout::new(dropout_p, false),
            d_model,
            d_ff,
        })
    }

    /// Runs the layer on `input`, optionally restricting attention with `mask`.
    ///
    /// Order of operations: self-attention -> dropout -> residual add ->
    /// layer norm, then linear -> ReLU -> linear -> dropout -> residual add ->
    /// layer norm.
    ///
    /// # Panics
    ///
    /// Panics if a sublayer output does not have the same shape as the value
    /// it is added to (see `add_tensors`).
    pub fn forward(&self, input: &Variable<T>, mask: Option<&Variable<T>>) -> Variable<T> {
        // Attention sublayer; the second tuple element is not used.
        let (attended, _) =
            self.self_attention
                .forward(input, input, input, mask, Some(false), None, Some(true));
        let attended = self.dropout1.forward(&attended);
        let hidden = self.norm1.forward(&self.add_tensors(input, &attended));

        // Feed-forward sublayer.
        let expanded = self.ff_linear1.forward(&hidden);
        let activated = self.apply_relu(&expanded);
        let projected = self.ff_linear2.forward(&activated);
        let projected = self.dropout2.forward(&projected);

        self.norm2.forward(&self.add_tensors(&hidden, &projected))
    }

    /// Element-wise sum of two equally shaped variables.
    ///
    /// The result requires gradients when either operand does.
    ///
    /// NOTE(review): the result is built with `Variable::new` from raw values;
    /// presumably no backward edge to `a` / `b` is recorded - confirm against
    /// the autograd implementation.
    ///
    /// # Panics
    ///
    /// Panics when the shapes differ, or when a data lock is poisoned.
    fn add_tensors(&self, a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
        let a_binding = a.data();
        let a_data = a_binding.read().unwrap();
        let b_binding = b.data();
        let b_data = b_binding.read().unwrap();

        assert!(
            a_data.shape() == b_data.shape(),
            "Cannot add tensors with different shapes: {:?} vs {:?}",
            a_data.shape(),
            b_data.shape()
        );

        // Iterate in logical order so non-contiguous arrays are summed
        // correctly; for contiguous arrays this matches slice order.
        let result_data: Vec<T> = a_data
            .as_array()
            .iter()
            .zip(b_data.as_array().iter())
            .map(|(&lhs, &rhs)| lhs + rhs)
            .collect();

        Variable::new(
            Tensor::from_vec(result_data, a_data.shape().to_vec()),
            a.requires_grad() || b.requires_grad(),
        )
    }

    /// Element-wise ReLU: values greater than zero pass through, everything
    /// else (including values that do not compare greater than zero, such as
    /// NaN) becomes zero.
    ///
    /// # Panics
    ///
    /// Panics when the data lock is poisoned.
    fn apply_relu(&self, input: &Variable<T>) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();

        // Logical-order iteration handles contiguous and non-contiguous
        // arrays alike.
        let output_data: Vec<T> = input_data
            .as_array()
            .iter()
            .map(|&val| if val > T::zero() { val } else { T::zero() })
            .collect();

        Variable::new(
            Tensor::from_vec(output_data, input_data.shape().to_vec()),
            input.requires_grad(),
        )
    }

    /// Model width this layer was built with.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Feed-forward hidden width this layer was built with.
    pub fn d_ff(&self) -> usize {
        self.d_ff
    }

    /// Collects the parameters of the attention, feed-forward and
    /// normalization sublayers, in that order.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        let mut params = Vec::new();
        params.extend(self.self_attention.parameters());
        params.extend(self.ff_linear1.parameters());
        params.extend(self.ff_linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }
}
/// `Module` adapter for the encoder layer: delegates to the inherent methods.
impl<T> Module<T> for TransformerEncoderLayer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Runs the layer without an attention mask.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent items take precedence over trait items, so this resolves
        // to the inherent two-argument `forward` and does not recurse.
        Self::forward(self, input, None)
    }

    /// Returns the layer's parameters via the inherent `parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the layer as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// A stack of [`TransformerEncoderLayer`]s applied one after another.
#[derive(Debug)]
pub struct TransformerEncoder<T>
where
    T: Float + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive + Sum,
{
    /// Layers in application order.
    layers: Vec<TransformerEncoderLayer<T>>,
    /// Number of layers requested at construction.
    num_layers: usize,
    /// Model width shared by every layer.
    d_model: usize,
}
impl<T> TransformerEncoder<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Builds an encoder of `num_layers` identically configured layers.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::InvalidParameters` when `num_layers` is zero,
    /// and otherwise forwards the first error produced by
    /// `TransformerEncoderLayer::new`.
    pub fn new(
        num_layers: usize,
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout: Option<T>,
    ) -> RusTorchResult<Self> {
        if num_layers == 0 {
            return Err(RusTorchError::InvalidParameters {
                operation: "TransformerEncoder::new".to_string(),
                message: "num_layers must be greater than 0".to_string(),
            });
        }

        // Collecting into a `Result` stops at the first failing layer.
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, num_heads, d_ff, dropout))
            .collect::<RusTorchResult<Vec<_>>>()?;

        Ok(TransformerEncoder {
            layers,
            num_layers,
            d_model,
        })
    }

    /// Feeds `input` through every layer in order, passing the same `mask`
    /// to each, and returns the last layer's output.
    pub fn forward(&self, input: &Variable<T>, mask: Option<&Variable<T>>) -> Variable<T> {
        self.layers
            .iter()
            .fold(input.clone(), |hidden, layer| layer.forward(&hidden, mask))
    }

    /// Number of layers in the stack.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Model width the encoder was built with.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Parameters of all layers, concatenated in layer order.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        self.layers
            .iter()
            .flat_map(|layer| layer.parameters())
            .collect()
    }
}
/// `Module` adapter for the encoder stack: delegates to the inherent methods.
impl<T> Module<T> for TransformerEncoder<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Runs the encoder without an attention mask.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent items take precedence over trait items, so this resolves
        // to the inherent two-argument `forward` and does not recurse.
        Self::forward(self, input, None)
    }

    /// Returns the encoder's parameters via the inherent `parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the encoder as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// One Transformer decoder layer.
///
/// Holds a self-attention sublayer, a cross-attention sublayer that attends
/// to encoder memory, and a two-layer feed-forward sublayer. Each sublayer
/// has its own dropout and layer normalization.
#[derive(Debug)]
pub struct TransformerDecoderLayer<T>
where
    T: Float + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive + Sum,
{
    /// Attention over the target sequence itself.
    self_attention: MultiheadAttention<T>,
    /// Attention with the target as query and the memory as key and value.
    cross_attention: MultiheadAttention<T>,
    /// First feed-forward projection (`d_model` -> `d_ff`).
    ff_linear1: Linear<T>,
    /// Second feed-forward projection (`d_ff` -> `d_model`).
    ff_linear2: Linear<T>,
    /// Normalization applied after the self-attention residual.
    norm1: LayerNorm<T>,
    /// Normalization applied after the cross-attention residual.
    norm2: LayerNorm<T>,
    /// Normalization applied after the feed-forward residual.
    norm3: LayerNorm<T>,
    /// Dropout applied to the self-attention output.
    dropout1: Dropout<T>,
    /// Dropout applied to the cross-attention output.
    dropout2: Dropout<T>,
    /// Dropout applied to the feed-forward output.
    dropout3: Dropout<T>,
    /// Model (embedding) width.
    d_model: usize,
    /// Hidden width of the feed-forward sublayer.
    d_ff: usize,
}
impl<T> TransformerDecoderLayer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Builds a decoder layer.
    ///
    /// * `d_model` - model width; must be non-zero and divisible by `num_heads`.
    /// * `num_heads` - number of attention heads; must be non-zero.
    /// * `d_ff` - hidden width of the feed-forward sublayer; must be non-zero.
    /// * `dropout` - dropout probability; defaults to `0.1` when `None`.
    ///
    /// # Errors
    ///
    /// Returns `RusTorchError::InvalidParameters` when any size is zero or when
    /// `d_model` is not divisible by `num_heads`. Checks run in that order, so
    /// the first violated constraint is the one reported.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout: Option<T>,
    ) -> RusTorchResult<Self> {
        // Errors are tagged with this constructor's own name so that callers
        // can tell decoder failures apart from encoder failures.
        let invalid = |message: String| RusTorchError::InvalidParameters {
            operation: "TransformerDecoderLayer::new".to_string(),
            message,
        };

        if d_model == 0 {
            return Err(invalid("d_model must be greater than 0".to_string()));
        }
        if num_heads == 0 {
            return Err(invalid("num_heads must be greater than 0".to_string()));
        }
        if d_ff == 0 {
            return Err(invalid("d_ff must be greater than 0".to_string()));
        }
        if d_model % num_heads != 0 {
            return Err(invalid(format!(
                "d_model ({}) must be divisible by num_heads ({})",
                d_model, num_heads
            )));
        }

        let dropout_p = dropout.unwrap_or_else(|| T::from(0.1).unwrap());

        // NOTE(review): the meaning of the trailing option flags is defined by
        // `MultiheadAttention::new`; both attention blocks use the same ones.
        let self_attention = MultiheadAttention::new(
            d_model,
            num_heads,
            Some(dropout_p),
            Some(true),
            None,
            None,
            Some(false),
        );
        let cross_attention = MultiheadAttention::new(
            d_model,
            num_heads,
            Some(dropout_p),
            Some(true),
            None,
            None,
            Some(false),
        );

        Ok(TransformerDecoderLayer {
            self_attention,
            cross_attention,
            ff_linear1: Linear::new(d_model, d_ff),
            ff_linear2: Linear::new(d_ff, d_model),
            norm1: LayerNorm::new(vec![d_model], None, None),
            norm2: LayerNorm::new(vec![d_model], None, None),
            norm3: LayerNorm::new(vec![d_model], None, None),
            dropout1: Dropout::new(dropout_p, false),
            dropout2: Dropout::new(dropout_p, false),
            dropout3: Dropout::new(dropout_p, false),
            d_model,
            d_ff,
        })
    }

    /// Runs the layer on `target`, attending to `memory`.
    ///
    /// Order of operations: self-attention over `target` (with `target_mask`),
    /// then cross-attention with the normalized target as query and `memory`
    /// as key and value (with `memory_mask`), then the feed-forward sublayer.
    /// Each sublayer output goes through dropout, a residual addition and
    /// layer normalization.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if a sublayer output does not have the same shape as the value
    /// it is added to (see `add_tensors`).
    pub fn forward(
        &self,
        target: &Variable<T>,
        memory: &Variable<T>,
        target_mask: Option<&Variable<T>>,
        memory_mask: Option<&Variable<T>>,
    ) -> RusTorchResult<Variable<T>> {
        // Self-attention sublayer; only the first tuple element is used.
        let (self_attended, _) = self.self_attention.forward(
            target,
            target,
            target,
            target_mask,
            Some(false),
            None,
            Some(true),
        );
        let self_attended = self.dropout1.forward(&self_attended);
        let after_self = self.norm1.forward(&self.add_tensors(target, &self_attended));

        // Cross-attention sublayer.
        let (cross_attended, _) = self.cross_attention.forward(
            &after_self,
            memory,
            memory,
            memory_mask,
            Some(false),
            None,
            Some(true),
        );
        let cross_attended = self.dropout2.forward(&cross_attended);
        let after_cross = self
            .norm2
            .forward(&self.add_tensors(&after_self, &cross_attended));

        // Feed-forward sublayer.
        let expanded = self.ff_linear1.forward(&after_cross);
        let activated = self.apply_relu(&expanded);
        let projected = self.ff_linear2.forward(&activated);
        let projected = self.dropout3.forward(&projected);

        Ok(self.norm3.forward(&self.add_tensors(&after_cross, &projected)))
    }

    /// Element-wise sum of two equally shaped variables.
    ///
    /// The result requires gradients when either operand does.
    ///
    /// NOTE(review): the result is built with `Variable::new` from raw values;
    /// presumably no backward edge to `a` / `b` is recorded - confirm against
    /// the autograd implementation.
    ///
    /// # Panics
    ///
    /// Panics when the shapes differ, or when a data lock is poisoned.
    fn add_tensors(&self, a: &Variable<T>, b: &Variable<T>) -> Variable<T> {
        let a_binding = a.data();
        let a_data = a_binding.read().unwrap();
        let b_binding = b.data();
        let b_data = b_binding.read().unwrap();

        assert!(
            a_data.shape() == b_data.shape(),
            "Cannot add tensors with different shapes: {:?} vs {:?}",
            a_data.shape(),
            b_data.shape()
        );

        // Iterate in logical order so non-contiguous arrays are summed
        // correctly; for contiguous arrays this matches slice order.
        let result_data: Vec<T> = a_data
            .as_array()
            .iter()
            .zip(b_data.as_array().iter())
            .map(|(&lhs, &rhs)| lhs + rhs)
            .collect();

        Variable::new(
            Tensor::from_vec(result_data, a_data.shape().to_vec()),
            a.requires_grad() || b.requires_grad(),
        )
    }

    /// Element-wise ReLU: values greater than zero pass through, everything
    /// else (including values that do not compare greater than zero, such as
    /// NaN) becomes zero.
    ///
    /// # Panics
    ///
    /// Panics when the data lock is poisoned.
    fn apply_relu(&self, input: &Variable<T>) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();

        // Logical-order iteration handles contiguous and non-contiguous
        // arrays alike.
        let output_data: Vec<T> = input_data
            .as_array()
            .iter()
            .map(|&val| if val > T::zero() { val } else { T::zero() })
            .collect();

        Variable::new(
            Tensor::from_vec(output_data, input_data.shape().to_vec()),
            input.requires_grad(),
        )
    }

    /// Model width this layer was built with.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Feed-forward hidden width this layer was built with.
    pub fn d_ff(&self) -> usize {
        self.d_ff
    }

    /// Collects the parameters of both attention blocks, the feed-forward
    /// projections and the three normalizations, in that order.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        let mut params = Vec::new();
        params.extend(self.self_attention.parameters());
        params.extend(self.cross_attention.parameters());
        params.extend(self.ff_linear1.parameters());
        params.extend(self.ff_linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }
}
/// `Module` adapter for the decoder layer: delegates to the inherent methods.
impl<T> Module<T> for TransformerDecoderLayer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Runs the layer with `input` used as both target and memory, without
    /// masks.
    ///
    /// # Panics
    ///
    /// Panics if the inherent `forward` returns an error.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent items take precedence over trait items, so this resolves
        // to the inherent four-argument `forward` and does not recurse.
        Self::forward(self, input, input, None, None).unwrap()
    }

    /// Returns the layer's parameters via the inherent `parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the layer as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Encoder-decoder Transformer: a [`TransformerEncoder`] plus a stack of
/// [`TransformerDecoderLayer`]s.
#[derive(Debug)]
pub struct Transformer<T>
where
    T: Float + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive + Sum,
{
    /// Encoder stack that produces the memory consumed by the decoder.
    encoder: TransformerEncoder<T>,
    /// Decoder layers in application order.
    decoder_layers: Vec<TransformerDecoderLayer<T>>,
    /// Number of decoder layers requested at construction.
    num_decoder_layers: usize,
    /// Model width shared by encoder and decoder.
    d_model: usize,
}
impl<T> Transformer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Builds the encoder first, then `num_decoder_layers` decoder layers,
    /// all sharing `d_model`, `num_heads`, `d_ff` and `dropout`.
    ///
    /// `num_decoder_layers` may be zero; no decoder layers are created in
    /// that case.
    ///
    /// # Errors
    ///
    /// Forwards the first error from `TransformerEncoder::new` or
    /// `TransformerDecoderLayer::new`.
    pub fn new(
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        d_model: usize,
        num_heads: usize,
        d_ff: usize,
        dropout: Option<T>,
    ) -> RusTorchResult<Self> {
        let encoder =
            TransformerEncoder::new(num_encoder_layers, d_model, num_heads, d_ff, dropout)?;

        // Collecting into a `Result` stops at the first failing layer.
        let decoder_layers = (0..num_decoder_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, num_heads, d_ff, dropout))
            .collect::<RusTorchResult<Vec<_>>>()?;

        Ok(Transformer {
            encoder,
            decoder_layers,
            num_decoder_layers,
            d_model,
        })
    }

    /// Encodes `src` into memory, then threads `tgt` through every decoder
    /// layer against that memory and returns the last layer's output.
    ///
    /// With zero decoder layers the result is a clone of `tgt`.
    ///
    /// # Panics
    ///
    /// Panics if a decoder layer returns an error.
    pub fn forward(
        &self,
        src: &Variable<T>,
        tgt: &Variable<T>,
        src_mask: Option<&Variable<T>>,
        tgt_mask: Option<&Variable<T>>,
        memory_mask: Option<&Variable<T>>,
    ) -> Variable<T> {
        let memory = self.encoder.forward(src, src_mask);
        self.decoder_layers
            .iter()
            .fold(tgt.clone(), |state, layer| {
                layer
                    .forward(&state, &memory, tgt_mask, memory_mask)
                    .unwrap()
            })
    }

    /// Runs only the encoder on `src`.
    pub fn encode(&self, src: &Variable<T>, src_mask: Option<&Variable<T>>) -> Variable<T> {
        self.encoder.forward(src, src_mask)
    }

    /// Model width the transformer was built with.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Number of encoder layers, as reported by the encoder.
    pub fn num_encoder_layers(&self) -> usize {
        self.encoder.num_layers()
    }

    /// Number of decoder layers requested at construction.
    pub fn num_decoder_layers(&self) -> usize {
        self.num_decoder_layers
    }

    /// Encoder parameters followed by each decoder layer's parameters.
    pub fn parameters(&self) -> Vec<Variable<T>> {
        let mut params = Vec::new();
        params.extend(self.encoder.parameters());
        params.extend(
            self.decoder_layers
                .iter()
                .flat_map(|layer| layer.parameters()),
        );
        params
    }
}
/// `Module` adapter for the full transformer: delegates to the inherent
/// methods.
impl<T> Module<T> for Transformer<T>
where
    T: Float
        + Debug
        + Default
        + FromPrimitive
        + ToPrimitive
        + Zero
        + One
        + 'static
        + Send
        + Sync
        + Copy
        + ScalarOperand
        + Sum
        + std::fmt::Display,
{
    /// Runs the model with `input` used as both source and target, without
    /// masks.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        // Inherent items take precedence over trait items, so this resolves
        // to the inherent five-argument `forward` and does not recurse.
        Self::forward(self, input, input, None, None, None)
    }

    /// Returns the model's parameters via the inherent `parameters`.
    fn parameters(&self) -> Vec<Variable<T>> {
        Self::parameters(self)
    }

    /// Exposes the model as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-dropout encoder layer reports its sizes and has parameters.
    #[test]
    fn test_transformer_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::<f32>::new(512, 8, 2048, None).unwrap();

        assert_eq!(layer.d_model(), 512);
        assert_eq!(layer.d_ff(), 2048);
        assert!(!layer.parameters().is_empty());
    }

    /// A six-layer encoder reports its depth and width and has parameters.
    #[test]
    fn test_transformer_encoder_creation() {
        let encoder = TransformerEncoder::<f32>::new(6, 512, 8, 2048, None).unwrap();

        assert_eq!(encoder.num_layers(), 6);
        assert_eq!(encoder.d_model(), 512);
        assert!(!encoder.parameters().is_empty());
    }

    /// A 6+6 transformer reports both depths and its width and has parameters.
    #[test]
    fn test_transformer_creation() {
        let transformer = Transformer::<f32>::new(6, 6, 512, 8, 2048, None).unwrap();

        assert_eq!(transformer.num_encoder_layers(), 6);
        assert_eq!(transformer.num_decoder_layers(), 6);
        assert_eq!(transformer.d_model(), 512);
        assert!(!transformer.parameters().is_empty());
    }
}