1use std::collections::HashMap;
30
31use burn::module::{Module, Param};
32use burn::nn::{LayerNorm, LayerNormConfig, LayerNormRecord, Linear, LinearConfig, LinearRecord};
33use burn::prelude::*;
34use burn::tensor::backend::Backend;
35use burn::tensor::TensorData;
36
37use jepa_core::ema::Ema;
38use jepa_core::types::Representation;
39use jepa_core::Encoder;
40
41use crate::patch::{PatchEmbedding, PatchEmbeddingConfig};
42use crate::rope::{RotaryPositionEncoding2D, RotaryPositionEncoding2DConfig};
43use crate::token_ops::gather_token_sequence;
44
/// Errors reported while copying named checkpoint tensors into a [`VitEncoder`].
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum VitLoadError {
    /// The checkpoint map has no entry under the required key (the payload is that key).
    #[error("missing checkpoint tensor `{0}`")]
    MissingKey(String),
    /// The checkpoint tensor exists but its shape differs from the model parameter's shape.
    #[error(
        "shape mismatch for `{key}`: checkpoint {checkpoint_shape:?} vs model {model_shape:?}"
    )]
    ShapeMismatch {
        /// Key of the offending checkpoint tensor.
        key: String,
        /// Shape of the tensor found in the checkpoint.
        checkpoint_shape: Vec<usize>,
        /// Shape of the model parameter the tensor was meant for.
        model_shape: Vec<usize>,
    },
}
59
/// Hyper-parameters used by [`VitConfig::init`] to build a [`VitEncoder`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VitConfig {
    /// Number of input channels, forwarded to the patch embedding.
    pub in_channels: usize,
    /// Input image height; divided by `patch_size.0` to obtain the patch-grid height.
    pub image_height: usize,
    /// Input image width; divided by `patch_size.1` to obtain the patch-grid width.
    pub image_width: usize,
    /// Patch size as `(height, width)`.
    pub patch_size: (usize, usize),
    /// Token embedding width shared by the patch embedding, positional encoding,
    /// transformer blocks and final layer norm.
    pub embed_dim: usize,
    /// Number of transformer blocks to create.
    pub num_layers: usize,
    /// Number of attention heads in every transformer block.
    pub num_heads: usize,
    /// Hidden width of the MLP inside every transformer block.
    pub mlp_dim: usize,
    /// Dropout rate.
    ///
    /// NOTE(review): `init` in this file never reads this field, so it currently has no
    /// effect on the constructed encoder — confirm whether that is intended.
    pub dropout: f64,
}
97
98impl VitConfig {
99 pub fn vit_base_patch16() -> Self {
101 Self {
102 in_channels: 3,
103 image_height: 224,
104 image_width: 224,
105 patch_size: (16, 16),
106 embed_dim: 768,
107 num_layers: 12,
108 num_heads: 12,
109 mlp_dim: 3072,
110 dropout: 0.0,
111 }
112 }
113
114 pub fn vit_small_patch16() -> Self {
116 Self {
117 in_channels: 3,
118 image_height: 224,
119 image_width: 224,
120 patch_size: (16, 16),
121 embed_dim: 384,
122 num_layers: 12,
123 num_heads: 6,
124 mlp_dim: 1536,
125 dropout: 0.0,
126 }
127 }
128
129 pub fn vit_large_patch16() -> Self {
133 Self {
134 in_channels: 3,
135 image_height: 224,
136 image_width: 224,
137 patch_size: (16, 16),
138 embed_dim: 1024,
139 num_layers: 24,
140 num_heads: 16,
141 mlp_dim: 4096,
142 dropout: 0.0,
143 }
144 }
145
146 pub fn vit_huge_patch14() -> Self {
151 Self {
152 in_channels: 3,
153 image_height: 224,
154 image_width: 224,
155 patch_size: (14, 14),
156 embed_dim: 1280,
157 num_layers: 32,
158 num_heads: 16,
159 mlp_dim: 5120,
160 dropout: 0.0,
161 }
162 }
163
164 pub fn vit_huge_patch16_448() -> Self {
168 Self {
169 in_channels: 3,
170 image_height: 448,
171 image_width: 448,
172 patch_size: (16, 16),
173 embed_dim: 1280,
174 num_layers: 32,
175 num_heads: 16,
176 mlp_dim: 5120,
177 dropout: 0.0,
178 }
179 }
180
181 pub fn vit_giant_patch16() -> Self {
185 Self {
186 in_channels: 3,
187 image_height: 224,
188 image_width: 224,
189 patch_size: (16, 16),
190 embed_dim: 1408,
191 num_layers: 40,
192 num_heads: 16,
193 mlp_dim: 6144,
194 dropout: 0.0,
195 }
196 }
197
198 pub fn tiny_test() -> Self {
200 Self {
201 in_channels: 1,
202 image_height: 8,
203 image_width: 8,
204 patch_size: (2, 2),
205 embed_dim: 32,
206 num_layers: 2,
207 num_heads: 4,
208 mlp_dim: 64,
209 dropout: 0.0,
210 }
211 }
212
213 fn grid_height(&self) -> usize {
214 self.image_height / self.patch_size.0
215 }
216
217 fn grid_width(&self) -> usize {
218 self.image_width / self.patch_size.1
219 }
220
221 pub fn init<B: Backend>(&self, device: &B::Device) -> VitEncoder<B> {
223 let patch_embed_config = PatchEmbeddingConfig::new(
224 self.in_channels,
225 self.patch_size.0,
226 self.patch_size.1,
227 self.embed_dim,
228 );
229 let patch_embed = patch_embed_config.init(device);
230
231 let rope_config = RotaryPositionEncoding2DConfig::new(
232 self.embed_dim,
233 self.grid_height(),
234 self.grid_width(),
235 );
236 let positional_encoding = rope_config.init(device);
237
238 let blocks: Vec<TransformerBlock<B>> = (0..self.num_layers)
239 .map(|_| {
240 TransformerBlockConfig {
241 embed_dim: self.embed_dim,
242 num_heads: self.num_heads,
243 mlp_dim: self.mlp_dim,
244 }
245 .init(device)
246 })
247 .collect();
248
249 let norm = LayerNormConfig::new(self.embed_dim).init(device);
250
251 VitEncoder {
252 patch_embed,
253 positional_encoding,
254 blocks,
255 norm,
256 embed_dim: self.embed_dim,
257 }
258 }
259}
260
/// Vision Transformer encoder: patch embedding, 2D rotary positional encoding, a stack of
/// transformer blocks and a final layer norm.
#[derive(Module, Debug)]
pub struct VitEncoder<B: Backend> {
    /// Converts an image batch into a sequence of patch tokens.
    patch_embed: PatchEmbedding<B>,
    /// Positional encoding applied to the patch tokens before the blocks.
    positional_encoding: RotaryPositionEncoding2D<B>,
    /// Transformer blocks, applied in order.
    blocks: Vec<TransformerBlock<B>>,
    /// Layer norm applied after the last block.
    norm: LayerNorm<B>,
    /// Embedding width, reported through `Encoder::embed_dim`.
    embed_dim: usize,
}
283
284impl<B: Backend> VitEncoder<B> {
285 fn positioned_patch_tokens(&self, images: &Tensor<B, 4>) -> Tensor<B, 3> {
286 let x = self.patch_embed.forward(images.clone());
288
289 self.positional_encoding.forward(x)
291 }
292
293 fn encode_positioned_tokens(&self, mut x: Tensor<B, 3>) -> Representation<B> {
294 for block in &self.blocks {
296 x = block.forward(x);
297 }
298
299 x = self.norm.forward(x);
301
302 Representation::new(x)
303 }
304
305 pub fn forward(&self, images: &Tensor<B, 4>) -> Representation<B> {
313 let x = self.positioned_patch_tokens(images);
314 self.encode_positioned_tokens(x)
315 }
316
317 pub fn forward_visible_tokens(
323 &self,
324 images: &Tensor<B, 4>,
325 visible_indices: &[usize],
326 ) -> Representation<B> {
327 let x = self.positioned_patch_tokens(images);
328 let x = gather_token_sequence(x, visible_indices);
329 self.encode_positioned_tokens(x)
330 }
331
332 pub fn load_named_tensors(
337 self,
338 tensors: &HashMap<String, TensorData>,
339 ) -> Result<Self, VitLoadError> {
340 let mut record = self.clone().into_record();
341
342 load_linear_record(
343 &mut record.patch_embed.projection,
344 "patch_embed.projection",
345 tensors,
346 )?;
347
348 for (index, block) in record.blocks.iter_mut().enumerate() {
349 load_layer_norm_record(&mut block.norm1, &format!("blocks.{index}.norm1"), tensors)?;
350 load_linear_record(
351 &mut block.attn.qkv,
352 &format!("blocks.{index}.attn.qkv"),
353 tensors,
354 )?;
355 load_linear_record(
356 &mut block.attn.out_proj,
357 &format!("blocks.{index}.attn.out_proj"),
358 tensors,
359 )?;
360 load_layer_norm_record(&mut block.norm2, &format!("blocks.{index}.norm2"), tensors)?;
361 load_linear_record(
362 &mut block.mlp.fc1,
363 &format!("blocks.{index}.mlp.fc1"),
364 tensors,
365 )?;
366 load_linear_record(
367 &mut block.mlp.fc2,
368 &format!("blocks.{index}.mlp.fc2"),
369 tensors,
370 )?;
371 }
372
373 load_layer_norm_record(&mut record.norm, "norm", tensors)?;
374
375 Ok(self.load_record(record))
376 }
377
378 pub fn ema_update_from(self, online: &Self, ema: &Ema, step: usize) -> Self {
384 let mut target_record = self.clone().into_record();
385 let online_record = online.clone().into_record();
386
387 ema_update_linear_record(
388 &mut target_record.patch_embed.projection,
389 &online_record.patch_embed.projection,
390 ema,
391 step,
392 );
393
394 for (target_block, online_block) in target_record
395 .blocks
396 .iter_mut()
397 .zip(online_record.blocks.iter())
398 {
399 ema_update_layer_norm_record(&mut target_block.norm1, &online_block.norm1, ema, step);
400 ema_update_linear_record(
401 &mut target_block.attn.qkv,
402 &online_block.attn.qkv,
403 ema,
404 step,
405 );
406 ema_update_linear_record(
407 &mut target_block.attn.out_proj,
408 &online_block.attn.out_proj,
409 ema,
410 step,
411 );
412 ema_update_layer_norm_record(&mut target_block.norm2, &online_block.norm2, ema, step);
413 ema_update_linear_record(&mut target_block.mlp.fc1, &online_block.mlp.fc1, ema, step);
414 ema_update_linear_record(&mut target_block.mlp.fc2, &online_block.mlp.fc2, ema, step);
415 }
416
417 ema_update_layer_norm_record(&mut target_record.norm, &online_record.norm, ema, step);
418
419 self.load_record(target_record)
420 }
421}
422
423impl<B: Backend> Encoder<B> for VitEncoder<B> {
424 type Input = Tensor<B, 4>;
425
426 fn encode(&self, input: &Self::Input) -> Representation<B> {
427 self.forward(input)
428 }
429
430 fn embed_dim(&self) -> usize {
431 self.embed_dim
432 }
433}
434
435fn load_linear_record<B: Backend>(
436 record: &mut LinearRecord<B>,
437 prefix: &str,
438 tensors: &HashMap<String, TensorData>,
439) -> Result<(), VitLoadError> {
440 load_param_from_tensors(&mut record.weight, &format!("{prefix}.weight"), tensors)?;
441 load_optional_param_from_tensors(&mut record.bias, &format!("{prefix}.bias"), tensors)?;
442 Ok(())
443}
444
445fn load_layer_norm_record<B: Backend>(
446 record: &mut LayerNormRecord<B>,
447 prefix: &str,
448 tensors: &HashMap<String, TensorData>,
449) -> Result<(), VitLoadError> {
450 load_param_from_tensors(&mut record.gamma, &format!("{prefix}.weight"), tensors)?;
451 load_optional_param_from_tensors(&mut record.beta, &format!("{prefix}.bias"), tensors)?;
452 Ok(())
453}
454
455fn load_param_from_tensors<B: Backend, const D: usize>(
456 param: &mut Param<Tensor<B, D>>,
457 key: &str,
458 tensors: &HashMap<String, TensorData>,
459) -> Result<(), VitLoadError> {
460 let tensor = tensors
461 .get(key)
462 .ok_or_else(|| VitLoadError::MissingKey(key.to_string()))?;
463 let expected_shape = param.lazy_shape().dims;
464 if tensor.shape != expected_shape {
465 return Err(VitLoadError::ShapeMismatch {
466 key: key.to_string(),
467 checkpoint_shape: tensor.shape.clone(),
468 model_shape: expected_shape,
469 });
470 }
471
472 *param = param
473 .clone()
474 .load_record(Param::from_data(tensor.clone(), ¶m.lazy_device()));
475 Ok(())
476}
477
478fn load_optional_param_from_tensors<B: Backend, const D: usize>(
479 param: &mut Option<Param<Tensor<B, D>>>,
480 key: &str,
481 tensors: &HashMap<String, TensorData>,
482) -> Result<(), VitLoadError> {
483 let Some(inner) = param else {
484 return Ok(());
485 };
486
487 load_param_from_tensors(inner, key, tensors)
488}
489
/// Applies the EMA update to a linear-layer record: the weight always, the bias only when
/// both `target` and `online` have one.
fn ema_update_linear_record<B: Backend>(
    target: &mut LinearRecord<B>,
    online: &LinearRecord<B>,
    ema: &Ema,
    step: usize,
) {
    ema_update_param(&mut target.weight, &online.weight, ema, step);
    ema_update_optional_param(&mut target.bias, &online.bias, ema, step);
}
499
/// Applies the EMA update to a layer-norm record: `gamma` always, `beta` only when both
/// `target` and `online` have one.
fn ema_update_layer_norm_record<B: Backend>(
    target: &mut LayerNormRecord<B>,
    online: &LayerNormRecord<B>,
    ema: &Ema,
    step: usize,
) {
    ema_update_param(&mut target.gamma, &online.gamma, ema, step);
    ema_update_optional_param(&mut target.beta, &online.beta, ema, step);
}
509
510fn ema_update_param<B: Backend, const D: usize>(
511 target: &mut Param<Tensor<B, D>>,
512 online: &Param<Tensor<B, D>>,
513 ema: &Ema,
514 step: usize,
515) {
516 let param_id = target.clone().consume().0;
517 let updated = ema.update_tensor(target.val().detach(), &online.val().detach(), step);
518 let record = Param::initialized(param_id, updated.detach());
519 *target = target.clone().load_record(record);
520}
521
522fn ema_update_optional_param<B: Backend, const D: usize>(
523 target: &mut Option<Param<Tensor<B, D>>>,
524 online: &Option<Param<Tensor<B, D>>>,
525 ema: &Ema,
526 step: usize,
527) {
528 let (Some(target), Some(online)) = (target, online) else {
529 return;
530 };
531
532 ema_update_param(target, online, ema, step);
533}
534
/// Hyper-parameters for a single [`TransformerBlock`].
#[derive(Debug, Clone)]
struct TransformerBlockConfig {
    /// Token embedding width used by both layer norms, the attention and the MLP.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Hidden width of the block's MLP.
    mlp_dim: usize,
}
544
545impl TransformerBlockConfig {
546 fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
547 TransformerBlock {
548 norm1: LayerNormConfig::new(self.embed_dim).init(device),
549 attn: MultiHeadSelfAttentionConfig {
550 embed_dim: self.embed_dim,
551 num_heads: self.num_heads,
552 }
553 .init(device),
554 norm2: LayerNormConfig::new(self.embed_dim).init(device),
555 mlp: MlpConfig {
556 in_dim: self.embed_dim,
557 hidden_dim: self.mlp_dim,
558 }
559 .init(device),
560 }
561 }
562}
563
/// Transformer block with two residual sub-layers: self-attention and an MLP, each preceded
/// by its own layer norm.
#[derive(Module, Debug)]
struct TransformerBlock<B: Backend> {
    /// Layer norm applied before the attention sub-layer.
    norm1: LayerNorm<B>,
    /// Multi-head self-attention sub-layer.
    attn: MultiHeadSelfAttention<B>,
    /// Layer norm applied before the MLP sub-layer.
    norm2: LayerNorm<B>,
    /// Feed-forward sub-layer.
    mlp: Mlp<B>,
}
572
573impl<B: Backend> TransformerBlock<B> {
574 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
575 let residual = x.clone();
577 let x_norm = self.norm1.forward(x);
578 let attn_out = self.attn.forward(x_norm);
579 let x = residual + attn_out;
580
581 let residual = x.clone();
583 let x_norm = self.norm2.forward(x);
584 let mlp_out = self.mlp.forward(x_norm);
585 residual + mlp_out
586 }
587}
588
/// Hyper-parameters for [`MultiHeadSelfAttention`].
#[derive(Debug, Clone)]
struct MultiHeadSelfAttentionConfig {
    /// Token embedding width; the per-head width is `embed_dim / num_heads`.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
}
596
597impl MultiHeadSelfAttentionConfig {
598 fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
599 let head_dim = self.embed_dim / self.num_heads;
600 MultiHeadSelfAttention {
601 qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
602 out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
603 num_heads: self.num_heads,
604 head_dim,
605 }
606 }
607}
608
/// Multi-head self-attention with a fused Q/K/V projection.
#[derive(Module, Debug)]
struct MultiHeadSelfAttention<B: Backend> {
    /// Fused projection producing queries, keys and values side by side on the last axis.
    qkv: Linear<B>,
    /// Projection applied to the re-merged head outputs.
    out_proj: Linear<B>,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of each head.
    head_dim: usize,
}
623
624impl<B: Backend> MultiHeadSelfAttention<B> {
625 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
626 let [batch, seq_len, _embed_dim] = x.dims();
627 let embed_dim = self.num_heads * self.head_dim;
628
629 let qkv = self.qkv.forward(x);
631
632 let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
634 let k = qkv
635 .clone()
636 .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
637 let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
638
639 let q = q
641 .reshape([batch, seq_len, self.num_heads, self.head_dim])
642 .swap_dims(1, 2);
643 let k = k
644 .reshape([batch, seq_len, self.num_heads, self.head_dim])
645 .swap_dims(1, 2);
646 let v = v
647 .reshape([batch, seq_len, self.num_heads, self.head_dim])
648 .swap_dims(1, 2);
649
650 let scale = (self.head_dim as f64).sqrt();
652 let attn_weights = q.matmul(k.transpose()) / scale; let attn_weights = burn::tensor::activation::softmax(attn_weights, 3);
654
655 let out = attn_weights.matmul(v); let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
660
661 self.out_proj.forward(out)
662 }
663}
664
/// Hyper-parameters for the two-layer [`Mlp`].
#[derive(Debug, Clone)]
struct MlpConfig {
    /// Input width, which is also the output width.
    in_dim: usize,
    /// Width of the hidden layer.
    hidden_dim: usize,
}
672
673impl MlpConfig {
674 fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
675 Mlp {
676 fc1: LinearConfig::new(self.in_dim, self.hidden_dim).init(device),
677 fc2: LinearConfig::new(self.hidden_dim, self.in_dim).init(device),
678 }
679 }
680}
681
/// Two-layer feed-forward network with a GELU between the layers.
#[derive(Module, Debug)]
struct Mlp<B: Backend> {
    /// First linear layer (input width to hidden width).
    fc1: Linear<B>,
    /// Second linear layer (hidden width back to input width).
    fc2: Linear<B>,
}
688
689impl<B: Backend> Mlp<B> {
690 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
691 let x = self.fc1.forward(x);
692 let x = burn::tensor::activation::gelu(x);
693 self.fc2.forward(x)
694 }
695}
696
#[cfg(test)]
mod tests {
    //! Unit and property tests for the ViT encoder and its building blocks.

    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;
    use std::collections::HashMap;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Sum of the absolute values of every element of `tensor`, as an `f32`.
    ///
    /// Used both as an L1 distance (on a difference of tensors) and as a finiteness probe.
    fn abs_sum<const D: usize>(tensor: Tensor<TestBackend, D>) -> f32 {
        tensor.abs().sum().into_scalar().elem()
    }

    #[test]
    fn test_vit_encoder_output_shape() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([2, 1, 8, 8], &device());
        let repr = encoder.forward(&images);

        assert_eq!(repr.batch_size(), 2);
        // 8x8 image with 2x2 patches -> 4x4 = 16 tokens.
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(repr.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_trait_impl() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let repr = Encoder::encode(&encoder, &images);

        assert_eq!(repr.batch_size(), 1);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(encoder.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_different_inputs_different_outputs() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let a: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let b: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());

        let repr_a = encoder.forward(&a);
        let repr_b = encoder.forward(&b);

        let diff = abs_sum(repr_a.embeddings - repr_b.embeddings);
        assert!(
            diff > 1e-6,
            "different inputs should produce different representations"
        );
    }

    #[test]
    fn test_transformer_block_residual() {
        let block = TransformerBlockConfig {
            embed_dim: 16,
            num_heads: 2,
            mlp_dim: 32,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([1, 4, 16], &device());
        let out = block.forward(x);
        assert_eq!(out.dims(), [1, 4, 16]);
    }

    #[test]
    fn test_mhsa_output_shape() {
        let attn = MultiHeadSelfAttentionConfig {
            embed_dim: 16,
            num_heads: 4,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = attn.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    #[test]
    fn test_mlp_output_shape() {
        let mlp = MlpConfig {
            in_dim: 16,
            hidden_dim: 64,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = mlp.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    /// Exports every tensor that `load_named_tensors` reads, under the same key layout.
    fn checkpoint_tensors_from_encoder(
        encoder: &VitEncoder<TestBackend>,
    ) -> HashMap<String, TensorData> {
        let record = encoder.clone().into_record();
        let mut tensors = HashMap::new();

        insert_linear_tensors(
            &mut tensors,
            "patch_embed.projection",
            &record.patch_embed.projection,
        );

        for (index, block) in record.blocks.iter().enumerate() {
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm1"), &block.norm1);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.qkv"),
                &block.attn.qkv,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.out_proj"),
                &block.attn.out_proj,
            );
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm2"), &block.norm2);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc1"),
                &block.mlp.fc1,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc2"),
                &block.mlp.fc2,
            );
        }

        insert_layer_norm_tensors(&mut tensors, "norm", &record.norm);

        tensors
    }

    /// Inserts `{prefix}.weight` and, if present, `{prefix}.bias` of a linear record.
    fn insert_linear_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LinearRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.weight.val().to_data());
        if let Some(bias) = &record.bias {
            tensors.insert(format!("{prefix}.bias"), bias.val().to_data());
        }
    }

    /// Inserts `{prefix}.weight` (gamma) and, if present, `{prefix}.bias` (beta) of a
    /// layer-norm record.
    fn insert_layer_norm_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LayerNormRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.gamma.val().to_data());
        if let Some(beta) = &record.beta {
            tensors.insert(format!("{prefix}.bias"), beta.val().to_data());
        }
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_restores_encoder_state() {
        let config = VitConfig::tiny_test();
        let source = config.init::<TestBackend>(&device());
        let target = config.init::<TestBackend>(&device());
        let tensors = checkpoint_tensors_from_encoder(&source);

        let loaded = target
            .load_named_tensors(&tensors)
            .expect("loading tensors exported from a matching encoder should succeed");

        let images: Tensor<TestBackend, 4> = Tensor::random(
            [2, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );

        let source_repr = source.forward(&images);
        let loaded_repr = loaded.forward(&images);
        let diff = abs_sum(source_repr.embeddings - loaded_repr.embeddings);
        assert!(
            diff < 1e-6,
            "loading the exported tensors should restore the encoder exactly, diff={diff}"
        );
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_rejects_shape_mismatch() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mut tensors = checkpoint_tensors_from_encoder(&encoder);
        // The model's final norm has width 32; 31 must be rejected.
        tensors.insert(
            "norm.weight".to_string(),
            TensorData::new(vec![1.0f32; 31], [31]),
        );

        let err = config
            .init::<TestBackend>(&device())
            .load_named_tensors(&tensors)
            .expect_err("shape mismatch should be reported");

        assert!(matches!(
            err,
            VitLoadError::ShapeMismatch { key, .. } if key == "norm.weight"
        ));
    }

    #[test]
    fn test_vit_encoder_ema_update_moves_target_toward_online() {
        let config = VitConfig::tiny_test();
        let target = config.init::<TestBackend>(&device());
        let online = config.init::<TestBackend>(&device());
        let ema = Ema::new(0.5);
        let images: Tensor<TestBackend, 4> = Tensor::random(
            [1, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );

        let target_before = target.forward(&images);
        let online_before = online.forward(&images);
        // `target` is not needed afterwards, so it is moved rather than cloned.
        let updated = target.ema_update_from(&online, &ema, 0);
        let updated_repr = updated.forward(&images);

        let before_distance =
            abs_sum(target_before.embeddings - online_before.embeddings.clone());
        let after_distance = abs_sum(updated_repr.embeddings - online_before.embeddings);

        assert!(
            after_distance < before_distance,
            "EMA update should move target toward online encoder"
        );
    }

    proptest! {
        #[test]
        fn prop_vit_output_is_finite(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());

            let images: Tensor<TestBackend, 4> = Tensor::random(
                [batch, 1, 8, 8],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let repr = encoder.forward(&images);

            prop_assert_eq!(repr.batch_size(), batch);
            prop_assert_eq!(repr.seq_len(), 16);
            prop_assert_eq!(repr.embed_dim(), 32);

            // A non-finite element anywhere makes the absolute sum non-finite.
            let total = abs_sum(repr.embeddings);
            prop_assert!(total.is_finite(), "ViT output should be finite, got {}", total);
        }

        #[test]
        fn prop_vit_is_deterministic(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());

            let images: Tensor<TestBackend, 4> = Tensor::ones([batch, 1, 8, 8], &device());
            let repr1 = encoder.forward(&images);
            let repr2 = encoder.forward(&images);

            let diff = abs_sum(repr1.embeddings - repr2.embeddings);
            prop_assert!(diff < 1e-6, "ViT should be deterministic, diff={}", diff);
        }

        #[test]
        fn prop_transformer_block_preserves_shape(
            seq_len in 2usize..8,
            num_heads in proptest::sample::select(vec![2usize, 4]),
        ) {
            let embed_dim = 16;
            let block = TransformerBlockConfig {
                embed_dim,
                num_heads,
                mlp_dim: embed_dim * 4,
            }
            .init::<TestBackend>(&device());

            let x: Tensor<TestBackend, 3> = Tensor::random(
                [1, seq_len, embed_dim],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let out = block.forward(x);
            prop_assert_eq!(out.dims(), [1, seq_len, embed_dim]);

            let total = abs_sum(out);
            prop_assert!(total.is_finite(), "block output should be finite");
        }
    }
}