kizzasi_tokenizer/transformer.rs

//! Transformer-based signal tokenization using self-attention mechanisms.
//!
//! This module implements a modern transformer architecture for signal tokenization,
//! featuring multi-head self-attention, positional encoding, and encoder-decoder structure.
//! Inspired by "Attention Is All You Need" (Vaswani et al., 2017).
//!
//! # Architecture
//!
//! - Multi-head self-attention for capturing global dependencies
//! - Positional encoding for sequence order information
//! - Feed-forward networks for non-linear transformations
//! - Layer normalization for training stability
//! - Residual connections for gradient flow
//!
//! # Example
//!
//! ```
//! use kizzasi_tokenizer::{TransformerTokenizer, TransformerConfig, SignalTokenizer};
//! use scirs2_core::ndarray::Array1;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let config = TransformerConfig {
//!     input_dim: 128,
//!     embed_dim: 256,
//!     num_heads: 8,
//!     num_encoder_layers: 4,
//!     num_decoder_layers: 4,
//!     feedforward_dim: 1024,
//!     dropout: 0.1,
//!     max_seq_len: 512,
//! };
//!
//! let tokenizer = TransformerTokenizer::new(config)?;
//! let signal = Array1::linspace(0.0, 1.0, 128);
//! let tokens = tokenizer.encode(&signal)?;
//! let reconstructed = tokenizer.decode(&tokens)?;
//! # Ok(())
//! # }
//! ```

use crate::error::{TokenizerError, TokenizerResult};
use crate::SignalTokenizer;
use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};

/// Configuration for the Transformer tokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerConfig {
    /// Input signal dimension
    pub input_dim: usize,
    /// Embedding dimension (must be divisible by num_heads)
    pub embed_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Number of encoder layers
    pub num_encoder_layers: usize,
    /// Number of decoder layers
    pub num_decoder_layers: usize,
    /// Feedforward network hidden dimension
    pub feedforward_dim: usize,
    /// Dropout probability (0.0 to 1.0)
    pub dropout: f32,
    /// Maximum sequence length for positional encoding
    pub max_seq_len: usize,
}

impl Default for TransformerConfig {
    fn default() -> Self {
        Self {
            input_dim: 128,
            embed_dim: 256,
            num_heads: 8,
            num_encoder_layers: 6,
            num_decoder_layers: 6,
            feedforward_dim: 1024,
            dropout: 0.1,
            max_seq_len: 512,
        }
    }
}

impl TransformerConfig {
    /// Validate the configuration
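    ///
    /// A minimal sketch of the intended use (this assumes `TransformerConfig` is
    /// re-exported at the crate root, as in the module-level example):
    ///
    /// ```
    /// use kizzasi_tokenizer::TransformerConfig;
    ///
    /// let mut config = TransformerConfig::default();
    /// assert!(config.validate().is_ok());
    ///
    /// config.embed_dim = 100; // not divisible by num_heads (8)
    /// assert!(config.validate().is_err());
    /// ```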
    pub fn validate(&self) -> TokenizerResult<()> {
        if self.input_dim == 0 {
            return Err(TokenizerError::invalid_input(
                "input_dim must be positive",
                "TransformerConfig::validate",
            ));
        }
        if self.embed_dim == 0 {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be positive",
                "TransformerConfig::validate",
            ));
        }
        if self.num_heads == 0 {
            return Err(TokenizerError::invalid_input(
                "num_heads must be positive",
                "TransformerConfig::validate",
            ));
        }
        if !self.embed_dim.is_multiple_of(self.num_heads) {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be divisible by num_heads",
                "TransformerConfig::validate",
            ));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TokenizerError::invalid_input(
                "dropout must be in range [0.0, 1.0]",
                "TransformerConfig::validate",
            ));
        }
        if self.max_seq_len == 0 {
            return Err(TokenizerError::invalid_input(
                "max_seq_len must be positive",
                "TransformerConfig::validate",
            ));
        }
        Ok(())
    }
}

/// Multi-head self-attention mechanism
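///
/// Each head computes scaled dot-product attention,
/// `softmax(Q @ K^T / sqrt(d_k)) @ V`, over its own `head_dim`-wide slice of the
/// projected queries, keys, and values; the head outputs are concatenated and
/// passed through a final output projection.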
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
    /// Number of attention heads
    num_heads: usize,
    /// Dimension per head
    head_dim: usize,
    /// Query projection weights [embed_dim, embed_dim]
    w_query: Array2<f32>,
    /// Key projection weights [embed_dim, embed_dim]
    w_key: Array2<f32>,
    /// Value projection weights [embed_dim, embed_dim]
    w_value: Array2<f32>,
    /// Output projection weights [embed_dim, embed_dim]
    w_out: Array2<f32>,
}

impl MultiHeadAttention {
    /// Create a new multi-head attention layer
    pub fn new(embed_dim: usize, num_heads: usize) -> TokenizerResult<Self> {
        if num_heads == 0 {
            return Err(TokenizerError::invalid_input(
                "num_heads must be positive",
                "MultiHeadAttention::new",
            ));
        }
        if !embed_dim.is_multiple_of(num_heads) {
            return Err(TokenizerError::invalid_input(
                "embed_dim must be divisible by num_heads",
                "MultiHeadAttention::new",
            ));
        }

        let head_dim = embed_dim / num_heads;
        let mut rng = Random::seed(42);

        // Xavier/Glorot initialization
        let scale = (2.0 / (embed_dim + embed_dim) as f32).sqrt();

        Ok(Self {
            num_heads,
            head_dim,
            w_query: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_key: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_value: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
            w_out: Self::init_weights(embed_dim, embed_dim, scale, &mut rng),
        })
    }

    /// Initialize weights with Xavier/Glorot uniform distribution
    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
        let mut weights = Array2::zeros((rows, cols));
        for val in weights.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }
        weights
    }

    /// Forward pass through multi-head attention
    ///
    /// # Arguments
    ///
    /// * `x` - Input tensor [seq_len, embed_dim]
    ///
    /// # Returns
    ///
    /// Output tensor [seq_len, embed_dim]
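    ///
    /// A minimal usage sketch (marked `ignore` because this file alone does not
    /// establish the public path of `MultiHeadAttention`):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let mha = MultiHeadAttention::new(64, 4).unwrap();
    /// let x = Array2::<f32>::ones((10, 64)); // [seq_len, embed_dim]
    /// let y = mha.forward(&x).unwrap();
    /// assert_eq!(y.shape(), &[10, 64]);
    /// ```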
    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let seq_len = x.nrows();
        let embed_dim = x.ncols();

        // Linear projections: Q, K, V = x @ W_q, x @ W_k, x @ W_v
        let query = x.dot(&self.w_query); // [seq_len, embed_dim]
        let key = x.dot(&self.w_key); // [seq_len, embed_dim]
        let value = x.dot(&self.w_value); // [seq_len, embed_dim]

        // Scaled dot-product attention for each head
        let scale = (self.head_dim as f32).sqrt();
        let mut attention_output = Array2::zeros((seq_len, embed_dim));

        for h in 0..self.num_heads {
            // Extract Q, K, V for this head
            let mut q_head = Array2::zeros((seq_len, self.head_dim));
            let mut k_head = Array2::zeros((seq_len, self.head_dim));
            let mut v_head = Array2::zeros((seq_len, self.head_dim));

            let start_idx = h * self.head_dim;
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    q_head[[i, j]] = query[[i, start_idx + j]];
                    k_head[[i, j]] = key[[i, start_idx + j]];
                    v_head[[i, j]] = value[[i, start_idx + j]];
                }
            }

            // Attention scores: Q @ K^T / sqrt(d_k)
            let scores = q_head.dot(&k_head.t()) / scale; // [seq_len, seq_len]

            // Softmax over the last dimension
            let attention_weights = Self::softmax(&scores)?;

            // Weighted sum: softmax(scores) @ V
            let head_output = attention_weights.dot(&v_head); // [seq_len, head_dim]

            // Copy to output tensor
            for i in 0..seq_len {
                for j in 0..self.head_dim {
                    attention_output[[i, start_idx + j]] = head_output[[i, j]];
                }
            }
        }

        // Final linear projection
        Ok(attention_output.dot(&self.w_out))
    }

    /// Apply softmax to each row
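    ///
    /// Each row is mapped to `exp(x_i - max(x)) / sum_j exp(x_j - max(x))`;
    /// subtracting the row maximum keeps the exponentials numerically stable.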
    fn softmax(x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            // Subtract max for numerical stability
            let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            for val in row.iter_mut() {
                *val = (*val - max_val).exp();
            }
            let sum: f32 = row.iter().sum();
            if sum > 0.0 {
                for val in row.iter_mut() {
                    *val /= sum;
                }
            }
        }
        Ok(result)
    }
}

/// Positional encoding using sinusoidal functions
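///
/// Positions are encoded as `PE(pos, 2i) = sin(pos / 10000^(2i / d))` and
/// `PE(pos, 2i + 1) = cos(pos / 10000^(2i / d))`, following Vaswani et al. (2017),
/// where `d` is the embedding dimension.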
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
    /// Pre-computed positional encodings [max_seq_len, embed_dim]
    encodings: Array2<f32>,
}

impl PositionalEncoding {
    /// Create a new positional encoding
    pub fn new(max_seq_len: usize, embed_dim: usize) -> Self {
        let mut encodings = Array2::zeros((max_seq_len, embed_dim));

        for pos in 0..max_seq_len {
            for i in 0..embed_dim {
                let angle = pos as f32 / 10000.0_f32.powf(2.0 * (i / 2) as f32 / embed_dim as f32);
                if i % 2 == 0 {
                    encodings[[pos, i]] = angle.sin();
                } else {
                    encodings[[pos, i]] = angle.cos();
                }
            }
        }

        Self { encodings }
    }

    /// Add positional encoding to input tensor
    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        let seq_len = x.nrows();
        if seq_len > self.encodings.nrows() {
            return Err(TokenizerError::encoding(
                format!(
                    "Sequence length {} exceeds max_seq_len {}",
                    seq_len,
                    self.encodings.nrows()
                ),
                "PositionalEncoding::forward",
            ));
        }

        let pos_enc = self.encodings.slice(s![0..seq_len, ..]);
        Ok(x + &pos_enc)
    }
}

/// Layer normalization
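///
/// Normalizes each row to zero mean and unit variance:
/// `y = (x - mean(x)) / sqrt(var(x) + eps)`.
/// This simplified implementation has no learnable scale/shift (gamma/beta) parameters.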
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Feature dimension
    dim: usize,
    /// Small constant for numerical stability
    eps: f32,
}

impl LayerNorm {
    /// Create a new layer normalization
    pub fn new(dim: usize, eps: f32) -> Self {
        Self { dim, eps }
    }

    /// Apply layer normalization
    pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        let mut result = x.clone();
        for mut row in result.rows_mut() {
            let mean = row.mean().unwrap_or(0.0);
            let variance = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / self.dim as f32;
            let std = (variance + self.eps).sqrt();

            for val in row.iter_mut() {
                *val = (*val - mean) / std;
            }
        }
        result
    }
}

/// Feedforward network with GELU activation
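///
/// Computes `FFN(x) = GELU(x @ W1) @ W2` (two linear maps with no bias terms).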
#[derive(Debug, Clone)]
pub struct FeedForward {
    /// First linear layer weights
    w1: Array2<f32>,
    /// Second linear layer weights
    w2: Array2<f32>,
}

impl FeedForward {
    /// Create a new feedforward network
    pub fn new(embed_dim: usize, hidden_dim: usize) -> Self {
        let mut rng = Random::seed(43);
        let scale1 = (2.0 / (embed_dim + hidden_dim) as f32).sqrt();
        let scale2 = (2.0 / (hidden_dim + embed_dim) as f32).sqrt();

        Self {
            w1: Self::init_weights(embed_dim, hidden_dim, scale1, &mut rng),
            w2: Self::init_weights(hidden_dim, embed_dim, scale2, &mut rng),
        }
    }

    /// Initialize weights
    fn init_weights(rows: usize, cols: usize, scale: f32, rng: &mut Random<StdRng>) -> Array2<f32> {
        let mut weights = Array2::zeros((rows, cols));
        for val in weights.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale;
        }
        weights
    }

    /// GELU activation function
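    ///
    /// Uses the tanh approximation:
    /// `GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.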
    fn gelu(x: f32) -> f32 {
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
    }

    /// Forward pass
    pub fn forward(&self, x: &Array2<f32>) -> Array2<f32> {
        let hidden = x.dot(&self.w1);
        let activated = hidden.mapv(Self::gelu);
        activated.dot(&self.w2)
    }
}

/// Transformer encoder layer
#[derive(Debug, Clone)]
pub struct TransformerEncoderLayer {
    /// Multi-head attention
    attention: MultiHeadAttention,
    /// Feedforward network
    ffn: FeedForward,
    /// Layer normalization 1
    norm1: LayerNorm,
    /// Layer normalization 2
    norm2: LayerNorm,
}

impl TransformerEncoderLayer {
    /// Create a new encoder layer
    pub fn new(
        embed_dim: usize,
        num_heads: usize,
        feedforward_dim: usize,
    ) -> TokenizerResult<Self> {
        Ok(Self {
            attention: MultiHeadAttention::new(embed_dim, num_heads)?,
            ffn: FeedForward::new(embed_dim, feedforward_dim),
            norm1: LayerNorm::new(embed_dim, 1e-5),
            norm2: LayerNorm::new(embed_dim, 1e-5),
        })
    }

    /// Forward pass with residual connections
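    ///
    /// Uses the post-norm arrangement:
    /// `x <- LayerNorm(x + Attention(x))`, then `out <- LayerNorm(x + FFN(x))`.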
    pub fn forward(&self, x: &Array2<f32>) -> TokenizerResult<Array2<f32>> {
        // Self-attention with residual
        let attn_out = self.attention.forward(x)?;
        let x = x + &attn_out;
        let x_norm = self.norm1.forward(&x);

        // Feedforward with residual
        let ffn_out = self.ffn.forward(&x_norm);
        let out = &x_norm + &ffn_out;
        Ok(self.norm2.forward(&out))
    }
}

/// Transformer-based signal tokenizer
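///
/// `encode` projects framed input into the embedding space, adds positional
/// encoding, and runs the encoder stack; `decode` runs the decoder stack and
/// projects back to the input space. A signal of length `n` is zero-padded into
/// `ceil(n / input_dim)` frames, so `encode` returns
/// `ceil(n / input_dim) * embed_dim` values.
///
/// Weights are randomly initialized from fixed seeds; this module implements only
/// the forward passes and provides no training procedure.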
#[derive(Debug, Clone)]
pub struct TransformerTokenizer {
    /// Configuration
    config: TransformerConfig,
    /// Input projection
    input_proj: Array2<f32>,
    /// Output projection
    output_proj: Array2<f32>,
    /// Positional encoding
    pos_encoding: PositionalEncoding,
    /// Encoder layers
    encoder_layers: Vec<TransformerEncoderLayer>,
    /// Decoder layers (self-attention blocks reused as-is; no cross-attention)
    decoder_layers: Vec<TransformerEncoderLayer>,
}

impl TransformerTokenizer {
    /// Create a new transformer tokenizer
    pub fn new(config: TransformerConfig) -> TokenizerResult<Self> {
        config.validate()?;

        let mut rng = Random::seed(44);
        let scale_in = (2.0 / (config.input_dim + config.embed_dim) as f32).sqrt();
        let scale_out = (2.0 / (config.embed_dim + config.input_dim) as f32).sqrt();

        // Initialize projection layers
        let mut input_proj = Array2::zeros((config.input_dim, config.embed_dim));
        let mut output_proj = Array2::zeros((config.embed_dim, config.input_dim));

        for val in input_proj.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale_in;
        }
        for val in output_proj.iter_mut() {
            *val = (rng.gen_range(-1.0..1.0)) * scale_out;
        }

        // Save these before moving config
        let max_seq_len = config.max_seq_len;
        let embed_dim = config.embed_dim;
        let num_encoder_layers = config.num_encoder_layers;
        let num_decoder_layers = config.num_decoder_layers;
        let num_heads = config.num_heads;
        let feedforward_dim = config.feedforward_dim;

        // Create encoder layers
        let mut encoder_layers = Vec::new();
        for _ in 0..num_encoder_layers {
            encoder_layers.push(TransformerEncoderLayer::new(
                embed_dim,
                num_heads,
                feedforward_dim,
            )?);
        }

        // Create decoder layers
        let mut decoder_layers = Vec::new();
        for _ in 0..num_decoder_layers {
            decoder_layers.push(TransformerEncoderLayer::new(
                embed_dim,
                num_heads,
                feedforward_dim,
            )?);
        }

        Ok(Self {
            config,
            input_proj,
            output_proj,
            pos_encoding: PositionalEncoding::new(max_seq_len, embed_dim),
            encoder_layers,
            decoder_layers,
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &TransformerConfig {
        &self.config
    }
}

impl SignalTokenizer for TransformerTokenizer {
    fn encode(&self, signal: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let len = signal.len();
        if len > self.config.max_seq_len * self.config.input_dim {
            return Err(TokenizerError::encoding(
                format!(
                    "Signal too long: {} > {}",
                    len,
                    self.config.max_seq_len * self.config.input_dim
                ),
                "TransformerTokenizer::encode",
            ));
        }

        // Reshape to [seq_len, input_dim]
        let seq_len = len.div_ceil(self.config.input_dim);
        let mut padded = signal.to_vec();
        padded.resize(seq_len * self.config.input_dim, 0.0);

        let mut x = Array2::zeros((seq_len, self.config.input_dim));
        for i in 0..seq_len {
            for j in 0..self.config.input_dim {
                x[[i, j]] = padded[i * self.config.input_dim + j];
            }
        }

        // Project to embedding space
        let mut x = x.dot(&self.input_proj); // [seq_len, embed_dim]

        // Add positional encoding
        x = self.pos_encoding.forward(&x)?;

        // Encoder layers
        for layer in &self.encoder_layers {
            x = layer.forward(&x)?;
        }

        // Flatten to 1D
        let mut result = Vec::new();
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                result.push(x[[i, j]]);
            }
        }

        Ok(Array1::from_vec(result))
    }

    fn decode(&self, tokens: &Array1<f32>) -> TokenizerResult<Array1<f32>> {
        let total_len = tokens.len();
        if !total_len.is_multiple_of(self.config.embed_dim) {
            return Err(TokenizerError::decoding(
                format!(
                    "Invalid token length: {} not divisible by {}",
                    total_len, self.config.embed_dim
                ),
                "TransformerTokenizer::decode",
            ));
        }

        let seq_len = total_len / self.config.embed_dim;

        // Reshape to [seq_len, embed_dim]
        let mut x = Array2::zeros((seq_len, self.config.embed_dim));
        for i in 0..seq_len {
            for j in 0..self.config.embed_dim {
                x[[i, j]] = tokens[i * self.config.embed_dim + j];
            }
        }

        // Decoder layers
        for layer in &self.decoder_layers {
            x = layer.forward(&x)?;
        }

        // Project back to input space
        x = x.dot(&self.output_proj); // [seq_len, input_dim]

        // Flatten to 1D
        let mut result = Vec::new();
        for i in 0..x.nrows() {
            for j in 0..x.ncols() {
                result.push(x[[i, j]]);
            }
        }

        Ok(Array1::from_vec(result))
    }

    fn embed_dim(&self) -> usize {
        self.config.embed_dim
    }

    fn vocab_size(&self) -> usize {
        0 // Continuous tokenizer, no discrete vocabulary
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_config_validation() {
        let config = TransformerConfig::default();
        assert!(config.validate().is_ok());

        let mut bad_config = config.clone();
        bad_config.embed_dim = 0;
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.embed_dim = 100; // Not divisible by num_heads (8)
        assert!(bad_config.validate().is_err());

        let mut bad_config = config.clone();
        bad_config.dropout = 1.5;
        assert!(bad_config.validate().is_err());
    }

    #[test]
    fn test_multihead_attention_creation() {
        let mha = MultiHeadAttention::new(256, 8);
        assert!(mha.is_ok());

        let bad_mha = MultiHeadAttention::new(256, 7); // Not divisible
        assert!(bad_mha.is_err());
    }

    #[test]
    fn test_multihead_attention_forward() {
        let mha = MultiHeadAttention::new(64, 4).unwrap();
        let x = Array2::ones((10, 64)); // [seq_len=10, embed_dim=64]
        let out = mha.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(100, 64);
        let x = Array2::zeros((50, 64));
        let out = pe.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[50, 64]);
    }

    #[test]
    fn test_positional_encoding_seq_too_long() {
        let pe = PositionalEncoding::new(10, 64);
        let x = Array2::zeros((20, 64)); // Longer than max_seq_len
        let out = pe.forward(&x);
        assert!(out.is_err());
    }

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(64, 1e-5);
        let x = Array2::from_shape_fn((10, 64), |(i, j)| (i + j) as f32);
        let out = ln.forward(&x);
        assert_eq!(out.shape(), &[10, 64]);

        // Check that mean is approximately 0 and variance is approximately 1
        for row in out.rows() {
            let mean = row.mean().unwrap();
            let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / 64.0;
            assert!((mean.abs()) < 1e-5);
            assert!((var - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_feedforward() {
        let ffn = FeedForward::new(64, 256);
        let x = Array2::ones((10, 64));
        let out = ffn.forward(&x);
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 4, 256).unwrap();
        let x = Array2::ones((10, 64));
        let out = layer.forward(&x);
        assert!(out.is_ok());
        let out = out.unwrap();
        assert_eq!(out.shape(), &[10, 64]);
    }

    #[test]
    fn test_transformer_tokenizer_creation() {
        let config = TransformerConfig {
            input_dim: 32,
            embed_dim: 64,
            num_heads: 4,
            num_encoder_layers: 2,
            num_decoder_layers: 2,
            feedforward_dim: 128,
            dropout: 0.1,
            max_seq_len: 100,
        };
        let tokenizer = TransformerTokenizer::new(config);
        assert!(tokenizer.is_ok());
    }

    #[test]
    fn test_transformer_encode_decode() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            max_seq_len: 10,
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 64);
        let encoded = tokenizer.encode(&signal);
        assert!(encoded.is_ok());

        let encoded = encoded.unwrap();
        let decoded = tokenizer.decode(&encoded);
        assert!(decoded.is_ok());
        let decoded = decoded.unwrap();

        // Should preserve length (with padding)
        assert!(decoded.len() >= signal.len());
    }

    #[test]
    fn test_transformer_signal_too_long() {
        let config = TransformerConfig {
            input_dim: 16,
            embed_dim: 32,
            num_heads: 4,
            num_encoder_layers: 1,
            num_decoder_layers: 1,
            feedforward_dim: 64,
            dropout: 0.0,
            max_seq_len: 2, // Very small
        };
        let tokenizer = TransformerTokenizer::new(config).unwrap();

        let signal = Array1::linspace(0.0, 1.0, 1000); // Too long
        let encoded = tokenizer.encode(&signal);
        assert!(encoded.is_err());
    }

    #[test]
    fn test_softmax() {
        let x = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 1.0, 1.0, 1.0]).unwrap();
        let result = MultiHeadAttention::softmax(&x).unwrap();

        // Check that each row sums to 1
        for row in result.rows() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }

        // Check that all values are positive
        for &val in result.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_gelu_activation() {
        // GELU(0) should be approximately 0
        assert!((FeedForward::gelu(0.0)).abs() < 1e-5);

        // GELU should be monotonic for positive values
        assert!(FeedForward::gelu(1.0) > FeedForward::gelu(0.5));
        assert!(FeedForward::gelu(2.0) > FeedForward::gelu(1.0));

        // GELU should preserve sign but be smooth
        assert!(FeedForward::gelu(-1.0) < 0.0);
        assert!(FeedForward::gelu(1.0) > 0.0);
    }
}