// axonml_nn/layers/transformer.rs

//! Transformer Architecture - Encoder-Decoder Transformer
//!
//! Implements the full Transformer architecture following PyTorch's API:
//! - `TransformerEncoderLayer` - Single encoder layer (self-attn + FFN)
//! - `TransformerDecoderLayer` - Single decoder layer (self-attn + cross-attn + FFN)
//! - `TransformerEncoder` - Stack of encoder layers
//! - `TransformerDecoder` - Stack of decoder layers
//! - `Seq2SeqTransformer` - Full encoder-decoder Transformer
//!
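//! # Quick start
//!
//! A minimal sketch mirroring the unit tests at the bottom of this file.
//! The `axonml_nn::layers::transformer` import path is an assumption from the
//! crate layout; adjust it to match how the crate actually exports these types.
//!
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::layers::transformer::Seq2SeqTransformer; // assumed path
//! use axonml_tensor::Tensor;
//!
//! // d_model=64, nhead=4, 2 encoder layers, 2 decoder layers, FFN dim 256
//! let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
//!
//! // Batch of 2: source length 10, target length 5, embedding dim 64
//! let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
//! let tgt = Variable::new(Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
//!
//! // Causal mask: each target position attends only to itself and earlier positions
//! let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
//! let out = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
//! assert_eq!(out.shape(), vec![2, 5, 64]);
//! ```
//!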
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// TransformerEncoderLayer
// =============================================================================

/// A single Transformer encoder layer.
///
/// Consists of multi-head self-attention followed by a position-wise
/// feedforward network, each with a residual connection and layer
/// normalization applied after it (post-norm, as in the original Transformer).
///
/// # Shape
/// - Input: (N, S, E), batch-first
/// - Output: (N, S, E)
///
/// where N=batch, S=source seq len, E=d_model.
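///
/// # Example
///
/// A minimal sketch, mirroring `test_encoder_layer_forward` below:
///
/// ```ignore
/// let layer = TransformerEncoderLayer::new(64, 4, 256);
/// let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let out = layer.forward_with_mask(&src, None);
/// assert_eq!(out.shape(), vec![2, 10, 64]);
/// ```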
pub struct TransformerEncoderLayer {
    /// Self-attention.
    self_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after feedforward.
    norm2: LayerNorm,
    /// Model dimension.
    d_model: usize,
}

impl TransformerEncoderLayer {
    /// Creates a new TransformerEncoderLayer.
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of the feedforward network (2048 in the original Transformer)
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
        }
    }

    /// Forward pass with optional source mask.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `src_mask` - Optional attention mask
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        // Self-attention sublayer: norm1(x + self_attn(x))
        let attn_out = self.self_attn.attention(src, src, src, src_mask);
        let x = src.add_var(&attn_out);
        let x = self.norm1.forward(&x);

        // Feedforward sublayer: norm2(x + FFN(x))
        let ff_out = self.linear1.forward(&x).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x.add_var(&ff_out);
        self.norm2.forward(&x)
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

// =============================================================================
// TransformerDecoderLayer
// =============================================================================

/// A single Transformer decoder layer.
///
/// Consists of:
/// 1. Masked multi-head self-attention (causal)
/// 2. Multi-head cross-attention over encoder output
/// 3. Position-wise feedforward network
///
/// Each sublayer has a residual connection followed by layer normalization.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
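///
/// # Example
///
/// A minimal sketch, mirroring `test_decoder_layer_with_memory` below:
///
/// ```ignore
/// let layer = TransformerDecoderLayer::new(64, 4, 256);
/// let tgt = Variable::new(Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
/// let memory = Variable::new(Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let out = layer.forward_with_memory(&tgt, &memory, None, None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```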
pub struct TransformerDecoderLayer {
    /// Masked self-attention (causal).
    self_attn: MultiHeadAttention,
    /// Cross-attention over encoder output.
    cross_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after cross-attention.
    norm2: LayerNorm,
    /// Layer norm after feedforward.
    norm3: LayerNorm,
    /// Model dimension.
    d_model: usize,
}

impl TransformerDecoderLayer {
    /// Creates a new TransformerDecoderLayer.
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of the feedforward network
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
        }
    }

    /// Forward pass with encoder memory and optional masks.
    ///
    /// # Arguments
    /// * `tgt` - Target sequence (N, T, E)
    /// * `memory` - Encoder output (N, S, E)
    /// * `tgt_mask` - Optional causal mask for self-attention
    /// * `memory_mask` - Optional mask for cross-attention
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        // 1. Masked self-attention sublayer
        let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
        let x = tgt.add_var(&self_attn_out);
        let x = self.norm1.forward(&x);

        // 2. Cross-attention sublayer (query=decoder, key/value=encoder)
        let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
        let x = x.add_var(&cross_attn_out);
        let x = self.norm2.forward(&x);

        // 3. Feedforward sublayer
        let ff_out = self.linear1.forward(&x).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x.add_var(&ff_out);
        self.norm3.forward(&x)
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // Without encoder memory, only the self-attention path can run;
        // use forward_with_memory() for full decoder behavior.
        let self_attn_out = self.self_attn.attention(input, input, input, None);
        let x = input.add_var(&self_attn_out);
        let x = self.norm1.forward(&x);

        // Cross-attention is skipped (no memory); norm2 is still applied
        // before the FFN.
        let x_after_norm2 = self.norm2.forward(&x);
        let ff_out = self.linear1.forward(&x_after_norm2).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x_after_norm2.add_var(&ff_out);
        self.norm3.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

// =============================================================================
// TransformerEncoder
// =============================================================================

/// A stack of `TransformerEncoderLayer`s.
///
/// # Shape
/// - Input: (N, S, E)
/// - Output: (N, S, E)
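///
/// # Example
///
/// A minimal sketch, mirroring `test_encoder_stack` below; `without_norm`
/// builds the same stack but skips the final LayerNorm:
///
/// ```ignore
/// let encoder = TransformerEncoder::new(64, 4, 256, 3); // 3 layers
/// let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(), false);
/// let out = encoder.forward_with_mask(&src, None);
/// assert_eq!(out.shape(), vec![2, 8, 64]);
/// ```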
pub struct TransformerEncoder {
    /// Encoder layers.
    layers: Vec<TransformerEncoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    /// Creates a TransformerEncoder with the given number of layers.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerEncoder without final layer norm.
    pub fn without_norm(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with optional source mask.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

// =============================================================================
// TransformerDecoder
// =============================================================================

/// A stack of `TransformerDecoderLayer`s.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
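///
/// # Example
///
/// A minimal sketch, mirroring `test_decoder_stack` below but with a causal
/// mask for the self-attention sublayers; `tgt` (N, T, E) and `memory`
/// (N, S, E) are built as in the examples above:
///
/// ```ignore
/// let decoder = TransformerDecoder::new(64, 4, 256, 3); // 3 layers
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// let out = decoder.forward_with_memory(&tgt, &memory, Some(&tgt_mask), None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```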
pub struct TransformerDecoder {
    /// Decoder layers.
    layers: Vec<TransformerDecoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    /// Creates a TransformerDecoder with the given number of layers.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerDecoder without final layer norm.
    pub fn without_norm(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with encoder memory and optional masks.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        // Without encoder memory, each layer runs its memory-free path
        // (self-attention + FFN only); see TransformerDecoderLayer::forward.
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

// =============================================================================
// Seq2SeqTransformer
// =============================================================================

/// Full encoder-decoder Transformer for sequence-to-sequence tasks.
///
/// Combines a TransformerEncoder and TransformerDecoder into a single module.
/// Follows PyTorch's `nn.Transformer` API.
///
/// # Architecture
/// ```text
/// Source → [Encoder] → Memory
///                         ↓
/// Target → [Decoder] → Output
/// ```
///
/// # Shape
/// - Source: (N, S, E)
/// - Target: (N, T, E)
/// - Output: (N, T, E)
pub struct Seq2SeqTransformer {
    /// Encoder stack.
    encoder: TransformerEncoder,
    /// Decoder stack.
    decoder: TransformerDecoder,
    /// Model dimension.
    d_model: usize,
    /// Number of attention heads.
    nhead: usize,
}

impl Seq2SeqTransformer {
    /// Creates a new Seq2SeqTransformer.
    ///
    /// # Arguments
    /// * `d_model` - Embedding/model dimension
    /// * `nhead` - Number of attention heads
    /// * `num_encoder_layers` - Number of encoder layers
    /// * `num_decoder_layers` - Number of decoder layers
    /// * `dim_feedforward` - Hidden dimension of feedforward networks
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

    /// Creates a Seq2SeqTransformer with default settings: 6 encoder and
    /// 6 decoder layers, feedforward dimension 2048 (PyTorch's
    /// `nn.Transformer` defaults).
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

    /// Full forward pass: encode source, then decode target conditioned on encoder output.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `tgt` - Target sequence (N, T, E)
    /// * `src_mask` - Optional mask for encoder self-attention
    /// * `tgt_mask` - Optional causal mask for decoder self-attention
    /// * `memory_mask` - Optional mask for decoder cross-attention
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder.forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

    /// Encode source sequence only (useful for inference).
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

    /// Decode target given pre-computed encoder memory (useful for inference).
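    ///
    /// # Example
    ///
    /// A sketch of the encode-once / decode-many pattern (mirrors
    /// `test_seq2seq_encode_decode_separate` below; `tgt_so_far` and `t` are
    /// placeholders for whatever decoding loop the caller runs):
    ///
    /// ```ignore
    /// let memory = model.encode(&src, None); // run the encoder once
    /// let mask = Seq2SeqTransformer::generate_square_subsequent_mask(t);
    /// let out = model.decode(&tgt_so_far, &memory, Some(&mask), None);
    /// ```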
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder.forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

    /// Generates a causal (lower-triangular) mask for autoregressive decoding.
    ///
    /// Returns a mask of shape (seq_len, seq_len) where future positions are 0.0
    /// and valid positions are 1.0.
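    ///
    /// For `seq_len = 3` the mask is:
    ///
    /// ```text
    /// 1 0 0
    /// 1 1 0
    /// 1 1 1
    /// ```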
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            // Row i: positions 0..=i are visible; later positions stay masked.
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).unwrap(),
            false,
        )
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Returns the number of attention heads.
    pub fn nhead(&self) -> usize {
        self.nhead
    }

    /// Returns a reference to the encoder.
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

    /// Returns a reference to the decoder.
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

impl Module for Seq2SeqTransformer {
    fn forward(&self, input: &Variable) -> Variable {
        // Single-input forward: encode only (use forward_seq2seq for full pipeline)
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );

        // Encode once, decode multiple times (autoregressive inference)
        let memory = transformer.encode(&src, None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0: [1, 0, 0, 0]
        // Row 1: [1, 1, 0, 0]
        // Row 2: [1, 1, 1, 0]
        // Row 3: [1, 1, 1, 1]
        assert_eq!(mask_data[0], 1.0); // (0,0) = visible
        assert_eq!(mask_data[1], 0.0); // (0,1) = masked
        assert_eq!(mask_data[4], 1.0); // (1,0) = visible
        assert_eq!(mask_data[5], 1.0); // (1,1) = visible
        assert_eq!(mask_data[6], 0.0); // (1,2) = masked
        assert_eq!(mask_data[15], 1.0); // (3,3) = visible
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 4 projections × (weight + bias) = 8
        // linear1: weight + bias = 2
        // linear2: weight + bias = 2
        // norm1: weight + bias = 2
        // norm2: weight + bias = 2
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 8, cross_attn: 8, linear1: 2, linear2: 2, norm1: 2, norm2: 2, norm3: 2
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        // Verify hierarchical naming
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}