//! Transformer encoder/decoder stacks and the full `Seq2SeqTransformer`.
//!
//! Provides `TransformerEncoderLayer` (self-attention + FFN + LayerNorm +
//! residual), `TransformerDecoderLayer` (self-attention + cross-attention +
//! FFN + LayerNorm + residual), `TransformerEncoder` (a stack of N encoder
//! layers), `TransformerDecoder` (a stack of N decoder layers), and
//! `Seq2SeqTransformer` (encoder + decoder with a causal-mask helper for
//! autoregressive decoding). All implement `Module`.
//!
//! # File
//! `crates/axonml-nn/src/layers/transformer.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
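//!
//! # Quick example
//! A minimal sketch of the full encode-decode path. It mirrors the shapes used
//! in this module's unit tests; the `axonml_nn::layers::transformer` import
//! path is an assumption rather than a confirmed re-export.
//! ```ignore
//! use axonml_autograd::Variable;
//! use axonml_nn::layers::transformer::Seq2SeqTransformer;
//! use axonml_tensor::Tensor;
//!
//! // 2-layer encoder and decoder, d_model = 64, 4 heads, FFN width 256.
//! let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
//!
//! // Source (N=2, S=10, E=64) and target (N=2, T=5, E=64).
//! let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
//! let tgt = Variable::new(Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
//!
//! // Causal mask so each target position attends only to itself and the past.
//! let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
//! let out = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
//! assert_eq!(out.shape(), vec![2, 5, 64]);
//! ```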

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// TransformerEncoderLayer
// =============================================================================

/// A single Transformer encoder layer.
///
/// Consists of multi-head self-attention followed by a position-wise
/// feedforward network, each with residual connections and layer normalization.
///
/// # Shape
/// - Input: (N, S, E)
/// - Output: (N, S, E)
///
/// where N=batch, S=source seq len, E=d_model.
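///
/// # Example
/// A minimal sketch mirroring the shapes in this module's unit tests; the
/// `axonml_nn::layers::transformer` import path is an assumption.
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::transformer::TransformerEncoderLayer;
/// use axonml_tensor::Tensor;
///
/// // d_model = 64, 4 heads, FFN width 256 (post-norm by default).
/// let layer = TransformerEncoderLayer::new(64, 4, 256);
/// let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let out = layer.forward_with_mask(&src, None);
/// assert_eq!(out.shape(), vec![2, 10, 64]);
/// ```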
pub struct TransformerEncoderLayer {
    /// Self-attention.
    self_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after feedforward.
    norm2: LayerNorm,
    /// Model dimension.
    d_model: usize,
    /// Whether to use pre-norm (norm before sublayer) instead of post-norm.
    pre_norm: bool,
}

impl TransformerEncoderLayer {
    /// Creates a new TransformerEncoderLayer (post-norm, default).
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of the feedforward network (commonly 2048)
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates a TransformerEncoderLayer with configurable norm ordering.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

    /// Forward pass with optional source mask.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `src_mask` - Optional attention mask
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(src);
            let attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, src_mask);
            let x = src.add_var(&attn_out);

            let normed = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let attn_out = self.self_attn.attention(src, src, src, src_mask);
            let x = src.add_var(&attn_out);
            let x = self.norm1.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm2.forward(&x)
        }
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

// =============================================================================
// TransformerDecoderLayer
// =============================================================================

/// A single Transformer decoder layer.
///
/// Consists of:
/// 1. Masked multi-head self-attention (causal)
/// 2. Multi-head cross-attention over encoder output
/// 3. Position-wise feedforward network
///
/// Each sublayer has residual connections and layer normalization.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
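///
/// # Example
/// A minimal sketch mirroring the shapes in this module's unit tests; the
/// `axonml_nn::layers::transformer` import path is an assumption.
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::transformer::TransformerDecoderLayer;
/// use axonml_tensor::Tensor;
///
/// let layer = TransformerDecoderLayer::new(64, 4, 256);
/// // Target (N=2, T=5, E=64) attends over encoder memory (N=2, S=10, E=64).
/// let tgt = Variable::new(Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
/// let memory = Variable::new(Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let out = layer.forward_with_memory(&tgt, &memory, None, None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```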
pub struct TransformerDecoderLayer {
    /// Masked self-attention (causal).
    self_attn: MultiHeadAttention,
    /// Cross-attention over encoder output.
    cross_attn: MultiHeadAttention,
    /// Feedforward network — first linear.
    linear1: Linear,
    /// Feedforward network — second linear.
    linear2: Linear,
    /// Layer norm after self-attention.
    norm1: LayerNorm,
    /// Layer norm after cross-attention.
    norm2: LayerNorm,
    /// Layer norm after feedforward.
    norm3: LayerNorm,
    /// Model dimension.
    d_model: usize,
    /// Whether to use pre-norm (norm before sublayer) instead of post-norm.
    pre_norm: bool,
}

impl TransformerDecoderLayer {
    /// Creates a new TransformerDecoderLayer (post-norm, default).
    ///
    /// # Arguments
    /// * `d_model` - Embedding dimension
    /// * `nhead` - Number of attention heads
    /// * `dim_feedforward` - Hidden dimension of feedforward network
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates a TransformerDecoderLayer with configurable norm ordering.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

    /// Forward pass with encoder memory and optional masks.
    ///
    /// # Arguments
    /// * `tgt` - Target sequence (N, T, E)
    /// * `memory` - Encoder output (N, S, E)
    /// * `tgt_mask` - Optional causal mask for self-attention
    /// * `memory_mask` - Optional mask for cross-attention
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        if self.pre_norm {
            // Pre-norm: norm before each sublayer, inside residual branch
            let normed = self.norm1.forward(tgt);
            let self_attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, tgt_mask);
            let x = tgt.add_var(&self_attn_out);

            let normed = self.norm2.forward(&x);
            let cross_attn_out = self
                .cross_attn
                .attention(&normed, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            // Post-norm (original)
            let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
            let x = tgt.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);
            let x = self.norm2.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // Without memory, can only do self-attention pass.
        // Use forward_with_memory() for full decoder behavior.
        if self.pre_norm {
            let normed = self.norm1.forward(input);
            let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
            let x = input.add_var(&self_attn_out);

            // Skip cross-attention (no memory), go straight to FFN
            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(input, input, input, None);
            let x = input.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            // Cross-attention is skipped (no memory); norm2 is applied directly
            // before the feedforward sublayer.
            let x_after_norm2 = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&x_after_norm2).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x_after_norm2.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

// =============================================================================
// TransformerEncoder
// =============================================================================

/// Stack of N TransformerEncoderLayers.
///
/// # Shape
/// - Input: (N, S, E)
/// - Output: (N, S, E)
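///
/// # Example
/// A minimal sketch mirroring the unit tests below; the
/// `axonml_nn::layers::transformer` import path is an assumption.
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::transformer::TransformerEncoder;
/// use axonml_tensor::Tensor;
///
/// // 3 stacked encoder layers with d_model = 64, 4 heads, FFN width 256.
/// let encoder = TransformerEncoder::new(64, 4, 256, 3);
/// let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(), false);
/// let memory = encoder.forward_with_mask(&src, None);
/// assert_eq!(memory.shape(), vec![2, 8, 64]);
/// ```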
pub struct TransformerEncoder {
    /// Encoder layers.
    layers: Vec<TransformerEncoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    /// Creates a TransformerEncoder with the given number of layers (post-norm).
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates a TransformerEncoder with configurable norm ordering.
    ///
    /// A final LayerNorm is always added after the last layer. With pre-norm
    /// this is required, since pre-norm layers do not normalize their output.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerEncoder without final layer norm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with optional source mask.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

// =============================================================================
// TransformerDecoder
// =============================================================================

/// Stack of N TransformerDecoderLayers.
///
/// # Shape
/// - Target: (N, T, E)
/// - Memory: (N, S, E)
/// - Output: (N, T, E)
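///
/// # Example
/// A minimal sketch mirroring the unit tests below; the
/// `axonml_nn::layers::transformer` import path is an assumption.
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::transformer::TransformerDecoder;
/// use axonml_tensor::Tensor;
///
/// // 3 stacked decoder layers with d_model = 64, 4 heads, FFN width 256.
/// let decoder = TransformerDecoder::new(64, 4, 256, 3);
/// let tgt = Variable::new(Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
/// let memory = Variable::new(Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let out = decoder.forward_with_memory(&tgt, &memory, None, None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```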
pub struct TransformerDecoder {
    /// Decoder layers.
    layers: Vec<TransformerDecoderLayer>,
    /// Optional final layer norm.
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    /// Creates a TransformerDecoder with the given number of layers (post-norm).
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates a TransformerDecoder with configurable norm ordering.
    ///
    /// A final LayerNorm is always added after the last layer.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a TransformerDecoder without final layer norm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Forward pass with encoder memory and optional masks.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        // Without memory, runs self-attention only (for pretraining/testing)
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

// =============================================================================
// Seq2SeqTransformer
// =============================================================================

/// Full Encoder-Decoder Transformer for sequence-to-sequence tasks.
///
/// Combines a TransformerEncoder and TransformerDecoder into a single module.
/// Loosely follows PyTorch's `nn.Transformer` API; see `forward_seq2seq` for
/// the full encode-decode pass.
///
/// # Architecture
/// ```text
/// Source → [Encoder] → Memory
///                         ↓
/// Target → [Decoder] → Output
/// ```
///
/// # Shape
/// - Source: (N, S, E)
/// - Target: (N, T, E)
/// - Output: (N, T, E)
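///
/// # Example
/// A minimal sketch mirroring the unit tests below; the
/// `axonml_nn::layers::transformer` import path is an assumption.
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::transformer::Seq2SeqTransformer;
/// use axonml_tensor::Tensor;
///
/// let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
/// let src = Variable::new(Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(), false);
/// let tgt = Variable::new(Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(), false);
///
/// // Encode once, then decode against the cached memory (as in inference).
/// let memory = model.encode(&src, None);
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// let out = model.decode(&tgt, &memory, Some(&tgt_mask), None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```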
pub struct Seq2SeqTransformer {
    /// Encoder stack.
    encoder: TransformerEncoder,
    /// Decoder stack.
    decoder: TransformerDecoder,
    /// Model dimension.
    d_model: usize,
    /// Number of attention heads.
    nhead: usize,
}

impl Seq2SeqTransformer {
    /// Creates a new Seq2SeqTransformer.
    ///
    /// # Arguments
    /// * `d_model` - Embedding/model dimension
    /// * `nhead` - Number of attention heads
    /// * `num_encoder_layers` - Number of encoder layers
    /// * `num_decoder_layers` - Number of decoder layers
    /// * `dim_feedforward` - Hidden dimension of feedforward networks
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

    /// Creates a Seq2SeqTransformer with pre-norm ordering.
    ///
    /// Pre-norm applies LayerNorm before each sublayer (inside the residual
    /// branch), which generally yields more stable gradients in deep stacks
    /// and makes training less sensitive to learning-rate warmup.
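    ///
    /// # Example
    /// A minimal sketch; the arguments mirror `new`, and the import path is an
    /// assumption.
    /// ```ignore
    /// use axonml_nn::layers::transformer::Seq2SeqTransformer;
    ///
    /// // Pre-norm variant: 2 encoder layers, 2 decoder layers, FFN width 256.
    /// let model = Seq2SeqTransformer::new_pre_norm(64, 4, 2, 2, 256);
    /// assert_eq!(model.d_model(), 64);
    /// ```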
    pub fn new_pre_norm(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_encoder_layers,
                true,
            ),
            decoder: TransformerDecoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_decoder_layers,
                true,
            ),
            d_model,
            nhead,
        }
    }

    /// Creates a Seq2SeqTransformer with default settings (6 layers, 2048 FFN).
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

    /// Full forward pass: encode source, then decode target conditioned on encoder output.
    ///
    /// # Arguments
    /// * `src` - Source sequence (N, S, E)
    /// * `tgt` - Target sequence (N, T, E)
    /// * `src_mask` - Optional mask for encoder self-attention
    /// * `tgt_mask` - Optional causal mask for decoder self-attention
    /// * `memory_mask` - Optional mask for decoder cross-attention
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder
            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

    /// Encode source sequence only (useful for inference).
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

    /// Decode target given pre-computed encoder memory (useful for inference).
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder
            .forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

    /// Generates a causal (lower-triangular) mask for autoregressive decoding.
    ///
    /// Returns a mask of shape (seq_len, seq_len) where future positions are 0.0
    /// and visible positions are 1.0.
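    ///
    /// # Example
    /// A worked 4 × 4 case, mirroring the `test_causal_mask` unit test below.
    /// The mask is stored row-major, so element (i, j) sits at index `i * 4 + j`.
    /// ```ignore
    /// let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
    /// // Row 0: [1, 0, 0, 0]   Row 1: [1, 1, 0, 0]
    /// // Row 2: [1, 1, 1, 0]   Row 3: [1, 1, 1, 1]
    /// assert_eq!(mask.data().to_vec()[4..8], [1.0, 1.0, 0.0, 0.0]);
    /// ```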
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).expect("tensor creation failed"),
            false,
        )
    }

    /// Returns the model dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Returns the number of attention heads.
    pub fn nhead(&self) -> usize {
        self.nhead
    }

    /// Returns a reference to the encoder.
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

    /// Returns a reference to the decoder.
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

impl Module for Seq2SeqTransformer {
    /// Encode-only forward (the `Module` trait takes a single input).
    ///
    /// **Important:** this only runs the encoder. For a full encode + decode
    /// pass, use `forward_seq2seq`, which takes both source and target (plus masks).
    fn forward(&self, input: &Variable) -> Variable {
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );

        // Encode once, decode multiple times (autoregressive inference)
        let memory = transformer.encode(&src, None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0: [1, 0, 0, 0]
        // Row 1: [1, 1, 0, 0]
        // Row 2: [1, 1, 1, 0]
        // Row 3: [1, 1, 1, 1]
        assert_eq!(mask_data[0], 1.0); // (0,0) = visible
        assert_eq!(mask_data[1], 0.0); // (0,1) = masked
        assert_eq!(mask_data[4], 1.0); // (1,0) = visible
        assert_eq!(mask_data[5], 1.0); // (1,1) = visible
        assert_eq!(mask_data[6], 0.0); // (1,2) = masked
        assert_eq!(mask_data[15], 1.0); // (3,3) = visible
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 4 projections × (weight + bias) = 8
        // linear1: weight + bias = 2
        // linear2: weight + bias = 2
        // norm1: weight + bias = 2
        // norm2: weight + bias = 2
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // self_attn: 8, cross_attn: 8, linear1: 2, linear2: 2, norm1: 2, norm2: 2, norm3: 2
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        // Verify hierarchical naming
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}