//! Transformer building blocks: encoder and decoder layers, stacked
//! [`TransformerEncoder`]/[`TransformerDecoder`] modules, and a complete
//! [`Seq2SeqTransformer`] composed of both.

use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;
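
/// A single transformer encoder layer: multi-head self-attention followed by a
/// position-wise feed-forward network (Linear -> ReLU -> Linear), each sub-layer
/// wrapped in a residual connection with layer normalization. Supports the
/// post-norm arrangement of the original transformer (the default) as well as
/// the pre-norm variant.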
pub struct TransformerEncoderLayer {
self_attn: MultiHeadAttention,
linear1: Linear,
linear2: Linear,
norm1: LayerNorm,
norm2: LayerNorm,
d_model: usize,
pre_norm: bool,
}
impl TransformerEncoderLayer {
pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
}
pub fn new_with_pre_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
pre_norm: bool,
) -> Self {
Self {
self_attn: MultiHeadAttention::new(d_model, nhead),
linear1: Linear::new(d_model, dim_feedforward),
linear2: Linear::new(dim_feedforward, d_model),
norm1: LayerNorm::single(d_model),
norm2: LayerNorm::single(d_model),
d_model,
pre_norm,
}
}
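    /// Runs the layer on `src`, optionally applying `src_mask` to the
    /// self-attention scores. In pre-norm mode each sub-layer's input is
    /// normalized and the raw residual is added back; in post-norm mode the
    /// residual is added first and the sum is then normalized.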
pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
if self.pre_norm {
let normed = self.norm1.forward(src);
let attn_out = self
.self_attn
.attention(&normed, &normed, &normed, src_mask);
let x = src.add_var(&attn_out);
let normed = self.norm2.forward(&x);
let ff_out = self.linear1.forward(&normed).relu();
let ff_out = self.linear2.forward(&ff_out);
x.add_var(&ff_out)
} else {
let attn_out = self.self_attn.attention(src, src, src, src_mask);
let x = src.add_var(&attn_out);
let x = self.norm1.forward(&x);
let ff_out = self.linear1.forward(&x).relu();
let ff_out = self.linear2.forward(&ff_out);
let x = x.add_var(&ff_out);
self.norm2.forward(&x)
}
}
pub fn d_model(&self) -> usize {
self.d_model
}
}
impl Module for TransformerEncoderLayer {
fn forward(&self, input: &Variable) -> Variable {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.self_attn.parameters());
params.extend(self.linear1.parameters());
params.extend(self.linear2.parameters());
params.extend(self.norm1.parameters());
params.extend(self.norm2.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.self_attn.named_parameters() {
params.insert(format!("self_attn.{name}"), param);
}
for (name, param) in self.linear1.named_parameters() {
params.insert(format!("linear1.{name}"), param);
}
for (name, param) in self.linear2.named_parameters() {
params.insert(format!("linear2.{name}"), param);
}
for (name, param) in self.norm1.named_parameters() {
params.insert(format!("norm1.{name}"), param);
}
for (name, param) in self.norm2.named_parameters() {
params.insert(format!("norm2.{name}"), param);
}
params
}
fn name(&self) -> &'static str {
"TransformerEncoderLayer"
}
}
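
/// A single transformer decoder layer: masked self-attention, cross-attention
/// over the encoder memory, and a position-wise feed-forward network, each
/// sub-layer wrapped in a residual connection with layer normalization
/// (post-norm by default, pre-norm optionally).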
pub struct TransformerDecoderLayer {
self_attn: MultiHeadAttention,
cross_attn: MultiHeadAttention,
linear1: Linear,
linear2: Linear,
norm1: LayerNorm,
norm2: LayerNorm,
norm3: LayerNorm,
d_model: usize,
pre_norm: bool,
}
impl TransformerDecoderLayer {
pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
}
pub fn new_with_pre_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
pre_norm: bool,
) -> Self {
Self {
self_attn: MultiHeadAttention::new(d_model, nhead),
cross_attn: MultiHeadAttention::new(d_model, nhead),
linear1: Linear::new(d_model, dim_feedforward),
linear2: Linear::new(dim_feedforward, d_model),
norm1: LayerNorm::single(d_model),
norm2: LayerNorm::single(d_model),
norm3: LayerNorm::single(d_model),
d_model,
pre_norm,
}
}
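    /// Runs the layer on `tgt`, cross-attending to the encoder output `memory`.
    /// `tgt_mask` (typically a causal mask) is applied to the self-attention
    /// scores and `memory_mask` to the cross-attention scores.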
pub fn forward_with_memory(
&self,
tgt: &Variable,
memory: &Variable,
tgt_mask: Option<&Variable>,
memory_mask: Option<&Variable>,
) -> Variable {
if self.pre_norm {
let normed = self.norm1.forward(tgt);
let self_attn_out = self
.self_attn
.attention(&normed, &normed, &normed, tgt_mask);
let x = tgt.add_var(&self_attn_out);
let normed = self.norm2.forward(&x);
let cross_attn_out = self
.cross_attn
.attention(&normed, memory, memory, memory_mask);
let x = x.add_var(&cross_attn_out);
let normed = self.norm3.forward(&x);
let ff_out = self.linear1.forward(&normed).relu();
let ff_out = self.linear2.forward(&ff_out);
x.add_var(&ff_out)
} else {
let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
let x = tgt.add_var(&self_attn_out);
let x = self.norm1.forward(&x);
let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
let x = x.add_var(&cross_attn_out);
let x = self.norm2.forward(&x);
let ff_out = self.linear1.forward(&x).relu();
let ff_out = self.linear2.forward(&ff_out);
let x = x.add_var(&ff_out);
self.norm3.forward(&x)
}
}
pub fn d_model(&self) -> usize {
self.d_model
}
}
impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // With no encoder memory to cross-attend to, the cross-attention
        // sub-layer (and its norm2) is skipped in both variants; only the
        // self-attention and feed-forward sub-layers run.
        if self.pre_norm {
            let normed = self.norm1.forward(input);
            let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
            let x = input.add_var(&self_attn_out);
            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(input, input, input, None);
            let x = input.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);
            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }
fn parameters(&self) -> Vec<Parameter> {
let mut params = Vec::new();
params.extend(self.self_attn.parameters());
params.extend(self.cross_attn.parameters());
params.extend(self.linear1.parameters());
params.extend(self.linear2.parameters());
params.extend(self.norm1.parameters());
params.extend(self.norm2.parameters());
params.extend(self.norm3.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.self_attn.named_parameters() {
params.insert(format!("self_attn.{name}"), param);
}
for (name, param) in self.cross_attn.named_parameters() {
params.insert(format!("cross_attn.{name}"), param);
}
for (name, param) in self.linear1.named_parameters() {
params.insert(format!("linear1.{name}"), param);
}
for (name, param) in self.linear2.named_parameters() {
params.insert(format!("linear2.{name}"), param);
}
for (name, param) in self.norm1.named_parameters() {
params.insert(format!("norm1.{name}"), param);
}
for (name, param) in self.norm2.named_parameters() {
params.insert(format!("norm2.{name}"), param);
}
for (name, param) in self.norm3.named_parameters() {
params.insert(format!("norm3.{name}"), param);
}
params
}
fn name(&self) -> &'static str {
"TransformerDecoderLayer"
}
}
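
/// A stack of identical [`TransformerEncoderLayer`]s, optionally followed by a
/// final [`LayerNorm`] (applied by default; see
/// [`TransformerEncoder::without_norm`]).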
pub struct TransformerEncoder {
layers: Vec<TransformerEncoderLayer>,
norm: Option<LayerNorm>,
}
impl TransformerEncoder {
pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
}
pub fn new_with_pre_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
num_layers: usize,
pre_norm: bool,
) -> Self {
let layers = (0..num_layers)
.map(|_| {
TransformerEncoderLayer::new_with_pre_norm(
d_model,
nhead,
dim_feedforward,
pre_norm,
)
})
.collect();
Self {
layers,
norm: Some(LayerNorm::single(d_model)),
}
}
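    /// Builds a post-norm encoder stack without the final `LayerNorm`, for
    /// cases where the surrounding model applies its own normalization.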
pub fn without_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
num_layers: usize,
) -> Self {
let layers = (0..num_layers)
.map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
.collect();
Self { layers, norm: None }
}
pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
let mut x = src.clone();
for layer in &self.layers {
x = layer.forward_with_mask(&x, src_mask);
}
if let Some(ref norm) = self.norm {
x = norm.forward(&x);
}
x
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
impl Module for TransformerEncoder {
fn forward(&self, input: &Variable) -> Variable {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
if let Some(ref norm) = self.norm {
params.extend(norm.parameters());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, layer) in self.layers.iter().enumerate() {
for (name, param) in layer.named_parameters() {
params.insert(format!("layers.{i}.{name}"), param);
}
}
if let Some(ref norm) = self.norm {
for (name, param) in norm.named_parameters() {
params.insert(format!("norm.{name}"), param);
}
}
params
}
fn name(&self) -> &'static str {
"TransformerEncoder"
}
}
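
/// A stack of identical [`TransformerDecoderLayer`]s, optionally followed by a
/// final [`LayerNorm`] (applied by default; see
/// [`TransformerDecoder::without_norm`]).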
pub struct TransformerDecoder {
layers: Vec<TransformerDecoderLayer>,
norm: Option<LayerNorm>,
}
impl TransformerDecoder {
pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
}
pub fn new_with_pre_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
num_layers: usize,
pre_norm: bool,
) -> Self {
let layers = (0..num_layers)
.map(|_| {
TransformerDecoderLayer::new_with_pre_norm(
d_model,
nhead,
dim_feedforward,
pre_norm,
)
})
.collect();
Self {
layers,
norm: Some(LayerNorm::single(d_model)),
}
}
pub fn without_norm(
d_model: usize,
nhead: usize,
dim_feedforward: usize,
num_layers: usize,
) -> Self {
let layers = (0..num_layers)
.map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
.collect();
Self { layers, norm: None }
}
pub fn forward_with_memory(
&self,
tgt: &Variable,
memory: &Variable,
tgt_mask: Option<&Variable>,
memory_mask: Option<&Variable>,
) -> Variable {
let mut x = tgt.clone();
for layer in &self.layers {
x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
}
if let Some(ref norm) = self.norm {
x = norm.forward(&x);
}
x
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
impl Module for TransformerDecoder {
fn forward(&self, input: &Variable) -> Variable {
let mut x = input.clone();
for layer in &self.layers {
x = layer.forward(&x);
}
if let Some(ref norm) = self.norm {
x = norm.forward(&x);
}
x
}
fn parameters(&self) -> Vec<Parameter> {
let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
if let Some(ref norm) = self.norm {
params.extend(norm.parameters());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, layer) in self.layers.iter().enumerate() {
for (name, param) in layer.named_parameters() {
params.insert(format!("layers.{i}.{name}"), param);
}
}
if let Some(ref norm) = self.norm {
for (name, param) in norm.named_parameters() {
params.insert(format!("norm.{name}"), param);
}
}
params
}
fn name(&self) -> &'static str {
"TransformerDecoder"
}
}
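
/// A full encoder-decoder transformer for sequence-to-sequence modeling.
///
/// A minimal usage sketch (not compiled as a doctest; it assumes `src` and
/// `tgt` are `[batch, seq, d_model]` variables built as in the tests below):
///
/// ```ignore
/// let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// let output = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
/// assert_eq!(output.shape(), vec![2, 5, 64]);
/// ```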
pub struct Seq2SeqTransformer {
encoder: TransformerEncoder,
decoder: TransformerDecoder,
d_model: usize,
nhead: usize,
}
impl Seq2SeqTransformer {
pub fn new(
d_model: usize,
nhead: usize,
num_encoder_layers: usize,
num_decoder_layers: usize,
dim_feedforward: usize,
) -> Self {
Self {
encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
d_model,
nhead,
}
}
pub fn new_pre_norm(
d_model: usize,
nhead: usize,
num_encoder_layers: usize,
num_decoder_layers: usize,
dim_feedforward: usize,
) -> Self {
Self {
encoder: TransformerEncoder::new_with_pre_norm(
d_model,
nhead,
dim_feedforward,
num_encoder_layers,
true,
),
decoder: TransformerDecoder::new_with_pre_norm(
d_model,
nhead,
dim_feedforward,
num_decoder_layers,
true,
),
d_model,
nhead,
}
}
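    /// Builds a transformer with 6 encoder layers, 6 decoder layers, and a
    /// feed-forward width of 2048, matching the base configuration from
    /// "Attention Is All You Need".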
pub fn default_config(d_model: usize, nhead: usize) -> Self {
Self::new(d_model, nhead, 6, 6, 2048)
}
pub fn forward_seq2seq(
&self,
src: &Variable,
tgt: &Variable,
src_mask: Option<&Variable>,
tgt_mask: Option<&Variable>,
memory_mask: Option<&Variable>,
) -> Variable {
let memory = self.encoder.forward_with_mask(src, src_mask);
self.decoder
.forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
}
pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
self.encoder.forward_with_mask(src, src_mask)
}
pub fn decode(
&self,
tgt: &Variable,
memory: &Variable,
tgt_mask: Option<&Variable>,
memory_mask: Option<&Variable>,
) -> Variable {
self.decoder
.forward_with_memory(tgt, memory, tgt_mask, memory_mask)
}
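    /// Builds a `[seq_len, seq_len]` causal mask where entry `(i, j)` is 1.0
    /// for `j <= i` and 0.0 otherwise, so position `i` may attend only to
    /// positions at or before `i`. For `seq_len = 3`:
    ///
    /// ```text
    /// 1 0 0
    /// 1 1 0
    /// 1 1 1
    /// ```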
pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..=i {
mask_data[i * seq_len + j] = 1.0;
}
}
Variable::new(
Tensor::from_vec(mask_data, &[seq_len, seq_len]).unwrap(),
false,
)
}
pub fn d_model(&self) -> usize {
self.d_model
}
pub fn nhead(&self) -> usize {
self.nhead
}
pub fn encoder(&self) -> &TransformerEncoder {
&self.encoder
}
pub fn decoder(&self) -> &TransformerDecoder {
&self.decoder
}
}
impl Module for Seq2SeqTransformer {
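    /// Note: `Module::forward` runs only the encoder, since the trait takes a
    /// single input. Use [`Seq2SeqTransformer::forward_seq2seq`] for the full
    /// encoder-decoder pass.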
fn forward(&self, input: &Variable) -> Variable {
self.encoder.forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
let mut params = self.encoder.parameters();
params.extend(self.decoder.parameters());
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (name, param) in self.encoder.named_parameters() {
params.insert(format!("encoder.{name}"), param);
}
for (name, param) in self.decoder.named_parameters() {
params.insert(format!("decoder.{name}"), param);
}
params
}
fn name(&self) -> &'static str {
"Seq2SeqTransformer"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encoder_layer_creation() {
let layer = TransformerEncoderLayer::new(64, 4, 256);
assert_eq!(layer.d_model(), 64);
}
#[test]
fn test_encoder_layer_forward() {
let layer = TransformerEncoderLayer::new(64, 4, 256);
let input = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let output = layer.forward(&input);
assert_eq!(output.shape(), vec![2, 10, 64]);
}
#[test]
fn test_decoder_layer_with_memory() {
let layer = TransformerDecoderLayer::new(64, 4, 256);
let tgt = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
false,
);
let memory = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let output = layer.forward_with_memory(&tgt, &memory, None, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_encoder_stack() {
let encoder = TransformerEncoder::new(64, 4, 256, 3);
assert_eq!(encoder.num_layers(), 3);
let input = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
false,
);
let output = encoder.forward(&input);
assert_eq!(output.shape(), vec![2, 8, 64]);
}
#[test]
fn test_decoder_stack() {
let decoder = TransformerDecoder::new(64, 4, 256, 3);
assert_eq!(decoder.num_layers(), 3);
let tgt = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
false,
);
let memory = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let output = decoder.forward_with_memory(&tgt, &memory, None, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_seq2seq_transformer() {
let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
assert_eq!(transformer.d_model(), 64);
assert_eq!(transformer.nhead(), 4);
let src = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let tgt = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
false,
);
let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_seq2seq_encode_decode_separate() {
let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
let src = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let tgt = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
false,
);
let memory = transformer.encode(&src, None);
assert_eq!(memory.shape(), vec![2, 10, 64]);
let output = transformer.decode(&tgt, &memory, None, None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
#[test]
fn test_causal_mask() {
let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
let mask_data = mask.data().to_vec();
        assert_eq!(mask_data[0], 1.0);
        assert_eq!(mask_data[1], 0.0);
        assert_eq!(mask_data[4], 1.0);
        assert_eq!(mask_data[5], 1.0);
        assert_eq!(mask_data[6], 0.0);
        assert_eq!(mask_data[15], 1.0);
    }
#[test]
fn test_default_config() {
let transformer = Seq2SeqTransformer::default_config(512, 8);
assert_eq!(transformer.encoder().num_layers(), 6);
assert_eq!(transformer.decoder().num_layers(), 6);
}
#[test]
fn test_parameter_count() {
let layer = TransformerEncoderLayer::new(64, 4, 256);
let params = layer.parameters();
assert_eq!(params.len(), 16);
}
#[test]
fn test_decoder_parameter_count() {
let layer = TransformerDecoderLayer::new(64, 4, 256);
let params = layer.parameters();
assert_eq!(params.len(), 26);
}
#[test]
fn test_named_parameters_hierarchy() {
let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
let named = transformer.named_parameters();
assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
assert!(named.contains_key("encoder.norm.weight"));
assert!(named.contains_key("decoder.norm.weight"));
}
#[test]
fn test_seq2seq_with_causal_mask() {
let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
let src = Variable::new(
Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
false,
);
let tgt = Variable::new(
Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
false,
);
let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
assert_eq!(output.shape(), vec![2, 5, 64]);
}
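
    // Added check (a minimal sketch): the pre-norm variants should be
    // shape-preserving, just like the post-norm paths exercised above.
    #[test]
    fn test_pre_norm_shapes() {
        let transformer = Seq2SeqTransformer::new_pre_norm(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }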
}