//! # jepa-core
//!
//! Core traits and tensor abstractions for the
//! **Joint Embedding Predictive Architecture (JEPA)**.
//!
//! JEPA (LeCun, 2022) is a self-supervised learning framework that predicts in
//! *representation space* rather than pixel space. Instead of reconstructing raw
//! inputs (as in MAE or BERT), a JEPA model learns to predict the latent
//! representations of masked target regions from visible context regions. This
//! avoids wasting model capacity on pixel-level details and encourages the
//! encoder to capture high-level semantic structure.
//!
//! ```text
//!                   ┌────────────────┐
//!        x_context ─►  Context       │
//!                   │  Encoder (θ)   ├─► s_x ──┐
//!                   └────────────────┘         │
//!                                              ▼
//!                                        ┌──────────┐
//!                              z (opt.) ─►          │
//!                                        │ Predictor├─► ŝ_y ──┐
//!                     target_positions ─►│          │         │
//!                                        └──────────┘         │  ┌──────────┐
//!                                                             ├──► EnergyFn │─► loss
//!                   ┌────────────────┐                        │  └──────────┘
//!        x_target  ─►  Target        │                        │
//!                   │  Encoder (ξ)   ├─► s_y ─────────────────┘
//!                   └────────────────┘
//!                        ↑
//!                        │ EMA(θ → ξ)
//! ```
//!
//! This crate is **backend-agnostic**: all tensor-bearing APIs are generic over
//! [`burn::tensor::backend::Backend`], so they work with any burn backend
//! (NdArray, Wgpu, Tch, etc.).
//!
//! ## Crate layout
//!
//! | Module | Purpose |
//! |--------|---------|
//! | [`encoder`] | [`Encoder`] trait — maps raw inputs to [`Representation`]s |
//! | [`predictor`] | [`Predictor`] trait — predicts target representations from context |
//! | [`energy`] | [`EnergyFn`] trait and impls ([`L2Energy`], [`CosineEnergy`], [`SmoothL1Energy`]) |
//! | [`masking`] | [`MaskingStrategy`] trait and impls ([`BlockMasking`], [`SpatiotemporalMasking`], [`MultiBlockMasking`]) |
//! | [`collapse`] | [`CollapseRegularizer`] trait and impls ([`VICReg`], [`BarlowTwins`]) |
//! | [`ema`] | [`Ema`] — exponential moving average updater with optional cosine schedule |
//! | [`types`] | Semantic tensor wrappers: [`Representation`], [`Energy`], [`MaskSpec`], [`InputShape`] |
//! | [`config`] | [`JepaConfig`] with ViT presets and a validated [`JepaConfigBuilder`] |
//!
//! ## Quick start
//!
//! ```rust
//! use jepa_core::{Encoder, Predictor, EnergyFn, MaskingStrategy};
//! use jepa_core::types::{Representation, InputShape};
//! use jepa_core::energy::L2Energy;
//! use jepa_core::masking::BlockMasking;
//! use jepa_core::ema::Ema;
//! use rand::SeedableRng;
//!
//! // Configure masking: 4 target blocks covering ~15-20% of patches
//! let masking = BlockMasking {
//!     num_targets: 4,
//!     target_scale: (0.15, 0.2),
//!     target_aspect_ratio: (0.75, 1.5),
//! };
//!
//! // Generate a mask for a 14×14 patch grid (ViT-H/14 on 224×224)
//! let shape = InputShape::Image { height: 14, width: 14 };
//! let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
//! let mask = masking.generate_mask(&shape, &mut rng);
//! assert!(mask.validate().is_ok());
//!
//! // EMA with cosine momentum schedule
//! let ema = Ema::with_cosine_schedule(0.996, 100_000);
//! assert!((ema.get_momentum(0) - 0.996).abs() < 1e-6);
//! ```
//!
//! ## References
//!
//! - LeCun, Y. (2022). *A Path Towards Autonomous Machine Intelligence*.
//! - Assran, M. et al. (2023). *Self-Supervised Learning from Images with a
//!   Joint-Embedding Predictive Architecture*. CVPR.
//! - Bardes, A. et al. (2024). *V-JEPA: Latent Video Prediction for Visual
//!   Representation Learning*.
//! - Bardes, A. et al. (2025). *V-JEPA 2: Self-Supervised Video Models Enable
//!   Understanding, Generation, and Planning*.

88pub mod collapse;
89pub mod config;
90pub mod ema;
91pub mod encoder;
92pub mod energy;
93pub mod masking;
94pub mod predictor;
95pub mod types;
96
97// Core types
98pub use types::{Energy, InputShape, MaskError, MaskSpec, Representation};
99
100// Traits
101pub use collapse::CollapseRegularizer;
102pub use encoder::Encoder;
103pub use energy::EnergyFn;
104pub use masking::MaskingStrategy;
105pub use predictor::Predictor;
106
107// Config
108pub use config::{ConfigError, JepaConfig, JepaConfigBuilder};
109
110// Concrete implementations
111pub use collapse::{BarlowTwins, VICReg};
112pub use ema::{CosineMomentumSchedule, Ema};
113pub use energy::{CosineEnergy, L2Energy, SmoothL1Energy};
114pub use masking::{BlockMasking, MultiBlockMasking, SpatiotemporalMasking};