//! Transformer building blocks: encoder/decoder layers, layer stacks, and a
//! complete sequence-to-sequence model, with optional pre-norm placement.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

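/// A single transformer encoder layer: multi-head self-attention followed by a
/// position-wise feed-forward network (linear -> ReLU -> linear), each wrapped
/// in a residual connection with layer normalization.
///
/// Supports both post-norm (the original "Attention Is All You Need"
/// arrangement) and pre-norm sublayer ordering.
///
/// Minimal usage sketch; shapes follow the `[batch, seq, d_model]` convention
/// used by this file's tests:
///
/// ```ignore
/// let layer = TransformerEncoderLayer::new(64, 4, 256);
/// let out = layer.forward(&input); // same shape as `input`
/// ```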
pub struct TransformerEncoderLayer {
    /// Multi-head self-attention sublayer.
    self_attn: MultiHeadAttention,
    /// First feed-forward projection (d_model -> dim_feedforward).
    linear1: Linear,
    /// Second feed-forward projection (dim_feedforward -> d_model).
    linear2: Linear,
    /// Normalization around the attention sublayer.
    norm1: LayerNorm,
    /// Normalization around the feed-forward sublayer.
    norm2: LayerNorm,
    /// Model (embedding) dimension.
    d_model: usize,
    /// If true, normalize before each sublayer instead of after.
    pre_norm: bool,
}

impl TransformerEncoderLayer {
    /// Creates a post-norm encoder layer.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates an encoder layer with explicit control over norm placement.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

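    /// Runs the layer over `src`, optionally restricting attention with `src_mask`.
    ///
    /// Pre-norm normalizes the input of each sublayer and keeps the residual
    /// stream unnormalized; post-norm normalizes after each residual addition.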
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(src);
            let attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, src_mask);
            let x = src.add_var(&attn_out);

            let normed = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let attn_out = self.self_attn.attention(src, src, src, src_mask);
            let x = src.add_var(&attn_out);
            let x = self.norm1.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm2.forward(&x)
        }
    }

    /// Returns the model (embedding) dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

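/// A single transformer decoder layer: masked multi-head self-attention,
/// cross-attention over encoder memory, and a position-wise feed-forward
/// network, each with a residual connection and layer normalization.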
pub struct TransformerDecoderLayer {
    /// Multi-head self-attention over the target sequence.
    self_attn: MultiHeadAttention,
    /// Multi-head attention over the encoder output (memory).
    cross_attn: MultiHeadAttention,
    /// First feed-forward projection (d_model -> dim_feedforward).
    linear1: Linear,
    /// Second feed-forward projection (dim_feedforward -> d_model).
    linear2: Linear,
    /// Normalization around the self-attention sublayer.
    norm1: LayerNorm,
    /// Normalization around the cross-attention sublayer.
    norm2: LayerNorm,
    /// Normalization around the feed-forward sublayer.
    norm3: LayerNorm,
    /// Model (embedding) dimension.
    d_model: usize,
    /// If true, normalize before each sublayer instead of after.
    pre_norm: bool,
}

impl TransformerDecoderLayer {
    /// Creates a post-norm decoder layer.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

    /// Creates a decoder layer with explicit control over norm placement.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

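    /// Runs the layer over `tgt`, attending to the encoder output `memory`.
    /// `tgt_mask` typically carries the causal mask; `memory_mask` restricts
    /// cross-attention (e.g. to hide source padding).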
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(tgt);
            let self_attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, tgt_mask);
            let x = tgt.add_var(&self_attn_out);

            let normed = self.norm2.forward(&x);
            let cross_attn_out = self
                .cross_attn
                .attention(&normed, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
            let x = tgt.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);
            let x = self.norm2.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    /// Returns the model (embedding) dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerDecoderLayer {
    /// Memory-less forward: without encoder output, the cross-attention
    /// sublayer is skipped, leaving self-attention plus feed-forward.
    fn forward(&self, input: &Variable) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(input);
            let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
            let x = input.add_var(&self_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(input, input, input, None);
            let x = input.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            // norm2 normally follows cross-attention; with no memory it is
            // applied directly before the feed-forward block.
            let x_after_norm2 = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&x_after_norm2).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x_after_norm2.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

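/// A stack of [`TransformerEncoderLayer`]s with an optional final LayerNorm.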
pub struct TransformerEncoder {
    /// The encoder layers, applied in order.
    layers: Vec<TransformerEncoderLayer>,
    /// Optional normalization applied after the last layer.
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    /// Creates a post-norm encoder stack with a final LayerNorm.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates an encoder stack with explicit control over norm placement.
    /// A final LayerNorm is always included.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a post-norm encoder stack without the final LayerNorm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Applies each layer in order, then the final LayerNorm if present.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

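/// A stack of [`TransformerDecoderLayer`]s with an optional final LayerNorm.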
pub struct TransformerDecoder {
    /// The decoder layers, applied in order.
    layers: Vec<TransformerDecoderLayer>,
    /// Optional normalization applied after the last layer.
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    /// Creates a post-norm decoder stack with a final LayerNorm.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

    /// Creates a decoder stack with explicit control over norm placement.
    /// A final LayerNorm is always included.
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Creates a post-norm decoder stack without the final LayerNorm.
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Applies each layer in order with the given memory and masks, then the
    /// final LayerNorm if present.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
    /// Memory-less forward: each layer runs without cross-attention.
    fn forward(&self, input: &Variable) -> Variable {
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

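/// A full encoder-decoder transformer for sequence-to-sequence tasks.
///
/// Minimal usage sketch; it assumes inputs already embedded to
/// `[batch, seq, d_model]`, as in this file's tests:
///
/// ```ignore
/// let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// // src: [2, 10, 64], tgt: [2, 5, 64] -> output: [2, 5, 64]
/// let out = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
/// ```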
pub struct Seq2SeqTransformer {
    /// The encoder stack.
    encoder: TransformerEncoder,
    /// The decoder stack.
    decoder: TransformerDecoder,
    /// Model (embedding) dimension.
    d_model: usize,
    /// Number of attention heads.
    nhead: usize,
}

impl Seq2SeqTransformer {
    /// Creates a post-norm encoder-decoder transformer.
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

    /// Creates a pre-norm encoder-decoder transformer.
    pub fn new_pre_norm(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_encoder_layers,
                true,
            ),
            decoder: TransformerDecoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_decoder_layers,
                true,
            ),
            d_model,
            nhead,
        }
    }

    /// Creates a model with 6 encoder and 6 decoder layers and a 2048-wide
    /// feed-forward network, matching the base "Attention Is All You Need"
    /// configuration.
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

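    /// Encodes `src`, then decodes `tgt` against the resulting memory.
    /// Equivalent to calling [`Self::encode`] followed by [`Self::decode`].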
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder
            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

    /// Runs only the encoder, returning the memory for later decoding.
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

    /// Runs only the decoder against previously computed encoder memory.
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder
            .forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

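    /// Builds a `[seq_len, seq_len]` lower-triangular causal mask, following
    /// the convention that 1.0 marks a position attention may use and 0.0 a
    /// blocked (future) position. For `seq_len = 4`:
    ///
    /// ```text
    /// 1 0 0 0
    /// 1 1 0 0
    /// 1 1 1 0
    /// 1 1 1 1
    /// ```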
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            // Row i may attend to columns 0..=i (itself and earlier positions).
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).expect("tensor creation failed"),
            false,
        )
    }

    /// Returns the model (embedding) dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Returns the number of attention heads.
    pub fn nhead(&self) -> usize {
        self.nhead
    }

    /// Returns a reference to the encoder stack.
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

    /// Returns a reference to the decoder stack.
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

impl Module for Seq2SeqTransformer {
    /// `Module::forward` runs the encoder only; use
    /// [`Self::forward_seq2seq`] for the full encoder-decoder pass.
    fn forward(&self, input: &Variable) -> Variable {
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).expect("tensor creation failed"),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );

        let memory = transformer.encode(&src, None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0: only position 0 is visible.
        assert_eq!(mask_data[0], 1.0);
        assert_eq!(mask_data[1], 0.0);
        // Row 1: positions 0 and 1 are visible, position 2 is not.
        assert_eq!(mask_data[4], 1.0);
        assert_eq!(mask_data[5], 1.0);
        assert_eq!(mask_data[6], 0.0);
        // Last row: every position is visible.
        assert_eq!(mask_data[15], 1.0);
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // 8 from attention (4 projections x weight+bias), 4 from the two
        // feed-forward Linears, and 4 from the two LayerNorms.
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // 16 from the two attention blocks, 4 from the feed-forward Linears,
        // and 6 from the three LayerNorms.
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}