use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use crate::transformer::{TransformerDecoder, TransformerEncoder};
use crate::utils::{PositionalEncoding, PositionalEncodingFactory, PositionalEncodingType};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::{Rng, RngExt};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
/// Hyperparameters for constructing a full encoder-decoder `Transformer`.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Model (feature/embedding) dimension shared by encoder, decoder,
    /// and positional encoding; inputs must have this as their last dim.
    pub d_model: usize,
    /// Number of stacked encoder layers.
    pub n_encoder_layers: usize,
    /// Number of stacked decoder layers.
    pub n_decoder_layers: usize,
    /// Number of attention heads, forwarded to both stacks.
    pub n_heads: usize,
    /// Hidden width of the position-wise feed-forward sub-layers.
    pub d_ff: usize,
    /// Maximum sequence length supported by the positional encoding.
    pub max_seq_len: usize,
    /// Dropout probability forwarded to the encoder and decoder.
    pub dropout: f64,
    /// Which positional-encoding scheme the factory should build.
    pub pos_encoding_type: PositionalEncodingType,
    /// Numerical-stability epsilon forwarded to the encoder/decoder
    /// constructors (presumably for layer normalization — TODO confirm).
    pub epsilon: f64,
}
impl Default for TransformerConfig {
fn default() -> Self {
Self {
d_model: 512,
n_encoder_layers: 6,
n_decoder_layers: 6,
n_heads: 8,
d_ff: 2048,
max_seq_len: 512,
dropout: 0.1,
pos_encoding_type: PositionalEncodingType::Sinusoidal,
epsilon: 1e-5,
}
}
}
/// Full encoder-decoder transformer: positional encoding, encoder stack,
/// decoder stack, and a cached copy of the most recent encoder output.
pub struct Transformer<F: Float + Debug + Send + Sync + SimdUnifiedOps + NumAssign> {
    /// Stack of encoder layers.
    encoder: TransformerEncoder<F>,
    /// Stack of decoder layers.
    decoder: TransformerDecoder<F>,
    /// Positional-encoding strategy; boxed trait object so the concrete
    /// scheme is selected at runtime by the factory.
    pos_encoding: Box<dyn PositionalEncoding<F> + Send + Sync>,
    /// Configuration this model was built from.
    config: TransformerConfig,
    /// Most recent encoder output, overwritten on every forward pass;
    /// behind `Arc<RwLock<_>>` so `&self` methods can update it.
    encoder_output_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + SimdUnifiedOps + NumAssign> Clone
    for Transformer<F>
{
    /// Deep-clones the model. The positional encoding goes through its
    /// object-safe `clone_box` helper, and the encoder-output cache is
    /// deliberately reset to `None` rather than shared or copied — the
    /// clone starts with no cached encoder state.
    fn clone(&self) -> Self {
        Self {
            encoder: self.encoder.clone(),
            decoder: self.decoder.clone(),
            pos_encoding: self.pos_encoding.clone_box(),
            config: self.config.clone(),
            // Fresh, empty cache for the new instance.
            encoder_output_cache: Arc::new(RwLock::new(None)),
        }
    }
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + SimdUnifiedOps + NumAssign>
    Transformer<F>
{
    /// Builds a complete encoder-decoder transformer from `config`,
    /// initializing all parameters with `rng`.
    ///
    /// # Errors
    /// Propagates any construction error from the encoder or decoder
    /// sub-modules.
    pub fn new<R: Rng>(config: TransformerConfig, rng: &mut R) -> Result<Self> {
        let encoder = TransformerEncoder::new(
            config.d_model,
            config.n_encoder_layers,
            config.n_heads,
            config.d_ff,
            config.dropout,
            config.epsilon,
            rng,
        )?;
        let decoder = TransformerDecoder::new(
            config.d_model,
            config.n_decoder_layers,
            config.n_heads,
            config.d_ff,
            config.dropout,
            config.epsilon,
            rng,
        )?;
        let pos_encoding = PositionalEncodingFactory::create(
            config.pos_encoding_type,
            config.d_model,
            config.max_seq_len,
            rng,
        );
        Ok(Self {
            encoder,
            decoder,
            pos_encoding,
            config,
            encoder_output_cache: Arc::new(RwLock::new(None)),
        })
    }
    /// Checks that `input` has rank >= 3 and that its trailing (feature)
    /// dimension equals `d_model`.
    ///
    /// `role` is the capitalized tensor name used in error messages
    /// ("Source"/"Target"); `len_name` is the sequence-dimension label in
    /// the shape hint ("src_len"/"tgt_len"). Message text matches the
    /// previous per-method validation byte-for-byte.
    fn validate_model_input(
        &self,
        input: &Array<F, IxDyn>,
        role: &str,
        len_name: &str,
    ) -> Result<()> {
        if input.ndim() < 3 {
            return Err(NeuralError::InferenceError(format!(
                "{} must have at least 3 dimensions [batch, {}, features]",
                role, len_name
            )));
        }
        let feat_dim = input.shape()[input.ndim() - 1];
        if feat_dim != self.config.d_model {
            return Err(NeuralError::InferenceError(format!(
                "Last dimension of {} ({}) must match d_model ({})",
                role.to_lowercase(),
                feat_dim,
                self.config.d_model
            )));
        }
        Ok(())
    }
    /// Shared encoder-decoder pass used by both `forward_train` and
    /// `forward_inference`: validate both tensors, add positional
    /// encodings, run the encoder, cache its output, then decode.
    fn encode_decode(
        &self,
        src: &Array<F, IxDyn>,
        tgt: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.validate_model_input(src, "Source", "src_len")?;
        self.validate_model_input(tgt, "Target", "tgt_len")?;
        let src_pos = self.pos_encoding.forward(src)?;
        let tgt_pos = self.pos_encoding.forward(tgt)?;
        let encoder_output = self.encoder.forward(&src_pos)?;
        // Store a copy of the encoder output for later reuse. A poisoned
        // lock means another thread panicked mid-write; that is a broken
        // invariant, so panicking here is acceptable.
        *self
            .encoder_output_cache
            .write()
            .expect("encoder_output_cache lock poisoned") = Some(encoder_output.clone());
        self.decoder.forward_with_encoder(&tgt_pos, &encoder_output)
    }
    /// Training-mode forward pass over a (source, target) pair.
    ///
    /// Both tensors must be at least rank 3 with a trailing dimension of
    /// `d_model`; returns the decoder output with the target's shape.
    ///
    /// # Errors
    /// Returns `NeuralError::InferenceError` on shape mismatch, or any
    /// error propagated from the sub-modules.
    pub fn forward_train(
        &self,
        src: &Array<F, IxDyn>,
        tgt: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.encode_decode(src, tgt)
    }
    /// Inference-mode forward pass over a (source, target) pair.
    ///
    /// NOTE(review): currently identical to `forward_train` — the encoder
    /// and decoder calls used here take no training flag, so dropout
    /// behavior cannot differ between the two paths. Confirm this is
    /// intended; a mode switch would require a sub-module API change.
    ///
    /// # Errors
    /// Same conditions as [`Transformer::forward_train`].
    pub fn forward_inference(
        &self,
        src: &Array<F, IxDyn>,
        tgt: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        self.encode_decode(src, tgt)
    }
    /// Returns the configuration this model was built from.
    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }
    /// Returns a shared reference to the encoder stack.
    pub fn encoder(&self) -> &TransformerEncoder<F> {
        &self.encoder
    }
    /// Returns a shared reference to the decoder stack.
    pub fn decoder(&self) -> &TransformerDecoder<F> {
        &self.decoder
    }
    /// Returns a mutable reference to the encoder stack.
    pub fn encoder_mut(&mut self) -> &mut TransformerEncoder<F> {
        &mut self.encoder
    }
    /// Returns a mutable reference to the decoder stack.
    pub fn decoder_mut(&mut self) -> &mut TransformerDecoder<F> {
        &mut self.decoder
    }
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + SimdUnifiedOps + NumAssign> Layer<F>
    for Transformer<F>
{
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
    /// Encoder-only forward pass used by the generic `Layer` interface:
    /// validates the input shape, applies positional encoding, and runs
    /// the encoder stack (the decoder is not involved).
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let rank = input.ndim();
        if rank < 3 {
            return Err(NeuralError::InferenceError(
                "Input must have at least 3 dimensions [batch, seq_len, features]".to_string(),
            ));
        }
        let features = input.shape()[rank - 1];
        if features != self.config.d_model {
            return Err(NeuralError::InferenceError(format!(
                "Last dimension of input ({}) must match d_model ({})",
                features, self.config.d_model
            )));
        }
        let encoded = self.pos_encoding.forward(input)?;
        self.encoder.forward(&encoded)
    }
    /// Backward-pass stub: returns an all-zero gradient with the input's
    /// shape. NOTE(review): this does not backpropagate through the
    /// transformer — `_grad_output` is ignored entirely.
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        Ok(Array::zeros(input.dim()))
    }
    /// Applies a parameter update to the encoder, decoder, and positional
    /// encoding using `learning_rate`.
    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.encoder.update(learning_rate)?;
        self.decoder.update(learning_rate)?;
        self.pos_encoding.update(learning_rate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds a small test configuration; dropout (0.1), epsilon (1e-5),
    /// and the sinusoidal encoding type are fixed across all tests.
    fn small_config(
        d_model: usize,
        n_encoder_layers: usize,
        n_decoder_layers: usize,
        n_heads: usize,
        d_ff: usize,
        max_seq_len: usize,
    ) -> TransformerConfig {
        TransformerConfig {
            d_model,
            n_encoder_layers,
            n_decoder_layers,
            n_heads,
            d_ff,
            max_seq_len,
            dropout: 0.1,
            pos_encoding_type: PositionalEncodingType::Sinusoidal,
            epsilon: 1e-5,
        }
    }

    /// Builds a [batch, len, d_model] tensor filled with the constant 0.1.
    fn constant_input(batch: usize, len: usize, d_model: usize) -> Array<f64, IxDyn> {
        Array3::<f64>::from_elem((batch, len, d_model), 0.1).into_dyn()
    }

    #[test]
    fn test_transformer_config_default() {
        let config = TransformerConfig::default();
        assert_eq!(config.d_model, 512);
        assert_eq!(config.n_encoder_layers, 6);
        assert_eq!(config.n_decoder_layers, 6);
        assert_eq!(config.n_heads, 8);
        assert_eq!(config.d_ff, 2048);
        assert_eq!(config.max_seq_len, 512);
    }

    #[test]
    fn test_transformer_creation() {
        let mut rng = scirs2_core::random::rng();
        let result = Transformer::<f64>::new(small_config(64, 2, 2, 4, 128, 100), &mut rng);
        assert!(result.is_ok());
    }

    #[test]
    fn test_transformer_train() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(64, 2, 2, 4, 128, 100), &mut rng)
            .expect("Operation failed");
        let src = constant_input(2, 8, 64);
        let tgt = constant_input(2, 6, 64);
        let output = transformer
            .forward_train(&src, &tgt)
            .expect("Operation failed");
        // Decoder output mirrors the target's shape.
        assert_eq!(output.shape(), tgt.shape());
    }

    #[test]
    fn test_transformer_inference() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(64, 1, 1, 4, 128, 100), &mut rng)
            .expect("Operation failed");
        // Single-token target models the first step of autoregressive decoding.
        let src = constant_input(2, 8, 64);
        let tgt = constant_input(2, 1, 64);
        let output = transformer
            .forward_inference(&src, &tgt)
            .expect("Operation failed");
        assert_eq!(output.shape(), tgt.shape());
    }

    #[test]
    fn test_encoder_only() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(64, 1, 1, 4, 128, 100), &mut rng)
            .expect("Operation failed");
        let src = constant_input(2, 8, 64);
        // The Layer::forward path runs only the encoder, so the output
        // keeps the input's shape.
        let output = transformer.forward(&src).expect("Operation failed");
        assert_eq!(output.shape(), src.shape());
    }

    #[test]
    fn test_transformer_clone() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(32, 1, 1, 2, 64, 50), &mut rng)
            .expect("Operation failed");
        let transformer_clone = transformer.clone();
        let src = constant_input(1, 4, 32);
        let output1 = transformer.forward(&src).expect("Operation failed");
        let output2 = transformer_clone.forward(&src).expect("Operation failed");
        assert_eq!(output1.shape(), output2.shape());
    }

    #[test]
    fn test_transformer_invalid_input() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(64, 1, 1, 4, 128, 100), &mut rng)
            .expect("Operation failed");
        // Rank-2 input: fails the "at least 3 dimensions" check.
        let wrong_input = scirs2_core::ndarray::Array2::<f64>::from_elem((4, 64), 0.1).into_dyn();
        assert!(transformer.forward(&wrong_input).is_err());
        // Rank-3 input whose feature dim (32) does not match d_model (64).
        let wrong_dim_input = constant_input(2, 4, 32);
        assert!(transformer.forward(&wrong_dim_input).is_err());
    }

    #[test]
    fn test_transformer_accessors() {
        let mut rng = scirs2_core::random::rng();
        let transformer = Transformer::<f64>::new(small_config(64, 2, 3, 4, 128, 100), &mut rng)
            .expect("Operation failed");
        assert_eq!(transformer.config().d_model, 64);
        assert_eq!(transformer.config().n_encoder_layers, 2);
        assert_eq!(transformer.config().n_decoder_layers, 3);
        assert_eq!(transformer.encoder().num_layers(), 2);
        assert_eq!(transformer.decoder().num_layers(), 3);
    }
}