use burn::prelude::*;
use burn::tensor::ElementConversion;
use burn_ndarray::NdArray;
use rand::SeedableRng;
use jepa_core::collapse::VICReg;
use jepa_core::ema::Ema;
use jepa_core::energy::{EnergyFn, L2Energy};
use jepa_core::masking::{BlockMasking, MaskingStrategy};
use jepa_core::types::InputShape;
use jepa_core::Predictor;
use jepa_train::schedule::{LrSchedule, WarmupCosineSchedule};
use jepa_train::{CheckpointMeta, TrainConfig, TrainMetrics};
use jepa_vision::image::{IJepaConfig, TransformerPredictorConfig};
use jepa_vision::vit::VitConfig;
type B = NdArray<f32>;
fn main() {
println!("=== I-JEPA Training Loop Simulation ===\n");
let device = burn_ndarray::NdArrayDevice::Cpu;
let train_config = TrainConfig {
total_steps: 50,
warmup_steps: 10,
peak_lr: 1.5e-4,
regularization_weight: 1.0,
ema_momentum: 0.996,
batch_size: 2,
log_interval: 10,
checkpoint_interval: 25,
};
assert!(train_config.validate().is_ok());
let encoder_config = VitConfig {
in_channels: 3,
image_height: 16,
image_width: 16,
patch_size: (4, 4),
embed_dim: 32,
num_layers: 1,
num_heads: 4,
mlp_dim: 128,
dropout: 0.0,
};
let predictor_config = TransformerPredictorConfig {
encoder_embed_dim: encoder_config.embed_dim,
predictor_embed_dim: 16,
num_layers: 1,
num_heads: 4,
max_target_len: 16,
};
let model_config = IJepaConfig {
encoder: encoder_config.clone(),
predictor: predictor_config,
};
let grid_h = encoder_config.image_height / encoder_config.patch_size.0;
let grid_w = encoder_config.image_width / encoder_config.patch_size.1;
println!("Config:");
println!(
" Steps: {} (warmup: {})",
train_config.total_steps, train_config.warmup_steps
);
println!(" Peak LR: {}", train_config.peak_lr);
println!(" EMA momentum: {}", train_config.ema_momentum);
println!(" Image: {}x{}, patch grid: {}x{}", 16, 16, grid_h, grid_w);
println!();
let model = model_config.init::<B>(&device);
let masking = BlockMasking {
num_targets: 2,
target_scale: (0.15, 0.3),
target_aspect_ratio: (0.75, 1.5),
};
let input_shape = InputShape::Image {
height: grid_h,
width: grid_w,
};
let lr_schedule = WarmupCosineSchedule::new(
train_config.peak_lr,
train_config.warmup_steps,
train_config.total_steps,
);
let ema = Ema::with_cosine_schedule(train_config.ema_momentum, train_config.total_steps);
let vicreg = VICReg::default();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let mut metrics = TrainMetrics::default();
let mut checkpoint = CheckpointMeta::new(train_config.total_steps);
println!("--- Training Loop ---\n");
for step in 0..train_config.total_steps {
let images: Tensor<B, 4> = Tensor::random(
[train_config.batch_size, 3, 16, 16],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let mask = masking.generate_mask(&input_shape, &mut rng);
let context_repr = model.context_encoder.forward(&images);
let target_repr = model.target_encoder.forward(&images);
let num_targets = mask.target_indices.len();
let target_positions: Tensor<B, 2> =
Tensor::zeros([train_config.batch_size, num_targets], &device);
let predicted = model
.predictor
.predict(&context_repr, &target_positions, None);
let target_slice = target_repr.gather(&mask.target_indices);
let energy = L2Energy.compute(&predicted, &target_slice);
let energy_val: f32 = energy.value.into_scalar().elem();
let embed_dim = predicted.embed_dim();
let batch = predicted.batch_size();
let pred_flat = predicted
.embeddings
.clone()
.reshape([batch * num_targets, embed_dim]);
let tgt_flat = target_slice
.embeddings
.clone()
.reshape([batch * num_targets, embed_dim]);
let vicreg_loss = vicreg.compute(&pred_flat, &tgt_flat);
let reg_val: f32 = vicreg_loss.total().into_scalar().elem();
let total_loss = energy_val + train_config.regularization_weight as f32 * reg_val;
metrics.record(energy_val as f64, reg_val as f64, total_loss as f64);
let lr = lr_schedule.get_lr(step);
let momentum = ema.get_momentum(step);
if (step + 1) % train_config.log_interval == 0 || step == 0 {
let (avg_e, avg_r, avg_t) = metrics.take_averages();
println!(
"Step {:>3}/{}: loss={:.4} (energy={:.4} reg={:.4}) lr={:.2e} mom={:.6}",
step + 1,
train_config.total_steps,
avg_t,
avg_e,
avg_r,
lr,
momentum,
);
}
if (step + 1) % train_config.checkpoint_interval == 0 {
checkpoint.step = step + 1;
checkpoint.learning_rate = lr;
checkpoint.ema_momentum = momentum;
checkpoint.last_loss = Some(total_loss as f64);
println!(
" [Checkpoint] step={}, progress={:.0}%, loss={:.4}",
checkpoint.step,
checkpoint.progress() * 100.0,
total_loss,
);
}
}
checkpoint.step = train_config.total_steps;
println!();
println!("--- Training Complete ---");
println!(
"Final checkpoint: step={}, progress={:.0}%, complete={}",
checkpoint.step,
checkpoint.progress() * 100.0,
checkpoint.is_complete(),
);
println!();
println!("=== Simulation Done ===");
}