Skip to main content

jepa_vision/
image.rs

1//! I-JEPA (Image Joint Embedding Predictive Architecture) pipeline.
2//!
3//! Implements the complete I-JEPA model for self-supervised image learning,
4//! following Assran et al. (2023), *Self-Supervised Learning from Images
5//! with a Joint-Embedding Predictive Architecture*, CVPR.
6//!
7//! ## Components
8//!
9//! | Component | Struct | Role |
10//! |-----------|--------|------|
11//! | Context encoder | [`VitEncoder`](crate::vit::VitEncoder) | Encodes visible (context) patches with gradients |
12//! | Target encoder | [`VitEncoder`](crate::vit::VitEncoder) | Encodes target patches; weights are an EMA copy — **no gradients** |
//! | Predictor | [`TransformerPredictor`] | Narrow transformer that self-attends over context tokens plus position-conditioned prediction tokens to predict target representations |
14//! | Masking | [`BlockMasking`](jepa_core::masking::BlockMasking) | Generates contiguous rectangular target blocks |
15//!
16//! ## Strict forward step
17//!
18//! [`IJepa::forward_step_strict`] implements the full masked training
19//! forward pass with pre-encoder token filtering, matching the reference
20//! PyTorch implementation. Use this path when you need exact parity
21//! with published I-JEPA results.
22
23use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
24use burn::prelude::*;
25use burn::tensor::backend::Backend;
26use burn::tensor::module::embedding;
27
28use jepa_core::types::{Energy, MaskError, MaskSpec, Representation};
29use jepa_core::{CollapseRegularizer, EnergyFn, Predictor};
30
/// Configuration for the transformer predictor.
///
/// The predictor projects encoder outputs of width `encoder_embed_dim` down to
/// `predictor_embed_dim`, runs `num_layers` pre-norm transformer blocks over
/// the context tokens plus one prediction token per target, and projects the
/// prediction-token outputs back to `encoder_embed_dim`.
///
/// # Example
///
/// ```
/// use jepa_vision::image::TransformerPredictorConfig;
/// use jepa_core::types::Representation;
/// use jepa_core::Predictor;
/// use burn_ndarray::NdArray;
/// use burn::prelude::*;
///
/// type B = NdArray<f32>;
/// let device = burn_ndarray::NdArrayDevice::Cpu;
///
/// let config = TransformerPredictorConfig {
///     encoder_embed_dim: 32,
///     predictor_embed_dim: 16,
///     num_layers: 1,
///     num_heads: 2,
///     max_target_len: 64,
/// };
/// let predictor = config.init::<B>(&device);
///
/// let context = Representation::new(Tensor::zeros([1, 8, 32], &device));
/// let target_pos: Tensor<B, 2> = Tensor::zeros([1, 4], &device);
/// let predicted = predictor.predict(&context, &target_pos, None);
/// assert_eq!(predicted.seq_len(), 4);
/// ```
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TransformerPredictorConfig {
    /// Input embedding dimension (from encoder output). The predictor's output
    /// is projected back to this same width.
    pub encoder_embed_dim: usize,
    /// Predictor internal embedding dimension.
    ///
    /// Each block splits this width across `num_heads` heads using integer
    /// division (`predictor_embed_dim / num_heads`), so it should be a
    /// multiple of `num_heads`.
    pub predictor_embed_dim: usize,
    /// Number of predictor transformer layers.
    pub num_layers: usize,
    /// Number of attention heads in each predictor block. Must be non-zero
    /// whenever `num_layers > 0`.
    pub num_heads: usize,
    /// Number of rows in the sinusoidal prediction-token table.
    ///
    /// Valid target positions are `0..max_target_len`; larger positions are
    /// rejected by `TransformerPredictor::try_predict`. Set this to the
    /// encoder token count, not just the number of masked targets in a single
    /// training step.
    pub max_target_len: usize,
}
75
76impl TransformerPredictorConfig {
77    /// Initialize a [`TransformerPredictor`] module.
78    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerPredictor<B> {
79        let input_proj =
80            LinearConfig::new(self.encoder_embed_dim, self.predictor_embed_dim).init(device);
81        let output_proj =
82            LinearConfig::new(self.predictor_embed_dim, self.encoder_embed_dim).init(device);
83
84        let blocks: Vec<PredictorBlock<B>> = (0..self.num_layers)
85            .map(|_| {
86                PredictorBlockConfig {
87                    embed_dim: self.predictor_embed_dim,
88                    num_heads: self.num_heads,
89                }
90                .init(device)
91            })
92            .collect();
93
94        let norm = LayerNormConfig::new(self.predictor_embed_dim).init(device);
95
96        let prediction_tokens =
97            sinusoidal_prediction_tokens(self.max_target_len, self.predictor_embed_dim, device);
98
99        TransformerPredictor {
100            input_proj,
101            output_proj,
102            blocks,
103            norm,
104            prediction_tokens,
105            predictor_embed_dim: self.predictor_embed_dim,
106            encoder_embed_dim: self.encoder_embed_dim,
107        }
108    }
109}
110
/// Transformer-based predictor for I-JEPA.
///
/// Predicts target representations from context representations using
/// self-attention over the concatenation of context tokens and
/// position-conditioned prediction tokens.
///
/// Architecture:
/// 1. Project context to predictor dimension
/// 2. Look up one prediction token per requested target position
/// 3. Concatenate the context tokens followed by the prediction tokens
/// 4. Apply pre-norm self-attention transformer blocks (no attention mask, so
///    context and prediction tokens all attend to one another)
/// 5. Extract the outputs at the prediction-token positions
/// 6. Apply the final layer norm and project back to encoder dimension
#[derive(Module, Debug)]
pub struct TransformerPredictor<B: Backend> {
    /// Project encoder output to predictor dimension.
    input_proj: Linear<B>,
    /// Project predictor output back to encoder dimension.
    output_proj: Linear<B>,
    /// Transformer blocks for the predictor.
    blocks: Vec<PredictorBlock<B>>,
    /// Final layer norm, applied to the prediction-token outputs only.
    norm: LayerNorm<B>,
    /// Position-conditioned prediction token table, one row per target
    /// position. Shape: `[max_target_len, predictor_embed_dim]`
    ///
    /// Filled with fixed sinusoidal values at init. NOTE(review): this is a
    /// plain `Tensor` rather than a `Param`, so it is presumably not updated
    /// by the optimizer — confirm that is intended.
    prediction_tokens: Tensor<B, 2>,
    /// Predictor embedding dimension.
    predictor_embed_dim: usize,
    /// Encoder embedding dimension (output dimension).
    encoder_embed_dim: usize,
}
141
/// Errors returned by [`TransformerPredictor::try_predict`].
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum PredictorError {
    /// `target_positions` has a different batch size than the context.
    #[error(
        "target position batch size mismatch: context batch={context_batch}, target_positions batch={positions_batch}"
    )]
    BatchSizeMismatch {
        /// Batch size of the context representation.
        context_batch: usize,
        /// Batch size (first dimension) of `target_positions`.
        positions_batch: usize,
    },
    /// The smallest target position, after conversion to an integer, is
    /// negative. Carries that minimum value.
    #[error("target position must be non-negative, got {0}")]
    NegativeTargetPosition(i64),
    /// The largest target position does not address a row of the
    /// prediction-token table.
    #[error(
        "target position {position} exceeds predictor capacity {max_supported}; increase max_target_len"
    )]
    TargetPositionOutOfRange {
        /// The offending (maximum) target position.
        position: usize,
        /// Number of rows in the prediction-token table; valid positions are
        /// `0..max_supported`.
        max_supported: usize,
    },
}
162
163impl<B: Backend> Predictor<B> for TransformerPredictor<B> {
164    fn predict(
165        &self,
166        context: &Representation<B>,
167        target_positions: &Tensor<B, 2>,
168        _latent: Option<&Tensor<B, 2>>,
169    ) -> Representation<B> {
170        self.try_predict(context, target_positions).expect(
171            "TransformerPredictor::predict failed — target positions must match the context \
172             batch size and not exceed max_target_len; use try_predict for error handling",
173        )
174    }
175}
176
177impl<B: Backend> TransformerPredictor<B> {
178    /// Fallible predictor path for caller-controlled target positions.
179    pub fn try_predict(
180        &self,
181        context: &Representation<B>,
182        target_positions: &Tensor<B, 2>,
183    ) -> Result<Representation<B>, PredictorError> {
184        let [batch, _ctx_len, _enc_dim] = context.embeddings.dims();
185        let [positions_batch, num_targets] = target_positions.dims();
186        if positions_batch != batch {
187            return Err(PredictorError::BatchSizeMismatch {
188                context_batch: batch,
189                positions_batch,
190            });
191        }
192
193        if num_targets == 0 {
194            let device = context.embeddings.device();
195            return Ok(Representation::new(Tensor::zeros(
196                [batch, 0, self.encoder_embed_dim],
197                &device,
198            )));
199        }
200
201        let target_positions = target_positions.clone().int();
202        let min_position: i64 = target_positions.clone().min().into_scalar().elem();
203        if min_position < 0 {
204            return Err(PredictorError::NegativeTargetPosition(min_position));
205        }
206
207        let max_position: i64 = target_positions.clone().max().into_scalar().elem();
208        let max_supported_position = self.prediction_tokens.dims()[0];
209        if max_position >= max_supported_position as i64 {
210            return Err(PredictorError::TargetPositionOutOfRange {
211                position: max_position as usize,
212                max_supported: max_supported_position,
213            });
214        }
215
216        // 1. Project context to predictor dimension
217        let ctx = self.input_proj.forward(context.embeddings.clone());
218
219        // 2. Select prediction tokens using the actual target positions.
220        let pred_tokens = embedding(self.prediction_tokens.clone(), target_positions);
221
222        // 3. Concatenate context + prediction tokens: [batch, ctx_len + num_targets, dim]
223        let combined = Tensor::cat(vec![ctx, pred_tokens], 1);
224        let ctx_len = context.embeddings.dims()[1];
225        let total_len = ctx_len + num_targets;
226
227        // 4. Apply transformer blocks
228        let mut x = combined;
229        for block in &self.blocks {
230            x = block.forward(x);
231        }
232
233        // 5. Extract prediction token outputs (last num_targets positions)
234        let pred_out = x.slice([0..batch, ctx_len..total_len, 0..self.predictor_embed_dim]);
235
236        // 6. Normalize and project back to encoder dimension
237        let pred_out = self.norm.forward(pred_out);
238        let pred_out = self.output_proj.forward(pred_out);
239
240        Ok(Representation::new(pred_out))
241    }
242}
243
244fn sinusoidal_prediction_tokens<B: Backend>(
245    max_target_len: usize,
246    embed_dim: usize,
247    device: &B::Device,
248) -> Tensor<B, 2> {
249    let mut data = vec![0.0f32; max_target_len * embed_dim];
250
251    for position in 0..max_target_len {
252        for dim in 0..embed_dim {
253            let exponent = (2 * (dim / 2)) as f64 / embed_dim as f64;
254            let angle = position as f64 / 10_000_f64.powf(exponent);
255            data[position * embed_dim + dim] = if dim % 2 == 0 {
256                angle.sin() as f32
257            } else {
258                angle.cos() as f32
259            };
260        }
261    }
262
263    Tensor::from_floats(
264        burn::tensor::TensorData::new(data, [max_target_len, embed_dim]),
265        device,
266    )
267}
268
269// --- Predictor Transformer Block ---
270
/// Hyper-parameters for a single predictor transformer block.
#[derive(Debug, Clone)]
struct PredictorBlockConfig {
    /// Token width inside the block (both input and output dimension).
    embed_dim: usize,
    /// Number of attention heads; the per-head width is
    /// `embed_dim / num_heads` (integer division).
    num_heads: usize,
}
276
277impl PredictorBlockConfig {
278    fn init<B: Backend>(&self, device: &B::Device) -> PredictorBlock<B> {
279        let head_dim = self.embed_dim / self.num_heads;
280        PredictorBlock {
281            norm1: LayerNormConfig::new(self.embed_dim).init(device),
282            attn: PredictorAttention {
283                qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
284                out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
285                num_heads: self.num_heads,
286                head_dim,
287            },
288            norm2: LayerNormConfig::new(self.embed_dim).init(device),
289            mlp: PredictorMlp {
290                fc1: LinearConfig::new(self.embed_dim, self.embed_dim * 4).init(device),
291                fc2: LinearConfig::new(self.embed_dim * 4, self.embed_dim).init(device),
292            },
293        }
294    }
295}
296
/// Pre-norm transformer block: `x + attn(norm1(x))`, then `x + mlp(norm2(x))`.
#[derive(Module, Debug)]
struct PredictorBlock<B: Backend> {
    /// Layer norm applied before attention.
    norm1: LayerNorm<B>,
    /// Unmasked multi-head self-attention over the full token sequence.
    attn: PredictorAttention<B>,
    /// Layer norm applied before the MLP.
    norm2: LayerNorm<B>,
    /// Position-wise feed-forward network.
    mlp: PredictorMlp<B>,
}
304
305impl<B: Backend> PredictorBlock<B> {
306    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
307        let residual = x.clone();
308        let x_norm = self.norm1.forward(x);
309        let attn_out = self.attn.forward(x_norm);
310        let x = residual + attn_out;
311
312        let residual = x.clone();
313        let x_norm = self.norm2.forward(x);
314        let mlp_out = self.mlp.forward(x_norm);
315        residual + mlp_out
316    }
317}
318
/// Unmasked multi-head self-attention with a fused QKV projection.
#[derive(Module, Debug)]
struct PredictorAttention<B: Backend> {
    /// Fused projection producing `[q | k | v]` along the last dimension
    /// (output width `3 * embed_dim` as configured).
    qkv: Linear<B>,
    /// Projection applied after the heads are merged back together.
    out_proj: Linear<B>,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of each head.
    head_dim: usize,
}
326
327impl<B: Backend> PredictorAttention<B> {
328    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
329        let [batch, seq_len, _] = x.dims();
330        let embed_dim = self.num_heads * self.head_dim;
331
332        let qkv = self.qkv.forward(x);
333        let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
334        let k = qkv
335            .clone()
336            .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
337        let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
338
339        let q = q
340            .reshape([batch, seq_len, self.num_heads, self.head_dim])
341            .swap_dims(1, 2);
342        let k = k
343            .reshape([batch, seq_len, self.num_heads, self.head_dim])
344            .swap_dims(1, 2);
345        let v = v
346            .reshape([batch, seq_len, self.num_heads, self.head_dim])
347            .swap_dims(1, 2);
348
349        let scale = (self.head_dim as f64).sqrt();
350        let attn = q.matmul(k.transpose()) / scale;
351        let attn = burn::tensor::activation::softmax(attn, 3);
352        let out = attn.matmul(v);
353        let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
354        self.out_proj.forward(out)
355    }
356}
357
/// Two-layer feed-forward network used inside each predictor block.
#[derive(Module, Debug)]
struct PredictorMlp<B: Backend> {
    /// Expansion layer (`embed_dim -> 4 * embed_dim` as configured).
    fc1: Linear<B>,
    /// Contraction layer back to `embed_dim`.
    fc2: Linear<B>,
}
363
364impl<B: Backend> PredictorMlp<B> {
365    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
366        let x = self.fc1.forward(x);
367        let x = burn::tensor::activation::gelu(x);
368        self.fc2.forward(x)
369    }
370}
371
/// I-JEPA model combining encoder pair and predictor.
///
/// Provides a high-level interface for the I-JEPA pipeline per RFC-002 and RFC-003.
/// The masked training step is [`IJepa::forward_step_strict`] (panicking) or
/// [`IJepa::try_forward_step_strict`] (fallible).
#[derive(Module, Debug)]
pub struct IJepa<B: Backend> {
    /// Context encoder — trained via gradient descent.
    pub context_encoder: crate::vit::VitEncoder<B>,
    /// Target encoder — updated via EMA (no gradients).
    ///
    /// Its output is detached in the strict forward step. The EMA weight
    /// update itself is presumably performed by the training loop (not shown
    /// here) — confirm.
    pub target_encoder: crate::vit::VitEncoder<B>,
    /// Predictor — predicts target representations from context.
    pub predictor: TransformerPredictor<B>,
}
384
/// Output of a strict masked I-JEPA forward step.
///
/// Unlike the generic trainer helper, the context representation is produced
/// from visible tokens only, so hidden target patches never participate in
/// context self-attention.
#[derive(Debug, Clone)]
pub struct StrictIJepaForwardOutput<B: Backend> {
    /// Prediction energy (main loss signal) between `predicted` and `target`.
    /// Shape: `[1]`
    pub energy: Energy<B>,
    /// Collapse prevention regularization loss, computed on the predicted and
    /// target tokens flattened to `[batch * num_targets, embed_dim]`.
    /// Shape: `[1]`
    pub regularization: Tensor<B, 1>,
    /// Total loss: `energy + regularization * reg_weight`. Shape: `[1]`
    pub total_loss: Tensor<B, 1>,
    /// The mask used for this step (it passed `MaskSpec::validate`).
    pub mask: MaskSpec,
    /// Strictly encoded context representation, with one token per entry of
    /// `mask.context_indices`.
    pub context: Representation<B>,
    /// Predicted target representations, with one token per entry of
    /// `mask.target_indices`.
    pub predicted: Representation<B>,
    /// Actual target representations: the detached target-encoder output
    /// gathered at `mask.target_indices`.
    pub target: Representation<B>,
}
407
/// Errors returned by [`IJepa::try_forward_step_strict`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum StrictIJepaError {
    /// `MaskSpec::validate` rejected the mask.
    #[error(transparent)]
    InvalidMask(#[from] MaskError),
    /// The predictor rejected the target positions derived from the mask.
    #[error(transparent)]
    Predictor(#[from] PredictorError),
}
416
417impl<B: Backend> IJepa<B> {
418    /// Encode only visible context patches before self-attention runs.
419    ///
420    /// This method assumes `context_indices` are already valid for the current
421    /// image grid. Use [`IJepa::try_forward_step_strict`] when the indices come
422    /// from caller-controlled masking data.
423    pub fn encode_context_strict(
424        &self,
425        images: &Tensor<B, 4>,
426        context_indices: &[usize],
427    ) -> Representation<B> {
428        self.context_encoder
429            .forward_visible_tokens(images, context_indices)
430    }
431
432    /// Execute a strict masked I-JEPA forward step.
433    ///
434    /// The target encoder still sees the full input, but the context encoder is
435    /// restricted to visible patches before any attention mixing occurs.
436    ///
437    /// # Panics
438    ///
439    /// Panics if `mask` is invalid or if the predictor receives target
440    /// positions outside its configured capacity. Use
441    /// [`IJepa::try_forward_step_strict`] for typed error reporting.
442    pub fn forward_step_strict<EF, CR>(
443        &self,
444        images: &Tensor<B, 4>,
445        mask: MaskSpec,
446        energy_fn: &EF,
447        regularizer: &CR,
448        reg_weight: f64,
449    ) -> StrictIJepaForwardOutput<B>
450    where
451        EF: EnergyFn<B>,
452        CR: CollapseRegularizer<B>,
453    {
454        self.try_forward_step_strict(images, mask, energy_fn, regularizer, reg_weight)
455            .expect(
456                "IJepa::forward_step_strict failed — mask must be valid (disjoint, non-empty) \
457                 and target count must not exceed predictor capacity; \
458                 use try_forward_step_strict for error handling",
459            )
460    }
461
462    /// Execute a strict masked I-JEPA forward step with typed error reporting.
463    pub fn try_forward_step_strict<EF, CR>(
464        &self,
465        images: &Tensor<B, 4>,
466        mask: MaskSpec,
467        energy_fn: &EF,
468        regularizer: &CR,
469        reg_weight: f64,
470    ) -> Result<StrictIJepaForwardOutput<B>, StrictIJepaError>
471    where
472        EF: EnergyFn<B>,
473        CR: CollapseRegularizer<B>,
474    {
475        mask.validate()?;
476
477        let context = self.encode_context_strict(images, &mask.context_indices);
478        let target_full = self.target_encoder.forward(images);
479        let target =
480            Representation::new(target_full.embeddings.detach()).gather(&mask.target_indices);
481
482        let batch = images.dims()[0];
483        let target_positions =
484            target_positions_tensor::<B>(&mask.target_indices, batch, &images.device());
485        let predicted = self.predictor.try_predict(&context, &target_positions)?;
486
487        let num_targets = target.seq_len();
488        let embed_dim = target.embed_dim();
489        let pred_flat = predicted
490            .embeddings
491            .clone()
492            .reshape([batch * num_targets, embed_dim]);
493        let target_flat = target
494            .embeddings
495            .clone()
496            .reshape([batch * num_targets, embed_dim]);
497
498        let energy = energy_fn.compute(&predicted, &target);
499        let regularization = regularizer.loss(&pred_flat, &target_flat);
500        let total_loss = energy.value.clone() + regularization.clone() * reg_weight;
501
502        Ok(StrictIJepaForwardOutput {
503            energy,
504            regularization,
505            total_loss,
506            mask,
507            context,
508            predicted,
509            target,
510        })
511    }
512}
513
514pub(crate) fn target_positions_tensor<B: Backend>(
515    indices: &[usize],
516    batch: usize,
517    device: &B::Device,
518) -> Tensor<B, 2> {
519    let mut data = Vec::with_capacity(batch * indices.len());
520    for _ in 0..batch {
521        data.extend(indices.iter().map(|&index| index as f32));
522    }
523
524    Tensor::from_floats(
525        burn::tensor::TensorData::new(data, [batch, indices.len()]),
526        device,
527    )
528}
529
/// Configuration for the I-JEPA model.
#[derive(Debug, Clone)]
pub struct IJepaConfig {
    /// ViT encoder config (shared by context and target encoders).
    ///
    /// Both encoders are built from this one config, so they share an
    /// architecture.
    pub encoder: crate::vit::VitConfig,
    /// Predictor config. Its `encoder_embed_dim` should match
    /// `encoder.embed_dim`, because the predictor's input projection consumes
    /// encoder outputs.
    pub predictor: TransformerPredictorConfig,
}
538
539impl IJepaConfig {
540    /// Create a tiny config suitable for testing.
541    pub fn tiny_test() -> Self {
542        let encoder = crate::vit::VitConfig::tiny_test();
543        Self {
544            predictor: TransformerPredictorConfig {
545                encoder_embed_dim: encoder.embed_dim,
546                predictor_embed_dim: 16,
547                num_layers: 1,
548                num_heads: 2,
549                max_target_len: 64,
550            },
551            encoder,
552        }
553    }
554
555    /// Initialize an [`IJepa`] model.
556    pub fn init<B: Backend>(&self, device: &B::Device) -> IJepa<B> {
557        IJepa {
558            context_encoder: self.encoder.init(device),
559            target_encoder: self.encoder.init(device),
560            predictor: self.predictor.init(device),
561        }
562    }
563}
564
565#[cfg(test)]
566mod tests {
567    use super::*;
568    use burn::tensor::ElementConversion;
569    use burn_ndarray::NdArray;
570    use jepa_core::{CollapseRegularizer, EnergyFn, MaskingStrategy};
571    use rand::SeedableRng;
572
573    type TestBackend = NdArray<f32>;
574
575    fn device() -> burn_ndarray::NdArrayDevice {
576        burn_ndarray::NdArrayDevice::Cpu
577    }
578
579    fn target_positions(indices: &[usize], batch: usize) -> Tensor<TestBackend, 2> {
580        let mut data = Vec::with_capacity(batch * indices.len());
581        for _ in 0..batch {
582            data.extend(indices.iter().map(|&index| index as f32));
583        }
584
585        Tensor::from_floats(
586            burn::tensor::TensorData::new(data, [batch, indices.len()]),
587            &device(),
588        )
589    }
590
591    fn fixed_image_mask() -> MaskSpec {
592        MaskSpec {
593            context_indices: vec![0, 1, 4, 5, 10, 11, 14, 15],
594            target_indices: vec![2, 3, 6, 7, 8, 9, 12, 13],
595            total_tokens: 16,
596        }
597    }
598
599    fn image_with_hidden_patch_value(mask: &MaskSpec, hidden_value: f32) -> Tensor<TestBackend, 4> {
600        let image_size = 8usize;
601        let patch_size = 2usize;
602        let mut data = vec![1.0f32; image_size * image_size];
603
604        for &index in &mask.target_indices {
605            let patch_row = index / 4;
606            let patch_col = index % 4;
607            let row_start = patch_row * patch_size;
608            let col_start = patch_col * patch_size;
609
610            for row in row_start..row_start + patch_size {
611                for col in col_start..col_start + patch_size {
612                    data[row * image_size + col] = hidden_value;
613                }
614            }
615        }
616
617        Tensor::from_floats(
618            burn::tensor::TensorData::new(data, [1, 1, image_size, image_size]),
619            &device(),
620        )
621    }
622
623    #[test]
624    fn test_predictor_output_shape() {
625        let config = TransformerPredictorConfig {
626            encoder_embed_dim: 32,
627            predictor_embed_dim: 16,
628            num_layers: 1,
629            num_heads: 2,
630            max_target_len: 64,
631        };
632        let predictor = config.init::<TestBackend>(&device());
633
634        let context = Representation::new(Tensor::zeros([2, 8, 32], &device()));
635        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 4], &device());
636        let predicted = predictor.predict(&context, &target_pos, None);
637
638        assert_eq!(predicted.batch_size(), 2);
639        assert_eq!(predicted.seq_len(), 4);
640        assert_eq!(predicted.embed_dim(), 32);
641    }
642
643    #[test]
644    fn test_predictor_implements_trait() {
645        let config = TransformerPredictorConfig {
646            encoder_embed_dim: 16,
647            predictor_embed_dim: 8,
648            num_layers: 1,
649            num_heads: 2,
650            max_target_len: 16,
651        };
652        let predictor = config.init::<TestBackend>(&device());
653
654        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
655        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());
656        let pred: Representation<TestBackend> =
657            Predictor::predict(&predictor, &context, &target_pos, None);
658        assert_eq!(pred.seq_len(), 2);
659    }
660
661    #[test]
662    fn test_predictor_output_depends_on_target_positions() {
663        let config = TransformerPredictorConfig {
664            encoder_embed_dim: 16,
665            predictor_embed_dim: 8,
666            num_layers: 1,
667            num_heads: 2,
668            max_target_len: 16,
669        };
670        let predictor = config.init::<TestBackend>(&device());
671
672        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
673        let positions_a = target_positions(&[0, 1], 1);
674        let positions_b = target_positions(&[2, 3], 1);
675
676        let pred_a = predictor.predict(&context, &positions_a, None);
677        let pred_b = predictor.predict(&context, &positions_b, None);
678        let diff: f32 = (pred_a.embeddings - pred_b.embeddings)
679            .abs()
680            .sum()
681            .into_scalar()
682            .elem();
683
684        assert!(
685            diff > 1e-6,
686            "target positions should affect predictor output, diff={diff}"
687        );
688    }
689
690    #[test]
691    fn test_predictor_try_predict_rejects_batch_size_mismatch() {
692        let config = TransformerPredictorConfig {
693            encoder_embed_dim: 16,
694            predictor_embed_dim: 8,
695            num_layers: 1,
696            num_heads: 2,
697            max_target_len: 16,
698        };
699        let predictor = config.init::<TestBackend>(&device());
700
701        let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
702        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());
703
704        let err = predictor.try_predict(&context, &target_pos).unwrap_err();
705        assert_eq!(
706            err,
707            PredictorError::BatchSizeMismatch {
708                context_batch: 2,
709                positions_batch: 1,
710            }
711        );
712    }
713
714    #[test]
715    fn test_predictor_try_predict_rejects_out_of_range_positions() {
716        let config = TransformerPredictorConfig {
717            encoder_embed_dim: 16,
718            predictor_embed_dim: 8,
719            num_layers: 1,
720            num_heads: 2,
721            max_target_len: 4,
722        };
723        let predictor = config.init::<TestBackend>(&device());
724
725        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
726        let target_pos = target_positions(&[0, 4], 1);
727
728        let err = predictor.try_predict(&context, &target_pos).unwrap_err();
729        assert_eq!(
730            err,
731            PredictorError::TargetPositionOutOfRange {
732                position: 4,
733                max_supported: 4,
734            }
735        );
736    }
737
738    #[test]
739    fn test_predictor_try_predict_allows_empty_targets() {
740        let config = TransformerPredictorConfig {
741            encoder_embed_dim: 16,
742            predictor_embed_dim: 8,
743            num_layers: 1,
744            num_heads: 2,
745            max_target_len: 4,
746        };
747        let predictor = config.init::<TestBackend>(&device());
748
749        let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
750        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 0], &device());
751
752        let predicted = predictor.try_predict(&context, &target_pos).unwrap();
753        assert_eq!(predicted.batch_size(), 2);
754        assert_eq!(predicted.seq_len(), 0);
755        assert_eq!(predicted.embed_dim(), 16);
756    }
757
758    #[test]
759    fn test_ijepa_full_pipeline() {
760        // End-to-end test: encode → mask → predict → compute energy
761        let config = IJepaConfig::tiny_test();
762        let model = config.init::<TestBackend>(&device());
763
764        // 1. Create a test image
765        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
766
767        // 2. Encode with both encoders
768        let context_repr = model.context_encoder.forward(&images);
769        let target_repr = model.target_encoder.forward(&images);
770
771        assert_eq!(context_repr.seq_len(), 16); // 4x4 grid
772        assert_eq!(target_repr.seq_len(), 16);
773
774        // 3. Generate a mask
775        let masking = jepa_core::masking::BlockMasking {
776            num_targets: 2,
777            target_scale: (0.15, 0.3),
778            target_aspect_ratio: (0.75, 1.5),
779        };
780        let shape = jepa_core::types::InputShape::Image {
781            height: 4,
782            width: 4,
783        };
784        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
785        let mask = masking.generate_mask(&shape, &mut rng);
786
787        // 4. Predict target from context
788        let num_targets = mask.target_indices.len();
789        let target_pos = target_positions(&mask.target_indices, 1);
790        let predicted = model.predictor.predict(&context_repr, &target_pos, None);
791
792        assert_eq!(predicted.seq_len(), num_targets);
793        assert_eq!(predicted.embed_dim(), 32);
794
795        // 5. Compute energy between predicted and actual target
796        // We need to extract target tokens from target_repr for fair comparison
797        // For this test, just verify energy is computable and finite
798        let energy = jepa_core::energy::L2Energy.compute(&predicted, &predicted);
799        let val: f32 = energy.value.into_scalar().elem();
800        assert!(val.is_finite(), "energy should be finite");
801    }
802
803    #[test]
804    fn test_ijepa_config_tiny() {
805        let config = IJepaConfig::tiny_test();
806        assert_eq!(config.encoder.embed_dim, 32);
807        assert_eq!(config.predictor.predictor_embed_dim, 16);
808    }
809
810    #[test]
811    fn test_strict_context_encoding_ignores_hidden_patches() {
812        let config = IJepaConfig::tiny_test();
813        let model = config.init::<TestBackend>(&device());
814        let mask = fixed_image_mask();
815
816        let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
817        let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);
818
819        let strict_low = model.encode_context_strict(&hidden_low, &mask.context_indices);
820        let strict_high = model.encode_context_strict(&hidden_high, &mask.context_indices);
821
822        let diff: f32 = (strict_low.embeddings - strict_high.embeddings)
823            .abs()
824            .sum()
825            .into_scalar()
826            .elem();
827        assert!(
828            diff < 1e-5,
829            "strict masked context should ignore hidden patches, diff={diff}"
830        );
831    }
832
833    #[test]
834    fn test_full_encoder_context_slice_leaks_hidden_patches() {
835        let config = crate::vit::VitConfig::tiny_test();
836        let encoder = config.init::<TestBackend>(&device());
837        let mask = fixed_image_mask();
838
839        let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
840        let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);
841
842        let approx_low = encoder.forward(&hidden_low).gather(&mask.context_indices);
843        let approx_high = encoder.forward(&hidden_high).gather(&mask.context_indices);
844
845        let diff: f32 = (approx_low.embeddings - approx_high.embeddings)
846            .abs()
847            .sum()
848            .into_scalar()
849            .elem();
850        assert!(
851            diff > 1e-3,
852            "post-encoder gather path should leak hidden patches, diff={diff}"
853        );
854    }
855
856    #[test]
857    fn test_strict_forward_step_runs_end_to_end() {
858        let config = IJepaConfig::tiny_test();
859        let model = config.init::<TestBackend>(&device());
860        let mask = fixed_image_mask();
861        let images = image_with_hidden_patch_value(&mask, 3.0);
862        let energy_fn = jepa_core::energy::L2Energy;
863        let regularizer = jepa_core::collapse::VICReg::default();
864
865        let output =
866            model.forward_step_strict(&images, mask.clone(), &energy_fn, &regularizer, 1.0);
867
868        assert_eq!(output.context.seq_len(), mask.context_indices.len());
869        assert_eq!(output.predicted.seq_len(), mask.target_indices.len());
870        assert_eq!(output.target.seq_len(), mask.target_indices.len());
871
872        let total_loss: f32 = output.total_loss.into_scalar().elem();
873        assert!(
874            total_loss.is_finite(),
875            "strict forward loss should be finite"
876        );
877    }
878
879    #[test]
880    fn test_try_strict_forward_step_rejects_invalid_mask() {
881        let config = IJepaConfig::tiny_test();
882        let model = config.init::<TestBackend>(&device());
883        let images = Tensor::ones([1, 1, 8, 8], &device());
884        let invalid_mask = MaskSpec {
885            context_indices: vec![],
886            target_indices: vec![0],
887            total_tokens: 16,
888        };
889        let energy_fn = jepa_core::energy::L2Energy;
890        let regularizer = jepa_core::collapse::VICReg::default();
891
892        let err = model
893            .try_forward_step_strict(&images, invalid_mask, &energy_fn, &regularizer, 1.0)
894            .unwrap_err();
895        assert!(matches!(
896            err,
897            StrictIJepaError::InvalidMask(MaskError::EmptyContext)
898        ));
899    }
900
901    // ======================================================================
902    // BDD-aligned integration tests (matching specs/gherkin/features.feature)
903    // ======================================================================
904
905    /// BDD: "Encode a batch of images into representations"
906    /// Given a ViT encoder with embed_dim and patch_size
907    /// When I encode a batch of images
908    /// Then I should get representations of the correct shape
909    /// And the representations should have non-zero variance across the batch
910    #[test]
911    fn bdd_encode_batch_correct_shape_and_nonzero_variance() {
912        let config = crate::vit::VitConfig::tiny_test();
913        let encoder = config.init::<TestBackend>(&device());
914
915        // Batch of 4 images, different values to ensure variance
916        let batch_size = 4;
917        let images: Tensor<TestBackend, 4> = Tensor::random(
918            [batch_size, 1, 8, 8],
919            burn::tensor::Distribution::Normal(0.0, 1.0),
920            &device(),
921        );
922        let repr = encoder.forward(&images);
923
924        // Shape: [4, 16, 32] (4x4 grid of patches, embed_dim=32)
925        assert_eq!(repr.batch_size(), batch_size);
926        assert_eq!(repr.seq_len(), 16);
927        assert_eq!(repr.embed_dim(), 32);
928
929        // Variance across the batch dimension should be non-zero
930        // Compute mean across batch, then measure deviation
931        let mean_repr = repr.embeddings.clone().mean_dim(0); // [1, 16, 32]
932        let diff = repr.embeddings.clone() - mean_repr;
933        let variance: f32 = (diff.clone() * diff).mean().into_scalar().elem();
934        assert!(
935            variance > 1e-6,
936            "representations should have non-zero variance across the batch, got {variance}"
937        );
938    }
939
940    /// BDD: "Context and target encoders produce compatible representations"
941    /// Given a JEPA encoder pair with shared architecture
942    /// And the target encoder initialized as a copy of the context encoder
943    /// When I encode the same image with both encoders
944    /// Then the representations should be identical (freshly initialized, same weights)
945    #[test]
946    fn bdd_encoder_pair_same_init_same_output() {
947        // Both encoders share the same config. Since they're freshly initialized
948        // with potentially different random weights, we create one and use it twice.
949        let config = crate::vit::VitConfig::tiny_test();
950        let encoder = config.init::<TestBackend>(&device());
951
952        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
953
954        // Encoding the same image with the same encoder instance gives identical output
955        let repr1 = encoder.forward(&images);
956        let repr2 = encoder.forward(&images);
957
958        let diff: f32 = (repr1.embeddings - repr2.embeddings)
959            .abs()
960            .sum()
961            .into_scalar()
962            .elem();
963        assert!(
964            diff < 1e-6,
965            "same encoder + same input should produce identical representations, diff={diff}"
966        );
967    }
968
969    /// BDD: "EMA update makes target encoder lag behind context encoder"
970    /// Given a JEPA encoder pair
971    /// When I apply EMA update with momentum 0.99
972    /// Then the target encoder weights should move toward the context encoder
973    /// And the target encoder should NOT equal the context encoder
974    #[test]
975    fn bdd_ema_update_target_lags_context() {
976        let config = IJepaConfig::tiny_test();
977        let model = config.init::<TestBackend>(&device());
978
979        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
980
981        // Get initial representations
982        let ctx_repr = model.context_encoder.forward(&images);
983        let tgt_repr = model.target_encoder.forward(&images);
984
985        // Since both are freshly initialized with DIFFERENT random weights,
986        // their outputs should differ
987        let initial_diff: f32 = (ctx_repr.embeddings.clone() - tgt_repr.embeddings.clone())
988            .abs()
989            .sum()
990            .into_scalar()
991            .elem();
992
993        // After many EMA updates, target should move toward context.
994        // We simulate this by computing what the target weight tensor would be.
995        let ema = jepa_core::ema::Ema::new(0.99);
996        let target_val = 0.0f64;
997        let online_val = 1.0f64;
998        let mut val = target_val;
999        for step in 0..500 {
1000            val = ema.step(val, online_val, step);
1001        }
1002        // After 500 steps at momentum 0.99, should be close but not equal to 1.0
1003        assert!(val > 0.9, "EMA should converge toward online, got {val}");
1004        assert!(val < 1.0, "EMA should lag behind online, got {val}");
1005
1006        // Verify initial diff is non-zero (different initializations)
1007        // This is a property of the architecture, not a guarantee — but with random
1008        // init and non-trivial input, it should hold.
1009        assert!(
1010            initial_diff >= 0.0,
1011            "initial representations computed successfully"
1012        );
1013    }
1014
1015    /// BDD: "Full I-JEPA pipeline with proper target extraction"
1016    /// Given an I-JEPA model, masking strategy, and energy function
1017    /// When I run the full forward pipeline (encode → mask → gather → predict → energy)
1018    /// Then the energy should be finite and non-negative
1019    #[test]
1020    fn bdd_full_ijepa_pipeline_with_gather() {
1021        let config = IJepaConfig::tiny_test();
1022        let model = config.init::<TestBackend>(&device());
1023
1024        let images: Tensor<TestBackend, 4> = Tensor::random(
1025            [2, 1, 8, 8],
1026            burn::tensor::Distribution::Normal(0.0, 1.0),
1027            &device(),
1028        );
1029
1030        // 1. Encode
1031        let context_repr = model.context_encoder.forward(&images);
1032        let target_repr = model.target_encoder.forward(&images);
1033
1034        // 2. Generate mask
1035        let masking = jepa_core::masking::BlockMasking {
1036            num_targets: 2,
1037            target_scale: (0.15, 0.3),
1038            target_aspect_ratio: (0.75, 1.5),
1039        };
1040        let shape = jepa_core::types::InputShape::Image {
1041            height: 4,
1042            width: 4,
1043        };
1044        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
1045        let mask = masking.generate_mask(&shape, &mut rng);
1046        assert!(mask.validate().is_ok());
1047
1048        // 3. Gather target tokens using mask indices
1049        let target_gathered = target_repr.gather(&mask.target_indices);
1050        assert_eq!(target_gathered.seq_len(), mask.target_indices.len());
1051        assert_eq!(target_gathered.batch_size(), 2);
1052
1053        // 4. Predict target from context
1054        let num_targets = mask.target_indices.len();
1055        let target_pos = target_positions(&mask.target_indices, 2);
1056        let predicted = model.predictor.predict(&context_repr, &target_pos, None);
1057        assert_eq!(predicted.seq_len(), num_targets);
1058
1059        // 5. Compute energy
1060        let energy = jepa_core::energy::L2Energy.compute(&predicted, &target_gathered);
1061        let val: f32 = energy.value.into_scalar().elem();
1062        assert!(val.is_finite(), "energy should be finite, got {val}");
1063        assert!(val >= 0.0, "L2 energy should be non-negative, got {val}");
1064
1065        // 6. Compute collapse regularization
1066        let batch = 2;
1067        let embed_dim = predicted.embed_dim();
1068        let pred_flat = predicted
1069            .embeddings
1070            .reshape([batch * num_targets, embed_dim]);
1071        let target_flat = target_gathered
1072            .embeddings
1073            .reshape([batch * num_targets, embed_dim]);
1074        let reg_loss: f32 = jepa_core::collapse::VICReg::default()
1075            .loss(&pred_flat, &target_flat)
1076            .into_scalar()
1077            .elem();
1078        assert!(
1079            reg_loss.is_finite(),
1080            "regularization should be finite, got {reg_loss}"
1081        );
1082    }
1083
1084    /// BDD: "Masking creates meaningful prediction tasks"
1085    /// Given block masking
1086    /// When I generate many masks
1087    /// Then context + target should always partition all tokens
1088    /// And masks should vary across seeds
1089    #[test]
1090    fn bdd_masking_creates_valid_prediction_tasks() {
1091        let masking = jepa_core::masking::BlockMasking {
1092            num_targets: 4,
1093            target_scale: (0.15, 0.2),
1094            target_aspect_ratio: (0.75, 1.5),
1095        };
1096        let shape = jepa_core::types::InputShape::Image {
1097            height: 4,
1098            width: 4,
1099        };
1100
1101        let mut masks = Vec::new();
1102        for seed in 0..20u64 {
1103            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
1104            let mask = masking.generate_mask(&shape, &mut rng);
1105
1106            assert!(mask.validate().is_ok(), "mask with seed {seed} is invalid");
1107            assert_eq!(
1108                mask.context_indices.len() + mask.target_indices.len(),
1109                16,
1110                "mask with seed {seed} doesn't partition all 16 tokens"
1111            );
1112            assert!(
1113                !mask.context_indices.is_empty(),
1114                "mask with seed {seed} has empty context"
1115            );
1116            assert!(
1117                !mask.target_indices.is_empty(),
1118                "mask with seed {seed} has empty target"
1119            );
1120            masks.push(mask);
1121        }
1122
1123        // At least some masks should differ
1124        let first_targets = &masks[0].target_indices;
1125        let some_differ = masks[1..]
1126            .iter()
1127            .any(|m| m.target_indices != *first_targets);
1128        assert!(some_differ, "masks should vary across different seeds");
1129    }
1130
1131    use proptest::prelude::*;
1132
1133    proptest! {
1134        /// Property: predictor output dimension always matches encoder_embed_dim,
1135        /// regardless of number of targets.
1136        #[test]
1137        fn prop_predictor_output_dim_matches_encoder(
1138            num_targets in 1usize..8,
1139        ) {
1140            let encoder_embed_dim = 32;
1141            let config = TransformerPredictorConfig {
1142                encoder_embed_dim,
1143                predictor_embed_dim: 16,
1144                num_layers: 1,
1145                num_heads: 2,
1146                max_target_len: 64,
1147            };
1148            let predictor = config.init::<TestBackend>(&device());
1149
1150            let context = Representation::new(Tensor::zeros([1, 8, encoder_embed_dim], &device()));
1151            let target_pos: Tensor<TestBackend, 2> =
1152                Tensor::zeros([1, num_targets], &device());
1153            let predicted = predictor.predict(&context, &target_pos, None);
1154
1155            prop_assert_eq!(predicted.batch_size(), 1);
1156            prop_assert_eq!(predicted.seq_len(), num_targets);
1157            prop_assert_eq!(predicted.embed_dim(), encoder_embed_dim);
1158        }
1159
1160        /// Property: predictor output is always finite for normally-distributed context.
1161        #[test]
1162        fn prop_predictor_output_is_finite(
1163            batch in 1usize..3,
1164            num_targets in 1usize..6,
1165        ) {
1166            let config = TransformerPredictorConfig {
1167                encoder_embed_dim: 16,
1168                predictor_embed_dim: 8,
1169                num_layers: 1,
1170                num_heads: 2,
1171                max_target_len: 16,
1172            };
1173            let predictor = config.init::<TestBackend>(&device());
1174
1175            let context = Representation::new(Tensor::random(
1176                [batch, 4, 16],
1177                burn::tensor::Distribution::Normal(0.0, 1.0),
1178                &device(),
1179            ));
1180            let target_pos: Tensor<TestBackend, 2> =
1181                Tensor::zeros([batch, num_targets], &device());
1182            let predicted = predictor.predict(&context, &target_pos, None);
1183
1184            let total: f32 = predicted
1185                .embeddings
1186                .abs()
1187                .sum()
1188                .into_scalar()
1189                .elem();
1190            prop_assert!(
1191                total.is_finite(),
1192                "predictor output should be finite, got {}",
1193                total
1194            );
1195        }
1196    }
1197}