use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
use burn::prelude::*;
use burn::tensor::backend::Backend;
use burn::tensor::module::embedding;
use jepa_core::types::{Energy, MaskError, MaskSpec, Representation};
use jepa_core::{CollapseRegularizer, EnergyFn, Predictor};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TransformerPredictorConfig {
pub encoder_embed_dim: usize,
pub predictor_embed_dim: usize,
pub num_layers: usize,
pub num_heads: usize,
pub max_target_len: usize,
}
impl TransformerPredictorConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerPredictor<B> {
let input_proj =
LinearConfig::new(self.encoder_embed_dim, self.predictor_embed_dim).init(device);
let output_proj =
LinearConfig::new(self.predictor_embed_dim, self.encoder_embed_dim).init(device);
let blocks: Vec<PredictorBlock<B>> = (0..self.num_layers)
.map(|_| {
PredictorBlockConfig {
embed_dim: self.predictor_embed_dim,
num_heads: self.num_heads,
}
.init(device)
})
.collect();
let norm = LayerNormConfig::new(self.predictor_embed_dim).init(device);
let prediction_tokens =
sinusoidal_prediction_tokens(self.max_target_len, self.predictor_embed_dim, device);
TransformerPredictor {
input_proj,
output_proj,
blocks,
norm,
prediction_tokens,
predictor_embed_dim: self.predictor_embed_dim,
encoder_embed_dim: self.encoder_embed_dim,
}
}
}
#[derive(Module, Debug)]
pub struct TransformerPredictor<B: Backend> {
input_proj: Linear<B>,
output_proj: Linear<B>,
blocks: Vec<PredictorBlock<B>>,
norm: LayerNorm<B>,
prediction_tokens: Tensor<B, 2>,
predictor_embed_dim: usize,
encoder_embed_dim: usize,
}
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum PredictorError {
#[error(
"target position batch size mismatch: context batch={context_batch}, target_positions batch={positions_batch}"
)]
BatchSizeMismatch {
context_batch: usize,
positions_batch: usize,
},
#[error("target position must be non-negative, got {0}")]
NegativeTargetPosition(i64),
#[error(
"target position {position} exceeds predictor capacity {max_supported}; increase max_target_len"
)]
TargetPositionOutOfRange {
position: usize,
max_supported: usize,
},
}
impl<B: Backend> Predictor<B> for TransformerPredictor<B> {
fn predict(
&self,
context: &Representation<B>,
target_positions: &Tensor<B, 2>,
_latent: Option<&Tensor<B, 2>>,
) -> Representation<B> {
self.try_predict(context, target_positions).expect(
"TransformerPredictor::predict failed — target positions must match the context \
batch size and not exceed max_target_len; use try_predict for error handling",
)
}
}
impl<B: Backend> TransformerPredictor<B> {
pub fn try_predict(
&self,
context: &Representation<B>,
target_positions: &Tensor<B, 2>,
) -> Result<Representation<B>, PredictorError> {
let [batch, _ctx_len, _enc_dim] = context.embeddings.dims();
let [positions_batch, num_targets] = target_positions.dims();
if positions_batch != batch {
return Err(PredictorError::BatchSizeMismatch {
context_batch: batch,
positions_batch,
});
}
if num_targets == 0 {
let device = context.embeddings.device();
return Ok(Representation::new(Tensor::zeros(
[batch, 0, self.encoder_embed_dim],
&device,
)));
}
let target_positions = target_positions.clone().int();
let min_position: i64 = target_positions.clone().min().into_scalar().elem();
if min_position < 0 {
return Err(PredictorError::NegativeTargetPosition(min_position));
}
let max_position: i64 = target_positions.clone().max().into_scalar().elem();
let max_supported_position = self.prediction_tokens.dims()[0];
if max_position >= max_supported_position as i64 {
return Err(PredictorError::TargetPositionOutOfRange {
position: max_position as usize,
max_supported: max_supported_position,
});
}
let ctx = self.input_proj.forward(context.embeddings.clone());
let pred_tokens = embedding(self.prediction_tokens.clone(), target_positions);
let combined = Tensor::cat(vec![ctx, pred_tokens], 1);
let ctx_len = context.embeddings.dims()[1];
let total_len = ctx_len + num_targets;
let mut x = combined;
for block in &self.blocks {
x = block.forward(x);
}
let pred_out = x.slice([0..batch, ctx_len..total_len, 0..self.predictor_embed_dim]);
let pred_out = self.norm.forward(pred_out);
let pred_out = self.output_proj.forward(pred_out);
Ok(Representation::new(pred_out))
}
}
fn sinusoidal_prediction_tokens<B: Backend>(
max_target_len: usize,
embed_dim: usize,
device: &B::Device,
) -> Tensor<B, 2> {
let mut data = vec![0.0f32; max_target_len * embed_dim];
for position in 0..max_target_len {
for dim in 0..embed_dim {
let exponent = (2 * (dim / 2)) as f64 / embed_dim as f64;
let angle = position as f64 / 10_000_f64.powf(exponent);
data[position * embed_dim + dim] = if dim % 2 == 0 {
angle.sin() as f32
} else {
angle.cos() as f32
};
}
}
Tensor::from_floats(
burn::tensor::TensorData::new(data, [max_target_len, embed_dim]),
device,
)
}
#[derive(Debug, Clone)]
struct PredictorBlockConfig {
embed_dim: usize,
num_heads: usize,
}
impl PredictorBlockConfig {
fn init<B: Backend>(&self, device: &B::Device) -> PredictorBlock<B> {
let head_dim = self.embed_dim / self.num_heads;
PredictorBlock {
norm1: LayerNormConfig::new(self.embed_dim).init(device),
attn: PredictorAttention {
qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
num_heads: self.num_heads,
head_dim,
},
norm2: LayerNormConfig::new(self.embed_dim).init(device),
mlp: PredictorMlp {
fc1: LinearConfig::new(self.embed_dim, self.embed_dim * 4).init(device),
fc2: LinearConfig::new(self.embed_dim * 4, self.embed_dim).init(device),
},
}
}
}
#[derive(Module, Debug)]
struct PredictorBlock<B: Backend> {
norm1: LayerNorm<B>,
attn: PredictorAttention<B>,
norm2: LayerNorm<B>,
mlp: PredictorMlp<B>,
}
impl<B: Backend> PredictorBlock<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let residual = x.clone();
let x_norm = self.norm1.forward(x);
let attn_out = self.attn.forward(x_norm);
let x = residual + attn_out;
let residual = x.clone();
let x_norm = self.norm2.forward(x);
let mlp_out = self.mlp.forward(x_norm);
residual + mlp_out
}
}
#[derive(Module, Debug)]
struct PredictorAttention<B: Backend> {
qkv: Linear<B>,
out_proj: Linear<B>,
num_heads: usize,
head_dim: usize,
}
impl<B: Backend> PredictorAttention<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch, seq_len, _] = x.dims();
let embed_dim = self.num_heads * self.head_dim;
let qkv = self.qkv.forward(x);
let q = qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]);
let k = qkv
.clone()
.slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]);
let v = qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]);
let q = q
.reshape([batch, seq_len, self.num_heads, self.head_dim])
.swap_dims(1, 2);
let k = k
.reshape([batch, seq_len, self.num_heads, self.head_dim])
.swap_dims(1, 2);
let v = v
.reshape([batch, seq_len, self.num_heads, self.head_dim])
.swap_dims(1, 2);
let scale = (self.head_dim as f64).sqrt();
let attn = q.matmul(k.transpose()) / scale;
let attn = burn::tensor::activation::softmax(attn, 3);
let out = attn.matmul(v);
let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
self.out_proj.forward(out)
}
}
#[derive(Module, Debug)]
struct PredictorMlp<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
}
impl<B: Backend> PredictorMlp<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.fc1.forward(x);
let x = burn::tensor::activation::gelu(x);
self.fc2.forward(x)
}
}
#[derive(Module, Debug)]
pub struct IJepa<B: Backend> {
pub context_encoder: crate::vit::VitEncoder<B>,
pub target_encoder: crate::vit::VitEncoder<B>,
pub predictor: TransformerPredictor<B>,
}
#[derive(Debug, Clone)]
pub struct StrictIJepaForwardOutput<B: Backend> {
pub energy: Energy<B>,
pub regularization: Tensor<B, 1>,
pub total_loss: Tensor<B, 1>,
pub mask: MaskSpec,
pub context: Representation<B>,
pub predicted: Representation<B>,
pub target: Representation<B>,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum StrictIJepaError {
#[error(transparent)]
InvalidMask(#[from] MaskError),
#[error(transparent)]
Predictor(#[from] PredictorError),
}
impl<B: Backend> IJepa<B> {
pub fn encode_context_strict(
&self,
images: &Tensor<B, 4>,
context_indices: &[usize],
) -> Representation<B> {
self.context_encoder
.forward_visible_tokens(images, context_indices)
}
pub fn forward_step_strict<EF, CR>(
&self,
images: &Tensor<B, 4>,
mask: MaskSpec,
energy_fn: &EF,
regularizer: &CR,
reg_weight: f64,
) -> StrictIJepaForwardOutput<B>
where
EF: EnergyFn<B>,
CR: CollapseRegularizer<B>,
{
self.try_forward_step_strict(images, mask, energy_fn, regularizer, reg_weight)
.expect(
"IJepa::forward_step_strict failed — mask must be valid (disjoint, non-empty) \
and target count must not exceed predictor capacity; \
use try_forward_step_strict for error handling",
)
}
pub fn try_forward_step_strict<EF, CR>(
&self,
images: &Tensor<B, 4>,
mask: MaskSpec,
energy_fn: &EF,
regularizer: &CR,
reg_weight: f64,
) -> Result<StrictIJepaForwardOutput<B>, StrictIJepaError>
where
EF: EnergyFn<B>,
CR: CollapseRegularizer<B>,
{
mask.validate()?;
let context = self.encode_context_strict(images, &mask.context_indices);
let target_full = self.target_encoder.forward(images);
let target =
Representation::new(target_full.embeddings.detach()).gather(&mask.target_indices);
let batch = images.dims()[0];
let target_positions =
target_positions_tensor::<B>(&mask.target_indices, batch, &images.device());
let predicted = self.predictor.try_predict(&context, &target_positions)?;
let num_targets = target.seq_len();
let embed_dim = target.embed_dim();
let pred_flat = predicted
.embeddings
.clone()
.reshape([batch * num_targets, embed_dim]);
let target_flat = target
.embeddings
.clone()
.reshape([batch * num_targets, embed_dim]);
let energy = energy_fn.compute(&predicted, &target);
let regularization = regularizer.loss(&pred_flat, &target_flat);
let total_loss = energy.value.clone() + regularization.clone() * reg_weight;
Ok(StrictIJepaForwardOutput {
energy,
regularization,
total_loss,
mask,
context,
predicted,
target,
})
}
}
pub(crate) fn target_positions_tensor<B: Backend>(
indices: &[usize],
batch: usize,
device: &B::Device,
) -> Tensor<B, 2> {
let mut data = Vec::with_capacity(batch * indices.len());
for _ in 0..batch {
data.extend(indices.iter().map(|&index| index as f32));
}
Tensor::from_floats(
burn::tensor::TensorData::new(data, [batch, indices.len()]),
device,
)
}
#[derive(Debug, Clone)]
pub struct IJepaConfig {
pub encoder: crate::vit::VitConfig,
pub predictor: TransformerPredictorConfig,
}
impl IJepaConfig {
pub fn tiny_test() -> Self {
let encoder = crate::vit::VitConfig::tiny_test();
Self {
predictor: TransformerPredictorConfig {
encoder_embed_dim: encoder.embed_dim,
predictor_embed_dim: 16,
num_layers: 1,
num_heads: 2,
max_target_len: 64,
},
encoder,
}
}
pub fn init<B: Backend>(&self, device: &B::Device) -> IJepa<B> {
IJepa {
context_encoder: self.encoder.init(device),
target_encoder: self.encoder.init(device),
predictor: self.predictor.init(device),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::tensor::ElementConversion;
use burn_ndarray::NdArray;
use jepa_core::{CollapseRegularizer, EnergyFn, MaskingStrategy};
use rand::SeedableRng;
type TestBackend = NdArray<f32>;
fn device() -> burn_ndarray::NdArrayDevice {
burn_ndarray::NdArrayDevice::Cpu
}
fn target_positions(indices: &[usize], batch: usize) -> Tensor<TestBackend, 2> {
let mut data = Vec::with_capacity(batch * indices.len());
for _ in 0..batch {
data.extend(indices.iter().map(|&index| index as f32));
}
Tensor::from_floats(
burn::tensor::TensorData::new(data, [batch, indices.len()]),
&device(),
)
}
fn fixed_image_mask() -> MaskSpec {
MaskSpec {
context_indices: vec![0, 1, 4, 5, 10, 11, 14, 15],
target_indices: vec![2, 3, 6, 7, 8, 9, 12, 13],
total_tokens: 16,
}
}
fn image_with_hidden_patch_value(mask: &MaskSpec, hidden_value: f32) -> Tensor<TestBackend, 4> {
let image_size = 8usize;
let patch_size = 2usize;
let mut data = vec![1.0f32; image_size * image_size];
for &index in &mask.target_indices {
let patch_row = index / 4;
let patch_col = index % 4;
let row_start = patch_row * patch_size;
let col_start = patch_col * patch_size;
for row in row_start..row_start + patch_size {
for col in col_start..col_start + patch_size {
data[row * image_size + col] = hidden_value;
}
}
}
Tensor::from_floats(
burn::tensor::TensorData::new(data, [1, 1, image_size, image_size]),
&device(),
)
}
#[test]
fn test_predictor_output_shape() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 32,
predictor_embed_dim: 16,
num_layers: 1,
num_heads: 2,
max_target_len: 64,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([2, 8, 32], &device()));
let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 4], &device());
let predicted = predictor.predict(&context, &target_pos, None);
assert_eq!(predicted.batch_size(), 2);
assert_eq!(predicted.seq_len(), 4);
assert_eq!(predicted.embed_dim(), 32);
}
#[test]
fn test_predictor_implements_trait() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 16,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());
let pred: Representation<TestBackend> =
Predictor::predict(&predictor, &context, &target_pos, None);
assert_eq!(pred.seq_len(), 2);
}
#[test]
fn test_predictor_output_depends_on_target_positions() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 16,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
let positions_a = target_positions(&[0, 1], 1);
let positions_b = target_positions(&[2, 3], 1);
let pred_a = predictor.predict(&context, &positions_a, None);
let pred_b = predictor.predict(&context, &positions_b, None);
let diff: f32 = (pred_a.embeddings - pred_b.embeddings)
.abs()
.sum()
.into_scalar()
.elem();
assert!(
diff > 1e-6,
"target positions should affect predictor output, diff={diff}"
);
}
#[test]
fn test_predictor_try_predict_rejects_batch_size_mismatch() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 16,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([1, 2], &device());
let err = predictor.try_predict(&context, &target_pos).unwrap_err();
assert_eq!(
err,
PredictorError::BatchSizeMismatch {
context_batch: 2,
positions_batch: 1,
}
);
}
#[test]
fn test_predictor_try_predict_rejects_out_of_range_positions() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 4,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([1, 4, 16], &device()));
let target_pos = target_positions(&[0, 4], 1);
let err = predictor.try_predict(&context, &target_pos).unwrap_err();
assert_eq!(
err,
PredictorError::TargetPositionOutOfRange {
position: 4,
max_supported: 4,
}
);
}
#[test]
fn test_predictor_try_predict_allows_empty_targets() {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 4,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([2, 4, 16], &device()));
let target_pos: Tensor<TestBackend, 2> = Tensor::zeros([2, 0], &device());
let predicted = predictor.try_predict(&context, &target_pos).unwrap();
assert_eq!(predicted.batch_size(), 2);
assert_eq!(predicted.seq_len(), 0);
assert_eq!(predicted.embed_dim(), 16);
}
#[test]
fn test_ijepa_full_pipeline() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
let context_repr = model.context_encoder.forward(&images);
let target_repr = model.target_encoder.forward(&images);
assert_eq!(context_repr.seq_len(), 16); assert_eq!(target_repr.seq_len(), 16);
let masking = jepa_core::masking::BlockMasking {
num_targets: 2,
target_scale: (0.15, 0.3),
target_aspect_ratio: (0.75, 1.5),
};
let shape = jepa_core::types::InputShape::Image {
height: 4,
width: 4,
};
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let mask = masking.generate_mask(&shape, &mut rng);
let num_targets = mask.target_indices.len();
let target_pos = target_positions(&mask.target_indices, 1);
let predicted = model.predictor.predict(&context_repr, &target_pos, None);
assert_eq!(predicted.seq_len(), num_targets);
assert_eq!(predicted.embed_dim(), 32);
let energy = jepa_core::energy::L2Energy.compute(&predicted, &predicted);
let val: f32 = energy.value.into_scalar().elem();
assert!(val.is_finite(), "energy should be finite");
}
#[test]
fn test_ijepa_config_tiny() {
let config = IJepaConfig::tiny_test();
assert_eq!(config.encoder.embed_dim, 32);
assert_eq!(config.predictor.predictor_embed_dim, 16);
}
#[test]
fn test_strict_context_encoding_ignores_hidden_patches() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let mask = fixed_image_mask();
let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);
let strict_low = model.encode_context_strict(&hidden_low, &mask.context_indices);
let strict_high = model.encode_context_strict(&hidden_high, &mask.context_indices);
let diff: f32 = (strict_low.embeddings - strict_high.embeddings)
.abs()
.sum()
.into_scalar()
.elem();
assert!(
diff < 1e-5,
"strict masked context should ignore hidden patches, diff={diff}"
);
}
#[test]
fn test_full_encoder_context_slice_leaks_hidden_patches() {
let config = crate::vit::VitConfig::tiny_test();
let encoder = config.init::<TestBackend>(&device());
let mask = fixed_image_mask();
let hidden_low = image_with_hidden_patch_value(&mask, 0.0);
let hidden_high = image_with_hidden_patch_value(&mask, 1_000.0);
let approx_low = encoder.forward(&hidden_low).gather(&mask.context_indices);
let approx_high = encoder.forward(&hidden_high).gather(&mask.context_indices);
let diff: f32 = (approx_low.embeddings - approx_high.embeddings)
.abs()
.sum()
.into_scalar()
.elem();
assert!(
diff > 1e-3,
"post-encoder gather path should leak hidden patches, diff={diff}"
);
}
#[test]
fn test_strict_forward_step_runs_end_to_end() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let mask = fixed_image_mask();
let images = image_with_hidden_patch_value(&mask, 3.0);
let energy_fn = jepa_core::energy::L2Energy;
let regularizer = jepa_core::collapse::VICReg::default();
let output =
model.forward_step_strict(&images, mask.clone(), &energy_fn, ®ularizer, 1.0);
assert_eq!(output.context.seq_len(), mask.context_indices.len());
assert_eq!(output.predicted.seq_len(), mask.target_indices.len());
assert_eq!(output.target.seq_len(), mask.target_indices.len());
let total_loss: f32 = output.total_loss.into_scalar().elem();
assert!(
total_loss.is_finite(),
"strict forward loss should be finite"
);
}
#[test]
fn test_try_strict_forward_step_rejects_invalid_mask() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let images = Tensor::ones([1, 1, 8, 8], &device());
let invalid_mask = MaskSpec {
context_indices: vec![],
target_indices: vec![0],
total_tokens: 16,
};
let energy_fn = jepa_core::energy::L2Energy;
let regularizer = jepa_core::collapse::VICReg::default();
let err = model
.try_forward_step_strict(&images, invalid_mask, &energy_fn, ®ularizer, 1.0)
.unwrap_err();
assert!(matches!(
err,
StrictIJepaError::InvalidMask(MaskError::EmptyContext)
));
}
#[test]
fn bdd_encode_batch_correct_shape_and_nonzero_variance() {
let config = crate::vit::VitConfig::tiny_test();
let encoder = config.init::<TestBackend>(&device());
let batch_size = 4;
let images: Tensor<TestBackend, 4> = Tensor::random(
[batch_size, 1, 8, 8],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device(),
);
let repr = encoder.forward(&images);
assert_eq!(repr.batch_size(), batch_size);
assert_eq!(repr.seq_len(), 16);
assert_eq!(repr.embed_dim(), 32);
let mean_repr = repr.embeddings.clone().mean_dim(0); let diff = repr.embeddings.clone() - mean_repr;
let variance: f32 = (diff.clone() * diff).mean().into_scalar().elem();
assert!(
variance > 1e-6,
"representations should have non-zero variance across the batch, got {variance}"
);
}
#[test]
fn bdd_encoder_pair_same_init_same_output() {
let config = crate::vit::VitConfig::tiny_test();
let encoder = config.init::<TestBackend>(&device());
let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
let repr1 = encoder.forward(&images);
let repr2 = encoder.forward(&images);
let diff: f32 = (repr1.embeddings - repr2.embeddings)
.abs()
.sum()
.into_scalar()
.elem();
assert!(
diff < 1e-6,
"same encoder + same input should produce identical representations, diff={diff}"
);
}
#[test]
fn bdd_ema_update_target_lags_context() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let images: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
let ctx_repr = model.context_encoder.forward(&images);
let tgt_repr = model.target_encoder.forward(&images);
let initial_diff: f32 = (ctx_repr.embeddings.clone() - tgt_repr.embeddings.clone())
.abs()
.sum()
.into_scalar()
.elem();
let ema = jepa_core::ema::Ema::new(0.99);
let target_val = 0.0f64;
let online_val = 1.0f64;
let mut val = target_val;
for step in 0..500 {
val = ema.step(val, online_val, step);
}
assert!(val > 0.9, "EMA should converge toward online, got {val}");
assert!(val < 1.0, "EMA should lag behind online, got {val}");
assert!(
initial_diff >= 0.0,
"initial representations computed successfully"
);
}
#[test]
fn bdd_full_ijepa_pipeline_with_gather() {
let config = IJepaConfig::tiny_test();
let model = config.init::<TestBackend>(&device());
let images: Tensor<TestBackend, 4> = Tensor::random(
[2, 1, 8, 8],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device(),
);
let context_repr = model.context_encoder.forward(&images);
let target_repr = model.target_encoder.forward(&images);
let masking = jepa_core::masking::BlockMasking {
num_targets: 2,
target_scale: (0.15, 0.3),
target_aspect_ratio: (0.75, 1.5),
};
let shape = jepa_core::types::InputShape::Image {
height: 4,
width: 4,
};
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let mask = masking.generate_mask(&shape, &mut rng);
assert!(mask.validate().is_ok());
let target_gathered = target_repr.gather(&mask.target_indices);
assert_eq!(target_gathered.seq_len(), mask.target_indices.len());
assert_eq!(target_gathered.batch_size(), 2);
let num_targets = mask.target_indices.len();
let target_pos = target_positions(&mask.target_indices, 2);
let predicted = model.predictor.predict(&context_repr, &target_pos, None);
assert_eq!(predicted.seq_len(), num_targets);
let energy = jepa_core::energy::L2Energy.compute(&predicted, &target_gathered);
let val: f32 = energy.value.into_scalar().elem();
assert!(val.is_finite(), "energy should be finite, got {val}");
assert!(val >= 0.0, "L2 energy should be non-negative, got {val}");
let batch = 2;
let embed_dim = predicted.embed_dim();
let pred_flat = predicted
.embeddings
.reshape([batch * num_targets, embed_dim]);
let target_flat = target_gathered
.embeddings
.reshape([batch * num_targets, embed_dim]);
let reg_loss: f32 = jepa_core::collapse::VICReg::default()
.loss(&pred_flat, &target_flat)
.into_scalar()
.elem();
assert!(
reg_loss.is_finite(),
"regularization should be finite, got {reg_loss}"
);
}
#[test]
fn bdd_masking_creates_valid_prediction_tasks() {
let masking = jepa_core::masking::BlockMasking {
num_targets: 4,
target_scale: (0.15, 0.2),
target_aspect_ratio: (0.75, 1.5),
};
let shape = jepa_core::types::InputShape::Image {
height: 4,
width: 4,
};
let mut masks = Vec::new();
for seed in 0..20u64 {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
let mask = masking.generate_mask(&shape, &mut rng);
assert!(mask.validate().is_ok(), "mask with seed {seed} is invalid");
assert_eq!(
mask.context_indices.len() + mask.target_indices.len(),
16,
"mask with seed {seed} doesn't partition all 16 tokens"
);
assert!(
!mask.context_indices.is_empty(),
"mask with seed {seed} has empty context"
);
assert!(
!mask.target_indices.is_empty(),
"mask with seed {seed} has empty target"
);
masks.push(mask);
}
let first_targets = &masks[0].target_indices;
let some_differ = masks[1..]
.iter()
.any(|m| m.target_indices != *first_targets);
assert!(some_differ, "masks should vary across different seeds");
}
use proptest::prelude::*;
proptest! {
#[test]
fn prop_predictor_output_dim_matches_encoder(
num_targets in 1usize..8,
) {
let encoder_embed_dim = 32;
let config = TransformerPredictorConfig {
encoder_embed_dim,
predictor_embed_dim: 16,
num_layers: 1,
num_heads: 2,
max_target_len: 64,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::zeros([1, 8, encoder_embed_dim], &device()));
let target_pos: Tensor<TestBackend, 2> =
Tensor::zeros([1, num_targets], &device());
let predicted = predictor.predict(&context, &target_pos, None);
prop_assert_eq!(predicted.batch_size(), 1);
prop_assert_eq!(predicted.seq_len(), num_targets);
prop_assert_eq!(predicted.embed_dim(), encoder_embed_dim);
}
#[test]
fn prop_predictor_output_is_finite(
batch in 1usize..3,
num_targets in 1usize..6,
) {
let config = TransformerPredictorConfig {
encoder_embed_dim: 16,
predictor_embed_dim: 8,
num_layers: 1,
num_heads: 2,
max_target_len: 16,
};
let predictor = config.init::<TestBackend>(&device());
let context = Representation::new(Tensor::random(
[batch, 4, 16],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device(),
));
let target_pos: Tensor<TestBackend, 2> =
Tensor::zeros([batch, num_targets], &device());
let predicted = predictor.predict(&context, &target_pos, None);
let total: f32 = predicted
.embeddings
.abs()
.sum()
.into_scalar()
.elem();
prop_assert!(
total.is_finite(),
"predictor output should be finite, got {}",
total
);
}
}
}