use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};

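/// Configuration for the transformer-based signal tokenizer.
///
/// # Example
///
/// A minimal sketch of overriding a few fields and validating the result
/// (marked `ignore` because the crate path may differ in your build):
///
/// ```ignore
/// let config = TransformerConfig {
///     input_dim: 32,
///     embed_dim: 64,
///     num_heads: 4,
///     ..Default::default()
/// };
/// assert!(config.validate().is_ok());
/// ```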
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerConfig {
    /// Number of signal samples folded into each input frame.
    pub input_dim: usize,
    /// Dimension of the internal embedding space.
    pub embed_dim: usize,
    /// Number of attention heads (must divide `embed_dim`).
    pub num_heads: usize,
    /// Number of encoder layers.
    pub num_encoder_layers: usize,
    /// Number of decoder layers.
    pub num_decoder_layers: usize,
    /// Hidden dimension of the feed-forward sublayers.
    pub feedforward_dim: usize,
    /// Dropout probability in `[0.0, 1.0]`.
    pub dropout: f32,
    /// Maximum number of input frames per sequence.
    pub max_seq_len: usize,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        Self {
            input_dim: 128,
            embed_dim: 256,
            num_heads: 8,
            num_encoder_layers: 6,
            num_decoder_layers: 6,
            feedforward_dim: 1024,
            dropout: 0.1,
            max_seq_len: 512,
        }
    }
}

impl TransformerConfig {
    /// Validates the configuration, returning an error for any out-of-range field.
    pub fn validate(&self) -> TokenizerResult<()> {
        if self.input_dim == 0 {
            return Err(TokenizerError::invalid_input(
                "input_dim must be positive",
                "TransformerConfig::validate",
            ));
        }
        if self.embed_dim == 0 {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be positive",
                "TransformerConfig::validate",
            ));
        }
        if self.num_heads == 0 {
            return Err(TokenizerError::invalid_input(
                "num_heads must be positive",
                "TransformerConfig::validate",
            ));
        }
        if !self.embed_dim.is_multiple_of(self.num_heads) {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be divisible by num_heads",
                "TransformerConfig::validate",
            ));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TokenizerError::invalid_input(
                "dropout must be in range [0.0, 1.0]",
                "TransformerConfig::validate",
            ));
        }
        if self.max_seq_len == 0 {
            return Err(TokenizerError::invalid_input(
                "max_seq_len must be positive",
                "TransformerConfig::validate",
            ));
        }
        Ok(())
    }
}

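/// Multi-head self-attention over a `(seq_len, embed_dim)` activation matrix.
///
/// # Example
///
/// A minimal sketch mirroring the unit tests below (doctest ignored because
/// the crate path may differ in your build):
///
/// ```ignore
/// let mha = MultiHeadAttention::new(64, 4).unwrap();
/// let x = Array2::<f32>::ones((10, 64));
/// let out = mha.forward(&x).unwrap(); // same shape as the input
/// assert_eq!(out.shape(), &[10, 64]);
/// ```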
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// Number of attention heads.
    num_heads: usize,
    /// Per-head dimension (`embed_dim / num_heads`).
    head_dim: usize,
    /// Query projection weights.
    w_query: Array2<f32>,
    /// Key projection weights.
    w_key: Array2<f32>,
    /// Value projection weights.
    w_value: Array2<f32>,
    /// Output projection weights.
    w_out: Array2<f32>,
}

impl MultiHeadAttention {
    pub fn new(embed_dim: usize, num_heads: usize) -> TokenizerResult<Self> {
        if !embed_dim.is_multiple_of(num_heads) {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be divisible by num_heads",
                "MultiHeadAttention::new",
            ));
        }

        let head_dim = embed_dim / num_heads;
        let mut rng = Random::seed(42);

        // Xavier-style scale for the square projection matrices.
        let scale = (2.0 / (embed_dim + embed_dim) as f32).sqrt();

        Ok(Self {
            num_heads,
            head_dim,
            w_query: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_key: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_value: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_out: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
        })
    }

    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
        let mut weights = Array2::zeros((rows, cols));
        for val in weights.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }
        weights
    }

    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let seq_len = x.nrows();
        let embed_dim = x.ncols();

        // Project the input into query, key, and value spaces.
        let query = x.dot(&self.w_query);
        let key = x.dot(&self.w_key);
        let value = x.dot(&self.w_value);
        let scale = (self.head_dim as f32).sqrt();

        let mut attention_output = Array2::zeros((seq_len, embed_dim));

        for h in 0..self.num_heads {
            // Slice out this head's columns from the projected matrices.
            let mut q_head = Array2::zeros((seq_len, self.head_dim));
            let mut k_head = Array2::zeros((seq_len, self.head_dim));
            let mut v_head = Array2::zeros((seq_len, self.head_dim));

            let start_idx = h * self.head_dim;
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    q_head[[i, j]] = query[[i, start_idx + j]];
                    k_head[[i, j]] = key[[i, start_idx + j]];
                    v_head[[i, j]] = value[[i, start_idx + j]];
                }
            }

            // Scaled dot-product attention for this head.
            let scores = q_head.dot(&k_head.t()) / scale;
            let attention_weights = Self::softmax(&scores)?;

            // Write the head output back into its column block.
            let head_output = attention_weights.dot(&v_head);
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    attention_output[[i, start_idx + j]] = head_output[[i, j]];
                }
            }
        }

        Ok(attention_output.dot(&self.w_out))
    }

    fn softmax(x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            // Subtract the row maximum for numerical stability before exponentiating.
            let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            for val in row.iter_mut() {
                *val = (*val - max_val).exp();
            }
            let sum: f32 = row.iter().sum();
            if sum > 0.0 {
                for val in row.iter_mut() {
                    *val /= sum;
                }
            }
        }
        Ok(result)
    }
}

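/// Sinusoidal positional encoding.
///
/// Rows follow the standard formulation: for position `pos` and channel pair `k`,
/// `PE[pos, 2k] = sin(pos / 10000^(2k / embed_dim))` and
/// `PE[pos, 2k + 1] = cos(pos / 10000^(2k / embed_dim))`,
/// which matches the computation in [`PositionalEncoding::new`].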
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
    /// Precomputed encodings of shape `(max_seq_len, embed_dim)`.
    encodings: Array2<f32>,
}

impl PositionalEncoding {
    pub fn new(max_seq_len: usize, embed_dim: usize) -> Self {
        let mut encodings = Array2::zeros((max_seq_len, embed_dim));

        for pos in 0..max_seq_len {
            for i in 0..embed_dim {
                let angle = pos as f32 / 10000.0_f32.powf(2.0 * (i / 2) as f32 / embed_dim as f32);
                if i % 2 == 0 {
                    encodings[[pos, i]] = angle.sin();
                } else {
                    encodings[[pos, i]] = angle.cos();
                }
            }
        }

        Self { encodings }
    }

    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let seq_len = x.nrows();
        if seq_len > self.encodings.nrows() {
            return Err(TokenizerError::encoding(
                format!(
                    "Sequence length {} exceeds max_seq_len {}",
                    seq_len,
                    self.encodings.nrows()
                ),
                "PositionalEncoding::forward",
            ));
        }

        let pos_enc = self.encodings.slice(s![0..seq_len, ..]);
        Ok(x + &pos_enc)
    }
}

#[derive(Debug, Clone)]
pub struct LayerNorm {
    dim: usize,
    eps: f32,
}

impl LayerNorm {
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }

    pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            let mean = row.mean().unwrap_or(0.0);
            let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / self.dim as f32;
            let std = (variance + self.eps).sqrt();

            for val in row.iter_mut() {
                *val = (*val - mean) / std;
            }
        }
        result
    }
}

#[derive(Debug, Clone)]
pub struct FeedForward {
    w1: Array2<f32>,
    w2: Array2<f32>,
}

impl FeedForward {
    pub fn new(embed_dim: usize, hidden_dim: usize) -> Self {
        let mut rng = Random::seed(43);
        let scale1 = (2.0 / (embed_dim + hidden_dim) as f32).sqrt();
        let scale2 = (2.0 / (hidden_dim + embed_dim) as f32).sqrt();

        Self {
            w1: Self::init_weights(embed_dim, hidden_dim, scale1, &mut rng),
            w2: Self::init_weights(hidden_dim, embed_dim, scale2, &mut rng),
        }
    }

    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
        let mut weights = Array2::zeros((rows, cols));
        for val in weights.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }
        weights
    }

    /// Tanh approximation of GELU: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
    }

    pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        let hidden = x.dot(&self.w1);
        let activated = hidden.mapv(Self::gelu);
        activated.dot(&self.w2)
    }
}

#[derive(Debug, Clone)]
pub struct TransformerEncoderLayer {
    attention: MultiHeadAttention,
    ffn: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
}

impl TransformerEncoderLayer {
    pub fn new(
        embed_dim: usize,
        num_heads: usize,
        feedforward_dim: usize,
    ) -> TokenizerResult<Self> {
        Ok(Self {
            attention: MultiHeadAttention::new(embed_dim, num_heads)?,
            ffn: FeedForward::new(embed_dim, feedforward_dim),
            norm1: LayerNorm::new(embed_dim, 1e-5),
            norm2: LayerNorm::new(embed_dim, 1e-5),
        })
    }

    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        // Self-attention sublayer with a residual connection, then layer norm.
        let attn_out = self.attention.forward(x)?;
        let x = x + &attn_out;
        let x_norm = self.norm1.forward(&x);

        // Feed-forward sublayer with a residual connection, then layer norm.
        let ffn_out = self.ffn.forward(&x_norm);
        let out = &x_norm + &ffn_out;
        Ok(self.norm2.forward(&out))
    }
}

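/// Transformer-based signal tokenizer with an encoder/decoder stack.
///
/// # Example
///
/// An end-to-end sketch mirroring the unit tests below (doctest ignored
/// because the crate path may differ in your build):
///
/// ```ignore
/// let config = TransformerConfig {
///     input_dim: 16,
///     embed_dim: 32,
///     num_heads: 4,
///     num_encoder_layers: 1,
///     num_decoder_layers: 1,
///     feedforward_dim: 64,
///     dropout: 0.0,
///     max_seq_len: 10,
/// };
/// let tokenizer = TransformerTokenizer::new(config).unwrap();
/// let signal = Array1::linspace(0.0, 1.0, 64);
/// let tokens = tokenizer.encode(&signal).unwrap();
/// let reconstructed = tokenizer.decode(&tokens).unwrap();
/// assert!(reconstructed.len() >= signal.len());
/// ```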
#[derive(Debug, Clone)]
pub struct TransformerTokenizer {
    config: TransformerConfig,
    /// Input projection from `input_dim` to `embed_dim`.
    input_proj: Array2<f32>,
    /// Output projection from `embed_dim` back to `input_dim`.
    output_proj: Array2<f32>,
    pos_encoding: PositionalEncoding,
    encoder_layers: Vec<TransformerEncoderLayer>,
    decoder_layers: Vec<TransformerEncoderLayer>,
}

impl TransformerTokenizer {
    pub fn new(config: TransformerConfig) -> TokenizerResult<Self> {
        config.validate()?;

        let mut rng = Random::seed(44);
        let scale_in = (2.0 / (config.input_dim + config.embed_dim) as f32).sqrt();
        let scale_out = (2.0 / (config.embed_dim + config.input_dim) as f32).sqrt();

        let mut input_proj = Array2::zeros((config.input_dim, config.embed_dim));
        let mut output_proj = Array2::zeros((config.embed_dim, config.input_dim));

        for val in input_proj.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale_in;
        }
        for val in output_proj.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale_out;
        }

        let max_seq_len = config.max_seq_len;
        let embed_dim = config.embed_dim;
        let num_encoder_layers = config.num_encoder_layers;
        let num_decoder_layers = config.num_decoder_layers;
        let num_heads = config.num_heads;
        let feedforward_dim = config.feedforward_dim;

        let mut encoder_layers = Vec::new();
        for _ in 0..num_encoder_layers {
            encoder_layers.push(TransformerEncoderLayer::new(
                embed_dim,
                num_heads,
                feedforward_dim,
            )?);
        }

        let mut decoder_layers = Vec::new();
        for _ in 0..num_decoder_layers {
            decoder_layers.push(TransformerEncoderLayer::new(
                embed_dim,
                num_heads,
                feedforward_dim,
            )?);
        }

        Ok(Self {
            config,
            input_proj,
            output_proj,
            pos_encoding: PositionalEncoding::new(max_seq_len, embed_dim),
            encoder_layers,
            decoder_layers,
        })
    }

    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }
}

impl SignalTokenizer for TransformerTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let len = signal.len();
        if len > self.config.max_seq_len * self.config.input_dim {
            return Err(TokenizerError::encoding(
                format!(
                    "Signal too long: {} > {}",
                    len,
                    self.config.max_seq_len * self.config.input_dim
                ),
                "TransformerTokenizer::encode",
            ));
        }

        // Zero-pad the signal and fold it into (seq_len, input_dim) frames.
        let seq_len = len.div_ceil(self.config.input_dim);
        let mut padded = signal.to_vec();
        padded.resize(seq_len * self.config.input_dim, 0.0);

        let mut x = Array2::zeros((seq_len, self.config.input_dim));
        for i in 0..seq_len {
            for j in 0..self.config.input_dim {
                x[[i, j]] = padded[i * self.config.input_dim + j];
            }
        }

        // Project into the embedding space, add positional encodings, and
        // run the encoder stack.
        let mut x = x.dot(&self.input_proj);
        x = self.pos_encoding.forward(&x)?;

        for layer in &self.encoder_layers {
            x = layer.forward(&x)?;
        }

        // Flatten the (seq_len, embed_dim) activations row by row.
        let mut result = Vec::new();
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                result.push(x[[i, j]]);
            }
        }

        Ok(Array1::from_vec(result))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let total_len = tokens.len();
        if !total_len.is_multiple_of(self.config.embed_dim) {
            return Err(TokenizerError::decoding(
                format!(
                    "Invalid token length: {} not divisible by {}",
                    total_len, self.config.embed_dim
                ),
                "TransformerTokenizer::decode",
            ));
        }

        let seq_len = total_len / self.config.embed_dim;

        // Reshape the flat token vector back into (seq_len, embed_dim).
        let mut x = Array2::zeros((seq_len, self.config.embed_dim));
        for i in 0..seq_len {
            for j in 0..self.config.embed_dim {
                x[[i, j]] = tokens[i * self.config.embed_dim + j];
            }
        }

        for layer in &self.decoder_layers {
            x = layer.forward(&x)?;
        }

        // Project back to the input dimension and flatten row by row.
        x = x.dot(&self.output_proj);
        let mut result = Vec::new();
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                result.push(x[[i, j]]);
            }
        }

        Ok(Array1::from_vec(result))
    }

    fn embed_dim(&self) -> usize {
        self.config.embed_dim
    }

    fn vocab_size(&self) -> usize {
        // This tokenizer produces continuous embeddings, so there is no discrete vocabulary.
        0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_config_validation() {
        let config = TransformerConfig::default();
        assert!(config.validate().is_ok());

        let mut bad_config = config.clone();
        bad_config.embed_dim = 0;
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.embed_dim = 100; // not divisible by the default num_heads = 8
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.dropout = 1.5;
        assert!(bad_config.validate().is_err());
    }

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(256, 8);
        assert!(mha.is_ok());

        // 256 is not divisible by 7, so construction must fail.
        let bad_mha = MultiHeadAttention::new(256, 7);
        assert!(bad_mha.is_err());
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4).unwrap();
        let x = Array2::ones((10, 64));
        let out = mha.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(100, 64);
        let x = Array2::zeros((50, 64));
        let out = pe.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[50, 64]);
    }

    #[test]
    fn test_positional_encoding_seq_too_long() {
        let pe = PositionalEncoding::new(10, 64);
        let x = Array2::zeros((20, 64));
        let out = pe.forward(&x);
        assert!(out.is_err());
    }

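    // Added sanity check (follows from the sinusoidal formula in
    // `PositionalEncoding::new`): at position 0 every angle is 0, so even
    // channels hold sin(0) = 0 and odd channels hold cos(0) = 1.
    #[test]
    fn test_positional_encoding_first_position() {
        let pe = PositionalEncoding::new(10, 8);
        let out = pe.forward(&Array2::zeros((1, 8))).unwrap();
        for i in 0..8 {
            let expected = if i % 2 == 0 { 0.0 } else { 1.0 };
            assert!((out[[0, i]] - expected).abs() < 1e-6);
        }
    }
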
    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(64, 1e-5);
        let x = Array2::from_shape_fn((10, 64), |(i, j)| (i + j) as f32);
        let out = ln.forward(&x);
        assert_eq!(out.shape(), &[10, 64]);

        // Every row should be normalized to zero mean and unit variance.
        for row in out.rows() {
            let mean = row.mean().unwrap();
            let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / 64.0;
            assert!((mean.abs()) < 1e-5);
            assert!((var - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_feedforward() {
        let ffn = FeedForward::new(64, 256);
        let x = Array2::ones((10, 64));
        let out = ffn.forward(&x);
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 4, 256).unwrap();
        let x = Array2::ones((10, 64));
        let out = layer.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_transformer_tokenizer_creation() {
        let config = TransformerConfig {
            input_dim: 32,
            embed_dim: 64,
            num_heads: 4,
            num_encoder_layers: 2,
            num_decoder_layers: 2,
            feedforward_dim: 128,
            dropout: 0.1,
            max_seq_len: 100,
        };
        let tokenizer = TransformerTokenizer::new(config);
        assert!(tokenizer.is_ok());
    }

    #[test]
    fn test_transformer_encode_decode() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            max_seq_len: 10,
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 64);
        let encoded = tokenizer.encode(&signal);
        assert!(encoded.is_ok());

        let encoded = encoded.unwrap();
        let decoded = tokenizer.decode(&encoded);
        assert!(decoded.is_ok());
        let decoded = decoded.unwrap();

        // Decoding restores whole frames, so the output is at least as long
        // as the (possibly padded) input signal.
        assert!(decoded.len() >= signal.len());
    }

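    // Added check (follows directly from `encode`, which flattens a
    // (seq_len, embed_dim) matrix): a 64-sample signal with input_dim = 16
    // folds into 4 frames, so the encoding holds 4 * embed_dim values.
    #[test]
    fn test_transformer_encoded_length_and_accessors() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            max_seq_len: 10,
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 64);
        let encoded = tokenizer.encode(&signal).unwrap();
        assert_eq!(encoded.len(), 4 * 32);

        assert_eq!(tokenizer.embed_dim(), 32);
        assert_eq!(tokenizer.vocab_size(), 0);
    }
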
    #[test]
    fn test_transformer_signal_too_long() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            // At most 2 * 16 = 32 samples fit.
            max_seq_len: 2,
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        // 1000 samples exceed max_seq_len * input_dim.
        let signal = Array1::linspace(0.0, 1.0, 1000);
        let encoded = tokenizer.encode(&signal);
        assert!(encoded.is_err());
    }

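    // Added check for the decode-side validation: a token vector whose length
    // is not a multiple of embed_dim must be rejected.
    #[test]
    fn test_transformer_decode_invalid_length() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            max_seq_len: 10,
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        let tokens = Array1::zeros(33); // 33 is not a multiple of 32
        assert!(tokenizer.decode(&tokens).is_err());
    }
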
    #[test]
    fn test_softmax() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 1.0, 1.0, 1.0]).unwrap();
        let result = MultiHeadAttention::softmax(&x).unwrap();

        // Each row sums to 1.
        for row in result.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }

        // All probabilities are non-negative.
        for &val in result.iter() {
            assert!(val >= 0.0);
        }
    }

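    // Added numerical-stability check: the row maximum is subtracted before
    // exponentiating, so large inputs must not overflow to infinity or NaN.
    #[test]
    fn test_softmax_large_values() {
        let x = Array2::from_shape_vec((1, 3), vec![1000.0, 1001.0, 1002.0]).unwrap();
        let result = MultiHeadAttention::softmax(&x).unwrap();

        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
        for &val in result.iter() {
            assert!(val.is_finite());
        }
    }
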
    #[test]
    fn test_gelu_activation() {
        // GELU(0) = 0.
        assert!((FeedForward::gelu(0.0)).abs() < 1e-5);

        // Monotonically increasing on the positive side.
        assert!(FeedForward::gelu(1.0) > FeedForward::gelu(0.5));
        assert!(FeedForward::gelu(2.0) > FeedForward::gelu(1.0));

        // Negative inputs map to small negative values; positive inputs stay positive.
        assert!(FeedForward::gelu(-1.0) < 0.0);
        assert!(FeedForward::gelu(1.0) > 0.0);
    }
}