use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
pub struct TransformerEncoderLayer {
base: ModuleBase,
d_model: usize,
nhead: usize,
dim_feedforward: usize,
dropout: f32,
activation: String,
layer_norm_eps: f32,
batch_first: bool,
norm_first: bool,
}
impl TransformerEncoderLayer {
pub fn new(
d_model: usize,
nhead: usize,
dim_feedforward: Option<usize>,
dropout: Option<f32>,
activation: Option<String>,
layer_norm_eps: Option<f32>,
batch_first: Option<bool>,
norm_first: Option<bool>,
) -> Result<Self> {
let mut base = ModuleBase::new();
let dim_feedforward = dim_feedforward.unwrap_or(2048);
let dropout = dropout.unwrap_or(0.1);
let activation = activation.unwrap_or_else(|| "relu".to_string());
let layer_norm_eps = layer_norm_eps.unwrap_or(1e-5);
let batch_first = batch_first.unwrap_or(false);
let norm_first = norm_first.unwrap_or(false);
let self_attn_weight = crate::init::xavier_uniform(&[3 * d_model, d_model])?;
let self_attn_out_weight = crate::init::xavier_uniform(&[d_model, d_model])?;
base.register_parameter(
"self_attn.in_proj_weight".to_string(),
Parameter::new(self_attn_weight),
);
base.register_parameter(
"self_attn.out_proj.weight".to_string(),
Parameter::new(self_attn_out_weight),
);
let self_attn_bias = zeros(&[3 * d_model])?;
let self_attn_out_bias = zeros(&[d_model])?;
base.register_parameter(
"self_attn.in_proj_bias".to_string(),
Parameter::new(self_attn_bias),
);
base.register_parameter(
"self_attn.out_proj.bias".to_string(),
Parameter::new(self_attn_out_bias),
);
let linear1_weight = crate::init::xavier_uniform(&[dim_feedforward, d_model])?;
let linear2_weight = crate::init::xavier_uniform(&[d_model, dim_feedforward])?;
base.register_parameter("linear1.weight".to_string(), Parameter::new(linear1_weight));
base.register_parameter("linear2.weight".to_string(), Parameter::new(linear2_weight));
let linear1_bias = zeros(&[dim_feedforward])?;
let linear2_bias = zeros(&[d_model])?;
base.register_parameter("linear1.bias".to_string(), Parameter::new(linear1_bias));
base.register_parameter("linear2.bias".to_string(), Parameter::new(linear2_bias));
let norm1_weight = ones(&[d_model])?;
let norm1_bias = zeros(&[d_model])?;
let norm2_weight = ones(&[d_model])?;
let norm2_bias = zeros(&[d_model])?;
base.register_parameter("norm1.weight".to_string(), Parameter::new(norm1_weight));
base.register_parameter("norm1.bias".to_string(), Parameter::new(norm1_bias));
base.register_parameter("norm2.weight".to_string(), Parameter::new(norm2_weight));
base.register_parameter("norm2.bias".to_string(), Parameter::new(norm2_bias));
Ok(Self {
base,
d_model,
nhead,
dim_feedforward,
dropout,
activation,
layer_norm_eps,
batch_first,
norm_first,
})
}
fn self_attention(&self, src: &Tensor, src_mask: Option<&Tensor>) -> Result<Tensor> {
let in_proj_weight = self.base.parameters["self_attn.in_proj_weight"]
.tensor()
.read()
.clone();
let out_proj_weight = self.base.parameters["self_attn.out_proj.weight"]
.tensor()
.read()
.clone();
let in_proj_bias = self.base.parameters["self_attn.in_proj_bias"]
.tensor()
.read()
.clone();
let out_proj_bias = self.base.parameters["self_attn.out_proj.bias"]
.tensor()
.read()
.clone();
let qkv = src
.matmul(&in_proj_weight.transpose(0, 1)?)?
.add_op(&in_proj_bias)?;
let src_shape_binding = src.shape();
let src_shape = src_shape_binding.dims();
let batch_size = if self.batch_first {
src_shape[0]
} else {
src_shape[1]
};
let seq_len = if self.batch_first {
src_shape[1]
} else {
src_shape[0]
};
let chunk_size = self.d_model;
let q = qkv.narrow((qkv.shape().ndim() - 1) as i32, 0, chunk_size)?;
let k = qkv.narrow(
(qkv.shape().ndim() - 1) as i32,
chunk_size as i64,
chunk_size,
)?;
let v = qkv.narrow(
(qkv.shape().ndim() - 1) as i32,
(2 * chunk_size) as i64,
chunk_size,
)?;
let head_dim = self.d_model / self.nhead;
let q = if self.batch_first {
q.reshape(&[
batch_size as i32,
seq_len as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(1, 2)?
} else {
q.reshape(&[
seq_len as i32,
batch_size as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(0, 2)?
.transpose(1, 3)?
};
let k = if self.batch_first {
k.reshape(&[
batch_size as i32,
seq_len as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(1, 2)?
} else {
k.reshape(&[
seq_len as i32,
batch_size as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(0, 2)?
.transpose(1, 3)?
};
let v = if self.batch_first {
v.reshape(&[
batch_size as i32,
seq_len as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(1, 2)?
} else {
v.reshape(&[
seq_len as i32,
batch_size as i32,
self.nhead as i32,
head_dim as i32,
])?
.transpose(0, 2)?
.transpose(1, 3)?
};
let d_k = head_dim as f32;
let scale = 1.0 / d_k.sqrt();
let scores = q.matmul(&k.transpose(-2, -1)?)?.mul_scalar(scale)?;
let scores = if let Some(mask) = src_mask {
let large_neg = mask.mul_scalar(-1e9)?;
scores.add_op(&large_neg)?
} else {
scores
};
let attn_weights = scores.softmax(-1)?;
let attn_weights = if self.dropout > 0.0 && self.training() {
crate::functional::dropout(&attn_weights, self.dropout, self.training())?
} else {
attn_weights
};
let attn_output = attn_weights.matmul(&v)?;
let attn_output = if self.batch_first {
attn_output.transpose(1, 2)?.contiguous()?.reshape(&[
batch_size as i32,
seq_len as i32,
self.d_model as i32,
])?
} else {
attn_output
.transpose(0, 2)?
.transpose(1, 3)?
.contiguous()?
.reshape(&[seq_len as i32, batch_size as i32, self.d_model as i32])?
};
attn_output
.matmul(&out_proj_weight.transpose(0, 1)?)?
.add_op(&out_proj_bias)
}
fn feed_forward(&self, src: &Tensor) -> Result<Tensor> {
let linear1_weight = self.base.parameters["linear1.weight"]
.tensor()
.read()
.clone();
let linear1_bias = self.base.parameters["linear1.bias"].tensor().read().clone();
let linear2_weight = self.base.parameters["linear2.weight"]
.tensor()
.read()
.clone();
let linear2_bias = self.base.parameters["linear2.bias"].tensor().read().clone();
let x = src
.matmul(&linear1_weight.transpose(0, 1)?)?
.add_op(&linear1_bias)?;
let x = match self.activation.as_str() {
"relu" => x.relu()?,
"gelu" => x.gelu()?,
"silu" | "swish" => {
let sigmoid_x = x.sigmoid()?;
x.mul_op(&sigmoid_x)?
}
_ => x.relu()?, };
let x = if self.dropout > 0.0 && self.training() {
crate::functional::dropout(&x, self.dropout, self.training())?
} else {
x
};
x.matmul(&linear2_weight.transpose(0, 1)?)?
.add_op(&linear2_bias)
}
fn layer_norm(&self, x: &Tensor, norm_type: usize) -> Result<Tensor> {
let weight_key = format!("norm{}.weight", norm_type);
let bias_key = format!("norm{}.bias", norm_type);
let weight = self.base.parameters[&weight_key].tensor().read().clone();
let bias = self.base.parameters[&bias_key].tensor().read().clone();
let last_dim = x.shape().ndim() - 1;
let mean = x.mean(Some(&[last_dim]), true)?;
let var = x.var(
Some(&[last_dim]),
true,
torsh_tensor::stats::StatMode::Sample,
)?;
let normalized = x
.sub(&mean)?
.div(&(var.add_scalar(self.layer_norm_eps)?.sqrt()?))?;
normalized.mul_op(&weight)?.add_op(&bias)
}
}
impl Module for TransformerEncoderLayer {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
impl TransformerEncoderLayer {
pub fn forward_with_mask(&self, src: &Tensor, src_mask: Option<&Tensor>) -> Result<Tensor> {
let x = src;
if self.norm_first {
let norm_x = self.layer_norm(x, 1)?;
let attn_output = self.self_attention(&norm_x, src_mask)?;
let x = x.add_op(&attn_output)?;
let norm_x = self.layer_norm(&x, 2)?;
let ff_output = self.feed_forward(&norm_x)?;
x.add_op(&ff_output) } else {
let attn_output = self.self_attention(x, src_mask)?;
let x = x.add_op(&attn_output)?; let x = self.layer_norm(&x, 1)?;
let ff_output = self.feed_forward(&x)?;
let x = x.add_op(&ff_output)?; self.layer_norm(&x, 2)
}
}
}
pub struct TransformerEncoder {
base: ModuleBase,
layers: Vec<TransformerEncoderLayer>,
num_layers: usize,
#[allow(dead_code)]
norm: Option<()>, }
impl TransformerEncoder {
pub fn new(encoder_layer: TransformerEncoderLayer, num_layers: usize) -> Result<Self> {
let mut base = ModuleBase::new();
let mut layers = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let layer = TransformerEncoderLayer::new(
encoder_layer.d_model,
encoder_layer.nhead,
Some(encoder_layer.dim_feedforward),
Some(encoder_layer.dropout),
Some(encoder_layer.activation.clone()),
Some(encoder_layer.layer_norm_eps),
Some(encoder_layer.batch_first),
Some(encoder_layer.norm_first),
)?;
for (name, param) in layer.parameters() {
let layer_name = format!("layers.{}.{}", i, name);
base.register_parameter(layer_name, param);
}
layers.push(layer);
}
Ok(Self {
base,
layers,
num_layers,
norm: None,
})
}
pub fn forward_with_mask(&self, src: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let mut output = src.clone();
for layer in &self.layers {
output = layer.forward_with_mask(&output, mask)?;
}
Ok(output)
}
}
impl Module for TransformerEncoder {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
self.forward_with_mask(input, None)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
for layer in &mut self.layers {
layer.train();
}
}
fn eval(&mut self) {
self.base.set_training(false);
for layer in &mut self.layers {
layer.eval();
}
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
for layer in &mut self.layers {
layer.set_training(training);
}
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)?;
for layer in &mut self.layers {
layer.to_device(device)?;
}
Ok(())
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
pub struct Transformer {
base: ModuleBase,
d_model: usize,
nhead: usize,
num_encoder_layers: usize,
dim_feedforward: usize,
dropout: f32,
#[allow(dead_code)]
max_seq_length: usize,
vocab_size: Option<usize>,
}
impl Transformer {
pub fn new(
d_model: usize,
nhead: usize,
num_encoder_layers: usize,
num_decoder_layers: Option<usize>,
dim_feedforward: Option<usize>,
dropout: Option<f32>,
max_seq_length: Option<usize>,
vocab_size: Option<usize>,
) -> Result<Self> {
let mut base = ModuleBase::new();
let dim_feedforward = dim_feedforward.unwrap_or(2048);
let dropout = dropout.unwrap_or(0.1);
let max_seq_length = max_seq_length.unwrap_or(1024);
let _num_decoder_layers = num_decoder_layers.unwrap_or(0);
let pos_encoding = create_positional_encoding(max_seq_length, d_model)?;
base.register_parameter("pos_encoding".to_string(), Parameter::new(pos_encoding));
if let Some(vocab_size) = vocab_size {
let embedding_weight = crate::init::xavier_uniform(&[vocab_size, d_model])?;
base.register_parameter(
"embedding.weight".to_string(),
Parameter::new(embedding_weight),
);
}
Ok(Self {
base,
d_model,
nhead,
num_encoder_layers,
dim_feedforward,
dropout,
max_seq_length,
vocab_size,
})
}
pub fn create_encoder(&self) -> Result<TransformerEncoder> {
let encoder_layer = TransformerEncoderLayer::new(
self.d_model,
self.nhead,
Some(self.dim_feedforward),
Some(self.dropout),
None, None, Some(true), Some(false), )?;
TransformerEncoder::new(encoder_layer, self.num_encoder_layers)
}
pub fn add_positional_encoding(&self, x: &Tensor) -> Result<Tensor> {
let pos_encoding = self.base.parameters["pos_encoding"].tensor().read().clone();
let seq_len = x.shape().dims()[1];
let pos_slice = pos_encoding.narrow(0, 0, seq_len)?;
x.add_op(&pos_slice.unsqueeze(0)?)
}
pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
if self.vocab_size.is_none() {
return Err(torsh_core::TorshError::InvalidArgument(
"Transformer was not initialized with vocab_size".to_string(),
));
}
let embedding_weight = self.base.parameters["embedding.weight"]
.tensor()
.read()
.clone();
input_ids.matmul(&embedding_weight)
}
}
impl Module for Transformer {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let encoder = self.create_encoder()?;
encoder.forward(input)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
fn create_positional_encoding(max_len: usize, d_model: usize) -> Result<Tensor> {
let mut pos_encoding = vec![0.0f32; max_len * d_model];
for pos in 0..max_len {
for i in (0..d_model).step_by(2) {
let angle = pos as f32 / 10000.0_f32.powf(i as f32 / d_model as f32);
pos_encoding[pos * d_model + i] = angle.sin();
if i + 1 < d_model {
pos_encoding[pos * d_model + i + 1] = angle.cos();
}
}
}
Tensor::from_vec(pos_encoding, &[max_len, d_model])
}
impl std::fmt::Debug for TransformerEncoderLayer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TransformerEncoderLayer")
.field("d_model", &self.d_model)
.field("nhead", &self.nhead)
.field("dim_feedforward", &self.dim_feedforward)
.field("dropout", &self.dropout)
.field("activation", &self.activation)
.finish()
}
}
impl std::fmt::Debug for TransformerEncoder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TransformerEncoder")
.field("num_layers", &self.num_layers)
.finish()
}
}
impl std::fmt::Debug for Transformer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Transformer")
.field("d_model", &self.d_model)
.field("nhead", &self.nhead)
.field("num_encoder_layers", &self.num_encoder_layers)
.field("dim_feedforward", &self.dim_feedforward)
.field("vocab_size", &self.vocab_size)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::creation::zeros;
#[test]
fn test_transformer_encoder_layer_new() -> Result<()> {
let layer = TransformerEncoderLayer::new(
512, 8, Some(2048),
Some(0.1),
Some("relu".to_string()),
Some(1e-5),
Some(false),
Some(false),
)?;
assert_eq!(layer.d_model, 512);
assert_eq!(layer.nhead, 8);
assert_eq!(layer.dim_feedforward, 2048);
assert_eq!(layer.dropout, 0.1);
assert_eq!(layer.activation, "relu");
Ok(())
}
#[test]
fn test_transformer_encoder_layer_defaults() -> Result<()> {
let layer = TransformerEncoderLayer::new(256, 4, None, None, None, None, None, None)?;
assert_eq!(layer.d_model, 256);
assert_eq!(layer.nhead, 4);
assert_eq!(layer.dim_feedforward, 2048); assert_eq!(layer.dropout, 0.1); assert_eq!(layer.activation, "relu"); Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_layer_forward() -> Result<()> {
let layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(false),
None,
)?;
let input = zeros(&[10, 2, 64])?;
let output = layer.forward(&input)?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[10, 2, 64]);
Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_layer_forward_batch_first() -> Result<()> {
let layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(true),
None,
)?;
let input = zeros(&[2, 10, 64])?;
let output = layer.forward(&input)?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[2, 10, 64]);
Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_layer_forward_with_mask() -> Result<()> {
let layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(false),
None,
)?;
let input = zeros(&[10, 2, 64])?;
let mask = zeros(&[10, 10])?;
let output = layer.forward_with_mask(&input, Some(&mask))?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[10, 2, 64]);
Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_layer_norm_first() -> Result<()> {
let layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(false),
Some(true), )?;
let input = zeros(&[10, 2, 64])?;
let output = layer.forward(&input)?;
assert_eq!(output.shape().dims(), &[10, 2, 64]);
Ok(())
}
#[test]
fn test_transformer_encoder_layer_parameters() -> Result<()> {
let layer = TransformerEncoderLayer::new(128, 4, Some(512), None, None, None, None, None)?;
let params = layer.parameters();
assert_eq!(params.len(), 12);
assert!(params.contains_key("self_attn.in_proj_weight"));
assert!(params.contains_key("linear1.weight"));
assert!(params.contains_key("norm1.weight"));
Ok(())
}
#[test]
fn test_transformer_encoder_layer_training_mode() -> Result<()> {
let mut layer =
TransformerEncoderLayer::new(64, 4, Some(256), None, None, None, None, None)?;
assert!(layer.training());
layer.eval();
assert!(!layer.training());
layer.train();
assert!(layer.training());
Ok(())
}
#[test]
fn test_transformer_encoder_layer_activation_variants() -> Result<()> {
let layer_relu = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
Some("relu".to_string()),
None,
None,
None,
)?;
assert_eq!(layer_relu.activation, "relu");
let layer_gelu = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
Some("gelu".to_string()),
None,
None,
None,
)?;
assert_eq!(layer_gelu.activation, "gelu");
let layer_silu = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
Some("silu".to_string()),
None,
None,
None,
)?;
assert_eq!(layer_silu.activation, "silu");
Ok(())
}
#[test]
fn test_transformer_encoder_new() -> Result<()> {
let encoder_layer =
TransformerEncoderLayer::new(128, 8, Some(512), Some(0.1), None, None, None, None)?;
let encoder = TransformerEncoder::new(encoder_layer, 6)?;
assert_eq!(encoder.num_layers, 6);
assert_eq!(encoder.layers.len(), 6);
Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_forward() -> Result<()> {
let encoder_layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(false),
None,
)?;
let encoder = TransformerEncoder::new(encoder_layer, 3)?;
let input = zeros(&[10, 2, 64])?;
let output = encoder.forward(&input)?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[10, 2, 64]);
Ok(())
}
#[test]
#[ignore] fn test_transformer_encoder_forward_with_mask() -> Result<()> {
let encoder_layer = TransformerEncoderLayer::new(
64,
4,
Some(256),
Some(0.0),
None,
None,
Some(false),
None,
)?;
let encoder = TransformerEncoder::new(encoder_layer, 2)?;
let input = zeros(&[10, 2, 64])?;
let mask = zeros(&[10, 10])?;
let output = encoder.forward_with_mask(&input, Some(&mask))?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[10, 2, 64]);
Ok(())
}
#[test]
fn test_transformer_encoder_parameters() -> Result<()> {
let encoder_layer =
TransformerEncoderLayer::new(64, 4, Some(256), None, None, None, None, None)?;
let encoder = TransformerEncoder::new(encoder_layer, 3)?;
let params = encoder.parameters();
assert_eq!(params.len(), 36); assert!(params.contains_key("layers.0.self_attn.in_proj_weight"));
assert!(params.contains_key("layers.1.linear1.weight"));
assert!(params.contains_key("layers.2.norm2.bias"));
Ok(())
}
#[test]
fn test_transformer_encoder_training_mode() -> Result<()> {
let encoder_layer =
TransformerEncoderLayer::new(64, 4, Some(256), None, None, None, None, None)?;
let mut encoder = TransformerEncoder::new(encoder_layer, 2)?;
assert!(encoder.training());
for layer in &encoder.layers {
assert!(layer.training());
}
encoder.eval();
assert!(!encoder.training());
for layer in &encoder.layers {
assert!(!layer.training());
}
encoder.train();
assert!(encoder.training());
for layer in &encoder.layers {
assert!(layer.training());
}
Ok(())
}
#[test]
fn test_transformer_new() -> Result<()> {
let transformer = Transformer::new(
512, 8, 6, Some(6), Some(2048), Some(0.1), Some(1024), Some(10000), )?;
assert_eq!(transformer.d_model, 512);
assert_eq!(transformer.nhead, 8);
assert_eq!(transformer.num_encoder_layers, 6);
assert_eq!(transformer.dim_feedforward, 2048);
assert_eq!(transformer.dropout, 0.1);
assert_eq!(transformer.vocab_size, Some(10000));
Ok(())
}
#[test]
fn test_transformer_defaults() -> Result<()> {
let transformer = Transformer::new(256, 4, 3, None, None, None, None, None)?;
assert_eq!(transformer.d_model, 256);
assert_eq!(transformer.nhead, 4);
assert_eq!(transformer.num_encoder_layers, 3);
assert_eq!(transformer.dim_feedforward, 2048); assert_eq!(transformer.dropout, 0.1); assert_eq!(transformer.vocab_size, None);
Ok(())
}
#[test]
fn test_transformer_create_encoder() -> Result<()> {
let transformer = Transformer::new(128, 8, 4, None, Some(512), Some(0.1), None, None)?;
let encoder = transformer.create_encoder()?;
assert_eq!(encoder.num_layers, 4);
Ok(())
}
#[test]
#[ignore] fn test_transformer_forward() -> Result<()> {
let transformer = Transformer::new(64, 4, 2, None, Some(256), Some(0.0), None, None)?;
let input = zeros(&[2, 10, 64])?;
let output = transformer.forward(&input)?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[2, 10, 64]);
Ok(())
}
#[test]
fn test_transformer_add_positional_encoding() -> Result<()> {
let transformer = Transformer::new(128, 8, 2, None, None, None, Some(100), None)?;
let input = zeros(&[2, 10, 128])?;
let output = transformer.add_positional_encoding(&input)?;
let output_shape = output.shape();
assert_eq!(output_shape.dims(), &[2, 10, 128]);
Ok(())
}
#[test]
fn test_transformer_embed_tokens_without_vocab() -> Result<()> {
let transformer = Transformer::new(128, 8, 2, None, None, None, None, None)?;
let input_ids = zeros(&[2, 10])?;
let result = transformer.embed_tokens(&input_ids);
assert!(result.is_err()); Ok(())
}
#[test]
fn test_transformer_with_vocab_size() -> Result<()> {
let transformer = Transformer::new(128, 8, 2, None, None, None, None, Some(10000))?;
let params = transformer.parameters();
assert!(params.contains_key("embedding.weight"));
assert!(params.contains_key("pos_encoding"));
Ok(())
}
#[test]
fn test_transformer_parameters() -> Result<()> {
let transformer = Transformer::new(64, 4, 2, None, None, None, Some(100), Some(1000))?;
let params = transformer.parameters();
assert_eq!(params.len(), 2);
assert!(params.contains_key("embedding.weight"));
assert!(params.contains_key("pos_encoding"));
Ok(())
}
#[test]
fn test_transformer_training_mode() -> Result<()> {
let mut transformer = Transformer::new(64, 4, 2, None, None, None, None, None)?;
assert!(transformer.training());
transformer.eval();
assert!(!transformer.training());
transformer.train();
assert!(transformer.training());
Ok(())
}
#[test]
fn test_create_positional_encoding() -> Result<()> {
let max_len = 100;
let d_model = 128;
let pos_encoding = create_positional_encoding(max_len, d_model)?;
let shape = pos_encoding.shape();
assert_eq!(shape.dims(), &[max_len, d_model]);
Ok(())
}
#[test]
fn test_create_positional_encoding_odd_d_model() -> Result<()> {
let max_len = 50;
let d_model = 127;
let pos_encoding = create_positional_encoding(max_len, d_model)?;
let shape = pos_encoding.shape();
assert_eq!(shape.dims(), &[max_len, d_model]);
Ok(())
}
#[test]
fn test_create_positional_encoding_small() -> Result<()> {
let max_len = 10;
let d_model = 8;
let pos_encoding = create_positional_encoding(max_len, d_model)?;
let shape = pos_encoding.shape();
assert_eq!(shape.dims(), &[max_len, d_model]);
let data = pos_encoding.to_vec()?;
for val in data {
assert!(val.is_finite());
}
Ok(())
}
#[test]
fn test_module_named_parameters() -> Result<()> {
let layer = TransformerEncoderLayer::new(128, 8, Some(512), None, None, None, None, None)?;
let named_params = layer.named_parameters();
assert_eq!(named_params.len(), 12);
assert!(named_params.contains_key("self_attn.in_proj_weight"));
assert!(named_params.contains_key("self_attn.out_proj.weight"));
assert!(named_params.contains_key("linear1.weight"));
assert!(named_params.contains_key("linear2.weight"));
Ok(())
}
#[test]
fn test_module_to_device() -> Result<()> {
let mut layer =
TransformerEncoderLayer::new(64, 4, Some(256), None, None, None, None, None)?;
layer.to_device(DeviceType::Cpu)?;
Ok(())
}
#[test]
fn test_encoder_to_device_propagates() -> Result<()> {
let encoder_layer =
TransformerEncoderLayer::new(64, 4, Some(256), None, None, None, None, None)?;
let mut encoder = TransformerEncoder::new(encoder_layer, 3)?;
encoder.to_device(DeviceType::Cpu)?;
Ok(())
}
}