Skip to main content

jepa_vision/
video.rs

1//! V-JEPA video encoder with 3D tubelets and 3D RoPE.
2//!
3//! Implements RFC-002 (Encoder Module) for video input.
4//!
5//! V-JEPA extends I-JEPA from images to video by replacing 2-D patches
6//! with 3-D **tubelets** `(temporal × height × width)` and using 3-D
7//! Rotary Position Encoding for spatiotemporal position awareness.
8//!
9//! ```text
10//! [B, C, T, H, W]
11//!       │
12//!       ▼
13//! TubeletEmbedding  ──►  3D RoPE  ──►  N × TransformerBlock  ──►  LayerNorm
14//! [B, S, D]              [B, S, D]     [B, S, D]                   [B, S, D]
15//! ```
16//!
17//! where `S = (T/t) × (H/h) × (W/w)` for tubelet size `(t, h, w)`.
18//!
19//! The module also provides [`VJepa`], a full V-JEPA pipeline struct
20//! with `forward_step_strict` for masked training with pre-encoder
21//! token filtering, mirroring the reference implementation.
22//!
23//! References:
24//! - Bardes, A. et al. (2024). *V-JEPA: Latent Video Prediction for
25//!   Visual Representation Learning*.
26//! - Bardes, A. et al. (2025). *V-JEPA 2: Self-Supervised Video Models
27//!   Enable Understanding, Generation, and Planning*.
28
29use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
30use burn::prelude::*;
31use burn::tensor::backend::Backend;
32
33use jepa_core::types::{Energy, MaskError, MaskSpec, Representation};
34use jepa_core::{CollapseRegularizer, Encoder, EnergyFn};
35
36use crate::token_ops::gather_token_sequence;
37
38/// Configuration for a V-JEPA video encoder.
39#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
40pub struct VitVideoConfig {
41    /// Number of input channels (e.g., 3 for RGB).
42    pub in_channels: usize,
43    /// Number of input frames.
44    pub num_frames: usize,
45    /// Frame height in pixels.
46    pub frame_height: usize,
47    /// Frame width in pixels.
48    pub frame_width: usize,
49    /// Tubelet size `(temporal, height, width)`.
50    pub tubelet_size: (usize, usize, usize),
51    /// Embedding dimension.
52    pub embed_dim: usize,
53    /// Number of transformer layers.
54    pub num_layers: usize,
55    /// Number of attention heads.
56    pub num_heads: usize,
57    /// MLP hidden dimension (typically 4 * embed_dim).
58    pub mlp_dim: usize,
59}
60
61impl VitVideoConfig {
62    /// Grid dimensions `(temporal, height, width)` after tubelet embedding.
63    pub fn grid_dims(&self) -> (usize, usize, usize) {
64        (
65            self.num_frames / self.tubelet_size.0,
66            self.frame_height / self.tubelet_size.1,
67            self.frame_width / self.tubelet_size.2,
68        )
69    }
70
71    /// Total number of tubelets.
72    pub fn num_tubelets(&self) -> usize {
73        let (gt, gh, gw) = self.grid_dims();
74        gt * gh * gw
75    }
76
77    /// Create a tiny config for testing.
78    pub fn tiny_test() -> Self {
79        Self {
80            in_channels: 1,
81            num_frames: 4,
82            frame_height: 8,
83            frame_width: 8,
84            tubelet_size: (2, 2, 2),
85            embed_dim: 32,
86            num_layers: 2,
87            num_heads: 4,
88            mlp_dim: 64,
89        }
90    }
91
92    /// Initialize a [`VitVideoEncoder`] module.
93    pub fn init<B: Backend>(&self, device: &B::Device) -> VitVideoEncoder<B> {
94        let tubelet_embed_config = TubeletEmbeddingConfig {
95            in_channels: self.in_channels,
96            tubelet_t: self.tubelet_size.0,
97            tubelet_h: self.tubelet_size.1,
98            tubelet_w: self.tubelet_size.2,
99            embed_dim: self.embed_dim,
100        };
101        let tubelet_embed = tubelet_embed_config.init(device);
102
103        let (gt, gh, gw) = self.grid_dims();
104        let rope = RotaryPositionEncoding3DConfig::new(self.embed_dim, gt, gh, gw).init(device);
105
106        let blocks: Vec<VideoTransformerBlock<B>> = (0..self.num_layers)
107            .map(|_| {
108                VideoTransformerBlockConfig {
109                    embed_dim: self.embed_dim,
110                    num_heads: self.num_heads,
111                    mlp_dim: self.mlp_dim,
112                }
113                .init(device)
114            })
115            .collect();
116
117        let norm = LayerNormConfig::new(self.embed_dim).init(device);
118
119        VitVideoEncoder {
120            tubelet_embed,
121            positional_encoding: rope,
122            blocks,
123            norm,
124            embed_dim: self.embed_dim,
125        }
126    }
127}
128
129/// Vision Transformer encoder for video.
130///
131/// Maps video clips to tubelet-level representations via:
132/// 1. Tubelet embedding (linear projection of 3D patches)
133/// 2. 3D Rotary Position Encoding (temporal + spatial)
134/// 3. Stack of transformer blocks
135/// 4. Final layer normalization
136///
137/// Output shape: `[batch, num_tubelets, embed_dim]`
138#[derive(Module, Debug)]
139pub struct VitVideoEncoder<B: Backend> {
140    /// Tubelet embedding: video → tubelet tokens.
141    tubelet_embed: TubeletEmbedding<B>,
142    /// 3D Rotary Position Encoding for spatiotemporal positions.
143    positional_encoding: RotaryPositionEncoding3D<B>,
144    /// Stack of transformer blocks.
145    blocks: Vec<VideoTransformerBlock<B>>,
146    /// Final layer normalization.
147    norm: LayerNorm<B>,
148    /// Output embedding dimension.
149    embed_dim: usize,
150}
151
152impl<B: Backend> VitVideoEncoder<B> {
153    fn positioned_tubelet_tokens(&self, video: &Tensor<B, 5>) -> Tensor<B, 3> {
154        // 1. Tubelet embedding
155        let x = self.tubelet_embed.forward(video.clone());
156
157        // 2. Apply 3D RoPE before masking so surviving tubelets keep their
158        // original spatiotemporal coordinates.
159        self.positional_encoding.forward(x)
160    }
161
162    fn encode_positioned_tokens(&self, mut x: Tensor<B, 3>) -> Representation<B> {
163        for block in &self.blocks {
164            x = block.forward(x);
165        }
166
167        x = self.norm.forward(x);
168
169        Representation::new(x)
170    }
171
172    /// Forward pass: video → representation.
173    ///
174    /// # Arguments
175    /// * `video` - Input video. Shape: `[batch, channels, frames, height, width]`
176    ///
177    /// # Returns
178    /// Tubelet-level representations. Shape: `[batch, num_tubelets, embed_dim]`
179    pub fn forward(&self, video: &Tensor<B, 5>) -> Representation<B> {
180        let x = self.positioned_tubelet_tokens(video);
181        self.encode_positioned_tokens(x)
182    }
183
184    /// Encode only the visible tubelets for strict JEPA context encoding.
185    pub fn forward_visible_tokens(
186        &self,
187        video: &Tensor<B, 5>,
188        visible_indices: &[usize],
189    ) -> Representation<B> {
190        let x = self.positioned_tubelet_tokens(video);
191        let x = gather_token_sequence(x, visible_indices);
192        self.encode_positioned_tokens(x)
193    }
194}
195
196impl<B: Backend> Encoder<B> for VitVideoEncoder<B> {
197    type Input = Tensor<B, 5>;
198
199    fn encode(&self, input: &Self::Input) -> Representation<B> {
200        self.forward(input)
201    }
202
203    fn embed_dim(&self) -> usize {
204        self.embed_dim
205    }
206}
207
208// --- Tubelet Embedding ---
209
210/// Configuration for tubelet embedding.
211#[derive(Debug, Clone)]
212struct TubeletEmbeddingConfig {
213    in_channels: usize,
214    tubelet_t: usize,
215    tubelet_h: usize,
216    tubelet_w: usize,
217    embed_dim: usize,
218}
219
220impl TubeletEmbeddingConfig {
221    fn init<B: Backend>(&self, device: &B::Device) -> TubeletEmbedding<B> {
222        let tubelet_dim = self.in_channels * self.tubelet_t * self.tubelet_h * self.tubelet_w;
223        let projection = LinearConfig::new(tubelet_dim, self.embed_dim).init(device);
224        TubeletEmbedding {
225            projection,
226            tubelet_t: self.tubelet_t,
227            tubelet_h: self.tubelet_h,
228            tubelet_w: self.tubelet_w,
229            in_channels: self.in_channels,
230        }
231    }
232}
233
234/// Tubelet embedding for video.
235///
236/// Splits a video into non-overlapping 3D tubelets (temporal × height × width)
237/// and projects each through a linear layer.
238///
239/// Input shape: `[batch, channels, frames, height, width]`
240/// Output shape: `[batch, num_tubelets, embed_dim]`
241#[derive(Module, Debug)]
242struct TubeletEmbedding<B: Backend> {
243    projection: Linear<B>,
244    tubelet_t: usize,
245    tubelet_h: usize,
246    tubelet_w: usize,
247    in_channels: usize,
248}
249
250impl<B: Backend> TubeletEmbedding<B> {
251    /// Convert a video batch to tubelet embeddings.
252    ///
253    /// # Arguments
254    /// * `video` - Input video. Shape: `[batch, channels, frames, height, width]`
255    fn forward(&self, video: Tensor<B, 5>) -> Tensor<B, 3> {
256        let [batch, _channels, frames, height, width] = video.dims();
257
258        let grid_t = frames / self.tubelet_t;
259        let grid_h = height / self.tubelet_h;
260        let grid_w = width / self.tubelet_w;
261        let num_tubelets = grid_t * grid_h * grid_w;
262        let tubelet_dim = self.in_channels * self.tubelet_t * self.tubelet_h * self.tubelet_w;
263
264        // NdArray supports max 6 dims, so we split into two steps:
265        // Step 1: Split temporal axis. [B, C, F, H, W] → [B, C, grid_t, tub_t, H, W]
266        let x = video.reshape([
267            batch,
268            self.in_channels,
269            grid_t,
270            self.tubelet_t,
271            height,
272            width,
273        ]);
274        // Permute to [B, grid_t, C, tub_t, H, W] then flatten: [B*grid_t, C*tub_t, H, W]
275        let x = x.permute([0, 2, 1, 3, 4, 5]);
276        let c_t = self.in_channels * self.tubelet_t;
277        let x: Tensor<B, 4> = x.reshape([batch * grid_t, c_t, height, width]);
278
279        // Step 2: Split spatial axes. [B*grid_t, C*tub_t, H, W] → [B*gt, C*tt, gh, th, gw, tw]
280        let x = x.reshape([
281            batch * grid_t,
282            c_t,
283            grid_h,
284            self.tubelet_h,
285            grid_w,
286            self.tubelet_w,
287        ]);
288        // Permute to [B*gt, gh, gw, C*tt, th, tw]
289        let x = x.permute([0, 2, 4, 1, 3, 5]);
290        // Flatten: [B*gt, gh*gw, C*tt*th*tw] then reshape to [B, gt*gh*gw, tubelet_dim]
291        let spatial_tubelets = grid_h * grid_w;
292        let x: Tensor<B, 3> = x.reshape([batch * grid_t, spatial_tubelets, tubelet_dim]);
293        let x = x.reshape([batch, num_tubelets, tubelet_dim]);
294
295        // Project: [B, num_tubelets, embed_dim]
296        self.projection.forward(x)
297    }
298}
299
300// --- 3D Rotary Position Encoding ---
301
302/// Configuration for 3D Rotary Position Encoding.
303#[derive(Debug, Clone)]
304pub struct RotaryPositionEncoding3DConfig {
305    /// Embedding dimension (must be divisible by 2).
306    pub embed_dim: usize,
307    /// Maximum temporal grid size.
308    pub max_t: usize,
309    /// Maximum spatial grid height.
310    pub max_h: usize,
311    /// Maximum spatial grid width.
312    pub max_w: usize,
313    /// Base frequency (default: 10000.0).
314    pub base_freq: f64,
315}
316
317impl RotaryPositionEncoding3DConfig {
318    /// Create a new config.
319    pub fn new(embed_dim: usize, max_t: usize, max_h: usize, max_w: usize) -> Self {
320        Self {
321            embed_dim,
322            max_t,
323            max_h,
324            max_w,
325            base_freq: 10000.0,
326        }
327    }
328
329    /// Initialize the 3D position encoding with precomputed sin/cos tables.
330    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryPositionEncoding3D<B> {
331        let half_dim = self.embed_dim / 2;
332        // Divide half_dim into 3 parts for temporal, height, width
333        // If not perfectly divisible, temporal and height get one extra each
334        let sixth = half_dim / 3;
335        let dim_t = sixth + (half_dim % 3).min(1);
336        let dim_h = sixth + if half_dim % 3 >= 2 { 1 } else { 0 };
337        let dim_w = sixth;
338        debug_assert_eq!(dim_t + dim_h + dim_w, half_dim);
339
340        let max_seq = self.max_t * self.max_h * self.max_w;
341
342        // Compute frequency bands for each axis
343        let freqs_t = compute_freqs(dim_t, self.base_freq, half_dim);
344        let freqs_h = compute_freqs(dim_h, self.base_freq, half_dim);
345        let freqs_w = compute_freqs(dim_w, self.base_freq, half_dim);
346
347        let mut cos_data = vec![0.0f32; max_seq * half_dim];
348        let mut sin_data = vec![0.0f32; max_seq * half_dim];
349
350        for t in 0..self.max_t {
351            for h in 0..self.max_h {
352                for w in 0..self.max_w {
353                    let pos = t * self.max_h * self.max_w + h * self.max_w + w;
354                    let mut offset = 0;
355
356                    // Temporal frequencies
357                    for (i, &freq) in freqs_t.iter().enumerate() {
358                        let angle = t as f64 * freq;
359                        cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
360                        sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
361                    }
362                    offset += dim_t;
363
364                    // Height frequencies
365                    for (i, &freq) in freqs_h.iter().enumerate() {
366                        let angle = h as f64 * freq;
367                        cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
368                        sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
369                    }
370                    offset += dim_h;
371
372                    // Width frequencies
373                    for (i, &freq) in freqs_w.iter().enumerate() {
374                        let angle = w as f64 * freq;
375                        cos_data[pos * half_dim + offset + i] = angle.cos() as f32;
376                        sin_data[pos * half_dim + offset + i] = angle.sin() as f32;
377                    }
378                }
379            }
380        }
381
382        let cos_table = Tensor::from_floats(
383            burn::tensor::TensorData::new(cos_data, [max_seq, half_dim]),
384            device,
385        );
386        let sin_table = Tensor::from_floats(
387            burn::tensor::TensorData::new(sin_data, [max_seq, half_dim]),
388            device,
389        );
390
391        RotaryPositionEncoding3D {
392            cos_table,
393            sin_table,
394            embed_dim: self.embed_dim,
395        }
396    }
397}
398
/// Inverse frequencies for one axis of the 3D RoPE.
///
/// Returns `num_freqs` values `1 / base_freq^(2i / full_half_dim)` for
/// `i = 0, 1, ..., num_freqs - 1`. The exponent is normalised by
/// `full_half_dim` (the total rotary half-width) rather than by `num_freqs`, so
/// the smallest frequency returned is
/// `base_freq^(-2 (num_freqs - 1) / full_half_dim)`.
fn compute_freqs(num_freqs: usize, base_freq: f64, full_half_dim: usize) -> Vec<f64> {
    let denom = full_half_dim as f64;
    let mut freqs = Vec::with_capacity(num_freqs);
    for i in 0..num_freqs {
        let exponent = 2.0 * i as f64 / denom;
        freqs.push(1.0 / base_freq.powf(exponent));
    }
    freqs
}
405
406/// 3D Rotary Position Encoding for video.
407///
408/// Extends RoPE to three dimensions (temporal, height, width) by splitting
409/// the embedding dimension into three groups and applying separate rotary
410/// frequencies for each spatial/temporal axis.
411#[derive(Module, Debug)]
412pub struct RotaryPositionEncoding3D<B: Backend> {
413    /// Precomputed cosine table. Shape: `[max_seq, half_dim]`
414    cos_table: Tensor<B, 2>,
415    /// Precomputed sine table. Shape: `[max_seq, half_dim]`
416    sin_table: Tensor<B, 2>,
417    /// Full embedding dimension.
418    embed_dim: usize,
419}
420
421impl<B: Backend> RotaryPositionEncoding3D<B> {
422    /// Apply 3D rotary encoding to a tensor.
423    ///
424    /// # Arguments
425    /// * `x` - Input tensor. Shape: `[batch, seq_len, embed_dim]`
426    ///
427    /// # Returns
428    /// Rotated tensor with 3D position information encoded. Same shape as input.
429    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
430        let [batch, seq_len, _dim] = x.dims();
431        let half_dim = self.embed_dim / 2;
432
433        let cos = self.cos_table.clone().slice([0..seq_len, 0..half_dim]);
434        let sin = self.sin_table.clone().slice([0..seq_len, 0..half_dim]);
435
436        let cos = cos.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
437        let sin = sin.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
438
439        let x1 = x.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
440        let x2 = x
441            .clone()
442            .slice([0..batch, 0..seq_len, half_dim..self.embed_dim]);
443
444        let out1 = x1.clone() * cos.clone() - x2.clone() * sin.clone();
445        let out2 = x1 * sin + x2 * cos;
446
447        Tensor::cat(vec![out1, out2], 2)
448    }
449}
450
451// --- Video Transformer Block ---
452
453#[derive(Debug, Clone)]
454struct VideoTransformerBlockConfig {
455    embed_dim: usize,
456    num_heads: usize,
457    mlp_dim: usize,
458}
459
460impl VideoTransformerBlockConfig {
461    fn init<B: Backend>(&self, device: &B::Device) -> VideoTransformerBlock<B> {
462        let head_dim = self.embed_dim / self.num_heads;
463        VideoTransformerBlock {
464            norm1: LayerNormConfig::new(self.embed_dim).init(device),
465            attn: VideoSelfAttention {
466                qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
467                out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
468                num_heads: self.num_heads,
469                head_dim,
470            },
471            norm2: LayerNormConfig::new(self.embed_dim).init(device),
472            mlp: VideoMlp {
473                fc1: LinearConfig::new(self.embed_dim, self.mlp_dim).init(device),
474                fc2: LinearConfig::new(self.mlp_dim, self.embed_dim).init(device),
475            },
476        }
477    }
478}
479
480/// Pre-norm transformer block for video encoder.
481#[derive(Module, Debug)]
482struct VideoTransformerBlock<B: Backend> {
483    norm1: LayerNorm<B>,
484    attn: VideoSelfAttention<B>,
485    norm2: LayerNorm<B>,
486    mlp: VideoMlp<B>,
487}
488
489impl<B: Backend> VideoTransformerBlock<B> {
490    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
491        let residual = x.clone();
492        let x_norm = self.norm1.forward(x);
493        let attn_out = self.attn.forward(x_norm);
494        let x = residual + attn_out;
495
496        let residual = x.clone();
497        let x_norm = self.norm2.forward(x);
498        let mlp_out = self.mlp.forward(x_norm);
499        residual + mlp_out
500    }
501}
502
503/// Multi-head self-attention for video transformer.
504#[derive(Module, Debug)]
505struct VideoSelfAttention<B: Backend> {
506    qkv: Linear<B>,
507    out_proj: Linear<B>,
508    num_heads: usize,
509    head_dim: usize,
510}
511
512impl<B: Backend> VideoSelfAttention<B> {
513    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
514        let [batch, seq_len, _] = x.dims();
515        let embed_dim = self.num_heads * self.head_dim;
516
517        let qkv = self.qkv.forward(x);
518        let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
519        let k = qkv
520            .clone()
521            .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
522        let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
523
524        let q = q
525            .reshape([batch, seq_len, self.num_heads, self.head_dim])
526            .swap_dims(1, 2);
527        let k = k
528            .reshape([batch, seq_len, self.num_heads, self.head_dim])
529            .swap_dims(1, 2);
530        let v = v
531            .reshape([batch, seq_len, self.num_heads, self.head_dim])
532            .swap_dims(1, 2);
533
534        let scale = (self.head_dim as f64).sqrt();
535        let attn = q.matmul(k.transpose()) / scale;
536        let attn = burn::tensor::activation::softmax(attn, 3);
537        let out = attn.matmul(v);
538        let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
539        self.out_proj.forward(out)
540    }
541}
542
543/// Two-layer MLP with GELU activation for video transformer.
544#[derive(Module, Debug)]
545struct VideoMlp<B: Backend> {
546    fc1: Linear<B>,
547    fc2: Linear<B>,
548}
549
550impl<B: Backend> VideoMlp<B> {
551    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
552        let x = self.fc1.forward(x);
553        let x = burn::tensor::activation::gelu(x);
554        self.fc2.forward(x)
555    }
556}
557
558/// V-JEPA model combining video encoder pair and predictor.
559///
560/// Provides a high-level interface for the V-JEPA video pipeline per RFC-002 and RFC-003.
561/// Uses spatiotemporal masking of tubelets for self-supervised learning on video.
562#[derive(Module, Debug)]
563pub struct VJepa<B: Backend> {
564    /// Context encoder — trained via gradient descent.
565    pub context_encoder: VitVideoEncoder<B>,
566    /// Target encoder — updated via EMA (no gradients).
567    pub target_encoder: VitVideoEncoder<B>,
568    /// Predictor — predicts target tubelet representations from context.
569    pub predictor: crate::image::TransformerPredictor<B>,
570}
571
572/// Output of a strict masked V-JEPA forward step.
573#[derive(Debug, Clone)]
574pub struct StrictVJepaForwardOutput<B: Backend> {
575    /// Prediction energy (main loss signal). Shape: `[1]`
576    pub energy: Energy<B>,
577    /// Collapse prevention regularization loss. Shape: `[1]`
578    pub regularization: Tensor<B, 1>,
579    /// Total loss (energy + weighted regularization). Shape: `[1]`
580    pub total_loss: Tensor<B, 1>,
581    /// The mask used for this step.
582    pub mask: MaskSpec,
583    /// Strictly encoded context representation.
584    pub context: Representation<B>,
585    /// Predicted target representations.
586    pub predicted: Representation<B>,
587    /// Actual target representations from the target encoder.
588    pub target: Representation<B>,
589}
590
591/// Errors returned by [`VJepa::try_forward_step_strict`].
592#[derive(Debug, Clone, thiserror::Error)]
593pub enum StrictVJepaError {
594    #[error(transparent)]
595    InvalidMask(#[from] MaskError),
596    #[error(transparent)]
597    Predictor(#[from] crate::image::PredictorError),
598}
599
600impl<B: Backend> VJepa<B> {
601    /// Encode only the visible tubelets before context self-attention runs.
602    ///
603    /// This method assumes `context_indices` are already valid for the current
604    /// tubelet grid. Use [`VJepa::try_forward_step_strict`] when the indices
605    /// come from caller-controlled masking data.
606    pub fn encode_context_strict(
607        &self,
608        video: &Tensor<B, 5>,
609        context_indices: &[usize],
610    ) -> Representation<B> {
611        self.context_encoder
612            .forward_visible_tokens(video, context_indices)
613    }
614
615    /// Execute a strict masked V-JEPA forward step.
616    ///
617    /// # Panics
618    ///
619    /// Panics if `mask` is invalid or if the predictor receives target
620    /// positions outside its configured capacity. Use
621    /// [`VJepa::try_forward_step_strict`] for typed error reporting.
622    pub fn forward_step_strict<EF, CR>(
623        &self,
624        video: &Tensor<B, 5>,
625        mask: MaskSpec,
626        energy_fn: &EF,
627        regularizer: &CR,
628        reg_weight: f64,
629    ) -> StrictVJepaForwardOutput<B>
630    where
631        EF: EnergyFn<B>,
632        CR: CollapseRegularizer<B>,
633    {
634        self.try_forward_step_strict(video, mask, energy_fn, regularizer, reg_weight)
635            .expect(
636                "VJepa::forward_step_strict failed — mask must be valid (disjoint, non-empty) \
637                 and target count must not exceed predictor capacity; \
638                 use try_forward_step_strict for error handling",
639            )
640    }
641
642    /// Execute a strict masked V-JEPA forward step with typed error reporting.
643    pub fn try_forward_step_strict<EF, CR>(
644        &self,
645        video: &Tensor<B, 5>,
646        mask: MaskSpec,
647        energy_fn: &EF,
648        regularizer: &CR,
649        reg_weight: f64,
650    ) -> Result<StrictVJepaForwardOutput<B>, StrictVJepaError>
651    where
652        EF: EnergyFn<B>,
653        CR: CollapseRegularizer<B>,
654    {
655        mask.validate()?;
656
657        let context = self.encode_context_strict(video, &mask.context_indices);
658        let target_full = self.target_encoder.forward(video);
659        let target = target_full.gather(&mask.target_indices);
660
661        let batch = video.dims()[0];
662        let target_positions = crate::image::target_positions_tensor::<B>(
663            &mask.target_indices,
664            batch,
665            &video.device(),
666        );
667        let predicted = self.predictor.try_predict(&context, &target_positions)?;
668
669        let num_targets = target.seq_len();
670        let embed_dim = target.embed_dim();
671        let pred_flat = predicted
672            .embeddings
673            .clone()
674            .reshape([batch * num_targets, embed_dim]);
675        let target_flat = target
676            .embeddings
677            .clone()
678            .reshape([batch * num_targets, embed_dim]);
679
680        let energy = energy_fn.compute(&predicted, &target);
681        let regularization = regularizer.loss(&pred_flat, &target_flat);
682        let total_loss = energy.value.clone() + regularization.clone() * reg_weight;
683
684        Ok(StrictVJepaForwardOutput {
685            energy,
686            regularization,
687            total_loss,
688            mask,
689            context,
690            predicted,
691            target,
692        })
693    }
694}
695
696/// Configuration for the V-JEPA model.
697#[derive(Debug, Clone)]
698pub struct VJepaConfig {
699    /// Video encoder config (shared by context and target encoders).
700    pub encoder: VitVideoConfig,
701    /// Predictor config.
702    pub predictor: crate::image::TransformerPredictorConfig,
703}
704
705impl VJepaConfig {
706    /// Create a tiny config suitable for testing.
707    pub fn tiny_test() -> Self {
708        let encoder = VitVideoConfig::tiny_test();
709        Self {
710            predictor: crate::image::TransformerPredictorConfig {
711                encoder_embed_dim: encoder.embed_dim,
712                predictor_embed_dim: 16,
713                num_layers: 1,
714                num_heads: 2,
715                max_target_len: 64,
716            },
717            encoder,
718        }
719    }
720
721    /// Initialize a [`VJepa`] model.
722    pub fn init<B: Backend>(&self, device: &B::Device) -> VJepa<B> {
723        VJepa {
724            context_encoder: self.encoder.init(device),
725            target_encoder: self.encoder.init(device),
726            predictor: self.predictor.init(device),
727        }
728    }
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use burn::tensor::ElementConversion;
735    use burn_ndarray::NdArray;
736    use jepa_core::Predictor;
737
738    type TestBackend = NdArray<f32>;
739
740    fn device() -> burn_ndarray::NdArrayDevice {
741        burn_ndarray::NdArrayDevice::Cpu
742    }
743
744    fn fixed_video_mask() -> MaskSpec {
745        MaskSpec {
746            context_indices: (0..16).collect(),
747            target_indices: (16..32).collect(),
748            total_tokens: 32,
749        }
750    }
751
752    fn video_with_hidden_tubelet_value(
753        mask: &MaskSpec,
754        hidden_value: f32,
755    ) -> Tensor<TestBackend, 5> {
756        let frames = 4usize;
757        let height = 8usize;
758        let width = 8usize;
759        let mut data = vec![1.0f32; frames * height * width];
760
761        for &index in &mask.target_indices {
762            let temporal_block = index / 16;
763            let spatial_index = index % 16;
764            let spatial_row = spatial_index / 4;
765            let spatial_col = spatial_index % 4;
766
767            let frame_start = temporal_block * 2;
768            let row_start = spatial_row * 2;
769            let col_start = spatial_col * 2;
770
771            for frame in frame_start..frame_start + 2 {
772                for row in row_start..row_start + 2 {
773                    for col in col_start..col_start + 2 {
774                        data[(frame * height + row) * width + col] = hidden_value;
775                    }
776                }
777            }
778        }
779
780        Tensor::from_floats(
781            burn::tensor::TensorData::new(data, [1, 1, frames, height, width]),
782            &device(),
783        )
784    }
785
786    #[test]
787    fn test_vit_video_output_shape() {
788        let config = VitVideoConfig::tiny_test();
789        let encoder = config.init::<TestBackend>(&device());
790
791        // [batch=2, channels=1, frames=4, height=8, width=8]
792        let video: Tensor<TestBackend, 5> = Tensor::zeros([2, 1, 4, 8, 8], &device());
793        let repr = encoder.forward(&video);
794
795        // grid: (4/2, 8/2, 8/2) = (2, 4, 4) = 32 tubelets
796        assert_eq!(repr.batch_size(), 2);
797        assert_eq!(repr.seq_len(), 32);
798        assert_eq!(repr.embed_dim(), 32);
799    }
800
801    #[test]
802    fn test_vit_video_encoder_trait() {
803        let config = VitVideoConfig::tiny_test();
804        let encoder = config.init::<TestBackend>(&device());
805
806        let video: Tensor<TestBackend, 5> = Tensor::zeros([1, 1, 4, 8, 8], &device());
807        let repr = Encoder::encode(&encoder, &video);
808
809        assert_eq!(repr.batch_size(), 1);
810        assert_eq!(repr.seq_len(), 32);
811        assert_eq!(encoder.embed_dim(), 32);
812    }
813
814    #[test]
815    fn test_vit_video_different_inputs_different_outputs() {
816        let config = VitVideoConfig::tiny_test();
817        let encoder = config.init::<TestBackend>(&device());
818
819        let a: Tensor<TestBackend, 5> = Tensor::zeros([1, 1, 4, 8, 8], &device());
820        let b: Tensor<TestBackend, 5> = Tensor::ones([1, 1, 4, 8, 8], &device());
821
822        let repr_a = encoder.forward(&a);
823        let repr_b = encoder.forward(&b);
824
825        let diff: f32 = (repr_a.embeddings - repr_b.embeddings)
826            .abs()
827            .sum()
828            .into_scalar()
829            .elem();
830        assert!(
831            diff > 1e-6,
832            "different video inputs should produce different representations"
833        );
834    }
835
836    #[test]
837    fn test_tubelet_embedding_shape() {
838        let config = TubeletEmbeddingConfig {
839            in_channels: 3,
840            tubelet_t: 2,
841            tubelet_h: 16,
842            tubelet_w: 16,
843            embed_dim: 256,
844        };
845        let embed = config.init::<TestBackend>(&device());
846
847        // 16 frames, 224x224
848        let video: Tensor<TestBackend, 5> = Tensor::zeros([1, 3, 16, 224, 224], &device());
849        let out = embed.forward(video);
850
851        // grid: (16/2, 224/16, 224/16) = (8, 14, 14) = 1568 tubelets
852        assert_eq!(out.dims(), [1, 1568, 256]);
853    }
854
855    #[test]
856    fn test_rope3d_output_shape() {
857        let config = RotaryPositionEncoding3DConfig::new(64, 2, 4, 4);
858        let rope = config.init::<TestBackend>(&device());
859
860        let x: Tensor<TestBackend, 3> = Tensor::ones([2, 32, 64], &device());
861        let out = rope.forward(x);
862        assert_eq!(out.dims(), [2, 32, 64]);
863    }
864
865    #[test]
866    fn test_rope3d_preserves_norm() {
867        let config = RotaryPositionEncoding3DConfig::new(32, 2, 4, 4);
868        let rope = config.init::<TestBackend>(&device());
869
870        let x: Tensor<TestBackend, 3> = Tensor::random(
871            [1, 32, 32],
872            burn::tensor::Distribution::Normal(0.0, 1.0),
873            &device(),
874        );
875
876        let x_norm: f32 = (x.clone() * x.clone()).sum().into_scalar().elem();
877        let out = rope.forward(x);
878        let out_norm: f32 = (out.clone() * out.clone()).sum().into_scalar().elem();
879
880        let ratio = out_norm / x_norm;
881        assert!(
882            (ratio - 1.0).abs() < 0.01,
883            "3D RoPE should approximately preserve norm, ratio: {ratio}"
884        );
885    }
886
887    #[test]
888    fn test_rope3d_different_positions_give_different_outputs() {
889        let config = RotaryPositionEncoding3DConfig::new(16, 2, 2, 2);
890        let rope = config.init::<TestBackend>(&device());
891
892        let x: Tensor<TestBackend, 3> = Tensor::ones([1, 8, 16], &device());
893        let out = rope.forward(x);
894
895        // Positions 0 and 1 should differ (different temporal/spatial positions)
896        let pos0 = out.clone().slice([0..1, 0..1, 0..16]);
897        let pos1 = out.clone().slice([0..1, 1..2, 0..16]);
898
899        let diff: f32 = (pos0 - pos1).abs().sum().into_scalar().elem();
900        assert!(
901            diff > 1e-6,
902            "different 3D positions should produce different outputs"
903        );
904    }
905
906    #[test]
907    fn test_video_config_grid_dims() {
908        let config = VitVideoConfig {
909            in_channels: 3,
910            num_frames: 16,
911            frame_height: 224,
912            frame_width: 224,
913            tubelet_size: (2, 16, 16),
914            embed_dim: 768,
915            num_layers: 12,
916            num_heads: 12,
917            mlp_dim: 3072,
918        };
919        assert_eq!(config.grid_dims(), (8, 14, 14));
920        assert_eq!(config.num_tubelets(), 1568);
921    }
922
923    #[test]
924    fn test_video_transformer_block_residual() {
925        let block = VideoTransformerBlockConfig {
926            embed_dim: 16,
927            num_heads: 2,
928            mlp_dim: 32,
929        }
930        .init::<TestBackend>(&device());
931
932        let x: Tensor<TestBackend, 3> = Tensor::zeros([1, 8, 16], &device());
933        let out = block.forward(x);
934        assert_eq!(out.dims(), [1, 8, 16]);
935    }
936
937    #[test]
938    fn test_video_self_attention_shape() {
939        let attn = VideoSelfAttention {
940            qkv: LinearConfig::new(16, 48).init::<TestBackend>(&device()),
941            out_proj: LinearConfig::new(16, 16).init::<TestBackend>(&device()),
942            num_heads: 4,
943            head_dim: 4,
944        };
945
946        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
947        let out = attn.forward(x);
948        assert_eq!(out.dims(), [2, 8, 16]);
949    }
950
951    use proptest::prelude::*;
952
953    proptest! {
954        #[test]
955        fn prop_video_config_num_tubelets(
956            grid_t in 1usize..4,
957            grid_h in 1usize..4,
958            grid_w in 1usize..4,
959        ) {
960            let tub = 2;
961            let config = VitVideoConfig {
962                in_channels: 1,
963                num_frames: grid_t * tub,
964                frame_height: grid_h * tub,
965                frame_width: grid_w * tub,
966                tubelet_size: (tub, tub, tub),
967                embed_dim: 16,
968                num_layers: 1,
969                num_heads: 2,
970                mlp_dim: 32,
971            };
972            prop_assert_eq!(config.grid_dims(), (grid_t, grid_h, grid_w));
973            prop_assert_eq!(config.num_tubelets(), grid_t * grid_h * grid_w);
974        }
975
976        #[test]
977        fn prop_rope3d_preserves_shape(
978            max_t in 1usize..3,
979            max_h in 1usize..3,
980            max_w in 1usize..3,
981        ) {
982            let embed_dim = 12; // divisible by 2, and 6/3=2 per axis
983            let config = RotaryPositionEncoding3DConfig::new(embed_dim, max_t, max_h, max_w);
984            let rope = config.init::<TestBackend>(&device());
985            let seq_len = max_t * max_h * max_w;
986            let x: Tensor<TestBackend, 3> = Tensor::ones([1, seq_len, embed_dim], &device());
987            let out = rope.forward(x);
988            prop_assert_eq!(out.dims(), [1, seq_len, embed_dim]);
989        }
990
991        #[test]
992        fn prop_rope3d_preserves_norm(
993            max_t in 1usize..3,
994            max_h in 2usize..4,
995            max_w in 2usize..4,
996        ) {
997            let embed_dim = 12;
998            let config = RotaryPositionEncoding3DConfig::new(embed_dim, max_t, max_h, max_w);
999            let rope = config.init::<TestBackend>(&device());
1000            let seq_len = max_t * max_h * max_w;
1001            let x: Tensor<TestBackend, 3> = Tensor::random(
1002                [1, seq_len, embed_dim],
1003                burn::tensor::Distribution::Normal(0.0, 1.0),
1004                &device(),
1005            );
1006            let x_norm: f32 = (x.clone() * x.clone()).sum().into_scalar().elem();
1007            let out = rope.forward(x);
1008            let out_norm: f32 = (out.clone() * out.clone()).sum().into_scalar().elem();
1009            let ratio = out_norm / x_norm;
1010            prop_assert!((ratio - 1.0).abs() < 0.01, "3D RoPE norm ratio: {}", ratio);
1011        }
1012    }
1013
1014    // ======================================================================
1015    // V-JEPA model tests
1016    // ======================================================================
1017
1018    #[test]
1019    fn test_vjepa_config_tiny() {
1020        let config = VJepaConfig::tiny_test();
1021        assert_eq!(config.encoder.embed_dim, 32);
1022        assert_eq!(config.predictor.predictor_embed_dim, 16);
1023        assert_eq!(config.predictor.encoder_embed_dim, 32);
1024    }
1025
1026    #[test]
1027    fn test_vjepa_model_init() {
1028        let config = VJepaConfig::tiny_test();
1029        let model = config.init::<TestBackend>(&device());
1030
1031        assert_eq!(model.context_encoder.embed_dim, 32);
1032        assert_eq!(model.target_encoder.embed_dim, 32);
1033    }
1034
1035    #[test]
1036    fn test_strict_video_context_encoding_ignores_hidden_tubelets() {
1037        let config = VJepaConfig::tiny_test();
1038        let model = config.init::<TestBackend>(&device());
1039        let mask = fixed_video_mask();
1040
1041        let hidden_low = video_with_hidden_tubelet_value(&mask, 0.0);
1042        let hidden_high = video_with_hidden_tubelet_value(&mask, 1_000.0);
1043
1044        let strict_low = model.encode_context_strict(&hidden_low, &mask.context_indices);
1045        let strict_high = model.encode_context_strict(&hidden_high, &mask.context_indices);
1046
1047        let diff: f32 = (strict_low.embeddings - strict_high.embeddings)
1048            .abs()
1049            .sum()
1050            .into_scalar()
1051            .elem();
1052        assert!(
1053            diff < 1e-5,
1054            "strict masked video context should ignore hidden tubelets, diff={diff}"
1055        );
1056    }
1057
1058    #[test]
1059    fn test_full_video_encoder_context_slice_leaks_hidden_tubelets() {
1060        let config = VitVideoConfig::tiny_test();
1061        let encoder = config.init::<TestBackend>(&device());
1062        let mask = fixed_video_mask();
1063
1064        let hidden_low = video_with_hidden_tubelet_value(&mask, 0.0);
1065        let hidden_high = video_with_hidden_tubelet_value(&mask, 1_000.0);
1066
1067        let approx_low = encoder.forward(&hidden_low).gather(&mask.context_indices);
1068        let approx_high = encoder.forward(&hidden_high).gather(&mask.context_indices);
1069
1070        let diff: f32 = (approx_low.embeddings - approx_high.embeddings)
1071            .abs()
1072            .sum()
1073            .into_scalar()
1074            .elem();
1075        assert!(
1076            diff > 1e-3,
1077            "post-encoder gather path should leak hidden tubelets, diff={diff}"
1078        );
1079    }
1080
1081    #[test]
1082    fn test_strict_video_forward_step_runs_end_to_end() {
1083        let config = VJepaConfig::tiny_test();
1084        let model = config.init::<TestBackend>(&device());
1085        let mask = fixed_video_mask();
1086        let video = video_with_hidden_tubelet_value(&mask, 5.0);
1087        let energy_fn = jepa_core::energy::L2Energy;
1088        let regularizer = jepa_core::collapse::VICReg::default();
1089
1090        let output = model.forward_step_strict(&video, mask.clone(), &energy_fn, &regularizer, 1.0);
1091
1092        assert_eq!(output.context.seq_len(), mask.context_indices.len());
1093        assert_eq!(output.predicted.seq_len(), mask.target_indices.len());
1094        assert_eq!(output.target.seq_len(), mask.target_indices.len());
1095
1096        let total_loss: f32 = output.total_loss.into_scalar().elem();
1097        assert!(
1098            total_loss.is_finite(),
1099            "strict video forward loss should be finite"
1100        );
1101    }
1102
1103    #[test]
1104    fn test_try_strict_video_forward_step_rejects_invalid_mask() {
1105        let config = VJepaConfig::tiny_test();
1106        let model = config.init::<TestBackend>(&device());
1107        let video = Tensor::ones([1, 1, 4, 8, 8], &device());
1108        let invalid_mask = MaskSpec {
1109            context_indices: vec![0],
1110            target_indices: vec![],
1111            total_tokens: 32,
1112        };
1113        let energy_fn = jepa_core::energy::L2Energy;
1114        let regularizer = jepa_core::collapse::VICReg::default();
1115
1116        let err = model
1117            .try_forward_step_strict(&video, invalid_mask, &energy_fn, &regularizer, 1.0)
1118            .unwrap_err();
1119        assert!(matches!(
1120            err,
1121            StrictVJepaError::InvalidMask(MaskError::EmptyTarget)
1122        ));
1123    }
1124
1125    /// BDD: "V-JEPA full pipeline with spatiotemporal masking"
1126    /// Given a V-JEPA model with video encoder pair and predictor
1127    /// When I encode a video clip, generate a spatiotemporal mask,
1128    ///   gather target tubelets, and predict from context
1129    /// Then the energy should be finite and non-negative
1130    #[test]
1131    fn bdd_vjepa_full_pipeline_with_spatiotemporal_masking() {
1132        use jepa_core::{CollapseRegularizer, EnergyFn, MaskingStrategy};
1133        use rand::SeedableRng;
1134
1135        let config = VJepaConfig::tiny_test();
1136        let model = config.init::<TestBackend>(&device());
1137
1138        // Video: [batch=1, channels=1, frames=4, height=8, width=8]
1139        let video: Tensor<TestBackend, 5> = Tensor::random(
1140            [1, 1, 4, 8, 8],
1141            burn::tensor::Distribution::Normal(0.0, 1.0),
1142            &device(),
1143        );
1144
1145        // 1. Encode with both encoders
1146        let context_repr = model.context_encoder.forward(&video);
1147        let target_repr = model.target_encoder.forward(&video);
1148
1149        // grid: (4/2, 8/2, 8/2) = (2, 4, 4) = 32 tubelets
1150        assert_eq!(context_repr.seq_len(), 32);
1151        assert_eq!(target_repr.seq_len(), 32);
1152
1153        // 2. Generate spatiotemporal mask
1154        let masking = jepa_core::masking::SpatiotemporalMasking {
1155            num_targets: 2,
1156            temporal_extent: (1, 2),
1157            spatial_scale: (0.1, 0.2),
1158        };
1159        let shape = jepa_core::types::InputShape::Video {
1160            frames: 2,
1161            height: 4,
1162            width: 4,
1163        };
1164        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
1165        let mask = masking.generate_mask(&shape, &mut rng);
1166        assert!(mask.validate().is_ok());
1167        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 32);
1168
1169        // 3. Gather target tubelets
1170        let target_gathered = target_repr.gather(&mask.target_indices);
1171        assert_eq!(target_gathered.seq_len(), mask.target_indices.len());
1172
1173        // 4. Predict targets from context
1174        let num_targets = mask.target_indices.len();
1175        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, num_targets], &device());
1176        let predicted = model.predictor.predict(&context_repr, &target_pos, None);
1177        assert_eq!(predicted.seq_len(), num_targets);
1178        assert_eq!(predicted.embed_dim(), 32);
1179
1180        // 5. Compute energy
1181        let energy = jepa_core::energy::L2Energy.compute(&predicted, &target_gathered);
1182        let val: f32 = energy.value.into_scalar().elem();
1183        assert!(val.is_finite(), "energy should be finite, got {val}");
1184        assert!(val >= 0.0, "L2 energy should be non-negative, got {val}");
1185
1186        // 6. Collapse regularization
1187        let embed_dim = predicted.embed_dim();
1188        let pred_flat = predicted.embeddings.reshape([num_targets, embed_dim]);
1189        let target_flat = target_gathered.embeddings.reshape([num_targets, embed_dim]);
1190        let reg: f32 = jepa_core::collapse::VICReg::default()
1191            .loss(&pred_flat, &target_flat)
1192            .into_scalar()
1193            .elem();
1194        assert!(
1195            reg.is_finite(),
1196            "regularization should be finite, got {reg}"
1197        );
1198    }
1199}