use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};
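/// Configuration for a [`TransformerTokenizer`].
///
/// `embed_dim` must be divisible by `num_heads`, since the embedding is
/// split evenly across attention heads (`head_dim = embed_dim / num_heads`).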
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerConfig {
pub input_dim: usize,
pub embed_dim: usize,
pub num_heads: usize,
pub num_encoder_layers: usize,
pub num_decoder_layers: usize,
pub feedforward_dim: usize,
pub dropout: f32,
pub max_seq_len: usize,
}
impl Default for TransformerConfig {
fn default() -> Self {
Self {
input_dim: 128,
embed_dim: 256,
num_heads: 8,
num_encoder_layers: 6,
num_decoder_layers: 6,
feedforward_dim: 1024,
dropout: 0.1,
max_seq_len: 512,
}
}
}
impl TransformerConfig {
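    /// Checks that all dimensions are positive, that `embed_dim` is
    /// divisible by `num_heads`, and that `dropout` lies in `[0.0, 1.0]`.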
pub fn validate(&self) -> TokenizerResult<()> {
if self.input_dim == 0 {
return Err(TokenizerError::invalid_input(
"input_dim must be positive",
"TransformerConfig::validate",
));
}
if self.embed_dim == 0 {
return Err(TokenizerError::invalid_input(
"embed_dim must be positive",
"TransformerConfig::validate",
));
}
        // Check num_heads before divisibility so num_heads == 0 reports the
        // right error (`n.is_multiple_of(0)` is false for any nonzero `n`,
        // which would otherwise trip the divisibility message instead).
        if self.num_heads == 0 {
            return Err(TokenizerError::invalid_input(
                "num_heads must be positive",
                "TransformerConfig::validate",
            ));
        }
        if !self.embed_dim.is_multiple_of(self.num_heads) {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be divisible by num_heads",
                "TransformerConfig::validate",
            ));
        }
if !(0.0..=1.0).contains(&self.dropout) {
return Err(TokenizerError::invalid_input(
"dropout must be in range [0.0, 1.0]",
"TransformerConfig::validate",
));
}
if self.max_seq_len == 0 {
return Err(TokenizerError::invalid_input(
"max_seq_len must be positive",
"TransformerConfig::validate",
));
}
Ok(())
}
}
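/// Multi-head scaled dot-product attention.
///
/// The embedding is split into `num_heads` slices of width `head_dim`; each
/// head computes `softmax(Q K^T / sqrt(head_dim)) V` over its slice, and the
/// concatenated head outputs are mixed by the output projection `w_out`.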
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
num_heads: usize,
head_dim: usize,
w_query: Array2<f32>,
w_key: Array2<f32>,
w_value: Array2<f32>,
w_out: Array2<f32>,
}
impl MultiHeadAttention {
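    /// Creates the four `embed_dim x embed_dim` projection matrices, drawing
    /// weights uniformly from `(-scale, scale)` with a Xavier-style scale of
    /// `sqrt(2 / (fan_in + fan_out))`. The RNG is seeded so construction is
    /// deterministic.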
pub fn new(embed_dim: usize, num_heads: usize) -> TokenizerResult<Self> {
if !embed_dim.is_multiple_of(num_heads) {
return Err(TokenizerError::invalid_input(
"embed_dim must be divisible by num_heads",
"MultiHeadAttention::new",
));
}
let head_dim = embed_dim / num_heads;
let mut rng = Random::seed(42);
let scale = (2.0 / (embed_dim + embed_dim) as f32).sqrt();
Ok(Self {
num_heads,
head_dim,
w_query: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
w_key: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
w_value: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
w_out: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
})
}
fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
let mut weights = Array2::zeros((rows, cols));
for val in weights.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale;
}
weights
}
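    /// Runs self-attention over `x` of shape `(seq_len, embed_dim)` and
    /// returns an output of the same shape. No attention mask is applied,
    /// so every position can attend to every other position.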
pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
let seq_len = x.nrows();
let embed_dim = x.ncols();
        let query = x.dot(&self.w_query);
        let key = x.dot(&self.w_key);
        let value = x.dot(&self.w_value);
let scale = (self.head_dim as f32).sqrt();
let mut attention_output = Array2::zeros((seq_len, embed_dim));
for h in 0..self.num_heads {
let mut q_head = Array2::zeros((seq_len, self.head_dim));
let mut k_head = Array2::zeros((seq_len, self.head_dim));
let mut v_head = Array2::zeros((seq_len, self.head_dim));
let start_idx = h * self.head_dim;
for i in 0..seq_len {
for j in 0..self.head_dim {
q_head[[i, j]] = query[[i, start_idx + j]];
k_head[[i, j]] = key[[i, start_idx + j]];
v_head[[i, j]] = value[[i, start_idx + j]];
}
}
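            // Scaled dot-product attention for this head:
            // softmax(Q K^T / sqrt(head_dim)) V.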
let scores = q_head.dot(&k_head.t()) / scale;
let attention_weights = Self::softmax(&scores)?;
let head_output = attention_weights.dot(&v_head);
for i in 0..seq_len {
for j in 0..self.head_dim {
attention_output[[i, start_idx + j]] = head_output[[i, j]];
}
}
}
Ok(attention_output.dot(&self.w_out))
}
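    /// Row-wise softmax, stabilized by subtracting each row's maximum before
    /// exponentiating. The `sum > 0.0` guard is defensive: after the
    /// max-subtraction at least one entry is `exp(0) = 1`, so a non-empty
    /// row always normalizes.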
fn softmax(x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
let mut result = x.clone();
for mut row in result.rows_mut() {
let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
for val in row.iter_mut() {
*val = (*val - max_val).exp();
}
let sum: f32 = row.iter().sum();
if sum > 0.0 {
for val in row.iter_mut() {
*val /= sum;
}
}
}
Ok(result)
}
}
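/// Sinusoidal positional encodings, precomputed for `max_seq_len` positions:
/// `PE[pos, 2i] = sin(pos / 10000^(2i / embed_dim))` and
/// `PE[pos, 2i + 1] = cos(pos / 10000^(2i / embed_dim))`.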
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
encodings: Array2<f32>,
}
impl PositionalEncoding {
pub fn new(max_seq_len: usize, embed_dim: usize) -> Self {
let mut encodings = Array2::zeros((max_seq_len, embed_dim));
for pos in 0..max_seq_len {
for i in 0..embed_dim {
let angle = pos as f32 / 10000.0_f32.powf(2.0 * (i / 2) as f32 / embed_dim as f32);
if i % 2 == 0 {
encodings[[pos, i]] = angle.sin();
} else {
encodings[[pos, i]] = angle.cos();
}
}
}
Self { encodings }
}
pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
let seq_len = x.nrows();
if seq_len > self.encodings.nrows() {
return Err(TokenizerError::encoding(
format!(
"Sequence length {} exceeds max_seq_len {}",
seq_len,
self.encodings.nrows()
),
"PositionalEncoding::forward",
));
}
let pos_enc = self.encodings.slice(s![0..seq_len, ..]);
Ok(x + &pos_enc)
}
}
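/// Per-row layer normalization without learnable scale/shift parameters:
/// each row is shifted to zero mean and scaled to unit variance, with `eps`
/// guarding against division by zero.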
#[derive(Debug, Clone)]
pub struct LayerNorm {
dim: usize,
eps: f32,
}
impl LayerNorm {
pub fn new(dim: usize, eps: f32) -> Self {
Self { dim, eps }
}
pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
let mut result = x.clone();
for mut row in result.rows_mut() {
let mean = row.mean().unwrap_or(0.0);
let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / self.dim as f32;
let std = (variance + self.eps).sqrt();
for val in row.iter_mut() {
*val = (*val - mean) / std;
}
}
result
}
}
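/// Position-wise feed-forward network: a two-layer MLP
/// (`embed_dim -> hidden_dim -> embed_dim`) with a GELU activation between
/// the layers and no bias terms.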
#[derive(Debug, Clone)]
pub struct FeedForward {
w1: Array2<f32>,
w2: Array2<f32>,
}
impl FeedForward {
pub fn new(embed_dim: usize, hidden_dim: usize) -> Self {
let mut rng = Random::seed(43);
let scale1 = (2.0 / (embed_dim + hidden_dim) as f32).sqrt();
let scale2 = (2.0 / (hidden_dim + embed_dim) as f32).sqrt();
Self {
w1: Self::init_weights(embed_dim, hidden_dim, scale1, &mut rng),
w2: Self::init_weights(hidden_dim, embed_dim, scale2, &mut rng),
}
}
fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
let mut weights = Array2::zeros((rows, cols));
for val in weights.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale;
}
weights
}
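    /// Tanh approximation of GELU:
    /// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.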
fn gelu(x: f32) -> f32 {
0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}
pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
let hidden = x.dot(&self.w1);
let activated = hidden.mapv(Self::gelu);
activated.dot(&self.w2)
}
}
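/// One Transformer encoder layer: a self-attention sub-layer and a
/// feed-forward sub-layer, each wrapped in a residual connection followed
/// by layer normalization (post-norm).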
#[derive(Debug, Clone)]
pub struct TransformerEncoderLayer {
attention: MultiHeadAttention,
ffn: FeedForward,
norm1: LayerNorm,
norm2: LayerNorm,
}
impl TransformerEncoderLayer {
pub fn new(
embed_dim: usize,
num_heads: usize,
feedforward_dim: usize,
) -> TokenizerResult<Self> {
Ok(Self {
attention: MultiHeadAttention::new(embed_dim, num_heads)?,
ffn: FeedForward::new(embed_dim, feedforward_dim),
norm1: LayerNorm::new(embed_dim, 1e-5),
norm2: LayerNorm::new(embed_dim, 1e-5),
})
}
    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        // Attention sub-layer: residual connection, then layer norm.
        let attn_out = self.attention.forward(x)?;
        let x_norm = self.norm1.forward(&(x + &attn_out));
        // Feed-forward sub-layer: residual connection, then layer norm.
        let ffn_out = self.ffn.forward(&x_norm);
        Ok(self.norm2.forward(&(&x_norm + &ffn_out)))
    }
}
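/// A transformer-based signal tokenizer.
///
/// `encode` chunks a 1-D signal into frames of `input_dim` samples, projects
/// them into `embed_dim`-dimensional embeddings, and runs the encoder stack;
/// `decode` runs a second stack of the same layer type (a symmetric encoder
/// reused as a decoder, with no cross-attention) and projects back to
/// `input_dim` samples per frame.
///
/// A minimal usage sketch, mirroring the round-trip test below (fenced as
/// `ignore` because the exact crate path depends on how this module is
/// re-exported):
///
/// ```ignore
/// // Requires the SignalTokenizer trait in scope for encode/decode.
/// let config = TransformerConfig {
///     input_dim: 16,
///     embed_dim: 32,
///     num_heads: 4,
///     num_encoder_layers: 1,
///     num_decoder_layers: 1,
///     feedforward_dim: 64,
///     dropout: 0.0,
///     max_seq_len: 10,
/// };
/// let tokenizer = TransformerTokenizer::new(config)?;
/// let signal = Array1::linspace(0.0, 1.0, 64);
/// let tokens = tokenizer.encode(&signal)?;    // length: seq_len * embed_dim
/// let recovered = tokenizer.decode(&tokens)?; // length: seq_len * input_dim
/// ```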
#[derive(Debug, Clone)]
pub struct TransformerTokenizer {
config: TransformerConfig,
input_proj: Array2<f32>,
output_proj: Array2<f32>,
pos_encoding: PositionalEncoding,
encoder_layers: Vec<TransformerEncoderLayer>,
decoder_layers: Vec<TransformerEncoderLayer>,
}
impl TransformerTokenizer {
pub fn new(config: TransformerConfig) -> TokenizerResult<Self> {
config.validate()?;
let mut rng = Random::seed(44);
let scale_in = (2.0 / (config.input_dim + config.embed_dim) as f32).sqrt();
let scale_out = (2.0 / (config.embed_dim + config.input_dim) as f32).sqrt();
let mut input_proj = Array2::zeros((config.input_dim, config.embed_dim));
let mut output_proj = Array2::zeros((config.embed_dim, config.input_dim));
for val in input_proj.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale_in;
}
for val in output_proj.iter_mut() {
*val = (rng.gen_range(-1.0..1.0)) * scale_out;
}
let max_seq_len = config.max_seq_len;
let embed_dim = config.embed_dim;
let num_encoder_layers = config.num_encoder_layers;
let num_decoder_layers = config.num_decoder_layers;
let num_heads = config.num_heads;
let feedforward_dim = config.feedforward_dim;
let mut encoder_layers = Vec::new();
for _ in 0..num_encoder_layers {
encoder_layers.push(TransformerEncoderLayer::new(
embed_dim,
num_heads,
feedforward_dim,
)?);
}
let mut decoder_layers = Vec::new();
for _ in 0..num_decoder_layers {
decoder_layers.push(TransformerEncoderLayer::new(
embed_dim,
num_heads,
feedforward_dim,
)?);
}
Ok(Self {
config,
input_proj,
output_proj,
pos_encoding: PositionalEncoding::new(max_seq_len, embed_dim),
encoder_layers,
decoder_layers,
})
}
pub fn config(&self) -> &TransformerConfig {
&self.config
}
}
impl SignalTokenizer for TransformerTokenizer {
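    /// Splits the signal into `input_dim`-sample frames (zero-padding the
    /// final frame), projects each frame to `embed_dim`, adds positional
    /// encodings, and applies the encoder stack. The output is the encoder
    /// state flattened row-major to length `seq_len * embed_dim`.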
fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let len = signal.len();
if len > self.config.max_seq_len * self.config.input_dim {
return Err(TokenizerError::encoding(
format!(
"Signal too long: {} > {}",
len,
self.config.max_seq_len * self.config.input_dim
),
"TransformerTokenizer::encode",
));
}
let seq_len = len.div_ceil(self.config.input_dim);
let mut padded = signal.to_vec();
padded.resize(seq_len * self.config.input_dim, 0.0);
let mut x = Array2::zeros((seq_len, self.config.input_dim));
for i in 0..seq_len {
for j in 0..self.config.input_dim {
x[[i, j]] = padded[i * self.config.input_dim + j];
}
}
let mut x = x.dot(&self.input_proj);
x = self.pos_encoding.forward(&x)?;
for layer in &self.encoder_layers {
x = layer.forward(&x)?;
}
let mut result = Vec::new();
for i in 0..x.nrows() {
for j in 0..x.ncols() {
result.push(x[[i, j]]);
}
}
Ok(Array1::from_vec(result))
}
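    /// Reshapes the flat token array back to `(seq_len, embed_dim)`, applies
    /// the decoder stack, and projects back to `input_dim` samples per frame.
    /// The result has length `seq_len * input_dim`, which may exceed the
    /// original signal length because of zero-padding during `encode`.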
fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
let total_len = tokens.len();
if !total_len.is_multiple_of(self.config.embed_dim) {
return Err(TokenizerError::decoding(
format!(
"Invalid token length: {} not divisible by {}",
total_len, self.config.embed_dim
),
"TransformerTokenizer::decode",
));
}
let seq_len = total_len / self.config.embed_dim;
let mut x = Array2::zeros((seq_len, self.config.embed_dim));
for i in 0..seq_len {
for j in 0..self.config.embed_dim {
x[[i, j]] = tokens[i * self.config.embed_dim + j];
}
}
for layer in &self.decoder_layers {
x = layer.forward(&x)?;
}
x = x.dot(&self.output_proj);
let mut result = Vec::new();
for i in 0..x.nrows() {
for j in 0..x.ncols() {
result.push(x[[i, j]]);
}
}
Ok(Array1::from_vec(result))
}
fn embed_dim(&self) -> usize {
self.config.embed_dim
}
    fn vocab_size(&self) -> usize {
        // Tokens here are continuous embeddings rather than indices into a
        // discrete vocabulary, so there is no vocabulary size to report.
        0
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transformer_config_validation() {
let config = TransformerConfig::default();
assert!(config.validate().is_ok());
let mut bad_config = config.clone();
bad_config.embed_dim = 0;
assert!(bad_config.validate().is_err());
let mut bad_config = config.clone();
        bad_config.embed_dim = 100;
        assert!(bad_config.validate().is_err());
let mut bad_config = config.clone();
bad_config.dropout = 1.5;
assert!(bad_config.validate().is_err());
}
#[test]
fn test_multihead_attention_creation() {
let mha = MultiHeadAttention::new(256, 8);
assert!(mha.is_ok());
        let bad_mha = MultiHeadAttention::new(256, 7);
        assert!(bad_mha.is_err());
}
#[test]
fn test_multihead_attention_forward() {
let mha = MultiHeadAttention::new(64, 4).unwrap();
        let x = Array2::ones((10, 64));
        let out = mha.forward(&x);
assert!(out.is_ok());
let out = out.unwrap();
assert_eq!(out.shape(), &[10, 64]);
}
#[test]
fn test_positional_encoding() {
let pe = PositionalEncoding::new(100, 64);
let x = Array2::zeros((50, 64));
let out = pe.forward(&x);
assert!(out.is_ok());
let out = out.unwrap();
assert_eq!(out.shape(), &[50, 64]);
}
#[test]
fn test_positional_encoding_seq_too_long() {
let pe = PositionalEncoding::new(10, 64);
        let x = Array2::zeros((20, 64));
        let out = pe.forward(&x);
assert!(out.is_err());
}
#[test]
fn test_layer_norm() {
let ln = LayerNorm::new(64, 1e-5);
let x = Array2::from_shape_fn((10, 64), |(i, j)| (i + j) as f32);
let out = ln.forward(&x);
assert_eq!(out.shape(), &[10, 64]);
for row in out.rows() {
let mean = row.mean().unwrap();
let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / 64.0;
assert!((mean.abs()) < 1e-5);
assert!((var - 1.0).abs() < 1e-4);
}
}
#[test]
fn test_feedforward() {
let ffn = FeedForward::new(64, 256);
let x = Array2::ones((10, 64));
let out = ffn.forward(&x);
assert_eq!(out.shape(), &[10, 64]);
}
#[test]
fn test_encoder_layer() {
let layer = TransformerEncoderLayer::new(64, 4, 256).unwrap();
let x = Array2::ones((10, 64));
let out = layer.forward(&x);
assert!(out.is_ok());
let out = out.unwrap();
assert_eq!(out.shape(), &[10, 64]);
}
#[test]
fn test_transformer_tokenizer_creation() {
let config = TransformerConfig {
input_dim: 32,
embed_dim: 64,
num_heads: 4,
num_encoder_layers: 2,
num_decoder_layers: 2,
feedforward_dim: 128,
dropout: 0.1,
max_seq_len: 100,
};
let tokenizer = TransformerTokenizer::new(config);
assert!(tokenizer.is_ok());
}
#[test]
fn test_transformer_encode_decode() {
let config = TransformerConfig {
input_dim: 16,
embed_dim: 32,
num_heads: 4,
num_encoder_layers: 1,
num_decoder_layers: 1,
feedforward_dim: 64,
dropout: 0.0,
max_seq_len: 10,
};
let tokenizer = TransformerTokenizer::new(config).unwrap();
let signal = Array1::linspace(0.0, 1.0, 64);
let encoded = tokenizer.encode(&signal);
assert!(encoded.is_ok());
let encoded = encoded.unwrap();
let decoded = tokenizer.decode(&encoded);
assert!(decoded.is_ok());
let decoded = decoded.unwrap();
assert!(decoded.len() >= signal.len());
}
#[test]
fn test_transformer_signal_too_long() {
let config = TransformerConfig {
input_dim: 16,
embed_dim: 32,
num_heads: 4,
num_encoder_layers: 1,
num_decoder_layers: 1,
feedforward_dim: 64,
dropout: 0.0,
            max_seq_len: 2,
        };
let tokenizer = TransformerTokenizer::new(config).unwrap();
        let signal = Array1::linspace(0.0, 1.0, 1000);
        let encoded = tokenizer.encode(&signal);
assert!(encoded.is_err());
}
#[test]
fn test_softmax() {
let x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 1.0, 1.0, 1.0]).unwrap();
let result = MultiHeadAttention::softmax(&x).unwrap();
for row in result.rows() {
let sum: f32 = row.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
for &val in result.iter() {
assert!(val >= 0.0);
}
}
#[test]
fn test_gelu_activation() {
assert!((FeedForward::gelu(0.0)).abs() < 1e-5);
assert!(FeedForward::gelu(1.0) > FeedForward::gelu(0.5));
assert!(FeedForward::gelu(2.0) > FeedForward::gelu(1.0));
assert!(FeedForward::gelu(-1.0) < 0.0);
assert!(FeedForward::gelu(1.0) > 0.0);
}
}