use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::attention::MultiHeadAttention;
use crate::layers::linear::Linear;
use crate::layers::norm::LayerNorm;
use crate::module::Module;
use crate::parameter::Parameter;

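// Overview of the post-norm block layout implemented below (residual add,
// then LayerNorm, as in "Attention Is All You Need"):
//
//   encoder layer: x = LayerNorm(x + SelfAttn(x))
//                  x = LayerNorm(x + FFN(x))
//   decoder layer: x = LayerNorm(x + SelfAttn(x))
//                  x = LayerNorm(x + CrossAttn(x, memory))
//                  x = LayerNorm(x + FFN(x))
//
// where FFN(x) = linear2(relu(linear1(x))).
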
/// A single transformer encoder layer: multi-head self-attention followed by
/// a position-wise feed-forward network, each wrapped in a residual
/// connection and post-layer normalization.
pub struct TransformerEncoderLayer {
    self_attn: MultiHeadAttention,
    linear1: Linear,
    linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    d_model: usize,
}

impl TransformerEncoderLayer {
    /// Creates a layer with model width `d_model`, `nhead` attention heads,
    /// and a feed-forward hidden width of `dim_feedforward`.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            d_model,
        }
    }

    /// Runs the layer over `src` with an optional attention mask.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        // Self-attention block: residual add, then layer norm.
        let attn_out = self.self_attn.attention(src, src, src, src_mask);
        let x = src.add_var(&attn_out);
        let x = self.norm1.forward(&x);

        // Feed-forward block: linear -> ReLU -> linear, residual, norm.
        let ff_out = self.linear1.forward(&x).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x.add_var(&ff_out);
        self.norm2.forward(&x)
    }

    /// Returns the model (embedding) dimension of this layer.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

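// A minimal usage sketch, mirroring the tests at the bottom of this file:
//
// let layer = TransformerEncoderLayer::new(64, 4, 256);
// let src = Variable::new(
//     Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
//     false,
// );
// let out = layer.forward_with_mask(&src, None); // shape stays [2, 10, 64]
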
impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoderLayer"
    }
}

/// A single transformer decoder layer: masked self-attention, cross-attention
/// over the encoder output (`memory`), and a position-wise feed-forward
/// network, each with a residual connection and post-layer normalization.
pub struct TransformerDecoderLayer {
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    linear1: Linear,
    linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    d_model: usize,
}

impl TransformerDecoderLayer {
    /// Creates a layer with model width `d_model`, `nhead` attention heads,
    /// and a feed-forward hidden width of `dim_feedforward`.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize) -> Self {
        Self {
            self_attn: MultiHeadAttention::new(d_model, nhead),
            cross_attn: MultiHeadAttention::new(d_model, nhead),
            linear1: Linear::new(d_model, dim_feedforward),
            linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::single(d_model),
            norm2: LayerNorm::single(d_model),
            norm3: LayerNorm::single(d_model),
            d_model,
        }
    }

    /// Runs the layer against decoder input `tgt` and encoder output `memory`.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        // Masked self-attention over the target sequence.
        let self_attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
        let x = tgt.add_var(&self_attn_out);
        let x = self.norm1.forward(&x);

        // Cross-attention: queries from the decoder, keys/values from memory.
        let cross_attn_out = self.cross_attn.attention(&x, memory, memory, memory_mask);
        let x = x.add_var(&cross_attn_out);
        let x = self.norm2.forward(&x);

        // Feed-forward block: linear -> ReLU -> linear, residual, norm.
        let ff_out = self.linear1.forward(&x).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x.add_var(&ff_out);
        self.norm3.forward(&x)
    }

    /// Returns the model (embedding) dimension of this layer.
    pub fn d_model(&self) -> usize {
        self.d_model
    }
}

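// Typical mask usage for this layer (a sketch, assuming the 1.0/0.0 mask
// convention of `Seq2SeqTransformer::generate_square_subsequent_mask` below):
// `tgt_mask` carries the causal mask for self-attention, while `memory_mask`
// would usually mask padded encoder positions.
//
// let layer = TransformerDecoderLayer::new(64, 4, 256);
// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
// let out = layer.forward_with_memory(&tgt, &memory, Some(&tgt_mask), None);
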
impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // Degenerate path for the single-input `Module` interface: with no
        // encoder memory, cross-attention is skipped and `norm2` is applied
        // directly to the self-attention output. Prefer `forward_with_memory`
        // for real encoder-decoder use.
        let self_attn_out = self.self_attn.attention(input, input, input, None);
        let x = input.add_var(&self_attn_out);
        let x = self.norm1.forward(&x);

        let x_after_norm2 = self.norm2.forward(&x);
        let ff_out = self.linear1.forward(&x_after_norm2).relu();
        let ff_out = self.linear2.forward(&ff_out);
        let x = x_after_norm2.add_var(&ff_out);
        self.norm3.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.linear1.parameters());
        params.extend(self.linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.self_attn.named_parameters() {
            params.insert(format!("self_attn.{name}"), param);
        }
        for (name, param) in self.cross_attn.named_parameters() {
            params.insert(format!("cross_attn.{name}"), param);
        }
        for (name, param) in self.linear1.named_parameters() {
            params.insert(format!("linear1.{name}"), param);
        }
        for (name, param) in self.linear2.named_parameters() {
            params.insert(format!("linear2.{name}"), param);
        }
        for (name, param) in self.norm1.named_parameters() {
            params.insert(format!("norm1.{name}"), param);
        }
        for (name, param) in self.norm2.named_parameters() {
            params.insert(format!("norm2.{name}"), param);
        }
        for (name, param) in self.norm3.named_parameters() {
            params.insert(format!("norm3.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoderLayer"
    }
}

/// A stack of [`TransformerEncoderLayer`]s with an optional final layer norm.
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    /// Creates a stack of `num_layers` identically configured encoder layers,
    /// followed by a final layer norm.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Like [`TransformerEncoder::new`], but without the final layer norm.
    pub fn without_norm(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Runs `src` through every layer, applying the same `src_mask` to each,
    /// then through the final norm if present.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        let mut x = src.clone();
        for layer in &self.layers {
            x = layer.forward_with_mask(&x, src_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerEncoder"
    }
}

/// A stack of [`TransformerDecoderLayer`]s with an optional final layer norm.
pub struct TransformerDecoder {
    layers: Vec<TransformerDecoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    /// Creates a stack of `num_layers` identically configured decoder layers,
    /// followed by a final layer norm.
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::single(d_model)),
        }
    }

    /// Like [`TransformerDecoder::new`], but without the final layer norm.
    pub fn without_norm(d_model: usize, nhead: usize, dim_feedforward: usize, num_layers: usize) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward))
            .collect();

        Self { layers, norm: None }
    }

    /// Runs `tgt` through every layer against the same `memory`, then through
    /// the final norm if present.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut x = tgt.clone();
        for layer in &self.layers {
            x = layer.forward_with_memory(&x, memory, tgt_mask, memory_mask);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    /// Returns the number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        // Memory-less path: each layer runs its degenerate `forward` with no
        // cross-attention. Use `forward_with_memory` for encoder-decoder use.
        let mut x = input.clone();
        for layer in &self.layers {
            x = layer.forward(&x);
        }
        if let Some(ref norm) = self.norm {
            x = norm.forward(&x);
        }
        x
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (i, layer) in self.layers.iter().enumerate() {
            for (name, param) in layer.named_parameters() {
                params.insert(format!("layers.{i}.{name}"), param);
            }
        }
        if let Some(ref norm) = self.norm {
            for (name, param) in norm.named_parameters() {
                params.insert(format!("norm.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "TransformerDecoder"
    }
}

/// A full encoder-decoder transformer. It operates on already-embedded inputs
/// of shape `[batch, seq, d_model]`; token embedding and the output
/// projection to vocabulary logits are left to the caller.
pub struct Seq2SeqTransformer {
    encoder: TransformerEncoder,
    decoder: TransformerDecoder,
    d_model: usize,
    nhead: usize,
}

impl Seq2SeqTransformer {
    /// Creates an encoder-decoder model; all sub-layers share `d_model`,
    /// `nhead`, and `dim_feedforward`.
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(d_model, nhead, dim_feedforward, num_encoder_layers),
            decoder: TransformerDecoder::new(d_model, nhead, dim_feedforward, num_decoder_layers),
            d_model,
            nhead,
        }
    }

    /// Matches the depth and feed-forward width of the original "Attention Is
    /// All You Need" base model: 6 encoder layers, 6 decoder layers, 2048.
    pub fn default_config(d_model: usize, nhead: usize) -> Self {
        Self::new(d_model, nhead, 6, 6, 2048)
    }

    /// Encodes `src`, then decodes `tgt` against the resulting memory.
    pub fn forward_seq2seq(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder.forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }

    /// Runs only the encoder, returning the memory for later decoding.
    pub fn encode(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        self.encoder.forward_with_mask(src, src_mask)
    }

    /// Runs only the decoder against a previously computed `memory`.
    pub fn decode(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        self.decoder.forward_with_memory(tgt, memory, tgt_mask, memory_mask)
    }

    /// Builds a `[seq_len, seq_len]` causal mask: position `i` may attend to
    /// positions `0..=i` (1.0 = attend, 0.0 = blocked).
    pub fn generate_square_subsequent_mask(seq_len: usize) -> Variable {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..=i {
                mask_data[i * seq_len + j] = 1.0;
            }
        }
        Variable::new(
            Tensor::from_vec(mask_data, &[seq_len, seq_len]).unwrap(),
            false,
        )
    }

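    // For example, seq_len = 3 yields the lower-triangular matrix
    //
    //   [1, 0, 0]
    //   [1, 1, 0]
    //   [1, 1, 1]
    //
    // so token 0 sees only itself, token 1 sees tokens 0-1, and so on.
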
    /// Returns the model (embedding) dimension.
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Returns the number of attention heads.
    pub fn nhead(&self) -> usize {
        self.nhead
    }

    /// Returns a reference to the encoder stack.
    pub fn encoder(&self) -> &TransformerEncoder {
        &self.encoder
    }

    /// Returns a reference to the decoder stack.
    pub fn decoder(&self) -> &TransformerDecoder {
        &self.decoder
    }
}

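// A hedged sketch of step-wise (e.g. greedy) decoding with this API; the
// embedding and vocabulary projection are the caller's, not this module's:
//
// let model = Seq2SeqTransformer::default_config(512, 8);
// let memory = model.encode(&src, None);
// let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(tgt_len);
// let hidden = model.decode(&tgt_so_far, &memory, Some(&tgt_mask), None);
// // ...project `hidden` to logits and append the chosen token to `tgt_so_far`.
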
impl Module for Seq2SeqTransformer {
    fn forward(&self, input: &Variable) -> Variable {
        // The single-input `Module` interface can only run the encoder half;
        // use `forward_seq2seq` for the full encoder-decoder pass.
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = self.encoder.parameters();
        params.extend(self.decoder.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.encoder.named_parameters() {
            params.insert(format!("encoder.{name}"), param);
        }
        for (name, param) in self.decoder.named_parameters() {
            params.insert(format!("decoder.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Seq2SeqTransformer"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_layer_creation() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        assert_eq!(layer.d_model(), 64);
    }

    #[test]
    fn test_encoder_layer_forward() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_encoder_stack() {
        let encoder = TransformerEncoder::new(64, 4, 256, 3);
        assert_eq!(encoder.num_layers(), 3);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }

    #[test]
    fn test_decoder_stack() {
        let decoder = TransformerDecoder::new(64, 4, 256, 3);
        assert_eq!(decoder.num_layers(), 3);

        let tgt = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_transformer() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        assert_eq!(transformer.d_model(), 64);
        assert_eq!(transformer.nhead(), 4);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = transformer.forward_seq2seq(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_seq2seq_encode_decode_separate() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);

        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );

        let memory = transformer.encode(&src, None);
        assert_eq!(memory.shape(), vec![2, 10, 64]);

        let output = transformer.decode(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_causal_mask() {
        let mask = Seq2SeqTransformer::generate_square_subsequent_mask(4);
        let mask_data = mask.data().to_vec();
        // Row 0: only position 0 is visible.
        assert_eq!(mask_data[0], 1.0);
        assert_eq!(mask_data[1], 0.0);
        // Row 1: positions 0 and 1 are visible.
        assert_eq!(mask_data[4], 1.0);
        assert_eq!(mask_data[5], 1.0);
        assert_eq!(mask_data[6], 0.0);
        // Row 3: the final position sees the whole sequence.
        assert_eq!(mask_data[15], 1.0);
    }

    #[test]
    fn test_default_config() {
        let transformer = Seq2SeqTransformer::default_config(512, 8);
        assert_eq!(transformer.encoder().num_layers(), 6);
        assert_eq!(transformer.decoder().num_layers(), 6);
    }

    #[test]
    fn test_parameter_count() {
        let layer = TransformerEncoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // 16 = 8 (attention projections) + 2x2 (feed-forward linears)
        //    + 2x2 (layer norms), counting a weight and a bias per sub-module.
        assert_eq!(params.len(), 16);
    }

    #[test]
    fn test_decoder_parameter_count() {
        let layer = TransformerDecoderLayer::new(64, 4, 256);
        let params = layer.parameters();
        // 26 = 2x8 (self- and cross-attention) + 2x2 (linears) + 3x2 (norms).
        assert_eq!(params.len(), 26);
    }

    #[test]
    fn test_named_parameters_hierarchy() {
        let transformer = Seq2SeqTransformer::new(64, 4, 1, 1, 256);
        let named = transformer.named_parameters();
        assert!(named.contains_key("encoder.layers.0.self_attn.q_proj.weight"));
        assert!(named.contains_key("decoder.layers.0.cross_attn.q_proj.weight"));
        assert!(named.contains_key("encoder.norm.weight"));
        assert!(named.contains_key("decoder.norm.weight"));
    }

    #[test]
    fn test_seq2seq_with_causal_mask() {
        let transformer = Seq2SeqTransformer::new(64, 4, 2, 2, 256);
        let src = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![0.2; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let tgt_mask = Seq2SeqTransformer::generate_square_subsequent_mask(5);
        let output = transformer.forward_seq2seq(&src, &tgt, None, Some(&tgt_mask), None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
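
    // A hedged extra check, assuming `without_norm` simply omits the final
    // LayerNorm: layer count and shape preservation should be unaffected.
    #[test]
    fn test_encoder_without_final_norm() {
        let encoder = TransformerEncoder::without_norm(64, 4, 256, 2);
        assert_eq!(encoder.num_layers(), 2);

        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 8 * 64], &[2, 8, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 8, 64]);
    }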
}