//! Transformer Architecture - Encoder-Decoder Transformer
//!
//! # File
//! `crates/axonml-nn/src/layers/transformer.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// TransformerEncoderLayer
// =============================================================================

/// A single Transformer encoder layer.
///
/// Consists of multi-head self-attention followed by a position-wise
/// feedforward network, each with residual connections and layer normalization.
///
/// # Shape
/// - Input: (N, S, E) (batch-first)
/// - Output: (N, S, E)
///
/// where N=batch, S=source seq len, E=d_model.
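///
/// # Example
///
/// A minimal usage sketch based on the tests at the bottom of this file
/// (the dimensions are illustrative, not defaults):
///
/// ```ignore
/// // d_model = 64, nhead = 4, dim_feedforward = 256
/// let layer = TransformerEncoderLayer::new(64, 4, 256);
/// let input = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// // (N, S, E) in, (N, S, E) out
/// let output = layer.forward(&input);
/// assert_eq!(output.shape(), vec![2, 10, 64]);
/// ```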
pub struct TransformerEncoderLayer {
    /// Self-attention.
    self_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after feedforward.
    norm2: LayerNorm,
    /// Model dimension.
    d_model: usize,
    /// Whether to use pre-norm (norm before sublayer) instead of post-norm.
    pre_norm: bool,
}

impl TransformerEncoderLayer {
    /// Creates a new TransformerEncoderLayer (post-norm, default).
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of the feedforward network
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates a TransformerEncoderLayer with configurable norm ordering.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

    /// Forward pass with optional source mask.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `src_mask` - Optional attention mask
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(src);
            let attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, src_mask);
            let x = src.add_var(&attn_out);

            let normed = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let attn_out = self.self_attn.attention(src, src, src, src_mask);
            let x = src.add_var(&attn_out);
            let x = self.norm1.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm2.forward(&x)
        }
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

// =============================================================================
// TransformerDecoderLayer
// =============================================================================

/// A single Transformer decoder layer.
///
/// Consists of:
/// 1. Masked multi-head self-attention (causal)
/// 2. Multi-head cross-attention over encoder output
/// 3. Position-wise feedforward network
///
/// Each sublayer has residual connections and layer normalization.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
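///
/// # Example
///
/// A minimal sketch mirroring the decoder-layer test below; the causal mask is
/// optional and shown here only for illustration:
///
/// ```ignore
/// let layer = TransformerDecoderLayer::new(64, 4, 256);
/// let tgt = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
///     false,
/// );
/// let memory = Variable::new(
///     Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// // (N, T, E) out, conditioned on (N, S, E) memory
/// let output = layer.forward_with_memory(&tgt, &memory, Some(&tgt_mask), None);
/// assert_eq!(output.shape(), vec![2, 5, 64]);
/// ```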
pub struct TransformerDecoderLayer {
    /// Masked self-attention (causal).
    self_attn: MultiHeadAttention,
    /// Cross-attention over encoder output.
    cross_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after cross-attention.
    norm2: LayerNorm,
    /// Layer norm after feedforward.
    norm3: LayerNorm,
    /// Model dimension.
    d_model: usize,
    /// Whether to use pre-norm (norm before sublayer) instead of post-norm.
    pre_norm: bool,
}

impl TransformerDecoderLayer {
    /// Creates a new TransformerDecoderLayer (post-norm, default).
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of feedforward network
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates a TransformerDecoderLayer with configurable norm ordering.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

    /// Forward pass with encoder memory and optional masks.
    ///
    /// # Arguments
    /// * `tgt` - Target sequence (N, T, E)
    /// * `memory` - Encoder output (N, S, E)
    /// * `tgt_mask` - Optional causal mask for self-attention
    /// * `memory_mask` - Optional mask for cross-attention
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        if self.pre_norm {
            // Pre-norm: norm before each sublayer, inside residual branch
            let normed = self.norm1.forward(tgt);
            let self_attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, tgt_mask);
            let x = tgt.add_var(&self_attn_out);

            let normed = self.norm2.forward(&x);
            let cross_attn_out = self
                .cross_attn
                .attention(&normed, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            // Post-norm (original)
            let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
            let x = tgt.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);
            let x = self.norm2.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // Without encoder memory, only the self-attention and feedforward
        // sublayers run. Use forward_with_memory() for full decoder behavior.
        if self.pre_norm {
            let normed = self.norm1.forward(input);
            let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
            let x = input.add_var(&self_attn_out);

            // Cross-attention is skipped (no memory); norm2 is unused here and
            // the feedforward sublayer follows directly.
            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(input, input, input, None);
            let x = input.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            // Cross-attention is skipped, but norm2 is still applied, so the
            // feedforward residual is built on the normalized activations.
            let x_after_norm2 = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&x_after_norm2).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x_after_norm2.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

// =============================================================================
// TransformerEncoder
// =============================================================================

/// Stack of N TransformerEncoderLayers.
///
/// # Shape
/// - Input: (N, S, E)
/// - Output: (N, S, E)
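///
/// # Example
///
/// A minimal sketch (three layers, shapes as in the tests below):
///
/// ```ignore
/// // d_model = 64, nhead = 4, dim_feedforward = 256, num_layers = 3
/// let encoder = TransformerEncoder::new(64, 4, 256, 3);
/// let src = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
///     false,
/// );
/// let memory = encoder.forward_with_mask(&src, None);
/// assert_eq!(memory.shape(), vec![2, 8, 64]);
/// ```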
pub struct TransformerEncoder {
    /// Encoder layers.
    layers: Vec<TransformerEncoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    /// Creates a TransformerEncoder with the given number of layers (post-norm).
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates a TransformerEncoder with configurable norm ordering.
    ///
    /// A final LayerNorm is always added after the last layer; with pre-norm
    /// this is required, since pre-norm layers do not normalize their own output.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerEncoder without final layer norm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with optional source mask.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

// =============================================================================
// TransformerDecoder
// =============================================================================

/// Stack of N TransformerDecoderLayers.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
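///
/// # Example
///
/// A minimal sketch mirroring the decoder-stack test below:
///
/// ```ignore
/// let decoder = TransformerDecoder::new(64, 4, 256, 3);
/// let tgt = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
///     false,
/// );
/// let memory = Variable::new(
///     Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let output = decoder.forward_with_memory(&tgt, &memory, None, None);
/// assert_eq!(output.shape(), vec![2, 5, 64]);
/// ```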
pub struct TransformerDecoder {
    /// Decoder layers.
    layers: Vec<TransformerDecoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    /// Creates a TransformerDecoder with the given number of layers (post-norm).
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates a TransformerDecoder with configurable norm ordering.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerDecoder without final layer norm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with encoder memory and optional masks.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        // Without encoder memory, each layer runs only its self-attention and
        // feedforward sublayers (for pretraining/testing).
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

// =============================================================================
// Seq2SeqTransformer
// =============================================================================

/// Full Encoder-Decoder Transformer for sequence-to-sequence tasks.
///
/// Combines a TransformerEncoder and TransformerDecoder into a single module.
/// Follows PyTorch's `nn.Transformer` API.
///
/// # Architecture
/// ```text
/// Source → [Encoder] → Memory
///                         ↓
/// Target → [Decoder] → Output
/// ```
///
/// # Shape
/// - Source: (N, S, E)
/// - Target: (N, T, E)
/// - Output: (N, T, E)
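///
/// # Example
///
/// A minimal end-to-end sketch based on the tests below (2 encoder and 2
/// decoder layers purely for illustration; see `default_config` for the
/// 6-layer setup):
///
/// ```ignore
/// let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
/// let src = Variable::new(
///     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
///     false,
/// );
/// let tgt = Variable::new(
///     Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
///     false,
/// );
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// let output = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
/// assert_eq!(output.shape(), vec![2, 5, 64]); // (N, T, E)
/// ```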
pub struct Seq2SeqTransformer {
    /// Encoder stack.
    encoder: TransformerEncoder,
    /// Decoder stack.
    decoder: TransformerDecoder,
    /// Model dimension.
    d_model: usize,
    /// Number of attention heads.
    nhead: usize,
}

impl Seq2SeqTransformer {
    /// Creates a new Seq2SeqTransformer.
    ///
    /// # Arguments
    /// * `d_model` - Embedding/model dimension
    /// * `nhead` - Number of attention heads
    /// * `num_encoder_layers` - Number of encoder layers
    /// * `num_decoder_layers` - Number of decoder layers
    /// * `dim_feedforward` - Hidden dimension of feedforward networks
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

    /// Creates a Seq2SeqTransformer with pre-norm ordering.
    ///
    /// Pre-norm applies LayerNorm before each sublayer (inside the residual
    /// branch), which generally gives more stable gradients and easier training
    /// of deep stacks than post-norm.
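    ///
    /// A minimal construction sketch (hyperparameters are illustrative):
    ///
    /// ```ignore
    /// // Post-norm, as in the original Transformer:
    /// let post = Seq2SeqTransformer::new(512, 8, 6, 6, 2048);
    /// // Pre-norm: LayerNorm moved inside each residual branch:
    /// let pre = Seq2SeqTransformer::new_pre_norm(512, 8, 6, 6, 2048);
    /// ```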
    pub fn new_pre_norm(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_encoder_layers,
                true,
            ),
            decoder: TransformerDecoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_decoder_layers,
                true,
            ),
            d_model,
            nhead,
        }
    }

    /// Creates a Seq2SeqTransformer with default settings (6 layers, 2048 FFN).
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

    /// Full forward pass: encode source, then decode target conditioned on encoder output.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `tgt` - Target sequence (N, T, E)
    /// * `src_mask` - Optional mask for encoder self-attention
    /// * `tgt_mask` - Optional causal mask for decoder self-attention
    /// * `memory_mask` - Optional mask for decoder cross-attention
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder
            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

    /// Encode source sequence only (useful for inference).
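    ///
    /// A sketch of the encode-once / decode-many pattern for autoregressive
    /// inference (`src`, `tgt`, and `tgt_len` are assumed to be in scope; the
    /// token-appending loop itself is elided):
    ///
    /// ```ignore
    /// let memory = model.encode(&src, None); // run the encoder once
    /// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(tgt_len);
    /// // Reuse `memory` on every decoding step.
    /// let out = model.decode(&tgt, &memory, Some(&tgt_mask), None);
    /// ```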
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

    /// Decode target given pre-computed encoder memory (useful for inference).
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder
            .forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

    /// Generates a causal (lower-triangular) mask for autoregressive decoding.
    ///
    /// Returns a mask of shape (seq_len, seq_len) where future positions are 0.0
    /// and valid positions are 1.0.
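    ///
    /// For example, `generate_square_subsequent_mask(4)` yields (rows are query
    /// positions, columns are key positions):
    ///
    /// ```text
    /// 1 0 0 0
    /// 1 1 0 0
    /// 1 1 1 0
    /// 1 1 1 1
    /// ```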
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).unwrap(),
            false,
        )
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Returns the number of attention heads.
    pub fn nhead(&self) -> usize {
        self.nhead
    }

    /// Returns a reference to the encoder.
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

    /// Returns a reference to the decoder.
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

impl Module for Seq2SeqTransformer {
    fn forward(&self, input: &Variable) -> Variable {
        // Single-input forward: encode only (use forward_seq2seq for full pipeline)
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );

        // Encode once, decode multiple times (autoregressive inference)
        let memory = transformer.encode(&src, None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0: [1, 0, 0, 0]
        // Row 1: [1, 1, 0, 0]
        // Row 2: [1, 1, 1, 0]
        // Row 3: [1, 1, 1, 1]
        assert_eq!(mask_data[0], 1.0); // (0,0) = visible
        assert_eq!(mask_data[1], 0.0); // (0,1) = masked
        assert_eq!(mask_data[4], 1.0); // (1,0) = visible
        assert_eq!(mask_data[5], 1.0); // (1,1) = visible
        assert_eq!(mask_data[6], 0.0); // (1,2) = masked
        assert_eq!(mask_data[15], 1.0); // (3,3) = visible
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 4 projections × (weight + bias) = 8
        // linear1: weight + bias = 2
        // linear2: weight + bias = 2
        // norm1: weight + bias = 2
        // norm2: weight + bias = 2
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 8, cross_attn: 8, linear1: 2, linear2: 2, norm1: 2, norm2: 2, norm3: 2
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        // Verify hierarchical naming
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}