// axonml_vision/models/transformer.rs
//! Transformer - Attention-based Neural Networks
//!
//! Implementations of Transformer architectures for sequence modelling and
//! image classification.
//!
//! # Models
//!
//! - **`TransformerEncoder`**: Stack of encoder layers
//! - **`TransformerDecoder`**: Stack of decoder layers
//! - **`Transformer`**: Full encoder-decoder Transformer
//! - **`VisionTransformer` (`ViT`)**: Transformer for image classification
//!
//! # References
//!
//! "Attention Is All You Need" (Vaswani et al., 2017)
//! <https://arxiv.org/abs/1706.03762>
//!
//! "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
//! <https://arxiv.org/abs/2010.11929>
19
20use axonml_autograd::Variable;
21use axonml_nn::{Dropout, LayerNorm, Linear, Module, MultiHeadAttention, Parameter};
22use axonml_tensor::Tensor;
23
24// =============================================================================
25// Positional Encoding
26// =============================================================================
27
28/// Positional encoding using sinusoidal functions.
29pub struct PositionalEncoding {
30    encoding: Tensor<f32>,
31    max_len: usize,
32    d_model: usize,
33}
34
35impl PositionalEncoding {
36    /// Create positional encoding.
37    #[must_use] pub fn new(d_model: usize, max_len: usize) -> Self {
38        let mut pe = vec![0.0f32; max_len * d_model];
39
40        for pos in 0..max_len {
41            for i in 0..d_model {
42                let div_term = (-(i as f32 / d_model as f32) * (10000.0f32).ln()).exp();
43                if i % 2 == 0 {
44                    pe[pos * d_model + i] = (pos as f32 * div_term).sin();
45                } else {
46                    pe[pos * d_model + i] = (pos as f32 * div_term).cos();
47                }
48            }
49        }
50
51        Self {
52            encoding: Tensor::from_vec(pe, &[max_len, d_model]).unwrap(),
53            max_len,
54            d_model,
55        }
56    }
57
58    /// Add positional encoding to input.
59    #[must_use] pub fn forward(&self, x: &Variable) -> Variable {
60        let shape = x.shape();
61        let seq_len = shape[1];
62        let x_data = x.data().to_vec();
63        let pe_data = self.encoding.to_vec();
64
65        // Broadcasting PE across batch
66        let batch_size = shape[0];
67        let mut result = x_data.clone();
68
69        for b in 0..batch_size {
70            for s in 0..seq_len.min(self.max_len) {
71                for d in 0..self.d_model {
72                    let idx = b * seq_len * self.d_model + s * self.d_model + d;
73                    result[idx] += pe_data[s * self.d_model + d];
74                }
75            }
76        }
77
78        Variable::new(Tensor::from_vec(result, &shape).unwrap(), x.requires_grad())
79    }
80}
81
82// =============================================================================
83// Transformer Encoder Layer
84// =============================================================================
85
86/// A single Transformer encoder layer.
87pub struct TransformerEncoderLayer {
88    self_attn: MultiHeadAttention,
89    ff_linear1: Linear,
90    ff_linear2: Linear,
91    norm1: LayerNorm,
92    norm2: LayerNorm,
93    dropout: Dropout,
94    d_model: usize,
95}
96
97impl TransformerEncoderLayer {
98    /// Create encoder layer.
99    #[must_use] pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, dropout: f32) -> Self {
100        Self {
101            self_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
102            ff_linear1: Linear::new(d_model, dim_feedforward),
103            ff_linear2: Linear::new(dim_feedforward, d_model),
104            norm1: LayerNorm::new(vec![d_model]),
105            norm2: LayerNorm::new(vec![d_model]),
106            dropout: Dropout::new(dropout),
107            d_model,
108        }
109    }
110
111    /// Returns the model dimension.
112    pub fn d_model(&self) -> usize {
113        self.d_model
114    }
115
116    /// Forward with optional attention mask.
117    pub fn forward_with_mask(&self, src: &Variable, src_mask: Option<&Variable>) -> Variable {
118        // Self-attention with residual
119        let attn_out = self.self_attn.attention(src, src, src, src_mask);
120        let attn_out = self.dropout.forward(&attn_out);
121        let src = src.add_var(&attn_out);
122        let src = self.norm1.forward(&src);
123
124        // Feed-forward with residual
125        let ff_out = self.ff_linear1.forward(&src);
126        let ff_out = ff_out.relu();
127        let ff_out = self.dropout.forward(&ff_out);
128        let ff_out = self.ff_linear2.forward(&ff_out);
129        let ff_out = self.dropout.forward(&ff_out);
130        let src = src.add_var(&ff_out);
131
132        self.norm2.forward(&src)
133    }
134}
135
136impl Module for TransformerEncoderLayer {
137    fn forward(&self, input: &Variable) -> Variable {
138        self.forward_with_mask(input, None)
139    }
140
141    fn parameters(&self) -> Vec<Parameter> {
142        let mut params = Vec::new();
143        params.extend(self.self_attn.parameters());
144        params.extend(self.ff_linear1.parameters());
145        params.extend(self.ff_linear2.parameters());
146        params.extend(self.norm1.parameters());
147        params.extend(self.norm2.parameters());
148        params
149    }
150
151    fn train(&mut self) {
152        self.dropout.train();
153    }
154
155    fn eval(&mut self) {
156        self.dropout.eval();
157    }
158
159    fn is_training(&self) -> bool {
160        self.dropout.is_training()
161    }
162}
163
164// =============================================================================
165// Transformer Encoder
166// =============================================================================
167
168/// Stack of Transformer encoder layers.
169pub struct TransformerEncoder {
170    layers: Vec<TransformerEncoderLayer>,
171    norm: Option<LayerNorm>,
172}
173
174impl TransformerEncoder {
175    /// Create encoder with specified layers.
176    #[must_use] pub fn new(
177        d_model: usize,
178        nhead: usize,
179        num_layers: usize,
180        dim_feedforward: usize,
181        dropout: f32,
182    ) -> Self {
183        let layers = (0..num_layers)
184            .map(|_| TransformerEncoderLayer::new(d_model, nhead, dim_feedforward, dropout))
185            .collect();
186
187        Self {
188            layers,
189            norm: Some(LayerNorm::new(vec![d_model])),
190        }
191    }
192
193    /// Forward with optional mask.
194    #[must_use] pub fn forward_with_mask(&self, src: &Variable, mask: Option<&Variable>) -> Variable {
195        let mut output = src.clone();
196        for layer in &self.layers {
197            output = layer.forward_with_mask(&output, mask);
198        }
199        if let Some(norm) = &self.norm {
200            output = norm.forward(&output);
201        }
202        output
203    }
204}
205
206impl Module for TransformerEncoder {
207    fn forward(&self, input: &Variable) -> Variable {
208        self.forward_with_mask(input, None)
209    }
210
211    fn parameters(&self) -> Vec<Parameter> {
212        let mut params = Vec::new();
213        for layer in &self.layers {
214            params.extend(layer.parameters());
215        }
216        if let Some(norm) = &self.norm {
217            params.extend(norm.parameters());
218        }
219        params
220    }
221
222    fn train(&mut self) {
223        for layer in &mut self.layers {
224            layer.train();
225        }
226    }
227
228    fn eval(&mut self) {
229        for layer in &mut self.layers {
230            layer.eval();
231        }
232    }
233
234    fn is_training(&self) -> bool {
235        self.layers.first().map_or(true, axonml_nn::Module::is_training)
236    }
237}
238
239// =============================================================================
240// Transformer Decoder Layer
241// =============================================================================
242
243/// A single Transformer decoder layer.
244pub struct TransformerDecoderLayer {
245    self_attn: MultiHeadAttention,
246    cross_attn: MultiHeadAttention,
247    ff_linear1: Linear,
248    ff_linear2: Linear,
249    norm1: LayerNorm,
250    norm2: LayerNorm,
251    norm3: LayerNorm,
252    dropout: Dropout,
253}
254
255impl TransformerDecoderLayer {
256    /// Create decoder layer.
257    #[must_use] pub fn new(d_model: usize, nhead: usize, dim_feedforward: usize, dropout: f32) -> Self {
258        Self {
259            self_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
260            cross_attn: MultiHeadAttention::with_options(d_model, nhead, dropout, true),
261            ff_linear1: Linear::new(d_model, dim_feedforward),
262            ff_linear2: Linear::new(dim_feedforward, d_model),
263            norm1: LayerNorm::new(vec![d_model]),
264            norm2: LayerNorm::new(vec![d_model]),
265            norm3: LayerNorm::new(vec![d_model]),
266            dropout: Dropout::new(dropout),
267        }
268    }
269
270    /// Forward with memory and masks.
271    pub fn forward_with_memory(
272        &self,
273        tgt: &Variable,
274        memory: &Variable,
275        tgt_mask: Option<&Variable>,
276        memory_mask: Option<&Variable>,
277    ) -> Variable {
278        // Self-attention with residual
279        let attn_out = self.self_attn.attention(tgt, tgt, tgt, tgt_mask);
280        let attn_out = self.dropout.forward(&attn_out);
281        let tgt = tgt.add_var(&attn_out);
282        let tgt = self.norm1.forward(&tgt);
283
284        // Cross-attention with residual
285        let cross_out = self.cross_attn.attention(&tgt, memory, memory, memory_mask);
286        let cross_out = self.dropout.forward(&cross_out);
287        let tgt = tgt.add_var(&cross_out);
288        let tgt = self.norm2.forward(&tgt);
289
290        // Feed-forward with residual
291        let ff_out = self.ff_linear1.forward(&tgt);
292        let ff_out = ff_out.relu();
293        let ff_out = self.dropout.forward(&ff_out);
294        let ff_out = self.ff_linear2.forward(&ff_out);
295        let ff_out = self.dropout.forward(&ff_out);
296        let tgt = tgt.add_var(&ff_out);
297
298        self.norm3.forward(&tgt)
299    }
300}
301
302impl Module for TransformerDecoderLayer {
303    fn forward(&self, input: &Variable) -> Variable {
304        // For standard forward, use self-attention only
305        self.self_attn.forward(input)
306    }
307
308    fn parameters(&self) -> Vec<Parameter> {
309        let mut params = Vec::new();
310        params.extend(self.self_attn.parameters());
311        params.extend(self.cross_attn.parameters());
312        params.extend(self.ff_linear1.parameters());
313        params.extend(self.ff_linear2.parameters());
314        params.extend(self.norm1.parameters());
315        params.extend(self.norm2.parameters());
316        params.extend(self.norm3.parameters());
317        params
318    }
319
320    fn train(&mut self) {
321        self.dropout.train();
322    }
323
324    fn eval(&mut self) {
325        self.dropout.eval();
326    }
327
328    fn is_training(&self) -> bool {
329        self.dropout.is_training()
330    }
331}
332
333// =============================================================================
334// Transformer Decoder
335// =============================================================================
336
337/// Stack of Transformer decoder layers.
338pub struct TransformerDecoder {
339    layers: Vec<TransformerDecoderLayer>,
340    norm: Option<LayerNorm>,
341}
342
343impl TransformerDecoder {
344    /// Create decoder with specified layers.
345    #[must_use] pub fn new(
346        d_model: usize,
347        nhead: usize,
348        num_layers: usize,
349        dim_feedforward: usize,
350        dropout: f32,
351    ) -> Self {
352        let layers = (0..num_layers)
353            .map(|_| TransformerDecoderLayer::new(d_model, nhead, dim_feedforward, dropout))
354            .collect();
355
356        Self {
357            layers,
358            norm: Some(LayerNorm::new(vec![d_model])),
359        }
360    }
361
362    /// Forward with memory and masks.
363    #[must_use] pub fn forward_with_memory(
364        &self,
365        tgt: &Variable,
366        memory: &Variable,
367        tgt_mask: Option<&Variable>,
368        memory_mask: Option<&Variable>,
369    ) -> Variable {
370        let mut output = tgt.clone();
371        for layer in &self.layers {
372            output = layer.forward_with_memory(&output, memory, tgt_mask, memory_mask);
373        }
374        if let Some(norm) = &self.norm {
375            output = norm.forward(&output);
376        }
377        output
378    }
379}
380
381impl Module for TransformerDecoder {
382    fn forward(&self, input: &Variable) -> Variable {
383        let mut output = input.clone();
384        for layer in &self.layers {
385            output = layer.forward(&output);
386        }
387        if let Some(norm) = &self.norm {
388            output = norm.forward(&output);
389        }
390        output
391    }
392
393    fn parameters(&self) -> Vec<Parameter> {
394        let mut params = Vec::new();
395        for layer in &self.layers {
396            params.extend(layer.parameters());
397        }
398        if let Some(norm) = &self.norm {
399            params.extend(norm.parameters());
400        }
401        params
402    }
403
404    fn train(&mut self) {
405        for layer in &mut self.layers {
406            layer.train();
407        }
408    }
409
410    fn eval(&mut self) {
411        for layer in &mut self.layers {
412            layer.eval();
413        }
414    }
415
416    fn is_training(&self) -> bool {
417        self.layers.first().map_or(true, axonml_nn::Module::is_training)
418    }
419}
420
421// =============================================================================
422// Full Transformer
423// =============================================================================
424
425/// Full Transformer model (encoder-decoder).
426pub struct Transformer {
427    encoder: TransformerEncoder,
428    decoder: TransformerDecoder,
429    d_model: usize,
430}
431
432impl Transformer {
433    /// Create a Transformer model.
434    #[must_use] pub fn new(
435        d_model: usize,
436        nhead: usize,
437        num_encoder_layers: usize,
438        num_decoder_layers: usize,
439        dim_feedforward: usize,
440        dropout: f32,
441    ) -> Self {
442        Self {
443            encoder: TransformerEncoder::new(
444                d_model,
445                nhead,
446                num_encoder_layers,
447                dim_feedforward,
448                dropout,
449            ),
450            decoder: TransformerDecoder::new(
451                d_model,
452                nhead,
453                num_decoder_layers,
454                dim_feedforward,
455                dropout,
456            ),
457            d_model,
458        }
459    }
460
461    /// Returns the model dimension.
462    #[must_use] pub fn d_model(&self) -> usize {
463        self.d_model
464    }
465
466    /// Forward pass with source and target.
467    #[must_use] pub fn forward_full(
468        &self,
469        src: &Variable,
470        tgt: &Variable,
471        src_mask: Option<&Variable>,
472        tgt_mask: Option<&Variable>,
473        memory_mask: Option<&Variable>,
474    ) -> Variable {
475        let memory = self.encoder.forward_with_mask(src, src_mask);
476        self.decoder
477            .forward_with_memory(tgt, &memory, tgt_mask, memory_mask)
478    }
479}
480
481impl Module for Transformer {
482    fn forward(&self, input: &Variable) -> Variable {
483        // Encoder-only forward for classification tasks
484        self.encoder.forward(input)
485    }
486
487    fn parameters(&self) -> Vec<Parameter> {
488        let mut params = Vec::new();
489        params.extend(self.encoder.parameters());
490        params.extend(self.decoder.parameters());
491        params
492    }
493
494    fn train(&mut self) {
495        self.encoder.train();
496        self.decoder.train();
497    }
498
499    fn eval(&mut self) {
500        self.encoder.eval();
501        self.decoder.eval();
502    }
503
504    fn is_training(&self) -> bool {
505        self.encoder.is_training()
506    }
507}
508
509// =============================================================================
510// Vision Transformer (ViT)
511// =============================================================================
512
513/// Vision Transformer for image classification.
514///
515/// Converts images into patches and processes them with a Transformer encoder.
516pub struct VisionTransformer {
517    patch_embedding: Linear,
518    pos_encoding: PositionalEncoding,
519    encoder: TransformerEncoder,
520    mlp_head: Linear,
521    cls_token: Parameter,
522    patch_size: usize,
523    num_patches: usize,
524    d_model: usize,
525}
526
527impl VisionTransformer {
528    /// Create a Vision Transformer.
529    ///
530    /// # Arguments
531    /// * `image_size` - Input image size (assumes square)
532    /// * `patch_size` - Size of each patch
533    /// * `in_channels` - Number of input channels (3 for RGB)
534    /// * `num_classes` - Number of output classes
535    /// * `d_model` - Model dimension
536    /// * `nhead` - Number of attention heads
537    /// * `num_layers` - Number of encoder layers
538    /// * `dim_feedforward` - Feed-forward dimension
539    /// * `dropout` - Dropout probability
540    #[must_use] pub fn new(
541        image_size: usize,
542        patch_size: usize,
543        in_channels: usize,
544        num_classes: usize,
545        d_model: usize,
546        nhead: usize,
547        num_layers: usize,
548        dim_feedforward: usize,
549        dropout: f32,
550    ) -> Self {
551        assert!(
552            image_size % patch_size == 0,
553            "Image size must be divisible by patch size"
554        );
555
556        let num_patches = (image_size / patch_size) * (image_size / patch_size);
557        let patch_dim = in_channels * patch_size * patch_size;
558
559        // CLS token as learnable parameter
560        let cls_data = Tensor::from_vec(vec![0.0f32; d_model], &[1, 1, d_model]).unwrap();
561        let cls_token = Parameter::named("cls_token", cls_data, true);
562
563        Self {
564            patch_embedding: Linear::new(patch_dim, d_model),
565            pos_encoding: PositionalEncoding::new(d_model, num_patches + 1), // +1 for CLS
566            encoder: TransformerEncoder::new(d_model, nhead, num_layers, dim_feedforward, dropout),
567            mlp_head: Linear::new(d_model, num_classes),
568            cls_token,
569            patch_size,
570            num_patches,
571            d_model,
572        }
573    }
574
575    /// Create ViT-Tiny.
576    #[must_use] pub fn vit_tiny(image_size: usize, num_classes: usize) -> Self {
577        Self::new(image_size, 16, 3, num_classes, 192, 3, 12, 768, 0.0)
578    }
579
580    /// Create ViT-Small.
581    #[must_use] pub fn vit_small(image_size: usize, num_classes: usize) -> Self {
582        Self::new(image_size, 16, 3, num_classes, 384, 6, 12, 1536, 0.0)
583    }
584
585    /// Create ViT-Base.
586    #[must_use] pub fn vit_base(image_size: usize, num_classes: usize) -> Self {
587        Self::new(image_size, 16, 3, num_classes, 768, 12, 12, 3072, 0.0)
588    }
589
590    /// Create ViT-Large.
591    #[must_use] pub fn vit_large(image_size: usize, num_classes: usize) -> Self {
592        Self::new(image_size, 16, 3, num_classes, 1024, 16, 24, 4096, 0.0)
593    }
594
595    /// Extract patches from image.
596    fn extract_patches(&self, x: &Variable) -> Variable {
597        let shape = x.shape();
598        let batch_size = shape[0];
599        let channels = shape[1];
600        let height = shape[2];
601        let width = shape[3];
602
603        let num_patches_h = height / self.patch_size;
604        let num_patches_w = width / self.patch_size;
605        let patch_dim = channels * self.patch_size * self.patch_size;
606
607        let x_data = x.data().to_vec();
608        let mut patches = vec![0.0f32; batch_size * self.num_patches * patch_dim];
609
610        for b in 0..batch_size {
611            for ph in 0..num_patches_h {
612                for pw in 0..num_patches_w {
613                    let patch_idx = ph * num_patches_w + pw;
614                    for c in 0..channels {
615                        for i in 0..self.patch_size {
616                            for j in 0..self.patch_size {
617                                let img_h = ph * self.patch_size + i;
618                                let img_w = pw * self.patch_size + j;
619                                let img_idx = b * channels * height * width
620                                    + c * height * width
621                                    + img_h * width
622                                    + img_w;
623                                let patch_offset =
624                                    c * self.patch_size * self.patch_size + i * self.patch_size + j;
625                                let out_idx = b * self.num_patches * patch_dim
626                                    + patch_idx * patch_dim
627                                    + patch_offset;
628                                patches[out_idx] = x_data[img_idx];
629                            }
630                        }
631                    }
632                }
633            }
634        }
635
636        Variable::new(
637            Tensor::from_vec(patches, &[batch_size, self.num_patches, patch_dim]).unwrap(),
638            x.requires_grad(),
639        )
640    }
641}
642
643impl Module for VisionTransformer {
644    fn forward(&self, x: &Variable) -> Variable {
645        let shape = x.shape();
646        let batch_size = shape[0];
647
648        // Extract patches: [B, C, H, W] -> [B, num_patches, patch_dim]
649        let patches = self.extract_patches(x);
650
651        // Embed patches: [B, num_patches, patch_dim] -> [B, num_patches, d_model]
652        let patch_emb = self.patch_embedding.forward(&patches);
653
654        // Prepend CLS token
655        let cls_data = self.cls_token.data().to_vec();
656        let patch_emb_data = patch_emb.data().to_vec();
657
658        let mut tokens = vec![0.0f32; batch_size * (self.num_patches + 1) * self.d_model];
659
660        for b in 0..batch_size {
661            // CLS token
662            for d in 0..self.d_model {
663                tokens[b * (self.num_patches + 1) * self.d_model + d] = cls_data[d];
664            }
665            // Patch embeddings
666            for p in 0..self.num_patches {
667                for d in 0..self.d_model {
668                    let src_idx = b * self.num_patches * self.d_model + p * self.d_model + d;
669                    let dst_idx =
670                        b * (self.num_patches + 1) * self.d_model + (p + 1) * self.d_model + d;
671                    tokens[dst_idx] = patch_emb_data[src_idx];
672                }
673            }
674        }
675
676        let tokens = Variable::new(
677            Tensor::from_vec(tokens, &[batch_size, self.num_patches + 1, self.d_model]).unwrap(),
678            x.requires_grad(),
679        );
680
681        // Add positional encoding
682        let tokens = self.pos_encoding.forward(&tokens);
683
684        // Pass through encoder
685        let encoded = self.encoder.forward(&tokens);
686
687        // Extract CLS token output: [B, num_patches+1, d_model] -> [B, d_model]
688        let encoded_data = encoded.data().to_vec();
689        let mut cls_output = vec![0.0f32; batch_size * self.d_model];
690        for b in 0..batch_size {
691            for d in 0..self.d_model {
692                cls_output[b * self.d_model + d] =
693                    encoded_data[b * (self.num_patches + 1) * self.d_model + d];
694            }
695        }
696
697        let cls_output = Variable::new(
698            Tensor::from_vec(cls_output, &[batch_size, self.d_model]).unwrap(),
699            x.requires_grad(),
700        );
701
702        // Classification head
703        self.mlp_head.forward(&cls_output)
704    }
705
706    fn parameters(&self) -> Vec<Parameter> {
707        let mut params = Vec::new();
708        params.push(self.cls_token.clone());
709        params.extend(self.patch_embedding.parameters());
710        params.extend(self.encoder.parameters());
711        params.extend(self.mlp_head.parameters());
712        params
713    }
714
715    fn train(&mut self) {
716        self.encoder.train();
717    }
718
719    fn eval(&mut self) {
720        self.encoder.eval();
721    }
722
723    fn is_training(&self) -> bool {
724        self.encoder.is_training()
725    }
726}
727
728// =============================================================================
729// Convenience Functions
730// =============================================================================
731
732/// Create ViT-Base for `ImageNet` (224x224, 1000 classes).
733#[must_use] pub fn vit_base() -> VisionTransformer {
734    VisionTransformer::vit_base(224, 1000)
735}
736
737/// Create ViT-Large for `ImageNet` (224x224, 1000 classes).
738#[must_use] pub fn vit_large() -> VisionTransformer {
739    VisionTransformer::vit_large(224, 1000)
740}
741
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a constant, non-gradient input of the given shape.
    fn constant(value: f32, shape: &[usize]) -> Variable {
        let len: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![value; len], shape).unwrap(), false)
    }

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(64, 100);
        let output = pe.forward(&constant(0.0, &[2, 10, 64]));
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_positional_encoding_first_position() {
        // At position 0 every sin channel is 0 and every cos channel is 1.
        let pe = PositionalEncoding::new(8, 4);
        let out = pe.forward(&constant(0.0, &[1, 4, 8])).data().to_vec();
        for (d, &v) in out[..8].iter().enumerate() {
            let expected = if d % 2 == 0 { 0.0 } else { 1.0 };
            assert!((v - expected).abs() < 1e-6, "channel {d}: {v} != {expected}");
        }
    }

    #[test]
    fn test_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 4, 256, 0.1);
        let output = layer.forward(&constant(1.0, &[2, 10, 64]));
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(64, 4, 2, 256, 0.1);
        let output = encoder.forward(&constant(1.0, &[2, 10, 64]));
        assert_eq!(output.shape(), vec![2, 10, 64]);
    }

    #[test]
    fn test_decoder_layer_with_memory() {
        let layer = TransformerDecoderLayer::new(64, 4, 256, 0.1);
        let tgt = constant(1.0, &[2, 5, 64]);
        let memory = constant(0.5, &[2, 10, 64]);
        let output = layer.forward_with_memory(&tgt, &memory, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_transformer_decoder_without_memory() {
        let decoder = TransformerDecoder::new(64, 4, 2, 256, 0.1);
        let output = decoder.forward(&constant(1.0, &[2, 5, 64]));
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_transformer() {
        let transformer = Transformer::new(64, 4, 2, 2, 256, 0.1);
        let src = constant(1.0, &[2, 10, 64]);
        let tgt = constant(1.0, &[2, 5, 64]);
        let output = transformer.forward_full(&src, &tgt, None, None, None);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_vit_creation() {
        let vit = VisionTransformer::new(
            32,  // image_size
            8,   // patch_size
            3,   // channels
            10,  // num_classes
            64,  // d_model
            4,   // nhead
            2,   // num_layers
            256, // dim_ff
            0.1, // dropout
        );
        let params = vit.parameters();
        assert!(!params.is_empty());
    }

    #[test]
    fn test_vit_forward() {
        let vit = VisionTransformer::new(32, 8, 3, 10, 64, 4, 2, 256, 0.1);
        let output = vit.forward(&constant(0.5, &[2, 3, 32, 32]));
        assert_eq!(output.shape(), vec![2, 10]);
    }

    #[test]
    fn test_vit_train_eval_toggle() {
        let mut vit = VisionTransformer::new(32, 8, 3, 10, 64, 4, 2, 256, 0.1);
        vit.eval();
        assert!(!vit.is_training());
        vit.train();
        assert!(vit.is_training());
    }

    #[test]
    fn test_vit_tiny() {
        let vit = VisionTransformer::vit_tiny(32, 10);
        let params = vit.parameters();
        assert!(!params.is_empty());
    }
}