1use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
30use burn::prelude::*;
31use burn::tensor::backend::Backend;
32
33use jepa_core::types::{Energy, MaskError, MaskSpec, Representation};
34use jepa_core::{CollapseRegularizer, Encoder, EnergyFn};
35
36use crate::token_ops::gather_token_sequence;
37
38#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct VitVideoConfig {
41 pub in_channels: usize,
43 pub num_frames: usize,
45 pub frame_height: usize,
47 pub frame_width: usize,
49 pub tubelet_size: (usize, usize, usize),
51 pub embed_dim: usize,
53 pub num_layers: usize,
55 pub num_heads: usize,
57 pub mlp_dim: usize,
59}
60
61impl VitVideoConfig {
62 pub fn grid_dims(&self) -> (usize, usize, usize) {
64 (
65 self.num_frames / self.tubelet_size.0,
66 self.frame_height / self.tubelet_size.1,
67 self.frame_width / self.tubelet_size.2,
68 )
69 }
70
71 pub fn num_tubelets(&self) -> usize {
73 let (gt, gh, gw) = self.grid_dims();
74 gt * gh * gw
75 }
76
77 pub fn tiny_test() -> Self {
79 Self {
80 in_channels: 1,
81 num_frames: 4,
82 frame_height: 8,
83 frame_width: 8,
84 tubelet_size: (2, 2, 2),
85 embed_dim: 32,
86 num_layers: 2,
87 num_heads: 4,
88 mlp_dim: 64,
89 }
90 }
91
92 pub fn init<B: Backend>(&self, device: &B::Device) -> VitVideoEncoder<B> {
94 let tubelet_embed_config = TubeletEmbeddingConfig {
95 in_channels: self.in_channels,
96 tubelet_t: self.tubelet_size.0,
97 tubelet_h: self.tubelet_size.1,
98 tubelet_w: self.tubelet_size.2,
99 embed_dim: self.embed_dim,
100 };
101 let tubelet_embed = tubelet_embed_config.init(device);
102
103 let (gt, gh, gw) = self.grid_dims();
104 let rope = RotaryPositionEncoding3DConfig::new(self.embed_dim, gt, gh, gw).init(device);
105
106 let blocks: Vec<VideoTransformerBlock<B>> = (0..self.num_layers)
107 .map(|_| {
108 VideoTransformerBlockConfig {
109 embed_dim: self.embed_dim,
110 num_heads: self.num_heads,
111 mlp_dim: self.mlp_dim,
112 }
113 .init(device)
114 })
115 .collect();
116
117 let norm = LayerNormConfig::new(self.embed_dim).init(device);
118
119 VitVideoEncoder {
120 tubelet_embed,
121 positional_encoding: rope,
122 blocks,
123 norm,
124 embed_dim: self.embed_dim,
125 }
126 }
127}
128
129#[derive(Module, Debug)]
139pub struct VitVideoEncoder<B: Backend> {
140 tubelet_embed: TubeletEmbedding<B>,
142 positional_encoding: RotaryPositionEncoding3D<B>,
144 blocks: Vec<VideoTransformerBlock<B>>,
146 norm: LayerNorm<B>,
148 embed_dim: usize,
150}
151
152impl<B: Backend> VitVideoEncoder<B> {
153 fn positioned_tubelet_tokens(&self, video: &Tensor<B, 5>) -> Tensor<B, 3> {
154 let x = self.tubelet_embed.forward(video.clone());
156
157 self.positional_encoding.forward(x)
160 }
161
162 fn encode_positioned_tokens(&self, mut x: Tensor<B, 3>) -> Representation<B> {
163 for block in &self.blocks {
164 x = block.forward(x);
165 }
166
167 x = self.norm.forward(x);
168
169 Representation::new(x)
170 }
171
172 pub fn forward(&self, video: &Tensor<B, 5>) -> Representation<B> {
180 let x = self.positioned_tubelet_tokens(video);
181 self.encode_positioned_tokens(x)
182 }
183
184 pub fn forward_visible_tokens(
186 &self,
187 video: &Tensor<B, 5>,
188 visible_indices: &[usize],
189 ) -> Representation<B> {
190 let x = self.positioned_tubelet_tokens(video);
191 let x = gather_token_sequence(x, visible_indices);
192 self.encode_positioned_tokens(x)
193 }
194}
195
196impl<B: Backend> Encoder<B> for VitVideoEncoder<B> {
197 type Input = Tensor<B, 5>;
198
199 fn encode(&self, input: &Self::Input) -> Representation<B> {
200 self.forward(input)
201 }
202
203 fn embed_dim(&self) -> usize {
204 self.embed_dim
205 }
206}
207
/// Configuration for the linear tubelet embedding.
#[derive(Debug, Clone)]
struct TubeletEmbeddingConfig {
    /// Channels per input frame.
    in_channels: usize,
    /// Tubelet extent along the frame axis.
    tubelet_t: usize,
    /// Tubelet extent along the height axis.
    tubelet_h: usize,
    /// Tubelet extent along the width axis.
    tubelet_w: usize,
    /// Width of the produced token embeddings.
    embed_dim: usize,
}
219
220impl TubeletEmbeddingConfig {
221 fn init<B: Backend>(&self, device: &B::Device) -> TubeletEmbedding<B> {
222 let tubelet_dim = self.in_channels * self.tubelet_t * self.tubelet_h * self.tubelet_w;
223 let projection = LinearConfig::new(tubelet_dim, self.embed_dim).init(device);
224 TubeletEmbedding {
225 projection,
226 tubelet_t: self.tubelet_t,
227 tubelet_h: self.tubelet_h,
228 tubelet_w: self.tubelet_w,
229 in_channels: self.in_channels,
230 }
231 }
232}
233
234#[derive(Module, Debug)]
242struct TubeletEmbedding<B: Backend> {
243 projection: Linear<B>,
244 tubelet_t: usize,
245 tubelet_h: usize,
246 tubelet_w: usize,
247 in_channels: usize,
248}
249
250impl<B: Backend> TubeletEmbedding<B> {
251 fn forward(&self, video: Tensor<B, 5>) -> Tensor<B, 3> {
256 let [batch, _channels, frames, height, width] = video.dims();
257
258 let grid_t = frames / self.tubelet_t;
259 let grid_h = height / self.tubelet_h;
260 let grid_w = width / self.tubelet_w;
261 let num_tubelets = grid_t * grid_h * grid_w;
262 let tubelet_dim = self.in_channels * self.tubelet_t * self.tubelet_h * self.tubelet_w;
263
264 let x = video.reshape([
267 batch,
268 self.in_channels,
269 grid_t,
270 self.tubelet_t,
271 height,
272 width,
273 ]);
274 let x = x.permute([0, 2, 1, 3, 4, 5]);
276 let c_t = self.in_channels * self.tubelet_t;
277 let x: Tensor<B, 4> = x.reshape([batch * grid_t, c_t, height, width]);
278
279 let x = x.reshape([
281 batch * grid_t,
282 c_t,
283 grid_h,
284 self.tubelet_h,
285 grid_w,
286 self.tubelet_w,
287 ]);
288 let x = x.permute([0, 2, 4, 1, 3, 5]);
290 let spatial_tubelets = grid_h * grid_w;
292 let x: Tensor<B, 3> = x.reshape([batch * grid_t, spatial_tubelets, tubelet_dim]);
293 let x = x.reshape([batch, num_tubelets, tubelet_dim]);
294
295 self.projection.forward(x)
297 }
298}
299
/// Configuration for [`RotaryPositionEncoding3D`] over a
/// `max_t × max_h × max_w` token grid.
#[derive(Debug, Clone)]
pub struct RotaryPositionEncoding3DConfig {
    /// Width of the embeddings that will be rotated.
    pub embed_dim: usize,
    /// Number of positions along the time axis.
    pub max_t: usize,
    /// Number of positions along the height axis.
    pub max_h: usize,
    /// Number of positions along the width axis.
    pub max_w: usize,
    /// Base of the geometric frequency schedule (`new` uses `10000.0`).
    pub base_freq: f64,
}
316
317impl RotaryPositionEncoding3DConfig {
318 pub fn new(embed_dim: usize, max_t: usize, max_h: usize, max_w: usize) -> Self {
320 Self {
321 embed_dim,
322 max_t,
323 max_h,
324 max_w,
325 base_freq: 10000.0,
326 }
327 }
328
329 pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryPositionEncoding3D<B> {
331 let half_dim = self.embed_dim / 2;
332 let sixth = half_dim / 3;
335 let dim_t = sixth + (half_dim % 3).min(1);
336 let dim_h = sixth + if half_dim % 3 >= 2 { 1 } else { 0 };
337 let dim_w = sixth;
338 debug_assert_eq!(dim_t + dim_h + dim_w, half_dim);
339
340 let max_seq = self.max_t * self.max_h * self.max_w;
341
342 let freqs_t = compute_freqs(dim_t, self.base_freq, half_dim);
344 let freqs_h = compute_freqs(dim_h, self.base_freq, half_dim);
345 let freqs_w = compute_freqs(dim_w, self.base_freq, half_dim);
346
347 let mut cos_data = vec![0.0f32; max_seq * half_dim];
348 let mut sin_data = vec![0.0f32; max_seq * half_dim];
349
350 for t in 0..self.max_t {
351 for h in 0..self.max_h {
352 for w in 0..self.max_w {
353 let pos = t * self.max_h * self.max_w + h * self.max_w + w;
354 let mut offset = 0;
355
356 for (i, &freq) in freqs_t.iter().enumerate() {
358 let angle = t as f64 * freq;
359 cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
360 sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
361 }
362 offset += dim_t;
363
364 for (i, &freq) in freqs_h.iter().enumerate() {
366 let angle = h as f64 * freq;
367 cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
368 sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
369 }
370 offset += dim_h;
371
372 for (i, &freq) in freqs_w.iter().enumerate() {
374 let angle = w as f64 * freq;
375 cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
376 sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
377 }
378 }
379 }
380 }
381
382 let cos_table = Tensor::from_floats(
383 burn::tensor::TensorData::new(cos_data, [max_seq, half_dim]),
384 device,
385 );
386 let sin_table = Tensor::from_floats(
387 burn::tensor::TensorData::new(sin_data, [max_seq, half_dim]),
388 device,
389 );
390
391 RotaryPositionEncoding3D {
392 cos_table,
393 sin_table,
394 embed_dim: self.embed_dim,
395 }
396 }
397}
398
/// Returns `num_freqs` geometrically decreasing rotary frequencies.
///
/// Entry `i` is `base_freq^(-2i / full_half_dim)`. The exponent is normalised
/// by the full half-width (all axes together), not by `num_freqs`.
fn compute_freqs(num_freqs: usize, base_freq: f64, full_half_dim: usize) -> Vec<f64> {
    let denominator = full_half_dim as f64;
    let mut freqs = Vec::with_capacity(num_freqs);
    for slot in 0..num_freqs {
        let exponent = (2 * slot) as f64 / denominator;
        freqs.push(base_freq.powf(exponent).recip());
    }
    freqs
}
405
406#[derive(Module, Debug)]
412pub struct RotaryPositionEncoding3D<B: Backend> {
413 cos_table: Tensor<B, 2>,
415 sin_table: Tensor<B, 2>,
417 embed_dim: usize,
419}
420
421impl<B: Backend> RotaryPositionEncoding3D<B> {
422 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
430 let [batch, seq_len, _dim] = x.dims();
431 let half_dim = self.embed_dim / 2;
432
433 let cos = self.cos_table.clone().slice([0..seq_len, 0..half_dim]);
434 let sin = self.sin_table.clone().slice([0..seq_len, 0..half_dim]);
435
436 let cos = cos.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
437 let sin = sin.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
438
439 let x1 = x.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
440 let x2 = x
441 .clone()
442 .slice([0..batch, 0..seq_len, half_dim..self.embed_dim]);
443
444 let out1 = x1.clone() * cos.clone() - x2.clone() * sin.clone();
445 let out2 = x1 * sin + x2 * cos;
446
447 Tensor::cat(vec![out1, out2], 2)
448 }
449}
450
/// Configuration for a single [`VideoTransformerBlock`].
#[derive(Debug, Clone)]
struct VideoTransformerBlockConfig {
    /// Token width.
    embed_dim: usize,
    /// Number of attention heads; each head has width `embed_dim / num_heads`.
    num_heads: usize,
    /// Hidden width of the feed-forward network.
    mlp_dim: usize,
}
459
460impl VideoTransformerBlockConfig {
461 fn init<B: Backend>(&self, device: &B::Device) -> VideoTransformerBlock<B> {
462 let head_dim = self.embed_dim / self.num_heads;
463 VideoTransformerBlock {
464 norm1: LayerNormConfig::new(self.embed_dim).init(device),
465 attn: VideoSelfAttention {
466 qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
467 out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
468 num_heads: self.num_heads,
469 head_dim,
470 },
471 norm2: LayerNormConfig::new(self.embed_dim).init(device),
472 mlp: VideoMlp {
473 fc1: LinearConfig::new(self.embed_dim, self.mlp_dim).init(device),
474 fc2: LinearConfig::new(self.mlp_dim, self.embed_dim).init(device),
475 },
476 }
477 }
478}
479
480#[derive(Module, Debug)]
482struct VideoTransformerBlock<B: Backend> {
483 norm1: LayerNorm<B>,
484 attn: VideoSelfAttention<B>,
485 norm2: LayerNorm<B>,
486 mlp: VideoMlp<B>,
487}
488
489impl<B: Backend> VideoTransformerBlock<B> {
490 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
491 let residual = x.clone();
492 let x_norm = self.norm1.forward(x);
493 let attn_out = self.attn.forward(x_norm);
494 let x = residual + attn_out;
495
496 let residual = x.clone();
497 let x_norm = self.norm2.forward(x);
498 let mlp_out = self.mlp.forward(x_norm);
499 residual + mlp_out
500 }
501}
502
503#[derive(Module, Debug)]
505struct VideoSelfAttention<B: Backend> {
506 qkv: Linear<B>,
507 out_proj: Linear<B>,
508 num_heads: usize,
509 head_dim: usize,
510}
511
512impl<B: Backend> VideoSelfAttention<B> {
513 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
514 let [batch, seq_len, _] = x.dims();
515 let embed_dim = self.num_heads * self.head_dim;
516
517 let qkv = self.qkv.forward(x);
518 let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
519 let k = qkv
520 .clone()
521 .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
522 let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
523
524 let q = q
525 .reshape([batch, seq_len, self.num_heads, self.head_dim])
526 .swap_dims(1, 2);
527 let k = k
528 .reshape([batch, seq_len, self.num_heads, self.head_dim])
529 .swap_dims(1, 2);
530 let v = v
531 .reshape([batch, seq_len, self.num_heads, self.head_dim])
532 .swap_dims(1, 2);
533
534 let scale = (self.head_dim as f64).sqrt();
535 let attn = q.matmul(k.transpose()) / scale;
536 let attn = burn::tensor::activation::softmax(attn, 3);
537 let out = attn.matmul(v);
538 let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
539 self.out_proj.forward(out)
540 }
541}
542
543#[derive(Module, Debug)]
545struct VideoMlp<B: Backend> {
546 fc1: Linear<B>,
547 fc2: Linear<B>,
548}
549
550impl<B: Backend> VideoMlp<B> {
551 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
552 let x = self.fc1.forward(x);
553 let x = burn::tensor::activation::gelu(x);
554 self.fc2.forward(x)
555 }
556}
557
558#[derive(Module, Debug)]
563pub struct VJepa<B: Backend> {
564 pub context_encoder: VitVideoEncoder<B>,
566 pub target_encoder: VitVideoEncoder<B>,
568 pub predictor: crate::image::TransformerPredictor<B>,
570}
571
572#[derive(Debug, Clone)]
574pub struct StrictVJepaForwardOutput<B: Backend> {
575 pub energy: Energy<B>,
577 pub regularization: Tensor<B, 1>,
579 pub total_loss: Tensor<B, 1>,
581 pub mask: MaskSpec,
583 pub context: Representation<B>,
585 pub predicted: Representation<B>,
587 pub target: Representation<B>,
589}
590
591#[derive(Debug, Clone, thiserror::Error)]
593pub enum StrictVJepaError {
594 #[error(transparent)]
595 InvalidMask(#[from] MaskError),
596 #[error(transparent)]
597 Predictor(#[from] crate::image::PredictorError),
598}
599
600impl<B: Backend> VJepa<B> {
601 pub fn encode_context_strict(
607 &self,
608 video: &Tensor<B, 5>,
609 context_indices: &[usize],
610 ) -> Representation<B> {
611 self.context_encoder
612 .forward_visible_tokens(video, context_indices)
613 }
614
615 pub fn forward_step_strict<EF, CR>(
623 &self,
624 video: &Tensor<B, 5>,
625 mask: MaskSpec,
626 energy_fn: &EF,
627 regularizer: &CR,
628 reg_weight: f64,
629 ) -> StrictVJepaForwardOutput<B>
630 where
631 EF: EnergyFn<B>,
632 CR: CollapseRegularizer<B>,
633 {
634 self.try_forward_step_strict(video, mask, energy_fn, regularizer, reg_weight)
635 .expect(
636 "VJepa::forward_step_strict failed — mask must be valid (disjoint, non-empty) \
637 and target count must not exceed predictor capacity; \
638 use try_forward_step_strict for error handling",
639 )
640 }
641
642 pub fn try_forward_step_strict<EF, CR>(
644 &self,
645 video: &Tensor<B, 5>,
646 mask: MaskSpec,
647 energy_fn: &EF,
648 regularizer: &CR,
649 reg_weight: f64,
650 ) -> Result<StrictVJepaForwardOutput<B>, StrictVJepaError>
651 where
652 EF: EnergyFn<B>,
653 CR: CollapseRegularizer<B>,
654 {
655 mask.validate()?;
656
657 let context = self.encode_context_strict(video, &mask.context_indices);
658 let target_full = self.target_encoder.forward(video);
659 let target = target_full.gather(&mask.target_indices);
660
661 let batch = video.dims()[0];
662 let target_positions = crate::image::target_positions_tensor::<B>(
663 &mask.target_indices,
664 batch,
665 &video.device(),
666 );
667 let predicted = self.predictor.try_predict(&context, &target_positions)?;
668
669 let num_targets = target.seq_len();
670 let embed_dim = target.embed_dim();
671 let pred_flat = predicted
672 .embeddings
673 .clone()
674 .reshape([batch * num_targets, embed_dim]);
675 let target_flat = target
676 .embeddings
677 .clone()
678 .reshape([batch * num_targets, embed_dim]);
679
680 let energy = energy_fn.compute(&predicted, &target);
681 let regularization = regularizer.loss(&pred_flat, &target_flat);
682 let total_loss = energy.value.clone() + regularization.clone() * reg_weight;
683
684 Ok(StrictVJepaForwardOutput {
685 energy,
686 regularization,
687 total_loss,
688 mask,
689 context,
690 predicted,
691 target,
692 })
693 }
694}
695
696#[derive(Debug, Clone)]
698pub struct VJepaConfig {
699 pub encoder: VitVideoConfig,
701 pub predictor: crate::image::TransformerPredictorConfig,
703}
704
705impl VJepaConfig {
706 pub fn tiny_test() -> Self {
708 let encoder = VitVideoConfig::tiny_test();
709 Self {
710 predictor: crate::image::TransformerPredictorConfig {
711 encoder_embed_dim: encoder.embed_dim,
712 predictor_embed_dim: 16,
713 num_layers: 1,
714 num_heads: 2,
715 max_target_len: 64,
716 },
717 encoder,
718 }
719 }
720
721 pub fn init<B: Backend>(&self, device: &B::Device) -> VJepa<B> {
723 VJepa {
724 context_encoder: self.encoder.init(device),
725 target_encoder: self.encoder.init(device),
726 predictor: self.predictor.init(device),
727 }
728 }
729}
730
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use jepa_core::Predictor;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Mask over the 32 tubelets of the tiny test clip: tokens 0..16 (first
    /// temporal slab) are context and tokens 16..32 (second slab) are targets.
    fn fixed_video_mask() -> MaskSpec {
        MaskSpec {
            context_indices: (0..16).collect(),
            target_indices: (16..32).collect(),
            total_tokens: 32,
        }
    }

    /// Builds a `[1, 1, 4, 8, 8]` clip of ones in which every voxel of every
    /// target tubelet of `mask` is set to `hidden_value`.
    ///
    /// The index arithmetic assumes the tiny-test layout: 2×2×2 tubelets on a
    /// 2×4×4 grid, i.e. 16 tubelets per temporal slab.
    fn video_with_hidden_tubelet_value(
        mask: &MaskSpec,
        hidden_value: f32,
    ) -> Tensor<TestBackend, 5> {
        let frames = 4usize;
        let height = 8usize;
        let width = 8usize;
        let mut data = vec![1.0f32; frames * height * width];

        for &index in &mask.target_indices {
            let temporal_block = index / 16;
            let spatial_index = index % 16;
            let spatial_row = spatial_index / 4;
            let spatial_col = spatial_index % 4;

            let frame_start = temporal_block * 2;
            let row_start = spatial_row * 2;
            let col_start = spatial_col * 2;

            for frame in frame_start..frame_start + 2 {
                for row in row_start..row_start + 2 {
                    for col in col_start..col_start + 2 {
                        data[(frame * height + row) * width + col] = hidden_value;
                    }
                }
            }
        }

        Tensor::from_floats(
            burn::tensor::TensorData::new(data, [1, 1, frames, height, width]),
            &device(),
        )
    }

    #[test]
    fn test_vit_video_output_shape() {
        let config = VitVideoConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let video: Tensor<TestBackend, 5> = Tensor::zeros([2, 1, 4, 8, 8], &device());
        let repr = encoder.forward(&video);

        assert_eq!(repr.batch_size(), 2);
        // 2 x 4 x 4 tubelet grid = 32 tokens.
        assert_eq!(repr.seq_len(), 32);
        assert_eq!(repr.embed_dim(), 32);
    }

    #[test]
    fn test_vit_video_encoder_trait() {
        let config = VitVideoConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let video: Tensor<TestBackend, 5> = Tensor::zeros([1, 1, 4, 8, 8], &device());
        let repr = Encoder::encode(&encoder, &video);

        assert_eq!(repr.batch_size(), 1);
        assert_eq!(repr.seq_len(), 32);
        assert_eq!(encoder.embed_dim(), 32);
    }

    #[test]
    fn test_vit_video_different_inputs_different_outputs() {
        let config = VitVideoConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let a: Tensor<TestBackend, 5> = Tensor::zeros([1, 1, 4, 8, 8], &device());
        let b: Tensor<TestBackend, 5> = Tensor::ones([1, 1, 4, 8, 8], &device());

        let repr_a = encoder.forward(&a);
        let repr_b = encoder.forward(&b);

        let diff: f32 = (repr_a.embeddings - repr_b.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff > 1e-6,
            "different video inputs should produce different representations"
        );
    }

    #[test]
    fn test_tubelet_embedding_shape() {
        let config = TubeletEmbeddingConfig {
            in_channels: 3,
            tubelet_t: 2,
            tubelet_h: 16,
            tubelet_w: 16,
            embed_dim: 256,
        };
        let embed = config.init::<TestBackend>(&device());

        let video: Tensor<TestBackend, 5> = Tensor::zeros([1, 3, 16, 224, 224], &device());
        let out = embed.forward(video);

        // 8 x 14 x 14 = 1568 tubelets.
        assert_eq!(out.dims(), [1, 1568, 256]);
    }

    #[test]
    fn test_rope3d_output_shape() {
        let config = RotaryPositionEncoding3DConfig::new(64, 2, 4, 4);
        let rope = config.init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::ones([2, 32, 64], &device());
        let out = rope.forward(x);
        assert_eq!(out.dims(), [2, 32, 64]);
    }

    #[test]
    fn test_rope3d_preserves_norm() {
        let config = RotaryPositionEncoding3DConfig::new(32, 2, 4, 4);
        let rope = config.init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::random(
            [1, 32, 32],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );

        let x_norm: f32 = (x.clone() * x.clone()).sum().into_scalar().elem();
        let out = rope.forward(x);
        let out_norm: f32 = (out.clone() * out.clone()).sum().into_scalar().elem();

        let ratio = out_norm / x_norm;
        assert!(
            (ratio - 1.0).abs() < 0.01,
            "3D RoPE should approximately preserve norm, ratio: {ratio}"
        );
    }

    #[test]
    fn test_rope3d_different_positions_give_different_outputs() {
        let config = RotaryPositionEncoding3DConfig::new(16, 2, 2, 2);
        let rope = config.init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::ones([1, 8, 16], &device());
        let out = rope.forward(x);

        // Positions 0 and 1 differ only in the width coordinate.
        let pos0 = out.clone().slice([0..1, 0..1, 0..16]);
        let pos1 = out.clone().slice([0..1, 1..2, 0..16]);

        let diff: f32 = (pos0 - pos1).abs().sum().into_scalar().elem();
        assert!(
            diff > 1e-6,
            "different 3D positions should produce different outputs"
        );
    }

    #[test]
    fn test_video_config_grid_dims() {
        let config = VitVideoConfig {
            in_channels: 3,
            num_frames: 16,
            frame_height: 224,
            frame_width: 224,
            tubelet_size: (2, 16, 16),
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            mlp_dim: 3072,
        };
        assert_eq!(config.grid_dims(), (8, 14, 14));
        assert_eq!(config.num_tubelets(), 1568);
    }

    #[test]
    fn test_video_transformer_block_residual() {
        let block = VideoTransformerBlockConfig {
            embed_dim: 16,
            num_heads: 2,
            mlp_dim: 32,
        }
        .init::<TestBackend>(&device());

        let x: Tensor<TestBackend, 3> = Tensor::zeros([1, 8, 16], &device());
        let out = block.forward(x);
        assert_eq!(out.dims(), [1, 8, 16]);
    }

    #[test]
    fn test_video_self_attention_shape() {
        let attn = VideoSelfAttention {
            qkv: LinearConfig::new(16, 48).init::<TestBackend>(&device()),
            out_proj: LinearConfig::new(16, 16).init::<TestBackend>(&device()),
            num_heads: 4,
            head_dim: 4,
        };

        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = attn.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_video_config_num_tubelets(
            grid_t in 1usize..4,
            grid_h in 1usize..4,
            grid_w in 1usize..4,
        ) {
            let tub = 2;
            let config = VitVideoConfig {
                in_channels: 1,
                num_frames: grid_t * tub,
                frame_height: grid_h * tub,
                frame_width: grid_w * tub,
                tubelet_size: (tub, tub, tub),
                embed_dim: 16,
                num_layers: 1,
                num_heads: 2,
                mlp_dim: 32,
            };
            prop_assert_eq!(config.grid_dims(), (grid_t, grid_h, grid_w));
            prop_assert_eq!(config.num_tubelets(), grid_t * grid_h * grid_w);
        }

        #[test]
        fn prop_rope3d_preserves_shape(
            max_t in 1usize..3,
            max_h in 1usize..3,
            max_w in 1usize..3,
        ) {
            let embed_dim = 12;
            let config = RotaryPositionEncoding3DConfig::new(embed_dim, max_t, max_h, max_w);
            let rope = config.init::<TestBackend>(&device());
            let seq_len = max_t * max_h * max_w;
            let x: Tensor<TestBackend, 3> = Tensor::ones([1, seq_len, embed_dim], &device());
            let out = rope.forward(x);
            prop_assert_eq!(out.dims(), [1, seq_len, embed_dim]);
        }

        #[test]
        fn prop_rope3d_preserves_norm(
            max_t in 1usize..3,
            max_h in 2usize..4,
            max_w in 2usize..4,
        ) {
            let embed_dim = 12;
            let config = RotaryPositionEncoding3DConfig::new(embed_dim, max_t, max_h, max_w);
            let rope = config.init::<TestBackend>(&device());
            let seq_len = max_t * max_h * max_w;
            let x: Tensor<TestBackend, 3> = Tensor::random(
                [1, seq_len, embed_dim],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let x_norm: f32 = (x.clone() * x.clone()).sum().into_scalar().elem();
            let out = rope.forward(x);
            let out_norm: f32 = (out.clone() * out.clone()).sum().into_scalar().elem();
            let ratio = out_norm / x_norm;
            prop_assert!((ratio - 1.0).abs() < 0.01, "3D RoPE norm ratio: {}", ratio);
        }
    }

    #[test]
    fn test_vjepa_config_tiny() {
        let config = VJepaConfig::tiny_test();
        assert_eq!(config.encoder.embed_dim, 32);
        assert_eq!(config.predictor.predictor_embed_dim, 16);
        assert_eq!(config.predictor.encoder_embed_dim, 32);
    }

    #[test]
    fn test_vjepa_model_init() {
        let config = VJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());

        assert_eq!(model.context_encoder.embed_dim, 32);
        assert_eq!(model.target_encoder.embed_dim, 32);
    }

    #[test]
    fn test_strict_video_context_encoding_ignores_hidden_tubelets() {
        let config = VJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let mask = fixed_video_mask();

        let hidden_low = video_with_hidden_tubelet_value(&mask, 0.0);
        let hidden_high = video_with_hidden_tubelet_value(&mask, 1_000.0);

        let strict_low = model.encode_context_strict(&hidden_low, &mask.context_indices);
        let strict_high = model.encode_context_strict(&hidden_high, &mask.context_indices);

        let diff: f32 = (strict_low.embeddings - strict_high.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff < 1e-5,
            "strict masked video context should ignore hidden tubelets, diff={diff}"
        );
    }

    #[test]
    fn test_full_video_encoder_context_slice_leaks_hidden_tubelets() {
        let config = VitVideoConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mask = fixed_video_mask();

        let hidden_low = video_with_hidden_tubelet_value(&mask, 0.0);
        let hidden_high = video_with_hidden_tubelet_value(&mask, 1_000.0);

        let approx_low = encoder.forward(&hidden_low).gather(&mask.context_indices);
        let approx_high = encoder.forward(&hidden_high).gather(&mask.context_indices);

        let diff: f32 = (approx_low.embeddings - approx_high.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff > 1e-3,
            "post-encoder gather path should leak hidden tubelets, diff={diff}"
        );
    }

    #[test]
    fn test_strict_video_forward_step_runs_end_to_end() {
        let config = VJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let mask = fixed_video_mask();
        let video = video_with_hidden_tubelet_value(&mask, 5.0);
        let energy_fn = jepa_core::energy::L2Energy;
        let regularizer = jepa_core::collapse::VICReg::default();

        let output =
            model.forward_step_strict(&video, mask.clone(), &energy_fn, &regularizer, 1.0);

        assert_eq!(output.context.seq_len(), mask.context_indices.len());
        assert_eq!(output.predicted.seq_len(), mask.target_indices.len());
        assert_eq!(output.target.seq_len(), mask.target_indices.len());

        let total_loss: f32 = output.total_loss.into_scalar().elem();
        assert!(
            total_loss.is_finite(),
            "strict video forward loss should be finite"
        );
    }

    #[test]
    fn test_try_strict_video_forward_step_rejects_invalid_mask() {
        let config = VJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let video = Tensor::ones([1, 1, 4, 8, 8], &device());
        let invalid_mask = MaskSpec {
            context_indices: vec![0],
            target_indices: vec![],
            total_tokens: 32,
        };
        let energy_fn = jepa_core::energy::L2Energy;
        let regularizer = jepa_core::collapse::VICReg::default();

        let err = model
            .try_forward_step_strict(&video, invalid_mask, &energy_fn, &regularizer, 1.0)
            .unwrap_err();
        assert!(matches!(
            err,
            StrictVJepaError::InvalidMask(MaskError::EmptyTarget)
        ));
    }

    #[test]
    fn bdd_vjepa_full_pipeline_with_spatiotemporal_masking() {
        use jepa_core::{CollapseRegularizer, EnergyFn, MaskingStrategy};
        use rand::SeedableRng;

        let config = VJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());

        // Given a random single-channel 4x8x8 clip.
        let video: Tensor<TestBackend, 5> = Tensor::random(
            [1, 1, 4, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );

        // When both encoders see the full clip.
        let context_repr = model.context_encoder.forward(&video);
        let target_repr = model.target_encoder.forward(&video);

        // Then each produces one token per tubelet.
        assert_eq!(context_repr.seq_len(), 32);
        assert_eq!(target_repr.seq_len(), 32);

        // And a spatiotemporal mask over the 2x4x4 tubelet grid partitions all tokens.
        let masking = jepa_core::masking::SpatiotemporalMasking {
            num_targets: 2,
            temporal_extent: (1, 2),
            spatial_scale: (0.1, 0.2),
        };
        let shape = jepa_core::types::InputShape::Video {
            frames: 2,
            height: 4,
            width: 4,
        };
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let mask = masking.generate_mask(&shape, &mut rng);
        assert!(mask.validate().is_ok());
        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 32);

        let target_gathered = target_repr.gather(&mask.target_indices);
        assert_eq!(target_gathered.seq_len(), mask.target_indices.len());

        // The predictor yields one prediction per target position.
        let num_targets = mask.target_indices.len();
        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, num_targets], &device());
        let predicted = model.predictor.predict(&context_repr, &target_pos, None);
        assert_eq!(predicted.seq_len(), num_targets);
        assert_eq!(predicted.embed_dim(), 32);

        // Energy is finite and non-negative.
        let energy = jepa_core::energy::L2Energy.compute(&predicted, &target_gathered);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.is_finite(), "energy should be finite, got {val}");
        assert!(val >= 0.0, "L2 energy should be non-negative, got {val}");

        // Collapse regularization on flattened embeddings is finite.
        let embed_dim = predicted.embed_dim();
        let pred_flat = predicted.embeddings.reshape([num_targets, embed_dim]);
        let target_flat = target_gathered.embeddings.reshape([num_targets, embed_dim]);
        let reg: f32 = jepa_core::collapse::VICReg::default()
            .loss(&pred_flat, &target_flat)
            .into_scalar()
            .elem();
        assert!(
            reg.is_finite(),
            "regularization should be finite, got {reg}"
        );
    }
}