1use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
24use burn::prelude::*;
25use burn::tensor::backend::Backend;
26use burn::tensor::module::embedding;
27
28use jepa_core::types::{Energy, MaskError, MaskSpec, Representation};
29use jepa_core::{CollapseRegularizer, EnergyFn, Predictor};
30
31#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
60pub struct TransformerPredictorConfig {
61 pub encoder_embed_dim: usize,
63 pub predictor_embed_dim: usize,
65 pub num_layers: usize,
67 pub num_heads: usize,
69 pub max_target_len: usize,
74}
75
76impl TransformerPredictorConfig {
77 pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerPredictor<B> {
79 let input_proj =
80 LinearConfig::new(self.encoder_embed_dim, self.predictor_embed_dim).init(device);
81 let output_proj =
82 LinearConfig::new(self.predictor_embed_dim, self.encoder_embed_dim).init(device);
83
84 let blocks: Vec<PredictorBlock<B>> = (0..self.num_layers)
85 .map(|_| {
86 PredictorBlockConfig {
87 embed_dim: self.predictor_embed_dim,
88 num_heads: self.num_heads,
89 }
90 .init(device)
91 })
92 .collect();
93
94 let norm = LayerNormConfig::new(self.predictor_embed_dim).init(device);
95
96 let prediction_tokens =
97 sinusoidal_prediction_tokens(self.max_target_len, self.predictor_embed_dim, device);
98
99 TransformerPredictor {
100 input_proj,
101 output_proj,
102 blocks,
103 norm,
104 prediction_tokens,
105 predictor_embed_dim: self.predictor_embed_dim,
106 encoder_embed_dim: self.encoder_embed_dim,
107 }
108 }
109}
110
111#[derive(Module, Debug)]
125pub struct TransformerPredictor<B: Backend> {
126 input_proj: Linear<B>,
128 output_proj: Linear<B>,
130 blocks: Vec<PredictorBlock<B>>,
132 norm: LayerNorm<B>,
134 prediction_tokens: Tensor<B, 2>,
136 predictor_embed_dim: usize,
138 encoder_embed_dim: usize,
140}
141
142#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
144pub enum PredictorError {
145 #[error(
146 "target position batch size mismatch: context batch={context_batch}, target_positions batch={positions_batch}"
147 )]
148 BatchSizeMismatch {
149 context_batch: usize,
150 positions_batch: usize,
151 },
152 #[error("target position must be non-negative, got {0}")]
153 NegativeTargetPosition(i64),
154 #[error(
155 "target position {position} exceeds predictor capacity {max_supported}; increase max_target_len"
156 )]
157 TargetPositionOutOfRange {
158 position: usize,
159 max_supported: usize,
160 },
161}
162
163impl<B: Backend> Predictor<B> for TransformerPredictor<B> {
164 fn predict(
165 &self,
166 context: &Representation<B>,
167 target_positions: &Tensor<B, 2>,
168 _latent: Option<&Tensor<B, 2>>,
169 ) -> Representation<B> {
170 self.try_predict(context, target_positions).expect(
171 "TransformerPredictor::predict failed — target positions must match the context \
172 batch size and not exceed max_target_len; use try_predict for error handling",
173 )
174 }
175}
176
177impl<B: Backend> TransformerPredictor<B> {
178 pub fn try_predict(
180 &self,
181 context: &Representation<B>,
182 target_positions: &Tensor<B, 2>,
183 ) -> Result<Representation<B>, PredictorError> {
184 let [batch, _ctx_len, _enc_dim] = context.embeddings.dims();
185 let [positions_batch, num_targets] = target_positions.dims();
186 if positions_batch != batch {
187 return Err(PredictorError::BatchSizeMismatch {
188 context_batch: batch,
189 positions_batch,
190 });
191 }
192
193 if num_targets == 0 {
194 let device = context.embeddings.device();
195 return Ok(Representation::new(Tensor::zeros(
196 [batch, 0, self.encoder_embed_dim],
197 &device,
198 )));
199 }
200
201 let target_positions = target_positions.clone().int();
202 let min_position: i64 = target_positions.clone().min().into_scalar().elem();
203 if min_position < 0 {
204 return Err(PredictorError::NegativeTargetPosition(min_position));
205 }
206
207 let max_position: i64 = target_positions.clone().max().into_scalar().elem();
208 let max_supported_position = self.prediction_tokens.dims()[0];
209 if max_position >= max_supported_position as i64 {
210 return Err(PredictorError::TargetPositionOutOfRange {
211 position: max_position as usize,
212 max_supported: max_supported_position,
213 });
214 }
215
216 let ctx = self.input_proj.forward(context.embeddings.clone());
218
219 let pred_tokens = embedding(self.prediction_tokens.clone(), target_positions);
221
222 let combined = Tensor::cat(vec![ctx, pred_tokens], 1);
224 let ctx_len = context.embeddings.dims()[1];
225 let total_len = ctx_len + num_targets;
226
227 let mut x = combined;
229 for block in &self.blocks {
230 x = block.forward(x);
231 }
232
233 let pred_out = x.slice([0..batch, ctx_len..total_len, 0..self.predictor_embed_dim]);
235
236 let pred_out = self.norm.forward(pred_out);
238 let pred_out = self.output_proj.forward(pred_out);
239
240 Ok(Representation::new(pred_out))
241 }
242}
243
244fn sinusoidal_prediction_tokens<B: Backend>(
245 max_target_len: usize,
246 embed_dim: usize,
247 device: &B::Device,
248) -> Tensor<B, 2> {
249 let mut data = vec![0.0f32; max_target_len * embed_dim];
250
251 for position in 0..max_target_len {
252 for dim in 0..embed_dim {
253 let exponent = (2 * (dim / 2)) as f64 / embed_dim as f64;
254 let angle = position as f64 / 10_000_f64.powf(exponent);
255 data[position * embed_dim + dim] = if dim % 2 == 0 {
256 angle.sin() as f32
257 } else {
258 angle.cos() as f32
259 };
260 }
261 }
262
263 Tensor::from_floats(
264 burn::tensor::TensorData::new(data, [max_target_len, embed_dim]),
265 device,
266 )
267}
268
269#[derive(Debug, Clone)]
272struct PredictorBlockConfig {
273 embed_dim: usize,
274 num_heads: usize,
275}
276
277impl PredictorBlockConfig {
278 fn init<B: Backend>(&self, device: &B::Device) -> PredictorBlock<B> {
279 let head_dim = self.embed_dim / self.num_heads;
280 PredictorBlock {
281 norm1: LayerNormConfig::new(self.embed_dim).init(device),
282 attn: PredictorAttention {
283 qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
284 out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
285 num_heads: self.num_heads,
286 head_dim,
287 },
288 norm2: LayerNormConfig::new(self.embed_dim).init(device),
289 mlp: PredictorMlp {
290 fc1: LinearConfig::new(self.embed_dim, self.embed_dim * 4).init(device),
291 fc2: LinearConfig::new(self.embed_dim * 4, self.embed_dim).init(device),
292 },
293 }
294 }
295}
296
297#[derive(Module, Debug)]
298struct PredictorBlock<B: Backend> {
299 norm1: LayerNorm<B>,
300 attn: PredictorAttention<B>,
301 norm2: LayerNorm<B>,
302 mlp: PredictorMlp<B>,
303}
304
305impl<B: Backend> PredictorBlock<B> {
306 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
307 let residual = x.clone();
308 let x_norm = self.norm1.forward(x);
309 let attn_out = self.attn.forward(x_norm);
310 let x = residual + attn_out;
311
312 let residual = x.clone();
313 let x_norm = self.norm2.forward(x);
314 let mlp_out = self.mlp.forward(x_norm);
315 residual + mlp_out
316 }
317}
318
319#[derive(Module, Debug)]
320struct PredictorAttention<B: Backend> {
321 qkv: Linear<B>,
322 out_proj: Linear<B>,
323 num_heads: usize,
324 head_dim: usize,
325}
326
327impl<B: Backend> PredictorAttention<B> {
328 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
329 let [batch, seq_len, _] = x.dims();
330 let embed_dim = self.num_heads * self.head_dim;
331
332 let qkv = self.qkv.forward(x);
333 let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
334 let k = qkv
335 .clone()
336 .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
337 let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
338
339 let q = q
340 .reshape([batch, seq_len, self.num_heads, self.head_dim])
341 .swap_dims(1, 2);
342 let k = k
343 .reshape([batch, seq_len, self.num_heads, self.head_dim])
344 .swap_dims(1, 2);
345 let v = v
346 .reshape([batch, seq_len, self.num_heads, self.head_dim])
347 .swap_dims(1, 2);
348
349 let scale = (self.head_dim as f64).sqrt();
350 let attn = q.matmul(k.transpose()) / scale;
351 let attn = burn::tensor::activation::softmax(attn, 3);
352 let out = attn.matmul(v);
353 let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
354 self.out_proj.forward(out)
355 }
356}
357
358#[derive(Module, Debug)]
359struct PredictorMlp<B: Backend> {
360 fc1: Linear<B>,
361 fc2: Linear<B>,
362}
363
364impl<B: Backend> PredictorMlp<B> {
365 fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
366 let x = self.fc1.forward(x);
367 let x = burn::tensor::activation::gelu(x);
368 self.fc2.forward(x)
369 }
370}
371
372#[derive(Module, Debug)]
376pub struct IJepa<B: Backend> {
377 pub context_encoder: crate::vit::VitEncoder<B>,
379 pub target_encoder: crate::vit::VitEncoder<B>,
381 pub predictor: TransformerPredictor<B>,
383}
384
385#[derive(Debug, Clone)]
391pub struct StrictIJepaForwardOutput<B: Backend> {
392 pub energy: Energy<B>,
394 pub regularization: Tensor<B, 1>,
396 pub total_loss: Tensor<B, 1>,
398 pub mask: MaskSpec,
400 pub context: Representation<B>,
402 pub predicted: Representation<B>,
404 pub target: Representation<B>,
406}
407
408#[derive(Debug, Clone, thiserror::Error)]
410pub enum StrictIJepaError {
411 #[error(transparent)]
412 InvalidMask(#[from] MaskError),
413 #[error(transparent)]
414 Predictor(#[from] PredictorError),
415}
416
417impl<B: Backend> IJepa<B> {
418 pub fn encode_context_strict(
424 &self,
425 images: &Tensor<B, 4>,
426 context_indices: &[usize],
427 ) -> Representation<B> {
428 self.context_encoder
429 .forward_visible_tokens(images, context_indices)
430 }
431
432 pub fn forward_step_strict<EF, CR>(
443 &self,
444 images: &Tensor<B, 4>,
445 mask: MaskSpec,
446 energy_fn: &EF,
447 regularizer: &CR,
448 reg_weight: f64,
449 ) -> StrictIJepaForwardOutput<B>
450 where
451 EF: EnergyFn<B>,
452 CR: CollapseRegularizer<B>,
453 {
454 self.try_forward_step_strict(images, mask, energy_fn, regularizer, reg_weight)
455 .expect(
456 "IJepa::forward_step_strict failed — mask must be valid (disjoint, non-empty) \
457 and target count must not exceed predictor capacity; \
458 use try_forward_step_strict for error handling",
459 )
460 }
461
462 pub fn try_forward_step_strict<EF, CR>(
464 &self,
465 images: &Tensor<B, 4>,
466 mask: MaskSpec,
467 energy_fn: &EF,
468 regularizer: &CR,
469 reg_weight: f64,
470 ) -> Result<StrictIJepaForwardOutput<B>, StrictIJepaError>
471 where
472 EF: EnergyFn<B>,
473 CR: CollapseRegularizer<B>,
474 {
475 mask.validate()?;
476
477 let context = self.encode_context_strict(images, &mask.context_indices);
478 let target_full = self.target_encoder.forward(images);
479 let target =
480 Representation::new(target_full.embeddings.detach()).gather(&mask.target_indices);
481
482 let batch = images.dims()[0];
483 let target_positions =
484 target_positions_tensor::<B>(&mask.target_indices, batch, &images.device());
485 let predicted = self.predictor.try_predict(&context, &target_positions)?;
486
487 let num_targets = target.seq_len();
488 let embed_dim = target.embed_dim();
489 let pred_flat = predicted
490 .embeddings
491 .clone()
492 .reshape([batch * num_targets, embed_dim]);
493 let target_flat = target
494 .embeddings
495 .clone()
496 .reshape([batch * num_targets, embed_dim]);
497
498 let energy = energy_fn.compute(&predicted, &target);
499 let regularization = regularizer.loss(&pred_flat, &target_flat);
500 let total_loss = energy.value.clone() + regularization.clone() * reg_weight;
501
502 Ok(StrictIJepaForwardOutput {
503 energy,
504 regularization,
505 total_loss,
506 mask,
507 context,
508 predicted,
509 target,
510 })
511 }
512}
513
514pub(crate) fn target_positions_tensor<B: Backend>(
515 indices: &[usize],
516 batch: usize,
517 device: &B::Device,
518) -> Tensor<B, 2> {
519 let mut data = Vec::with_capacity(batch * indices.len());
520 for _ in 0..batch {
521 data.extend(indices.iter().map(|&index| index as f32));
522 }
523
524 Tensor::from_floats(
525 burn::tensor::TensorData::new(data, [batch, indices.len()]),
526 device,
527 )
528}
529
530#[derive(Debug, Clone)]
532pub struct IJepaConfig {
533 pub encoder: crate::vit::VitConfig,
535 pub predictor: TransformerPredictorConfig,
537}
538
539impl IJepaConfig {
540 pub fn tiny_test() -> Self {
542 let encoder = crate::vit::VitConfig::tiny_test();
543 Self {
544 predictor: TransformerPredictorConfig {
545 encoder_embed_dim: encoder.embed_dim,
546 predictor_embed_dim: 16,
547 num_layers: 1,
548 num_heads: 2,
549 max_target_len: 64,
550 },
551 encoder,
552 }
553 }
554
555 pub fn init<B: Backend>(&self, device: &B::Device) -> IJepa<B> {
557 IJepa {
558 context_encoder: self.encoder.init(device),
559 target_encoder: self.encoder.init(device),
560 predictor: self.predictor.init(device),
561 }
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use jepa_core::{CollapseRegularizer, EnergyFn, MaskingStrategy};
    use rand::SeedableRng;

    /// CPU backend used by every test in this module.
    type TestBackend = NdArray<f32>;

    /// Device on which all test tensors live.
    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Tiles `indices` into a `[batch, indices.len()]` position tensor by
    /// delegating to the production helper, so tests and model agree.
    fn target_positions(indices: &[usize], batch: usize) -> Tensor<TestBackend, 2> {
        target_positions_tensor::<TestBackend>(indices, batch, &device())
    }

    /// Fixed, valid partition of a 4x4 token grid into context and targets.
    fn fixed_image_mask() -> MaskSpec {
        MaskSpec {
            context_indices: vec![0, 1, 4, 5, 10, 11, 14, 15],
            target_indices: vec![2, 3, 6, 7, 8, 9, 12, 13],
            total_tokens: 16,
        }
    }

    /// Builds a `[1, 1, 8, 8]` image of ones in which every 2x2 patch listed
    /// in `mask.target_indices` is overwritten with `hidden_value`.
    fn image_with_hidden_patch_value(mask: &MaskSpec, hidden_value: f32) -> Tensor<TestBackend, 4> {
        let image_size = 8usize;
        let patch_size = 2usize;
        let mut data = vec![1.0f32; image_size * image_size];

        for &index in &mask.target_indices {
            // Token index -> (row, column) on the 4x4 patch grid.
            let patch_row = index / 4;
            let patch_col = index % 4;
            let row_start = patch_row * patch_size;
            let col_start = patch_col * patch_size;

            for row in row_start..row_start + patch_size {
                for col in col_start..col_start + patch_size {
                    data[row * image_size + col] = hidden_value;
                }
            }
        }

        Tensor::from_floats(
            burn::tensor::TensorData::new(data, [1, 1, image_size, image_size]),
            &device(),
        )
    }

    /// The predictor returns one encoder-width embedding per target.
    #[test]
    fn test_predictor_output_shape() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 32,
            predictor_embed_dim: 16,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 64,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([2, 8, 32], &device()));
        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 4], &device());
        let predicted = predictor.predict(&context, &target_pos, None);

        assert_eq!(predicted.batch_size(), 2);
        assert_eq!(predicted.seq_len(), 4);
        assert_eq!(predicted.embed_dim(), 32);
    }

    /// The predictor is usable through the `Predictor` trait.
    #[test]
    fn test_predictor_implements_trait() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 16,
            predictor_embed_dim: 8,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 16,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());
        let pred: Representation<TestBackend> =
            Predictor::predict(&predictor, &context, &target_pos, None);
        assert_eq!(pred.seq_len(), 2);
    }

    /// Different target positions must yield different predictions.
    #[test]
    fn test_predictor_output_depends_on_target_positions() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 16,
            predictor_embed_dim: 8,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 16,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
        let positions_a = target_positions(&[0, 1], 1);
        let positions_b = target_positions(&[2, 3], 1);

        let pred_a = predictor.predict(&context, &positions_a, None);
        let pred_b = predictor.predict(&context, &positions_b, None);
        let diff: f32 = (pred_a.embeddings - pred_b.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();

        assert!(
            diff > 1e-6,
            "target positions should affect predictor output, diff={diff}"
        );
    }

    /// A batch-size disagreement is reported as `BatchSizeMismatch`.
    #[test]
    fn test_predictor_try_predict_rejects_batch_size_mismatch() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 16,
            predictor_embed_dim: 8,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 16,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());

        let err = predictor.try_predict(&context, &target_pos).unwrap_err();
        assert_eq!(
            err,
            PredictorError::BatchSizeMismatch {
                context_batch: 2,
                positions_batch: 1,
            }
        );
    }

    /// A position equal to `max_target_len` is out of range.
    #[test]
    fn test_predictor_try_predict_rejects_out_of_range_positions() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 16,
            predictor_embed_dim: 8,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 4,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
        let target_pos = target_positions(&[0, 4], 1);

        let err = predictor.try_predict(&context, &target_pos).unwrap_err();
        assert_eq!(
            err,
            PredictorError::TargetPositionOutOfRange {
                position: 4,
                max_supported: 4,
            }
        );
    }

    /// Zero targets produce an empty, correctly shaped prediction.
    #[test]
    fn test_predictor_try_predict_allows_empty_targets() {
        let config = TransformerPredictorConfig {
            encoder_embed_dim: 16,
            predictor_embed_dim: 8,
            num_layers: 1,
            num_heads: 2,
            max_target_len: 4,
        };
        let predictor = config.init::<TestBackend>(&device());

        let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
        let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 0], &device());

        let predicted = predictor.try_predict(&context, &target_pos).unwrap();
        assert_eq!(predicted.batch_size(), 2);
        assert_eq!(predicted.seq_len(), 0);
        assert_eq!(predicted.embed_dim(), 16);
    }

    /// Encoders, masking, predictor and energy compose without shape errors.
    #[test]
    fn test_ijepa_full_pipeline() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());

        let context_repr = model.context_encoder.forward(&images);
        let target_repr = model.target_encoder.forward(&images);

        assert_eq!(context_repr.seq_len(), 16);
        assert_eq!(target_repr.seq_len(), 16);

        let masking = jepa_core::masking::BlockMasking {
            num_targets: 2,
            target_scale: (0.15, 0.3),
            target_aspect_ratio: (0.75, 1.5),
        };
        let shape = jepa_core::types::InputShape::Image {
            height: 4,
            width: 4,
        };
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let mask = masking.generate_mask(&shape, &mut rng);

        let num_targets = mask.target_indices.len();
        let target_pos = target_positions(&mask.target_indices, 1);
        let predicted = model.predictor.predict(&context_repr, &target_pos, None);

        assert_eq!(predicted.seq_len(), num_targets);
        assert_eq!(predicted.embed_dim(), 32);

        let energy = jepa_core::energy::L2Energy.compute(&predicted, &predicted);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.is_finite(), "energy should be finite");
    }

    /// The tiny test configuration exposes the expected widths.
    #[test]
    fn test_ijepa_config_tiny() {
        let config = IJepaConfig::tiny_test();
        assert_eq!(config.encoder.embed_dim, 32);
        assert_eq!(config.predictor.predictor_embed_dim, 16);
    }

    /// Strict context encoding must not depend on the hidden patch contents.
    #[test]
    fn test_strict_context_encoding_ignores_hidden_patches() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let mask = fixed_image_mask();

        let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
        let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);

        let strict_low = model.encode_context_strict(&hidden_low, &mask.context_indices);
        let strict_high = model.encode_context_strict(&hidden_high, &mask.context_indices);

        let diff: f32 = (strict_low.embeddings - strict_high.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff < 1e-5,
            "strict masked context should ignore hidden patches, diff={diff}"
        );
    }

    /// Encoding the full image and gathering afterwards does leak hidden
    /// patch contents into the context tokens.
    #[test]
    fn test_full_encoder_context_slice_leaks_hidden_patches() {
        let config = crate::vit::VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mask = fixed_image_mask();

        let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
        let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);

        let approx_low = encoder.forward(&hidden_low).gather(&mask.context_indices);
        let approx_high = encoder.forward(&hidden_high).gather(&mask.context_indices);

        let diff: f32 = (approx_low.embeddings - approx_high.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff > 1e-3,
            "post-encoder gather path should leak hidden patches, diff={diff}"
        );
    }

    /// The strict forward step produces consistent shapes and a finite loss.
    #[test]
    fn test_strict_forward_step_runs_end_to_end() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let mask = fixed_image_mask();
        let images = image_with_hidden_patch_value(&mask, 3.0);
        let energy_fn = jepa_core::energy::L2Energy;
        let regularizer = jepa_core::collapse::VICReg::default();

        let output =
            model.forward_step_strict(&images, mask.clone(), &energy_fn, &regularizer, 1.0);

        assert_eq!(output.context.seq_len(), mask.context_indices.len());
        assert_eq!(output.predicted.seq_len(), mask.target_indices.len());
        assert_eq!(output.target.seq_len(), mask.target_indices.len());

        let total_loss: f32 = output.total_loss.into_scalar().elem();
        assert!(
            total_loss.is_finite(),
            "strict forward loss should be finite"
        );
    }

    /// A mask with an empty context is rejected before any encoding happens.
    #[test]
    fn test_try_strict_forward_step_rejects_invalid_mask() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());
        let images = Tensor::ones([1, 1, 8, 8], &device());
        let invalid_mask = MaskSpec {
            context_indices: vec![],
            target_indices: vec![0],
            total_tokens: 16,
        };
        let energy_fn = jepa_core::energy::L2Energy;
        let regularizer = jepa_core::collapse::VICReg::default();

        let err = model
            .try_forward_step_strict(&images, invalid_mask, &energy_fn, &regularizer, 1.0)
            .unwrap_err();
        assert!(matches!(
            err,
            StrictIJepaError::InvalidMask(MaskError::EmptyContext)
        ));
    }

    /// Encoding a random batch yields the expected shape and representations
    /// that differ across samples.
    #[test]
    fn bdd_encode_batch_correct_shape_and_nonzero_variance() {
        let config = crate::vit::VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let batch_size = 4;
        let images: Tensor<TestBackend, 4> = Tensor::random(
            [batch_size, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );
        let repr = encoder.forward(&images);

        assert_eq!(repr.batch_size(), batch_size);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(repr.embed_dim(), 32);

        // Mean squared deviation from the per-position batch mean.
        let mean_repr = repr.embeddings.clone().mean_dim(0);
        let diff = repr.embeddings.clone() - mean_repr;
        let variance: f32 = (diff.clone() * diff).mean().into_scalar().elem();
        assert!(
            variance > 1e-6,
            "representations should have non-zero variance across the batch, got {variance}"
        );
    }

    /// Running the same encoder twice on the same input is deterministic.
    #[test]
    fn bdd_encoder_pair_same_init_same_output() {
        let config = crate::vit::VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());

        let repr1 = encoder.forward(&images);
        let repr2 = encoder.forward(&images);

        let diff: f32 = (repr1.embeddings - repr2.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff < 1e-6,
            "same encoder + same input should produce identical representations, diff={diff}"
        );
    }

    /// A scalar EMA moves toward the online value without reaching it, and
    /// both encoders produce usable representations.
    #[test]
    fn bdd_ema_update_target_lags_context() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());

        let ctx_repr = model.context_encoder.forward(&images);
        let tgt_repr = model.target_encoder.forward(&images);

        let initial_diff: f32 = (ctx_repr.embeddings.clone() - tgt_repr.embeddings.clone())
            .abs()
            .sum()
            .into_scalar()
            .elem();

        let ema = jepa_core::ema::Ema::new(0.99);
        let target_val = 0.0f64;
        let online_val = 1.0f64;
        let mut val = target_val;
        for step in 0..500 {
            val = ema.step(val, online_val, step);
        }
        assert!(val > 0.9, "EMA should converge toward online, got {val}");
        assert!(val < 1.0, "EMA should lag behind online, got {val}");

        // A sum of absolute values is never negative, so `>= 0.0` only
        // rejected NaN; `is_finite` also rejects an overflowed difference.
        assert!(
            initial_diff.is_finite(),
            "initial representations computed successfully"
        );
    }

    /// Full pipeline on a random batch, with targets gathered by mask.
    #[test]
    fn bdd_full_ijepa_pipeline_with_gather() {
        let config = IJepaConfig::tiny_test();
        let model = config.init::<TestBackend>(&device());

        let images: Tensor<TestBackend, 4> = Tensor::random(
            [2, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );

        let context_repr = model.context_encoder.forward(&images);
        let target_repr = model.target_encoder.forward(&images);

        let masking = jepa_core::masking::BlockMasking {
            num_targets: 2,
            target_scale: (0.15, 0.3),
            target_aspect_ratio: (0.75, 1.5),
        };
        let shape = jepa_core::types::InputShape::Image {
            height: 4,
            width: 4,
        };
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let mask = masking.generate_mask(&shape, &mut rng);
        assert!(mask.validate().is_ok());

        let target_gathered = target_repr.gather(&mask.target_indices);
        assert_eq!(target_gathered.seq_len(), mask.target_indices.len());
        assert_eq!(target_gathered.batch_size(), 2);

        let num_targets = mask.target_indices.len();
        let target_pos = target_positions(&mask.target_indices, 2);
        let predicted = model.predictor.predict(&context_repr, &target_pos, None);
        assert_eq!(predicted.seq_len(), num_targets);

        let energy = jepa_core::energy::L2Energy.compute(&predicted, &target_gathered);
        let val: f32 = energy.value.into_scalar().elem();
        assert!(val.is_finite(), "energy should be finite, got {val}");
        assert!(val >= 0.0, "L2 energy should be non-negative, got {val}");

        let batch = 2;
        let embed_dim = predicted.embed_dim();
        let pred_flat = predicted
            .embeddings
            .reshape([batch * num_targets, embed_dim]);
        let target_flat = target_gathered
            .embeddings
            .reshape([batch * num_targets, embed_dim]);
        let reg_loss: f32 = jepa_core::collapse::VICReg::default()
            .loss(&pred_flat, &target_flat)
            .into_scalar()
            .elem();
        assert!(
            reg_loss.is_finite(),
            "regularization should be finite, got {reg_loss}"
        );
    }

    /// Block masking yields valid, complete partitions that vary with the seed.
    #[test]
    fn bdd_masking_creates_valid_prediction_tasks() {
        let masking = jepa_core::masking::BlockMasking {
            num_targets: 4,
            target_scale: (0.15, 0.2),
            target_aspect_ratio: (0.75, 1.5),
        };
        let shape = jepa_core::types::InputShape::Image {
            height: 4,
            width: 4,
        };

        let mut masks = Vec::new();
        for seed in 0..20u64 {
            let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
            let mask = masking.generate_mask(&shape, &mut rng);

            assert!(mask.validate().is_ok(), "mask with seed {seed} is invalid");
            assert_eq!(
                mask.context_indices.len() + mask.target_indices.len(),
                16,
                "mask with seed {seed} doesn't partition all 16 tokens"
            );
            assert!(
                !mask.context_indices.is_empty(),
                "mask with seed {seed} has empty context"
            );
            assert!(
                !mask.target_indices.is_empty(),
                "mask with seed {seed} has empty target"
            );
            masks.push(mask);
        }

        let first_targets = &masks[0].target_indices;
        let some_differ = masks[1..]
            .iter()
            .any(|m| m.target_indices != *first_targets);
        assert!(some_differ, "masks should vary across different seeds");
    }

    use proptest::prelude::*;

    proptest! {
        // For any small target count the output keeps the encoder width.
        #[test]
        fn prop_predictor_output_dim_matches_encoder(
            num_targets in 1usize..8,
        ) {
            let encoder_embed_dim = 32;
            let config = TransformerPredictorConfig {
                encoder_embed_dim,
                predictor_embed_dim: 16,
                num_layers: 1,
                num_heads: 2,
                max_target_len: 64,
            };
            let predictor = config.init::<TestBackend>(&device());

            let context = Representation::new(Tensor::zeros([1, 8, encoder_embed_dim], &device()));
            let target_pos: Tensor<TestBackend, 2> =
                Tensor::zeros([1, num_targets], &device());
            let predicted = predictor.predict(&context, &target_pos, None);

            prop_assert_eq!(predicted.batch_size(), 1);
            prop_assert_eq!(predicted.seq_len(), num_targets);
            prop_assert_eq!(predicted.embed_dim(), encoder_embed_dim);
        }

        // Random context input never produces NaN or infinite predictions.
        #[test]
        fn prop_predictor_output_is_finite(
            batch in 1usize..3,
            num_targets in 1usize..6,
        ) {
            let config = TransformerPredictorConfig {
                encoder_embed_dim: 16,
                predictor_embed_dim: 8,
                num_layers: 1,
                num_heads: 2,
                max_target_len: 16,
            };
            let predictor = config.init::<TestBackend>(&device());

            let context = Representation::new(Tensor::random(
                [batch, 4, 16],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            ));
            let target_pos: Tensor<TestBackend, 2> =
                Tensor::zeros([batch, num_targets], &device());
            let predicted = predictor.predict(&context, &target_pos, None);

            let total: f32 = predicted
                .embeddings
                .abs()
                .sum()
                .into_scalar()
                .elem();
            prop_assert!(
                total.is_finite(),
                "predictor output should be finite, got {}",
                total
            );
        }
    }
}