Skip to main content

jepa_vision/
vit.rs

//! Vision Transformer (ViT) encoder for JEPA.
//!
//! Implements RFC-002 (Encoder Module) — concrete ViT encoder.
//!
//! The ViT encoder converts an image into a sequence of patch-level
//! representations suitable for JEPA training and inference.
//!
//! ```text
//! [B, C, H, W]  ──►  PatchEmbedding  ──►  2D RoPE  ──►  N × TransformerBlock  ──►  LayerNorm
//!                   [B, S, D]         [B, S, D]        [B, S, D]                   [B, S, D]
//! ```
//!
//! Seven preset configurations are provided:
//!
//! | Preset | Layers | Dim | Heads | Params (approx.) |
//! |--------|--------|-----|-------|------------------|
//! | `tiny_test` | 2 | 32 | 4 | ~12 K |
//! | `vit_small_patch16` | 12 | 384 | 6 | ~22 M |
//! | `vit_base_patch16` | 12 | 768 | 12 | ~86 M |
//! | `vit_large_patch16` | 24 | 1024 | 16 | ~307 M |
//! | `vit_huge_patch14` | 32 | 1280 | 16 | ~632 M |
//! | `vit_huge_patch16_448` | 32 | 1280 | 16 | ~632 M |
//! | `vit_giant_patch16` | 40 | 1408 | 16 | ~1.0 B |
//!
//! Two forward paths exist:
//! - [`VitEncoder::forward`] — encode all patches (standard inference).
//! - [`VitEncoder::forward_visible_tokens`] — encode only visible (context)
//!   patches for efficient masked training.
28
29use std::collections::HashMap;
30
31use burn::module::{Module, Param};
32use burn::nn::{LayerNorm, LayerNormConfig, LayerNormRecord, Linear, LinearConfig, LinearRecord};
33use burn::prelude::*;
34use burn::tensor::backend::Backend;
35use burn::tensor::TensorData;
36
37use jepa_core::ema::Ema;
38use jepa_core::types::Representation;
39use jepa_core::Encoder;
40
41use crate::patch::{PatchEmbedding, PatchEmbeddingConfig};
42use crate::rope::{RotaryPositionEncoding2D, RotaryPositionEncoding2DConfig};
43use crate::token_ops::gather_token_sequence;
44
45/// Errors returned when loading a ViT encoder from named tensors.
46#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
47pub enum VitLoadError {
48    #[error("missing checkpoint tensor `{0}`")]
49    MissingKey(String),
50    #[error(
51        "shape mismatch for `{key}`: checkpoint {checkpoint_shape:?} vs model {model_shape:?}"
52    )]
53    ShapeMismatch {
54        key: String,
55        checkpoint_shape: Vec<usize>,
56        model_shape: Vec<usize>,
57    },
58}
59
60/// Configuration for a Vision Transformer encoder.
61///
62/// # Example
63///
64/// ```
65/// use jepa_vision::vit::VitConfig;
66/// use jepa_core::Encoder;
67/// use burn_ndarray::NdArray;
68///
69/// type B = NdArray<f32>;
70/// let device = burn_ndarray::NdArrayDevice::Cpu;
71///
72/// let config = VitConfig::tiny_test();
73/// let encoder = config.init::<B>(&device);
74/// assert_eq!(encoder.embed_dim(), 32);
75/// ```
76#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
77pub struct VitConfig {
78    /// Number of input channels (e.g., 3 for RGB).
79    pub in_channels: usize,
80    /// Input image height in pixels.
81    pub image_height: usize,
82    /// Input image width in pixels.
83    pub image_width: usize,
84    /// Patch size `(height, width)`.
85    pub patch_size: (usize, usize),
86    /// Embedding dimension.
87    pub embed_dim: usize,
88    /// Number of transformer layers.
89    pub num_layers: usize,
90    /// Number of attention heads.
91    pub num_heads: usize,
92    /// MLP hidden dimension (typically 4 * embed_dim).
93    pub mlp_dim: usize,
94    /// Dropout rate (not used during inference).
95    pub dropout: f64,
96}
97
98impl VitConfig {
99    /// Create a ViT-Base/16 config for 224x224 images.
100    pub fn vit_base_patch16() -> Self {
101        Self {
102            in_channels: 3,
103            image_height: 224,
104            image_width: 224,
105            patch_size: (16, 16),
106            embed_dim: 768,
107            num_layers: 12,
108            num_heads: 12,
109            mlp_dim: 3072,
110            dropout: 0.0,
111        }
112    }
113
114    /// Create a ViT-Small/16 config for 224x224 images.
115    pub fn vit_small_patch16() -> Self {
116        Self {
117            in_channels: 3,
118            image_height: 224,
119            image_width: 224,
120            patch_size: (16, 16),
121            embed_dim: 384,
122            num_layers: 12,
123            num_heads: 6,
124            mlp_dim: 1536,
125            dropout: 0.0,
126        }
127    }
128
129    /// Create a ViT-Large/16 config for 224x224 images.
130    ///
131    /// Matches the architecture used in Facebook Research I-JEPA ViT-L/16.
132    pub fn vit_large_patch16() -> Self {
133        Self {
134            in_channels: 3,
135            image_height: 224,
136            image_width: 224,
137            patch_size: (16, 16),
138            embed_dim: 1024,
139            num_layers: 24,
140            num_heads: 16,
141            mlp_dim: 4096,
142            dropout: 0.0,
143        }
144    }
145
146    /// Create a ViT-Huge/14 config for 224x224 images.
147    ///
148    /// Matches the architecture used in Facebook Research I-JEPA ViT-H/14
149    /// (the primary model released with the I-JEPA paper).
150    pub fn vit_huge_patch14() -> Self {
151        Self {
152            in_channels: 3,
153            image_height: 224,
154            image_width: 224,
155            patch_size: (14, 14),
156            embed_dim: 1280,
157            num_layers: 32,
158            num_heads: 16,
159            mlp_dim: 5120,
160            dropout: 0.0,
161        }
162    }
163
164    /// Create a ViT-Huge/16 config for 448x448 images.
165    ///
166    /// Matches the architecture used in Facebook Research I-JEPA ViT-H/16-448.
167    pub fn vit_huge_patch16_448() -> Self {
168        Self {
169            in_channels: 3,
170            image_height: 448,
171            image_width: 448,
172            patch_size: (16, 16),
173            embed_dim: 1280,
174            num_layers: 32,
175            num_heads: 16,
176            mlp_dim: 5120,
177            dropout: 0.0,
178        }
179    }
180
181    /// Create a ViT-Giant/16 config for 224x224 images.
182    ///
183    /// Matches the architecture used in Facebook Research I-JEPA ViT-G/16.
184    pub fn vit_giant_patch16() -> Self {
185        Self {
186            in_channels: 3,
187            image_height: 224,
188            image_width: 224,
189            patch_size: (16, 16),
190            embed_dim: 1408,
191            num_layers: 40,
192            num_heads: 16,
193            mlp_dim: 6144,
194            dropout: 0.0,
195        }
196    }
197
198    /// Create a minimal config for testing.
199    pub fn tiny_test() -> Self {
200        Self {
201            in_channels: 1,
202            image_height: 8,
203            image_width: 8,
204            patch_size: (2, 2),
205            embed_dim: 32,
206            num_layers: 2,
207            num_heads: 4,
208            mlp_dim: 64,
209            dropout: 0.0,
210        }
211    }
212
213    fn grid_height(&self) -> usize {
214        self.image_height / self.patch_size.0
215    }
216
217    fn grid_width(&self) -> usize {
218        self.image_width / self.patch_size.1
219    }
220
221    /// Initialize a [`VitEncoder`] module.
222    pub fn init<B: Backend>(&self, device: &B::Device) -> VitEncoder<B> {
223        let patch_embed_config = PatchEmbeddingConfig::new(
224            self.in_channels,
225            self.patch_size.0,
226            self.patch_size.1,
227            self.embed_dim,
228        );
229        let patch_embed = patch_embed_config.init(device);
230
231        let rope_config = RotaryPositionEncoding2DConfig::new(
232            self.embed_dim,
233            self.grid_height(),
234            self.grid_width(),
235        );
236        let positional_encoding = rope_config.init(device);
237
238        let blocks: Vec<TransformerBlock<B>> = (0..self.num_layers)
239            .map(|_| {
240                TransformerBlockConfig {
241                    embed_dim: self.embed_dim,
242                    num_heads: self.num_heads,
243                    mlp_dim: self.mlp_dim,
244                }
245                .init(device)
246            })
247            .collect();
248
249        let norm = LayerNormConfig::new(self.embed_dim).init(device);
250
251        VitEncoder {
252            patch_embed,
253            positional_encoding,
254            blocks,
255            norm,
256            embed_dim: self.embed_dim,
257        }
258    }
259}
260
261/// Vision Transformer encoder.
262///
263/// Maps images to patch-level representations via:
264/// 1. Patch embedding (linear projection of flattened patches)
265/// 2. 2D Rotary Position Encoding
266/// 3. Stack of transformer blocks (self-attention + MLP)
267/// 4. Final layer normalization
268///
269/// Output shape: `[batch, num_patches, embed_dim]`
270#[derive(Module, Debug)]
271pub struct VitEncoder<B: Backend> {
272    /// Patch embedding: image → patch tokens.
273    patch_embed: PatchEmbedding<B>,
274    /// 2D Rotary Position Encoding.
275    positional_encoding: RotaryPositionEncoding2D<B>,
276    /// Stack of transformer blocks.
277    blocks: Vec<TransformerBlock<B>>,
278    /// Final layer normalization.
279    norm: LayerNorm<B>,
280    /// Output embedding dimension.
281    embed_dim: usize,
282}
283
284impl<B: Backend> VitEncoder<B> {
285    fn positioned_patch_tokens(&self, images: &Tensor<B, 4>) -> Tensor<B, 3> {
286        // 1. Patch embedding
287        let x = self.patch_embed.forward(images.clone());
288
289        // 2. Apply RoPE before any masking so absolute token positions remain correct.
290        self.positional_encoding.forward(x)
291    }
292
293    fn encode_positioned_tokens(&self, mut x: Tensor<B, 3>) -> Representation<B> {
294        // Transformer blocks
295        for block in &self.blocks {
296            x = block.forward(x);
297        }
298
299        // Layer norm
300        x = self.norm.forward(x);
301
302        Representation::new(x)
303    }
304
305    /// Forward pass: image → representation.
306    ///
307    /// # Arguments
308    /// * `images` - Input images. Shape: `[batch, channels, height, width]`
309    ///
310    /// # Returns
311    /// Patch-level representations. Shape: `[batch, num_patches, embed_dim]`
312    pub fn forward(&self, images: &Tensor<B, 4>) -> Representation<B> {
313        let x = self.positioned_patch_tokens(images);
314        self.encode_positioned_tokens(x)
315    }
316
317    /// Encode only the visible patch tokens for strict JEPA context encoding.
318    ///
319    /// The image is patchified and position-encoded using the full grid so the
320    /// surviving tokens retain their real flattened positions, then masked
321    /// tokens are removed before self-attention runs.
322    pub fn forward_visible_tokens(
323        &self,
324        images: &Tensor<B, 4>,
325        visible_indices: &[usize],
326    ) -> Representation<B> {
327        let x = self.positioned_patch_tokens(images);
328        let x = gather_token_sequence(x, visible_indices);
329        self.encode_positioned_tokens(x)
330    }
331
332    /// Load a ViT encoder from a map of burn-style parameter names to tensor data.
333    ///
334    /// Expected parameter names match the burn module record layout, for example
335    /// `patch_embed.projection.weight` and `blocks.0.attn.out_proj.bias`.
336    pub fn load_named_tensors(
337        self,
338        tensors: &HashMap<String, TensorData>,
339    ) -> Result<Self, VitLoadError> {
340        let mut record = self.clone().into_record();
341
342        load_linear_record(
343            &mut record.patch_embed.projection,
344            "patch_embed.projection",
345            tensors,
346        )?;
347
348        for (index, block) in record.blocks.iter_mut().enumerate() {
349            load_layer_norm_record(&mut block.norm1, &format!("blocks.{index}.norm1"), tensors)?;
350            load_linear_record(
351                &mut block.attn.qkv,
352                &format!("blocks.{index}.attn.qkv"),
353                tensors,
354            )?;
355            load_linear_record(
356                &mut block.attn.out_proj,
357                &format!("blocks.{index}.attn.out_proj"),
358                tensors,
359            )?;
360            load_layer_norm_record(&mut block.norm2, &format!("blocks.{index}.norm2"), tensors)?;
361            load_linear_record(
362                &mut block.mlp.fc1,
363                &format!("blocks.{index}.mlp.fc1"),
364                tensors,
365            )?;
366            load_linear_record(
367                &mut block.mlp.fc2,
368                &format!("blocks.{index}.mlp.fc2"),
369                tensors,
370            )?;
371        }
372
373        load_layer_norm_record(&mut record.norm, "norm", tensors)?;
374
375        Ok(self.load_record(record))
376    }
377
378    /// Update this encoder toward an online encoder using EMA.
379    ///
380    /// The returned encoder preserves the gradient setting of the target
381    /// encoder parameters while detaching the blended tensors from any active
382    /// autodiff graph.
383    pub fn ema_update_from(self, online: &Self, ema: &Ema, step: usize) -> Self {
384        let mut target_record = self.clone().into_record();
385        let online_record = online.clone().into_record();
386
387        ema_update_linear_record(
388            &mut target_record.patch_embed.projection,
389            &online_record.patch_embed.projection,
390            ema,
391            step,
392        );
393
394        for (target_block, online_block) in target_record
395            .blocks
396            .iter_mut()
397            .zip(online_record.blocks.iter())
398        {
399            ema_update_layer_norm_record(&mut target_block.norm1, &online_block.norm1, ema, step);
400            ema_update_linear_record(
401                &mut target_block.attn.qkv,
402                &online_block.attn.qkv,
403                ema,
404                step,
405            );
406            ema_update_linear_record(
407                &mut target_block.attn.out_proj,
408                &online_block.attn.out_proj,
409                ema,
410                step,
411            );
412            ema_update_layer_norm_record(&mut target_block.norm2, &online_block.norm2, ema, step);
413            ema_update_linear_record(&mut target_block.mlp.fc1, &online_block.mlp.fc1, ema, step);
414            ema_update_linear_record(&mut target_block.mlp.fc2, &online_block.mlp.fc2, ema, step);
415        }
416
417        ema_update_layer_norm_record(&mut target_record.norm, &online_record.norm, ema, step);
418
419        self.load_record(target_record)
420    }
421}
422
423impl<B: Backend> Encoder<B> for VitEncoder<B> {
424    type Input = Tensor<B, 4>;
425
426    fn encode(&self, input: &Self::Input) -> Representation<B> {
427        self.forward(input)
428    }
429
430    fn embed_dim(&self) -> usize {
431        self.embed_dim
432    }
433}
434
435fn load_linear_record<B: Backend>(
436    record: &mut LinearRecord<B>,
437    prefix: &str,
438    tensors: &HashMap<String, TensorData>,
439) -> Result<(), VitLoadError> {
440    load_param_from_tensors(&mut record.weight, &format!("{prefix}.weight"), tensors)?;
441    load_optional_param_from_tensors(&mut record.bias, &format!("{prefix}.bias"), tensors)?;
442    Ok(())
443}
444
445fn load_layer_norm_record<B: Backend>(
446    record: &mut LayerNormRecord<B>,
447    prefix: &str,
448    tensors: &HashMap<String, TensorData>,
449) -> Result<(), VitLoadError> {
450    load_param_from_tensors(&mut record.gamma, &format!("{prefix}.weight"), tensors)?;
451    load_optional_param_from_tensors(&mut record.beta, &format!("{prefix}.bias"), tensors)?;
452    Ok(())
453}
454
455fn load_param_from_tensors<B: Backend, const D: usize>(
456    param: &mut Param<Tensor<B, D>>,
457    key: &str,
458    tensors: &HashMap<String, TensorData>,
459) -> Result<(), VitLoadError> {
460    let tensor = tensors
461        .get(key)
462        .ok_or_else(|| VitLoadError::MissingKey(key.to_string()))?;
463    let expected_shape = param.lazy_shape().dims;
464    if tensor.shape != expected_shape {
465        return Err(VitLoadError::ShapeMismatch {
466            key: key.to_string(),
467            checkpoint_shape: tensor.shape.clone(),
468            model_shape: expected_shape,
469        });
470    }
471
472    *param = param
473        .clone()
474        .load_record(Param::from_data(tensor.clone(), &param.lazy_device()));
475    Ok(())
476}
477
478fn load_optional_param_from_tensors<B: Backend, const D: usize>(
479    param: &mut Option<Param<Tensor<B, D>>>,
480    key: &str,
481    tensors: &HashMap<String, TensorData>,
482) -> Result<(), VitLoadError> {
483    let Some(inner) = param else {
484        return Ok(());
485    };
486
487    load_param_from_tensors(inner, key, tensors)
488}
489
490fn ema_update_linear_record<B: Backend>(
491    target: &mut LinearRecord<B>,
492    online: &LinearRecord<B>,
493    ema: &Ema,
494    step: usize,
495) {
496    ema_update_param(&mut target.weight, &online.weight, ema, step);
497    ema_update_optional_param(&mut target.bias, &online.bias, ema, step);
498}
499
500fn ema_update_layer_norm_record<B: Backend>(
501    target: &mut LayerNormRecord<B>,
502    online: &LayerNormRecord<B>,
503    ema: &Ema,
504    step: usize,
505) {
506    ema_update_param(&mut target.gamma, &online.gamma, ema, step);
507    ema_update_optional_param(&mut target.beta, &online.beta, ema, step);
508}
509
510fn ema_update_param<B: Backend, const D: usize>(
511    target: &mut Param<Tensor<B, D>>,
512    online: &Param<Tensor<B, D>>,
513    ema: &Ema,
514    step: usize,
515) {
516    let param_id = target.clone().consume().0;
517    let updated = ema.update_tensor(target.val().detach(), &online.val().detach(), step);
518    let record = Param::initialized(param_id, updated.detach());
519    *target = target.clone().load_record(record);
520}
521
522fn ema_update_optional_param<B: Backend, const D: usize>(
523    target: &mut Option<Param<Tensor<B, D>>>,
524    online: &Option<Param<Tensor<B, D>>>,
525    ema: &Ema,
526    step: usize,
527) {
528    let (Some(target), Some(online)) = (target, online) else {
529        return;
530    };
531
532    ema_update_param(target, online, ema, step);
533}
534
535// --- Transformer components ---
536
/// Hyper-parameters for a single transformer block.
#[derive(Clone, Debug)]
struct TransformerBlockConfig {
    /// Token embedding width.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Hidden width of the block's MLP.
    mlp_dim: usize,
}
544
545impl TransformerBlockConfig {
546    fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
547        TransformerBlock {
548            norm1: LayerNormConfig::new(self.embed_dim).init(device),
549            attn: MultiHeadSelfAttentionConfig {
550                embed_dim: self.embed_dim,
551                num_heads: self.num_heads,
552            }
553            .init(device),
554            norm2: LayerNormConfig::new(self.embed_dim).init(device),
555            mlp: MlpConfig {
556                in_dim: self.embed_dim,
557                hidden_dim: self.mlp_dim,
558            }
559            .init(device),
560        }
561    }
562}
563
564/// Pre-norm transformer block: LN → Attention → residual → LN → MLP → residual.
565#[derive(Module, Debug)]
566struct TransformerBlock<B: Backend> {
567    norm1: LayerNorm<B>,
568    attn: MultiHeadSelfAttention<B>,
569    norm2: LayerNorm<B>,
570    mlp: Mlp<B>,
571}
572
573impl<B: Backend> TransformerBlock<B> {
574    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
575        // Pre-norm attention with residual
576        let residual = x.clone();
577        let x_norm = self.norm1.forward(x);
578        let attn_out = self.attn.forward(x_norm);
579        let x = residual + attn_out;
580
581        // Pre-norm MLP with residual
582        let residual = x.clone();
583        let x_norm = self.norm2.forward(x);
584        let mlp_out = self.mlp.forward(x_norm);
585        residual + mlp_out
586    }
587}
588
589// --- Multi-Head Self-Attention ---
590
/// Hyper-parameters for [`MultiHeadSelfAttention`].
#[derive(Clone, Debug)]
struct MultiHeadSelfAttentionConfig {
    /// Token embedding width, split evenly across the heads.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
}
596
597impl MultiHeadSelfAttentionConfig {
598    fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
599        let head_dim = self.embed_dim / self.num_heads;
600        MultiHeadSelfAttention {
601            qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
602            out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
603            num_heads: self.num_heads,
604            head_dim,
605        }
606    }
607}
608
609/// Multi-head self-attention.
610///
611/// Computes scaled dot-product attention across multiple heads.
612#[derive(Module, Debug)]
613struct MultiHeadSelfAttention<B: Backend> {
614    /// Combined QKV projection.
615    qkv: Linear<B>,
616    /// Output projection.
617    out_proj: Linear<B>,
618    /// Number of attention heads.
619    num_heads: usize,
620    /// Dimension per head.
621    head_dim: usize,
622}
623
624impl<B: Backend> MultiHeadSelfAttention<B> {
625    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
626        let [batch, seq_len, _embed_dim] = x.dims();
627        let embed_dim = self.num_heads * self.head_dim;
628
629        // Combined QKV: [batch, seq_len, 3 * embed_dim]
630        let qkv = self.qkv.forward(x);
631
632        // Split into Q, K, V
633        let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
634        let k = qkv
635            .clone()
636            .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
637        let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
638
639        // Reshape to multi-head: [batch, seq_len, num_heads, head_dim] → [batch, num_heads, seq_len, head_dim]
640        let q = q
641            .reshape([batch, seq_len, self.num_heads, self.head_dim])
642            .swap_dims(1, 2);
643        let k = k
644            .reshape([batch, seq_len, self.num_heads, self.head_dim])
645            .swap_dims(1, 2);
646        let v = v
647            .reshape([batch, seq_len, self.num_heads, self.head_dim])
648            .swap_dims(1, 2);
649
650        // Scaled dot-product attention
651        let scale = (self.head_dim as f64).sqrt();
652        let attn_weights = q.matmul(k.transpose()) / scale; // [batch, heads, seq, seq]
653        let attn_weights = burn::tensor::activation::softmax(attn_weights, 3);
654
655        // Apply attention to values
656        let out = attn_weights.matmul(v); // [batch, heads, seq, head_dim]
657
658        // Reshape back: [batch, seq_len, embed_dim]
659        let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
660
661        self.out_proj.forward(out)
662    }
663}
664
665// --- MLP ---
666
/// Hyper-parameters for the two-layer [`Mlp`].
#[derive(Clone, Debug)]
struct MlpConfig {
    /// Input width, which is also the output width.
    in_dim: usize,
    /// Width of the hidden layer.
    hidden_dim: usize,
}
672
673impl MlpConfig {
674    fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
675        Mlp {
676            fc1: LinearConfig::new(self.in_dim, self.hidden_dim).init(device),
677            fc2: LinearConfig::new(self.hidden_dim, self.in_dim).init(device),
678        }
679    }
680}
681
682/// Two-layer MLP with GELU activation.
683#[derive(Module, Debug)]
684struct Mlp<B: Backend> {
685    fc1: Linear<B>,
686    fc2: Linear<B>,
687}
688
689impl<B: Backend> Mlp<B> {
690    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
691        let x = self.fc1.forward(x);
692        let x = burn::tensor::activation::gelu(x);
693        self.fc2.forward(x)
694    }
695}
696
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;
    use std::collections::HashMap;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Sum of absolute element-wise differences between two tensors.
    fn l1_distance<const D: usize>(
        a: Tensor<TestBackend, D>,
        b: Tensor<TestBackend, D>,
    ) -> f32 {
        (a - b).abs().sum().into_scalar().elem()
    }

    /// Standard-normal random images matching the `tiny_test` input layout.
    fn random_images(batch: usize) -> Tensor<TestBackend, 4> {
        Tensor::random(
            [batch, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        )
    }

    #[test]
    fn test_vit_encoder_output_shape() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([2, 1, 8, 8], &device());
        let repr = encoder.forward(&images);

        // 8/2 = 4 patches per side, 4*4 = 16 patches total
        assert_eq!(repr.batch_size(), 2);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(repr.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_trait_impl() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let repr = Encoder::encode(&encoder, &images);

        assert_eq!(repr.batch_size(), 1);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(encoder.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_different_inputs_different_outputs() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let a: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let b: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());

        let repr_a = encoder.forward(&a);
        let repr_b = encoder.forward(&b);

        let diff = l1_distance(repr_a.embeddings, repr_b.embeddings);
        assert!(
            diff > 1e-6,
            "different inputs should produce different representations"
        );
    }

    #[test]
    fn test_transformer_block_residual() {
        // Verify the block output equals the explicit pre-norm residual
        // computation: y = x + attn(norm1(x)); out = y + mlp(norm2(y)).
        let block = TransformerBlockConfig {
            embed_dim: 16,
            num_heads: 2,
            mlp_dim: 32,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::random(
            [1, 4, 16],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );
        let out = block.forward(x.clone());
        assert_eq!(out.dims(), [1, 4, 16]);

        let after_attn = x.clone() + block.attn.forward(block.norm1.forward(x));
        let expected =
            after_attn.clone() + block.mlp.forward(block.norm2.forward(after_attn));

        let diff = l1_distance(out, expected);
        assert!(
            diff < 1e-5,
            "block output should match the residual formulation, diff={diff}"
        );
    }

    #[test]
    fn test_mhsa_output_shape() {
        let attn = MultiHeadSelfAttentionConfig {
            embed_dim: 16,
            num_heads: 4,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = attn.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    #[test]
    fn test_mlp_output_shape() {
        let mlp = MlpConfig {
            in_dim: 16,
            hidden_dim: 64,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = mlp.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    /// Export every parameter of `encoder` under the names that
    /// `load_named_tensors` expects.
    fn checkpoint_tensors_from_encoder(
        encoder: &VitEncoder<TestBackend>,
    ) -> HashMap<String, TensorData> {
        let record = encoder.clone().into_record();
        let mut tensors = HashMap::new();

        insert_linear_tensors(
            &mut tensors,
            "patch_embed.projection",
            &record.patch_embed.projection,
        );

        for (index, block) in record.blocks.iter().enumerate() {
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm1"), &block.norm1);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.qkv"),
                &block.attn.qkv,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.out_proj"),
                &block.attn.out_proj,
            );
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm2"), &block.norm2);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc1"),
                &block.mlp.fc1,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc2"),
                &block.mlp.fc2,
            );
        }

        insert_layer_norm_tensors(&mut tensors, "norm", &record.norm);

        tensors
    }

    fn insert_linear_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LinearRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.weight.val().to_data());
        if let Some(bias) = &record.bias {
            tensors.insert(format!("{prefix}.bias"), bias.val().to_data());
        }
    }

    fn insert_layer_norm_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LayerNormRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.gamma.val().to_data());
        if let Some(beta) = &record.beta {
            tensors.insert(format!("{prefix}.bias"), beta.val().to_data());
        }
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_restores_encoder_state() {
        let config = VitConfig::tiny_test();
        let source = config.init::<TestBackend>(&device());
        let target = config.init::<TestBackend>(&device());
        let tensors = checkpoint_tensors_from_encoder(&source);

        let loaded = target
            .load_named_tensors(&tensors)
            .expect("loading tensors exported from a matching encoder should succeed");

        let images = random_images(2);

        let source_repr = source.forward(&images);
        let loaded_repr = loaded.forward(&images);
        let diff = l1_distance(source_repr.embeddings, loaded_repr.embeddings);
        assert!(
            diff < 1e-6,
            "loading the exported tensors should restore the encoder exactly, diff={diff}"
        );
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_rejects_shape_mismatch() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mut tensors = checkpoint_tensors_from_encoder(&encoder);
        tensors.insert(
            "norm.weight".to_string(),
            TensorData::new(vec![1.0f32; 31], [31]),
        );

        let err = config
            .init::<TestBackend>(&device())
            .load_named_tensors(&tensors)
            .expect_err("shape mismatch should be reported");

        assert!(matches!(
            err,
            VitLoadError::ShapeMismatch { key, .. } if key == "norm.weight"
        ));
    }

    #[test]
    fn test_vit_encoder_ema_update_moves_target_toward_online() {
        let config = VitConfig::tiny_test();
        let target = config.init::<TestBackend>(&device());
        let online = config.init::<TestBackend>(&device());
        let ema = Ema::new(0.5);
        let images = random_images(1);

        let target_before = target.forward(&images);
        let online_before = online.forward(&images);
        let updated = target.clone().ema_update_from(&online, &ema, 0);
        let updated_repr = updated.forward(&images);

        let before_distance =
            l1_distance(target_before.embeddings, online_before.embeddings.clone());
        let after_distance = l1_distance(updated_repr.embeddings, online_before.embeddings);

        assert!(
            after_distance < before_distance,
            "EMA update should move target toward online encoder"
        );
    }

    proptest! {
        /// Property: ViT encoder output is always finite (no NaN/Inf) for
        /// small normally-distributed inputs.
        #[test]
        fn prop_vit_output_is_finite(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());

            let images = random_images(batch);
            let repr = encoder.forward(&images);

            // Check shape
            prop_assert_eq!(repr.batch_size(), batch);
            prop_assert_eq!(repr.seq_len(), 16);
            prop_assert_eq!(repr.embed_dim(), 32);

            // Check finiteness: sum of abs should be finite and non-NaN
            let total: f32 = repr.embeddings.abs().sum().into_scalar().elem();
            prop_assert!(total.is_finite(), "ViT output should be finite, got {}", total);
        }

        /// Property: ViT encoder is deterministic — same input always produces
        /// the same output.
        #[test]
        fn prop_vit_is_deterministic(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());

            let images: Tensor<TestBackend, 4> = Tensor::ones([batch, 1, 8, 8], &device());
            let repr1 = encoder.forward(&images);
            let repr2 = encoder.forward(&images);

            let diff = l1_distance(repr1.embeddings, repr2.embeddings);
            prop_assert!(diff < 1e-6, "ViT should be deterministic, diff={}", diff);
        }

        /// Property: transformer block preserves tensor dimensions for any
        /// valid (embed_dim, num_heads) combination where embed_dim % num_heads == 0.
        #[test]
        fn prop_transformer_block_preserves_shape(
            seq_len in 2usize..8,
            num_heads in proptest::sample::select(vec![2usize, 4]),
        ) {
            let embed_dim = 16; // divisible by 2 and 4
            let block = TransformerBlockConfig {
                embed_dim,
                num_heads,
                mlp_dim: embed_dim * 4,
            }
            .init::<TestBackend>(&device());

            let x: Tensor<TestBackend, 3> = Tensor::random(
                [1, seq_len, embed_dim],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let out = block.forward(x);
            prop_assert_eq!(out.dims(), [1, seq_len, embed_dim]);

            let total: f32 = out.abs().sum().into_scalar().elem();
            prop_assert!(total.is_finite(), "block output should be finite");
        }
    }
}