use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

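/// A single transformer encoder layer: multi-head self-attention followed by a
/// position-wise feed-forward block (ReLU), each wrapped in a residual
/// connection. `pre_norm` selects pre-norm residual placement; the default is
/// the post-norm layout of the original transformer.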
pub struct TransformerEncoderLayer {
    self_attn: MultiHeadAttention,
    linear1: Linear,
    linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    d_model: usize,
    pre_norm: bool,
}

impl TransformerEncoderLayer {
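    /// Creates a post-norm encoder layer (residual added first, then LayerNorm).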
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

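    /// Creates an encoder layer with explicit pre-norm (`true`) or post-norm
    /// (`false`) residual placement.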
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

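    /// Runs the layer on `src`, passing `src_mask` (e.g. a padding mask) to
    /// self-attention if provided.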
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        if self.pre_norm {
            // Pre-norm: normalize before each sub-layer, residual on the raw input.
            let normed = self.norm1.forward(src);
            let attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, src_mask);
            let x = src.add_var(&attn_out);

            let normed = self.norm2.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            // Post-norm: residual first, then normalize.
            let attn_out = self.self_attn.attention(src, src, src, src_mask);
            let x = src.add_var(&attn_out);
            let x = self.norm1.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm2.forward(&x)
        }
    }

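    /// Returns the model (embedding) dimension of this layer.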
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

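/// A single transformer decoder layer: masked self-attention, cross-attention
/// over encoder memory, and a feed-forward block, each wrapped in a residual
/// connection. Supports pre-norm and post-norm placement like the encoder.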
pub struct TransformerDecoderLayer {
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    linear1: Linear,
    linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    d_model: usize,
    pre_norm: bool,
}

impl TransformerDecoderLayer {
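    /// Creates a post-norm decoder layer.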
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
    }

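    /// Creates a decoder layer with explicit pre-norm (`true`) or post-norm
    /// (`false`) residual placement.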
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        pre_norm: bool,
    ) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
            pre_norm,
        }
    }

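    /// Runs the layer on `tgt` attending to the encoder output `memory`.
    /// `tgt_mask` is applied in self-attention (typically a causal mask);
    /// `memory_mask` is applied in cross-attention.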
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        if self.pre_norm {
            // Pre-norm: normalize before each sub-layer, residual on the raw input.
            let normed = self.norm1.forward(tgt);
            let self_attn_out = self
                .self_attn
                .attention(&normed, &normed, &normed, tgt_mask);
            let x = tgt.add_var(&self_attn_out);

            let normed = self.norm2.forward(&x);
            let cross_attn_out = self
                .cross_attn
                .attention(&normed, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            // Post-norm: residual first, then normalize.
            let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
            let x = tgt.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
            let x = x.add_var(&cross_attn_out);
            let x = self.norm2.forward(&x);

            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

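    /// Returns the model (embedding) dimension of this layer.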
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

impl Module for TransformerDecoderLayer {
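    /// Memory-less forward pass: the cross-attention sub-layer (and `norm2`)
    /// is skipped because there is no encoder memory to attend to. Use
    /// [`TransformerDecoderLayer::forward_with_memory`] for the full stack.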
    fn forward(&self, input: &Variable) -> Variable {
        if self.pre_norm {
            let normed = self.norm1.forward(input);
            let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
            let x = input.add_var(&self_attn_out);

            let normed = self.norm3.forward(&x);
            let ff_out = self.linear1.forward(&normed).relu();
            let ff_out = self.linear2.forward(&ff_out);
            x.add_var(&ff_out)
        } else {
            let self_attn_out = self.self_attn.attention(input, input, input, None);
            let x = input.add_var(&self_attn_out);
            let x = self.norm1.forward(&x);

            // norm2 belongs to the skipped cross-attention sub-layer, so it is
            // bypassed here just as in the pre-norm branch above; the
            // feed-forward block reads from and adds its residual to `x`.
            let ff_out = self.linear1.forward(&x).relu();
            let ff_out = self.linear2.forward(&ff_out);
            let x = x.add_var(&ff_out);
            self.norm3.forward(&x)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

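/// A stack of [`TransformerEncoderLayer`]s with an optional final `LayerNorm`
/// applied after the last layer.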
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
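    /// Creates a post-norm encoder stack of `num_layers` identical layers with
    /// a final LayerNorm.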
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

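    /// Creates an encoder stack with explicit pre-norm/post-norm selection.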
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

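    /// Creates a post-norm encoder stack without the final LayerNorm.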
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

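    /// Runs every layer in order on `src`, passing `src_mask` to each layer,
    /// then applies the final norm if present.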
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

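    /// Returns the number of stacked layers.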
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

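/// A stack of [`TransformerDecoderLayer`]s with an optional final `LayerNorm`
/// applied after the last layer.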
pub struct TransformerDecoder {
    layers: Vec<TransformerDecoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
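    /// Creates a post-norm decoder stack of `num_layers` identical layers with
    /// a final LayerNorm.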
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
    }

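    /// Creates a decoder stack with explicit pre-norm/post-norm selection.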
    pub fn new_with_pre_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderLayer::new_with_pre_norm(
                    d_model,
                    nhead,
                    dim_feedforward,
                    pre_norm,
                )
            })
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

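    /// Creates a post-norm decoder stack without the final LayerNorm.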
    pub fn without_norm(
        d_model: usize,
        nhead: usize,
        dim_feedforward: usize,
        num_layers: usize,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

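    /// Runs every layer in order on `tgt`, each attending to the same encoder
    /// `memory`, then applies the final norm if present.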
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

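    /// Returns the number of stacked layers.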
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
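    /// Memory-less forward pass: each layer runs its own memory-less
    /// `forward`, which skips cross-attention.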
    fn forward(&self, input: &Variable) -> Variable {
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

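/// A full encoder-decoder transformer in the style of "Attention Is All You
/// Need": the encoder maps `src` to a memory sequence and the decoder attends
/// to that memory while processing `tgt`.
///
/// A minimal usage sketch (shapes follow the `[batch, seq, d_model]`
/// convention used in the tests below; `src` and `tgt` are placeholder
/// inputs):
///
/// ```ignore
/// let model = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
/// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
/// // src: [2, 10, 64], tgt: [2, 5, 64]
/// let out = model.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
/// assert_eq!(out.shape(), vec![2, 5, 64]);
/// ```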
pub struct Seq2SeqTransformer {
    encoder: TransformerEncoder,
    decoder: TransformerDecoder,
    d_model: usize,
    nhead: usize,
}

impl Seq2SeqTransformer {
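    /// Creates a post-norm seq2seq transformer with the given encoder and
    /// decoder depths.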
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

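    /// Creates a pre-norm variant, which is generally more stable to train at
    /// larger depths.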
    pub fn new_pre_norm(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_encoder_layers,
                true,
            ),
            decoder: TransformerDecoder::new_with_pre_norm(
                d_model,
                nhead,
                dim_feedforward,
                num_decoder_layers,
                true,
            ),
            d_model,
            nhead,
        }
    }

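    /// Creates the classic base configuration: 6 encoder layers, 6 decoder
    /// layers, and a 2048-dimensional feed-forward block.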
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

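    /// Encodes `src` and decodes `tgt` against the resulting memory in one
    /// call. Pass a causal `tgt_mask` (see
    /// [`Self::generate_square_subsequent_mask`]) for autoregressive training.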
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder
            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

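    /// Runs the encoder only, returning the memory sequence.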
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

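    /// Runs the decoder only against a precomputed `memory`; useful when the
    /// source is encoded once and decoding proceeds step by step.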
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder
            .forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

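    /// Builds a `[seq_len, seq_len]` causal mask: entry `(i, j)` is `1.0` when
    /// `j <= i` (position `i` may attend to `j`) and `0.0` otherwise. For
    /// `seq_len = 3` the rows are `[1, 0, 0]`, `[1, 1, 0]`, `[1, 1, 1]`.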
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).unwrap(),
            false,
        )
    }

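    /// Returns the model (embedding) dimension.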
    pub fn d_model(&self) -> usize {
        self.d_model
    }

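    /// Returns the number of attention heads per layer.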
    pub fn nhead(&self) -> usize {
        self.nhead
    }

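    /// Returns a reference to the encoder stack.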
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

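    /// Returns a reference to the decoder stack.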
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

impl Module for Seq2SeqTransformer {
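    /// Single-input forward runs the encoder only; use
    /// [`Seq2SeqTransformer::forward_seq2seq`] for the full encoder-decoder
    /// path.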
    fn forward(&self, input: &Variable) -> Variable {
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

846
847 #[test]
848 fn test_decoder_stack() {
849 let decoder = TransformerDecoder::new(64, 4, 256, 3);
850 assert_eq!(decoder.num_layers(), 3);
851
852 let tgt = Variable::new(
853 Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
854 false,
855 );
856 let memory = Variable::new(
857 Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
858 false,
859 );
860 let output = decoder.forward_with_memory(&tgt, &memory, None, None);
861 assert_eq!(output.shape(), vec![2, 5, 64]);
862 }
863
864 #[test]
865 fn test_seq2seq_transformer() {
866 let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
867 assert_eq!(transformer.d_model(), 64);
868 assert_eq!(transformer.nhead(), 4);
869
870 let src = Variable::new(
871 Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
872 false,
873 );
874 let tgt = Variable::new(
875 Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
876 false,
877 );
878 let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
879 assert_eq!(output.shape(), vec![2, 5, 64]);
880 }
881
882 #[test]
883 fn test_seq2seq_encode_decode_separate() {
884 let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
885
886 let src = Variable::new(
887 Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
888 false,
889 );
890 let tgt = Variable::new(
891 Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
892 false,
893 );
894
895 let memory = transformer.encode(&src, None);
897 assert_eq!(memory.shape(), vec![2, 10, 64]);
898
899 let output = transformer.decode(&tgt, &memory, None, None);
900 assert_eq!(output.shape(), vec![2, 5, 64]);
901 }
902
    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        assert_eq!(mask_data[0], 1.0);
        assert_eq!(mask_data[1], 0.0);
        assert_eq!(mask_data[4], 1.0);
        assert_eq!(mask_data[5], 1.0);
        assert_eq!(mask_data[6], 0.0);
        assert_eq!(mask_data[15], 1.0);
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
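        // Expected breakdown (assuming weight + bias per Linear and per
        // LayerNorm, and four projections inside MultiHeadAttention, as the
        // named-parameter test below implies):
        //   self_attn: 8, linear1 + linear2: 4, norm1 + norm2: 4.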
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
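        // Expected (same assumptions as above): two attention blocks (8 each)
        // + two linears (2 each) + three norms (2 each) = 16 + 4 + 6 = 26.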
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}