1use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_tensor::Tensor;
21
22use crate::layers::attention::MultiHeadAttention;
23use crate::layers::linear::Linear;
24use crate::layers::norm::LayerNorm;
25use crate::module::Module;
26use crate::parameter::Parameter;
27
/// A single transformer encoder layer: multi-head self-attention followed by
/// a position-wise feed-forward network, each wrapped in a residual
/// connection with layer normalization (pre- or post-norm ordering).
pub struct TransformerEncoderLayer {
    /// Multi-head self-attention sublayer.
    self_attn: MultiHeadAttention,
    /// First feed-forward projection (d_model -> dim_feedforward).
    linear1: Linear,
    /// Second feed-forward projection (dim_feedforward -> d_model).
    linear2: Linear,
    /// LayerNorm associated with the self-attention sublayer.
    norm1: LayerNorm,
    /// LayerNorm associated with the feed-forward sublayer.
    norm2: LayerNorm,
    /// Model (embedding) dimension.
    d_model: usize,
    /// If true, normalize *before* each sublayer (pre-norm);
    /// otherwise normalize after the residual add (post-norm).
    pre_norm: bool,
}
58
59impl TransformerEncoderLayer {
60 pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
67 Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
68 }
69
70 pub fn new_with_pre_norm(
72 d_model: usize,
73 nhead: usize,
74 dim_feedforward: usize,
75 pre_norm: bool,
76 ) -> Self {
77 Self {
78 self_attn: MultiHeadAttention::new(d_model, nhead),
79 linear1: Linear::new(d_model, dim_feedforward),
80 linear2: Linear::new(dim_feedforward, d_model),
81 norm1: LayerNorm::single(d_model),
82 norm2: LayerNorm::single(d_model),
83 d_model,
84 pre_norm,
85 }
86 }
87
88 pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
94 if self.pre_norm {
95 let normed = self.norm1.forward(src);
96 let attn_out = self
97 .self_attn
98 .attention(&normed, &normed, &normed, src_mask);
99 let x = src.add_var(&attn_out);
100
101 let normed = self.norm2.forward(&x);
102 let ff_out = self.linear1.forward(&normed).relu();
103 let ff_out = self.linear2.forward(&ff_out);
104 x.add_var(&ff_out)
105 } else {
106 let attn_out = self.self_attn.attention(src, src, src, src_mask);
107 let x = src.add_var(&attn_out);
108 let x = self.norm1.forward(&x);
109
110 let ff_out = self.linear1.forward(&x).relu();
111 let ff_out = self.linear2.forward(&ff_out);
112 let x = x.add_var(&ff_out);
113 self.norm2.forward(&x)
114 }
115 }
116
117 pub fn d_model(&self) -> usize {
119 self.d_model
120 }
121}
122
123impl Module for TransformerEncoderLayer {
124 fn forward(&self, input: &Variable) -> Variable {
125 self.forward_with_mask(input, None)
126 }
127
128 fn parameters(&self) -> Vec<Parameter> {
129 let mut params = Vec::new();
130 params.extend(self.self_attn.parameters());
131 params.extend(self.linear1.parameters());
132 params.extend(self.linear2.parameters());
133 params.extend(self.norm1.parameters());
134 params.extend(self.norm2.parameters());
135 params
136 }
137
138 fn named_parameters(&self) -> HashMap<String, Parameter> {
139 let mut params = HashMap::new();
140 for (name, param) in self.self_attn.named_parameters() {
141 params.insert(format!("self_attn.{name}"), param);
142 }
143 for (name, param) in self.linear1.named_parameters() {
144 params.insert(format!("linear1.{name}"), param);
145 }
146 for (name, param) in self.linear2.named_parameters() {
147 params.insert(format!("linear2.{name}"), param);
148 }
149 for (name, param) in self.norm1.named_parameters() {
150 params.insert(format!("norm1.{name}"), param);
151 }
152 for (name, param) in self.norm2.named_parameters() {
153 params.insert(format!("norm2.{name}"), param);
154 }
155 params
156 }
157
158 fn name(&self) -> &'static str {
159 "TransformerEncoderLayer"
160 }
161}
162
/// A single transformer decoder layer: masked self-attention over the target,
/// cross-attention onto encoder memory, and a position-wise feed-forward
/// network, each with a residual connection and layer normalization.
pub struct TransformerDecoderLayer {
    /// Self-attention over the target sequence.
    self_attn: MultiHeadAttention,
    /// Attention from target positions onto the encoder memory.
    cross_attn: MultiHeadAttention,
    /// First feed-forward projection (d_model -> dim_feedforward).
    linear1: Linear,
    /// Second feed-forward projection (dim_feedforward -> d_model).
    linear2: Linear,
    /// LayerNorm associated with the self-attention sublayer.
    norm1: LayerNorm,
    /// LayerNorm associated with the cross-attention sublayer.
    norm2: LayerNorm,
    /// LayerNorm associated with the feed-forward sublayer.
    norm3: LayerNorm,
    /// Model (embedding) dimension.
    d_model: usize,
    /// If true, normalize *before* each sublayer (pre-norm);
    /// otherwise normalize after the residual add (post-norm).
    pre_norm: bool,
}
200
201impl TransformerDecoderLayer {
202 pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
209 Self::new_with_pre_norm(d_model, nhead, dim_feedforward, false)
210 }
211
212 pub fn new_with_pre_norm(
214 d_model: usize,
215 nhead: usize,
216 dim_feedforward: usize,
217 pre_norm: bool,
218 ) -> Self {
219 Self {
220 self_attn: MultiHeadAttention::new(d_model, nhead),
221 cross_attn: MultiHeadAttention::new(d_model, nhead),
222 linear1: Linear::new(d_model, dim_feedforward),
223 linear2: Linear::new(dim_feedforward, d_model),
224 norm1: LayerNorm::single(d_model),
225 norm2: LayerNorm::single(d_model),
226 norm3: LayerNorm::single(d_model),
227 d_model,
228 pre_norm,
229 }
230 }
231
232 pub fn forward_with_memory(
240 &self,
241 tgt: &Variable,
242 memory: &Variable,
243 tgt_mask: Option<&Variable>,
244 memory_mask: Option<&Variable>,
245 ) -> Variable {
246 if self.pre_norm {
247 let normed = self.norm1.forward(tgt);
249 let self_attn_out = self
250 .self_attn
251 .attention(&normed, &normed, &normed, tgt_mask);
252 let x = tgt.add_var(&self_attn_out);
253
254 let normed = self.norm2.forward(&x);
255 let cross_attn_out = self
256 .cross_attn
257 .attention(&normed, memory, memory, memory_mask);
258 let x = x.add_var(&cross_attn_out);
259
260 let normed = self.norm3.forward(&x);
261 let ff_out = self.linear1.forward(&normed).relu();
262 let ff_out = self.linear2.forward(&ff_out);
263 x.add_var(&ff_out)
264 } else {
265 let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
267 let x = tgt.add_var(&self_attn_out);
268 let x = self.norm1.forward(&x);
269
270 let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
271 let x = x.add_var(&cross_attn_out);
272 let x = self.norm2.forward(&x);
273
274 let ff_out = self.linear1.forward(&x).relu();
275 let ff_out = self.linear2.forward(&ff_out);
276 let x = x.add_var(&ff_out);
277 self.norm3.forward(&x)
278 }
279 }
280
281 pub fn d_model(&self) -> usize {
283 self.d_model
284 }
285}
286
287impl Module for TransformerDecoderLayer {
288 fn forward(&self, input: &Variable) -> Variable {
289 if self.pre_norm {
292 let normed = self.norm1.forward(input);
293 let self_attn_out = self.self_attn.attention(&normed, &normed, &normed, None);
294 let x = input.add_var(&self_attn_out);
295
296 let normed = self.norm3.forward(&x);
298 let ff_out = self.linear1.forward(&normed).relu();
299 let ff_out = self.linear2.forward(&ff_out);
300 x.add_var(&ff_out)
301 } else {
302 let self_attn_out = self.self_attn.attention(input, input, input, None);
303 let x = input.add_var(&self_attn_out);
304 let x = self.norm1.forward(&x);
305
306 let x_after_norm2 = self.norm2.forward(&x);
307 let ff_out = self.linear1.forward(&x_after_norm2).relu();
308 let ff_out = self.linear2.forward(&ff_out);
309 let x = x_after_norm2.add_var(&ff_out);
310 self.norm3.forward(&x)
311 }
312 }
313
314 fn parameters(&self) -> Vec<Parameter> {
315 let mut params = Vec::new();
316 params.extend(self.self_attn.parameters());
317 params.extend(self.cross_attn.parameters());
318 params.extend(self.linear1.parameters());
319 params.extend(self.linear2.parameters());
320 params.extend(self.norm1.parameters());
321 params.extend(self.norm2.parameters());
322 params.extend(self.norm3.parameters());
323 params
324 }
325
326 fn named_parameters(&self) -> HashMap<String, Parameter> {
327 let mut params = HashMap::new();
328 for (name, param) in self.self_attn.named_parameters() {
329 params.insert(format!("self_attn.{name}"), param);
330 }
331 for (name, param) in self.cross_attn.named_parameters() {
332 params.insert(format!("cross_attn.{name}"), param);
333 }
334 for (name, param) in self.linear1.named_parameters() {
335 params.insert(format!("linear1.{name}"), param);
336 }
337 for (name, param) in self.linear2.named_parameters() {
338 params.insert(format!("linear2.{name}"), param);
339 }
340 for (name, param) in self.norm1.named_parameters() {
341 params.insert(format!("norm1.{name}"), param);
342 }
343 for (name, param) in self.norm2.named_parameters() {
344 params.insert(format!("norm2.{name}"), param);
345 }
346 for (name, param) in self.norm3.named_parameters() {
347 params.insert(format!("norm3.{name}"), param);
348 }
349 params
350 }
351
352 fn name(&self) -> &'static str {
353 "TransformerDecoderLayer"
354 }
355}
356
/// A stack of `TransformerEncoderLayer`s, optionally followed by a final
/// `LayerNorm` applied to the output of the last layer.
pub struct TransformerEncoder {
    /// The encoder layers, applied in order.
    layers: Vec<TransformerEncoderLayer>,
    /// Optional final normalization (absent when built via `without_norm`).
    norm: Option<LayerNorm>,
}
372
373impl TransformerEncoder {
374 pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
376 Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
377 }
378
379 pub fn new_with_pre_norm(
384 d_model: usize,
385 nhead: usize,
386 dim_feedforward: usize,
387 num_layers: usize,
388 pre_norm: bool,
389 ) -> Self {
390 let layers = (0..num_layers)
391 .map(|_| {
392 TransformerEncoderLayer::new_with_pre_norm(
393 d_model,
394 nhead,
395 dim_feedforward,
396 pre_norm,
397 )
398 })
399 .collect();
400
401 Self {
402 layers,
403 norm: Some(LayerNorm::single(d_model)),
404 }
405 }
406
407 pub fn without_norm(
409 d_model: usize,
410 nhead: usize,
411 dim_feedforward: usize,
412 num_layers: usize,
413 ) -> Self {
414 let layers = (0..num_layers)
415 .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
416 .collect();
417
418 Self { layers, norm: None }
419 }
420
421 pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
423 let mut x = src.clone();
424 for layer in &self.layers {
425 x = layer.forward_with_mask(&x, src_mask);
426 }
427 if let Some(ref norm) = self.norm {
428 x = norm.forward(&x);
429 }
430 x
431 }
432
433 pub fn num_layers(&self) -> usize {
435 self.layers.len()
436 }
437}
438
439impl Module for TransformerEncoder {
440 fn forward(&self, input: &Variable) -> Variable {
441 self.forward_with_mask(input, None)
442 }
443
444 fn parameters(&self) -> Vec<Parameter> {
445 let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
446 if let Some(ref norm) = self.norm {
447 params.extend(norm.parameters());
448 }
449 params
450 }
451
452 fn named_parameters(&self) -> HashMap<String, Parameter> {
453 let mut params = HashMap::new();
454 for (i, layer) in self.layers.iter().enumerate() {
455 for (name, param) in layer.named_parameters() {
456 params.insert(format!("layers.{i}.{name}"), param);
457 }
458 }
459 if let Some(ref norm) = self.norm {
460 for (name, param) in norm.named_parameters() {
461 params.insert(format!("norm.{name}"), param);
462 }
463 }
464 params
465 }
466
467 fn name(&self) -> &'static str {
468 "TransformerEncoder"
469 }
470}
471
/// A stack of `TransformerDecoderLayer`s, optionally followed by a final
/// `LayerNorm` applied to the output of the last layer.
pub struct TransformerDecoder {
    /// The decoder layers, applied in order.
    layers: Vec<TransformerDecoderLayer>,
    /// Optional final normalization (absent when built via `without_norm`).
    norm: Option<LayerNorm>,
}
488
489impl TransformerDecoder {
490 pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
492 Self::new_with_pre_norm(d_model, nhead, dim_feedforward, num_layers, false)
493 }
494
495 pub fn new_with_pre_norm(
497 d_model: usize,
498 nhead: usize,
499 dim_feedforward: usize,
500 num_layers: usize,
501 pre_norm: bool,
502 ) -> Self {
503 let layers = (0..num_layers)
504 .map(|_| {
505 TransformerDecoderLayer::new_with_pre_norm(
506 d_model,
507 nhead,
508 dim_feedforward,
509 pre_norm,
510 )
511 })
512 .collect();
513
514 Self {
515 layers,
516 norm: Some(LayerNorm::single(d_model)),
517 }
518 }
519
520 pub fn without_norm(
522 d_model: usize,
523 nhead: usize,
524 dim_feedforward: usize,
525 num_layers: usize,
526 ) -> Self {
527 let layers = (0..num_layers)
528 .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
529 .collect();
530
531 Self { layers, norm: None }
532 }
533
534 pub fn forward_with_memory(
536 &self,
537 tgt: &Variable,
538 memory: &Variable,
539 tgt_mask: Option<&Variable>,
540 memory_mask: Option<&Variable>,
541 ) -> Variable {
542 let mut x = tgt.clone();
543 for layer in &self.layers {
544 x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
545 }
546 if let Some(ref norm) = self.norm {
547 x = norm.forward(&x);
548 }
549 x
550 }
551
552 pub fn num_layers(&self) -> usize {
554 self.layers.len()
555 }
556}
557
558impl Module for TransformerDecoder {
559 fn forward(&self, input: &Variable) -> Variable {
560 let mut x = input.clone();
562 for layer in &self.layers {
563 x = layer.forward(&x);
564 }
565 if let Some(ref norm) = self.norm {
566 x = norm.forward(&x);
567 }
568 x
569 }
570
571 fn parameters(&self) -> Vec<Parameter> {
572 let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
573 if let Some(ref norm) = self.norm {
574 params.extend(norm.parameters());
575 }
576 params
577 }
578
579 fn named_parameters(&self) -> HashMap<String, Parameter> {
580 let mut params = HashMap::new();
581 for (i, layer) in self.layers.iter().enumerate() {
582 for (name, param) in layer.named_parameters() {
583 params.insert(format!("layers.{i}.{name}"), param);
584 }
585 }
586 if let Some(ref norm) = self.norm {
587 for (name, param) in norm.named_parameters() {
588 params.insert(format!("norm.{name}"), param);
589 }
590 }
591 params
592 }
593
594 fn name(&self) -> &'static str {
595 "TransformerDecoder"
596 }
597}
598
/// A full encoder-decoder (sequence-to-sequence) transformer pairing a
/// `TransformerEncoder` with a `TransformerDecoder`.
pub struct Seq2SeqTransformer {
    /// Encoder stack producing the memory tensor.
    encoder: TransformerEncoder,
    /// Decoder stack attending to the encoder memory.
    decoder: TransformerDecoder,
    /// Model (embedding) dimension shared by both stacks.
    d_model: usize,
    /// Number of attention heads per layer.
    nhead: usize,
}
629
630impl Seq2SeqTransformer {
631 pub fn new(
640 d_model: usize,
641 nhead: usize,
642 num_encoder_layers: usize,
643 num_decoder_layers: usize,
644 dim_feedforward: usize,
645 ) -> Self {
646 Self {
647 encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
648 decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
649 d_model,
650 nhead,
651 }
652 }
653
654 pub fn new_pre_norm(
659 d_model: usize,
660 nhead: usize,
661 num_encoder_layers: usize,
662 num_decoder_layers: usize,
663 dim_feedforward: usize,
664 ) -> Self {
665 Self {
666 encoder: TransformerEncoder::new_with_pre_norm(
667 d_model,
668 nhead,
669 dim_feedforward,
670 num_encoder_layers,
671 true,
672 ),
673 decoder: TransformerDecoder::new_with_pre_norm(
674 d_model,
675 nhead,
676 dim_feedforward,
677 num_decoder_layers,
678 true,
679 ),
680 d_model,
681 nhead,
682 }
683 }
684
685 pub fn default_config(d_model: usize, nhead: usize) -> Self {
687 Self::new(d_model, nhead, 6, 6, 2048)
688 }
689
690 pub fn forward_seq2seq(
699 &self,
700 src: &Variable,
701 tgt: &Variable,
702 src_mask: Option<&Variable>,
703 tgt_mask: Option<&Variable>,
704 memory_mask: Option<&Variable>,
705 ) -> Variable {
706 let memory = self.encoder.forward_with_mask(src, src_mask);
707 self.decoder
708 .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
709 }
710
711 pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
713 self.encoder.forward_with_mask(src, src_mask)
714 }
715
716 pub fn decode(
718 &self,
719 tgt: &Variable,
720 memory: &Variable,
721 tgt_mask: Option<&Variable>,
722 memory_mask: Option<&Variable>,
723 ) -> Variable {
724 self.decoder
725 .forward_with_memory(tgt, memory, tgt_mask, memory_mask)
726 }
727
728 pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
733 let mut mask_data = vec![0.0f32; seq_len * seq_len];
734 for i in 0..seq_len {
735 for j in 0..=i {
736 mask_data[i * seq_len + j] = 1.0;
737 }
738 }
739 Variable::new(
740 Tensor::from_vec(mask_data, &[seq_len, seq_len]).expect("tensor creation failed"),
741 false,
742 )
743 }
744
745 pub fn d_model(&self) -> usize {
747 self.d_model
748 }
749
750 pub fn nhead(&self) -> usize {
752 self.nhead
753 }
754
755 pub fn encoder(&self) -> &TransformerEncoder {
757 &self.encoder
758 }
759
760 pub fn decoder(&self) -> &TransformerDecoder {
762 &self.decoder
763 }
764}
765
766impl Module for Seq2SeqTransformer {
767 fn forward(&self, input: &Variable) -> Variable {
772 self.encoder.forward(input)
773 }
774
775 fn parameters(&self) -> Vec<Parameter> {
776 let mut params = self.encoder.parameters();
777 params.extend(self.decoder.parameters());
778 params
779 }
780
781 fn named_parameters(&self) -> HashMap<String, Parameter> {
782 let mut params = HashMap::new();
783 for (name, param) in self.encoder.named_parameters() {
784 params.insert(format!("encoder.{name}"), param);
785 }
786 for (name, param) in self.decoder.named_parameters() {
787 params.insert(format!("decoder.{name}"), param);
788 }
789 params
790 }
791
792 fn name(&self) -> &'static str {
793 "Seq2SeqTransformer"
794 }
795}
796
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-filled `Variable` of the given shape (no grad).
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![value; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_encoder_layer_creation() {
        assert_eq!(TransformerEncoderLayer::new(64, 4, 256).d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let output = layer.forward(&filled(0.1, &[2, 10, 64]));
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = filled(0.1, &[2, 5, 64]);
        let memory = filled(0.2, &[2, 10, 64]);
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let output = encoder.forward(&filled(0.1, &[2, 8, 64]));
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = filled(0.1, &[2, 5, 64]);
        let memory = filled(0.2, &[2, 10, 64]);
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = filled(0.1, &[2, 10, 64]);
        let tgt = filled(0.2, &[2, 5, 64]);
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let memory = transformer.encode(&filled(0.1, &[2, 10, 64]), None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&filled(0.2, &[2, 5, 64]), &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0 may attend only to position 0.
        assert_eq!(mask_data[0], 1.0);
        assert_eq!(mask_data[1], 0.0);
        // Row 1 may attend to positions 0 and 1.
        assert_eq!(mask_data[4], 1.0);
        assert_eq!(mask_data[5], 1.0);
        assert_eq!(mask_data[6], 0.0);
        // The last row may attend everywhere.
        assert_eq!(mask_data[15], 1.0);
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        // Total parameter tensors across self_attn, linear1/2, and norm1/2.
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.parameters().len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        // Adds cross_attn and a third norm on top of the encoder layer's set.
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        assert_eq!(layer.parameters().len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let named = Seq2SeqTransformer::new(64, 4, 1, 1, 256).named_parameters();
        for key in [
            "encoder.layers.0.self_attn.q_proj.weight",
            "decoder.layers.0.cross_attn.q_proj.weight",
            "encoder.norm.weight",
            "decoder.norm.weight",
        ] {
            assert!(named.contains_key(key));
        }
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = filled(0.1, &[2, 10, 64]);
        let tgt = filled(0.2, &[2, 5, 64]);
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
}