use axonml_autograd::Variable;
use axonml_nn::{Dropout, LayerNorm, Linear, Module, MultiHeadAttention, Parameter};
use axonml_tensor::Tensor;

/// Sinusoidal positional encoding, precomputed up to `max_len` positions.
pub struct PositionalEncoding {
    encoding: Tensor<f32>,
    max_len: usize,
    d_model: usize,
}

impl PositionalEncoding {
    #[must_use]
    pub fn new(d_model: usize, max_len: usize) -> Self {
        let mut pe = vec![0.0f32; max_len * d_model];

        for pos in 0..max_len {
            for i in 0..d_model {
                // PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
                // PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
                // The sin/cos pair shares one frequency, so round the
                // dimension index down to the nearest even value.
                let pair = (i - (i % 2)) as f32;
                let div_term = (-(pair / d_model as f32) * (10000.0f32).ln()).exp();
                if i % 2 == 0 {
                    pe[pos * d_model + i] = (pos as f32 * div_term).sin();
                } else {
                    pe[pos * d_model + i] = (pos as f32 * div_term).cos();
                }
            }
        }

        Self {
            encoding: Tensor::from_vec(pe, &[max_len, d_model]).unwrap(),
            max_len,
            d_model,
        }
    }

    /// Adds positional encodings to `x` (shape `[batch, seq, d_model]`).
    /// Positions at or beyond `max_len` are left unmodified.
    #[must_use]
    pub fn forward(&self, x: &Variable) -> Variable {
        let shape = x.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let pe_data = self.encoding.to_vec();

        // Copy the input and add the encoding entry-wise.
        let mut result = x.data().to_vec();

        for b in 0..batch_size {
            for s in 0..seq_len.min(self.max_len) {
                for d in 0..self.d_model {
                    let idx = b * seq_len * self.d_model + s * self.d_model + d;
                    result[idx] += pe_data[s * self.d_model + d];
                }
            }
        }

        Variable::new(Tensor::from_vec(result, &shape).unwrap(), x.requires_grad())
    }
}

/// A single Transformer encoder block: self-attention plus a position-wise
/// feed-forward network, each followed by dropout, a residual connection,
/// and LayerNorm (post-norm, as in the original Transformer).
pub struct TransformerEncoderLayer {
    self_attn: MultiHeadAttention,
    ff_linear1: Linear,
    ff_linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    dropout: Dropout,
    d_model: usize,
}

impl TransformerEncoderLayer {
    #[must_use]
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, dropout: f32) -> Self {
        Self {
            self_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
            ff_linear1: Linear::new(d_model, dim_feedforward),
            ff_linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::new(vec![d_model]),
            norm2: LayerNorm::new(vec![d_model]),
            dropout: Dropout::new(dropout),
            d_model,
        }
    }

    #[must_use]
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Runs the block with an optional attention mask.
    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
        // Self-attention sublayer: attention -> dropout -> residual -> norm.
        let attn_out = self.self_attn.attention(src, src, src, src_mask);
        let attn_out = self.dropout.forward(&attn_out);
        let src = src.add_var(&attn_out);
        let src = self.norm1.forward(&src);

        // Feed-forward sublayer: linear -> ReLU -> dropout -> linear ->
        // dropout -> residual -> norm.
        let ff_out = self.ff_linear1.forward(&src);
        let ff_out = ff_out.relu();
        let ff_out = self.dropout.forward(&ff_out);
        let ff_out = self.ff_linear2.forward(&ff_out);
        let ff_out = self.dropout.forward(&ff_out);
        let src = src.add_var(&ff_out);

        self.norm2.forward(&src)
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.ff_linear1.parameters());
        params.extend(self.ff_linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }

    fn is_training(&self) -> bool {
        self.dropout.is_training()
    }
}

/// A stack of `TransformerEncoderLayer`s with an optional final LayerNorm.
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    #[must_use]
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_layers: usize,
        dim_feedforward: usize,
        dropout: f32,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward, dropout))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::new(vec![d_model])),
        }
    }

    #[must_use]
    pub fn forward_with_mask(&self, src: &Variable, mask: Option<&Variable>) -> Variable {
        let mut output = src.clone();
        for layer in &self.layers {
            output = layer.forward_with_mask(&output, mask);
        }
        if let Some(norm) = &self.norm {
            output = norm.forward(&output);
        }
        output
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.parameters());
        }
        if let Some(norm) = &self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.layers.first().map_or(true, axonml_nn::Module::is_training)
    }
}

/// A single Transformer decoder block: masked self-attention, cross-attention
/// over the encoder memory, and a position-wise feed-forward network, each
/// with dropout, a residual connection, and LayerNorm (post-norm).
pub struct TransformerDecoderLayer {
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    ff_linear1: Linear,
    ff_linear2: Linear,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    dropout: Dropout,
}

impl TransformerDecoderLayer {
    #[must_use]
    pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, dropout: f32) -> Self {
        Self {
            self_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
            cross_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
            ff_linear1: Linear::new(d_model, dim_feedforward),
            ff_linear2: Linear::new(dim_feedforward, d_model),
            norm1: LayerNorm::new(vec![d_model]),
            norm2: LayerNorm::new(vec![d_model]),
            norm3: LayerNorm::new(vec![d_model]),
            dropout: Dropout::new(dropout),
        }
    }

    /// Runs the block against encoder output (`memory`) with optional
    /// target and memory attention masks.
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        // Masked self-attention sublayer.
        let attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
        let attn_out = self.dropout.forward(&attn_out);
        let tgt = tgt.add_var(&attn_out);
        let tgt = self.norm1.forward(&tgt);

        // Cross-attention sublayer: queries from the target, keys/values
        // from the encoder memory.
        let cross_out = self.cross_attn.attention(&tgt, memory, memory, memory_mask);
        let cross_out = self.dropout.forward(&cross_out);
        let tgt = tgt.add_var(&cross_out);
        let tgt = self.norm2.forward(&tgt);

        // Feed-forward sublayer.
        let ff_out = self.ff_linear1.forward(&tgt);
        let ff_out = ff_out.relu();
        let ff_out = self.dropout.forward(&ff_out);
        let ff_out = self.ff_linear2.forward(&ff_out);
        let ff_out = self.dropout.forward(&ff_out);
        let tgt = tgt.add_var(&ff_out);

        self.norm3.forward(&tgt)
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Variable) -> Variable {
        // No encoder memory is available through `Module::forward`, so use
        // the input as its own memory; this keeps the cross-attention and
        // feed-forward sublayers in the path rather than silently skipping them.
        self.forward_with_memory(input, input, None, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.ff_linear1.parameters());
        params.extend(self.ff_linear2.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }

    fn is_training(&self) -> bool {
        self.dropout.is_training()
    }
}

/// A stack of `TransformerDecoderLayer`s with an optional final LayerNorm.
pub struct TransformerDecoder {
    layers: Vec<TransformerDecoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerDecoder {
    #[must_use]
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_layers: usize,
        dim_feedforward: usize,
        dropout: f32,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward, dropout))
            .collect();

        Self {
            layers,
            norm: Some(LayerNorm::new(vec![d_model])),
        }
    }

    #[must_use]
    pub fn forward_with_memory(
        &self,
        tgt: &Variable,
        memory: &Variable,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let mut output = tgt.clone();
        for layer in &self.layers {
            output = layer.forward_with_memory(&output, memory, tgt_mask, memory_mask);
        }
        if let Some(norm) = &self.norm {
            output = norm.forward(&output);
        }
        output
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        // Memory-less pass: delegates to each layer's `Module::forward`.
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output);
        }
        if let Some(norm) = &self.norm {
            output = norm.forward(&output);
        }
        output
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.parameters());
        }
        if let Some(norm) = &self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.layers.first().map_or(true, axonml_nn::Module::is_training)
    }
}

/// A full encoder-decoder Transformer.
pub struct Transformer {
    encoder: TransformerEncoder,
    decoder: TransformerDecoder,
    d_model: usize,
}

impl Transformer {
    #[must_use]
    pub fn new(
        d_model: usize,
        nhead: usize,
        num_encoder_layers: usize,
        num_decoder_layers: usize,
        dim_feedforward: usize,
        dropout: f32,
    ) -> Self {
        Self {
            encoder: TransformerEncoder::new(
                d_model,
                nhead,
                num_encoder_layers,
                dim_feedforward,
                dropout,
            ),
            decoder: TransformerDecoder::new(
                d_model,
                nhead,
                num_decoder_layers,
                dim_feedforward,
                dropout,
            ),
            d_model,
        }
    }

    #[must_use]
    pub fn d_model(&self) -> usize {
        self.d_model
    }

    /// Full sequence-to-sequence pass: encode `src`, then decode `tgt`
    /// against the encoder memory.
    #[must_use]
    pub fn forward_full(
        &self,
        src: &Variable,
        tgt: &Variable,
        src_mask: Option<&Variable>,
        tgt_mask: Option<&Variable>,
        memory_mask: Option<&Variable>,
    ) -> Variable {
        let memory = self.encoder.forward_with_mask(src, src_mask);
        self.decoder
            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
    }
}

impl Module for Transformer {
    fn forward(&self, input: &Variable) -> Variable {
        // `Module::forward` takes a single input, so only the encoder runs
        // here; use `forward_full` for the full encoder-decoder pass.
        self.encoder.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.encoder.parameters());
        params.extend(self.decoder.parameters());
        params
    }

    fn train(&mut self) {
        self.encoder.train();
        self.decoder.train();
    }

    fn eval(&mut self) {
        self.encoder.eval();
        self.decoder.eval();
    }

    fn is_training(&self) -> bool {
        self.encoder.is_training()
    }
}

/// Vision Transformer (ViT): images are split into fixed-size patches,
/// linearly embedded, prefixed with a learnable class token, and fed through
/// a Transformer encoder; classification reads the encoded class token.
pub struct VisionTransformer {
    patch_embedding: Linear,
    pos_encoding: PositionalEncoding,
    encoder: TransformerEncoder,
    mlp_head: Linear,
    cls_token: Parameter,
    patch_size: usize,
    num_patches: usize,
    d_model: usize,
}

impl VisionTransformer {
    /// Builds a ViT that splits `image_size` x `image_size` inputs into
    /// `patch_size` x `patch_size` patches.
    #[must_use]
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        num_classes: usize,
        d_model: usize,
        nhead: usize,
        num_layers: usize,
        dim_feedforward: usize,
        dropout: f32,
    ) -> Self {
        assert!(
            image_size % patch_size == 0,
            "Image size must be divisible by patch size"
        );

        let num_patches = (image_size / patch_size) * (image_size / patch_size);
        let patch_dim = in_channels * patch_size * patch_size;

        let cls_data = Tensor::from_vec(vec![0.0f32; d_model], &[1, 1, d_model]).unwrap();
        let cls_token = Parameter::named("cls_token", cls_data, true);

        Self {
            patch_embedding: Linear::new(patch_dim, d_model),
            // One extra position for the class token.
            pos_encoding: PositionalEncoding::new(d_model, num_patches + 1),
            encoder: TransformerEncoder::new(d_model, nhead, num_layers, dim_feedforward, dropout),
            mlp_head: Linear::new(d_model, num_classes),
            cls_token,
            patch_size,
            num_patches,
            d_model,
        }
    }

    /// ViT-Tiny/16: d_model 192, 3 heads, 12 layers.
    #[must_use]
    pub fn vit_tiny(image_size: usize, num_classes: usize) -> Self {
        Self::new(image_size, 16, 3, num_classes, 192, 3, 12, 768, 0.0)
    }

    /// ViT-Small/16: d_model 384, 6 heads, 12 layers.
    #[must_use]
    pub fn vit_small(image_size: usize, num_classes: usize) -> Self {
        Self::new(image_size, 16, 3, num_classes, 384, 6, 12, 1536, 0.0)
    }

    /// ViT-Base/16: d_model 768, 12 heads, 12 layers.
    #[must_use]
    pub fn vit_base(image_size: usize, num_classes: usize) -> Self {
        Self::new(image_size, 16, 3, num_classes, 768, 12, 12, 3072, 0.0)
    }

    /// ViT-Large/16: d_model 1024, 16 heads, 24 layers.
    #[must_use]
    pub fn vit_large(image_size: usize, num_classes: usize) -> Self {
        Self::new(image_size, 16, 3, num_classes, 1024, 16, 24, 4096, 0.0)
    }

    /// Flattens an NCHW image batch into `[batch, num_patches, patch_dim]`
    /// rows of raw pixels, one row per non-overlapping patch. Assumes the
    /// input spatial size matches the `image_size` given at construction.
    fn extract_patches(&self, x: &Variable) -> Variable {
        let shape = x.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let num_patches_h = height / self.patch_size;
        let num_patches_w = width / self.patch_size;
        let patch_dim = channels * self.patch_size * self.patch_size;

        let x_data = x.data().to_vec();
        let mut patches = vec![0.0f32; batch_size * self.num_patches * patch_dim];

        for b in 0..batch_size {
            for ph in 0..num_patches_h {
                for pw in 0..num_patches_w {
                    let patch_idx = ph * num_patches_w + pw;
                    for c in 0..channels {
                        for i in 0..self.patch_size {
                            for j in 0..self.patch_size {
                                let img_h = ph * self.patch_size + i;
                                let img_w = pw * self.patch_size + j;
                                let img_idx = b * channels * height * width
                                    + c * height * width
                                    + img_h * width
                                    + img_w;
                                let patch_offset = c * self.patch_size * self.patch_size
                                    + i * self.patch_size
                                    + j;
                                let out_idx = b * self.num_patches * patch_dim
                                    + patch_idx * patch_dim
                                    + patch_offset;
                                patches[out_idx] = x_data[img_idx];
                            }
                        }
                    }
                }
            }
        }

        Variable::new(
            Tensor::from_vec(patches, &[batch_size, self.num_patches, patch_dim]).unwrap(),
            x.requires_grad(),
        )
    }
}

impl Module for VisionTransformer {
    fn forward(&self, x: &Variable) -> Variable {
        let shape = x.shape();
        let batch_size = shape[0];

        // 1. Split the image into flattened patches and embed them.
        let patches = self.extract_patches(x);
        let patch_emb = self.patch_embedding.forward(&patches);

        // 2. Prepend the learnable class token to every sequence.
        let cls_data = self.cls_token.data().to_vec();
        let patch_emb_data = patch_emb.data().to_vec();

        let mut tokens = vec![0.0f32; batch_size * (self.num_patches + 1) * self.d_model];

        for b in 0..batch_size {
            for d in 0..self.d_model {
                tokens[b * (self.num_patches + 1) * self.d_model + d] = cls_data[d];
            }
            for p in 0..self.num_patches {
                for d in 0..self.d_model {
                    let src_idx = b * self.num_patches * self.d_model + p * self.d_model + d;
                    let dst_idx =
                        b * (self.num_patches + 1) * self.d_model + (p + 1) * self.d_model + d;
                    tokens[dst_idx] = patch_emb_data[src_idx];
                }
            }
        }

        let tokens = Variable::new(
            Tensor::from_vec(tokens, &[batch_size, self.num_patches + 1, self.d_model]).unwrap(),
            x.requires_grad(),
        );

        // 3. Add positional encodings and run the encoder stack.
        let tokens = self.pos_encoding.forward(&tokens);
        let encoded = self.encoder.forward(&tokens);

        // 4. Classify from the encoded class token (sequence position 0).
        let encoded_data = encoded.data().to_vec();
        let mut cls_output = vec![0.0f32; batch_size * self.d_model];
        for b in 0..batch_size {
            for d in 0..self.d_model {
                cls_output[b * self.d_model + d] =
                    encoded_data[b * (self.num_patches + 1) * self.d_model + d];
            }
        }

        let cls_output = Variable::new(
            Tensor::from_vec(cls_output, &[batch_size, self.d_model]).unwrap(),
            x.requires_grad(),
        );

        self.mlp_head.forward(&cls_output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.push(self.cls_token.clone());
        params.extend(self.patch_embedding.parameters());
        params.extend(self.encoder.parameters());
        params.extend(self.mlp_head.parameters());
        params
    }

    fn train(&mut self) {
        self.encoder.train();
    }

    fn eval(&mut self) {
        self.encoder.eval();
    }

    fn is_training(&self) -> bool {
        self.encoder.is_training()
    }
}

/// ViT-Base/16 at 224x224 with 1000 output classes.
#[must_use]
pub fn vit_base() -> VisionTransformer {
    VisionTransformer::vit_base(224, 1000)
}

/// ViT-Large/16 at 224x224 with 1000 output classes.
#[must_use]
pub fn vit_large() -> VisionTransformer {
    VisionTransformer::vit_large(224, 1000)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(64, 100);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = pe.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }
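
    // A value-level sketch relying on the behavior of `forward` above: with
    // an all-zero input, the output is exactly the precomputed encoding.
    // At position 0, sin(0) = 0 on even dims and cos(0) = 1 on odd dims.
    #[test]
    fn test_positional_encoding_first_position() {
        let pe = PositionalEncoding::new(4, 8);
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; 1 * 2 * 4], &[1, 2, 4]).unwrap(),
            false,
        );
        let output = pe.forward(&input);
        let data = output.data().to_vec();
        assert!(data[0].abs() < 1e-6); // sin(0) = 0
        assert!((data[1] - 1.0).abs() < 1e-6); // cos(0) = 1
    }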

    #[test]
    fn test_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 4, 256, 0.1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(64, 4, 2, 256, 0.1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = encoder.forward(&input);
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }
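
    // Shape check for the decoder block's self-attention plus cross-attention
    // path (masks left as None; their layout is owned by MultiHeadAttention).
    #[test]
    fn test_decoder_layer() {
        let layer = TransformerDecoderLayer::new(64, 4, 256, 0.1);
        let tgt = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }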

    #[test]
    fn test_transformer() {
        let transformer = Transformer::new(64, 4, 2, 2, 256, 0.1);
        let src = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let tgt = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = transformer.forward_full(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }
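
    // Same check for a full decoder stack: the output keeps the target shape
    // regardless of the memory's sequence length.
    #[test]
    fn test_transformer_decoder() {
        let decoder = TransformerDecoder::new(64, 4, 2, 256, 0.1);
        let tgt = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let memory = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 10 * 64], &[2, 10, 64]).unwrap(),
            false,
        );
        let output = decoder.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // train()/eval() should propagate through encoder and decoder down to
    // every Dropout (uses only the Module trait methods defined above).
    #[test]
    fn test_transformer_mode_switch() {
        let mut transformer = Transformer::new(64, 4, 2, 2, 256, 0.1);
        transformer.eval();
        assert!(!transformer.is_training());
        transformer.train();
        assert!(transformer.is_training());
    }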

    #[test]
    fn test_vit_creation() {
        let vit = VisionTransformer::new(32, 8, 3, 10, 64, 4, 2, 256, 0.1);
        let params = vit.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vit_forward() {
        let vit = VisionTransformer::new(32, 8, 3, 10, 64, 4, 2, 256, 0.1);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 32 * 32], &[2, 3, 32, 32]).unwrap(),
            false,
        );
        let output = vit.forward(&input);
        assert_eq!(output.shape(), vec![2, 10]);
    }

    #[test]
    fn test_vit_tiny() {
        let vit = VisionTransformer::vit_tiny(32, 10);
        let params = vit.parameters();
        assert!(!params.is_empty());
    }
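
    // Patch arithmetic: a 32x32 image with 8x8 patches gives (32/8)^2 = 16
    // patches; the class token then makes 17 encoder positions. The private
    // field is reachable here because `tests` is a child module.
    #[test]
    fn test_vit_patch_count() {
        let vit = VisionTransformer::new(32, 8, 3, 10, 64, 4, 2, 256, 0.1);
        assert_eq!(vit.num_patches, 16);
    }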
}