Skip to main content

jepa_core/
config.rs

//! Configuration types for JEPA architecture.
//!
//! Implements RFC-001 (Core Tensor Abstractions) — configuration component.
//!
//! [`JepaConfig`] captures the hyperparameters that define the shape and
//! size of the encoder, predictor, and training components. Four standard
//! ViT presets are included (Base/16, Large/16, Huge/14, giant/14), and
//! a validated [`JepaConfigBuilder`] supports ergonomic customization.
//!
//! All configs are serializable via `serde` for checkpoint reproducibility.
11
12use serde::{Deserialize, Serialize};
13
14/// Configuration for JEPA architecture dimensions.
15///
16/// Specifies the hyperparameters that define the shape and size of
17/// encoder, predictor, and training components.
18///
19/// # Example
20/// ```
21/// use jepa_core::config::JepaConfig;
22///
23/// let config = JepaConfig {
24///     embed_dim: 256,
25///     predictor_embed_dim: 128,
26///     num_encoder_layers: 12,
27///     num_predictor_layers: 6,
28///     num_heads: 8,
29///     patch_size: (16, 16),
30///     tubelet_size: (2, 16, 16),
31///     ema_momentum: 0.996,
32/// };
33/// assert!(config.validate().is_ok());
34/// ```
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct JepaConfig {
37    /// Embedding dimension of the encoder output.
38    pub embed_dim: usize,
39    /// Embedding dimension of the predictor (can be smaller than encoder).
40    pub predictor_embed_dim: usize,
41    /// Number of transformer layers in the encoder.
42    pub num_encoder_layers: usize,
43    /// Number of transformer layers in the predictor.
44    pub num_predictor_layers: usize,
45    /// Number of attention heads (must divide embed_dim evenly).
46    pub num_heads: usize,
47    /// Patch size for images `(height, width)`.
48    pub patch_size: (usize, usize),
49    /// Tubelet size for video `(temporal, height, width)`.
50    pub tubelet_size: (usize, usize, usize),
51    /// EMA momentum for target encoder updates. Range: `[0.0, 1.0]`.
52    pub ema_momentum: f64,
53}
54
55/// Errors from config validation.
56#[derive(Debug, thiserror::Error)]
57pub enum ConfigError {
58    #[error("embed_dim must be positive, got {0}")]
59    ZeroEmbedDim(usize),
60    #[error("predictor_embed_dim must be positive, got {0}")]
61    ZeroPredictorEmbedDim(usize),
62    #[error("num_encoder_layers must be positive, got {0}")]
63    ZeroEncoderLayers(usize),
64    #[error("num_predictor_layers must be positive, got {0}")]
65    ZeroPredictorLayers(usize),
66    #[error("num_heads must be positive, got {0}")]
67    ZeroHeads(usize),
68    #[error("embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")]
69    HeadDimMismatch { embed_dim: usize, num_heads: usize },
70    #[error("patch_size dimensions must be positive, got ({0}, {1})")]
71    ZeroPatchSize(usize, usize),
72    #[error("tubelet_size dimensions must be positive, got ({0}, {1}, {2})")]
73    ZeroTubeletSize(usize, usize, usize),
74    #[error("ema_momentum must be in [0.0, 1.0], got {0}")]
75    InvalidMomentum(f64),
76}
77
78impl JepaConfig {
79    /// Validate all configuration parameters.
80    pub fn validate(&self) -> Result<(), ConfigError> {
81        if self.embed_dim == 0 {
82            return Err(ConfigError::ZeroEmbedDim(self.embed_dim));
83        }
84        if self.predictor_embed_dim == 0 {
85            return Err(ConfigError::ZeroPredictorEmbedDim(self.predictor_embed_dim));
86        }
87        if self.num_encoder_layers == 0 {
88            return Err(ConfigError::ZeroEncoderLayers(self.num_encoder_layers));
89        }
90        if self.num_predictor_layers == 0 {
91            return Err(ConfigError::ZeroPredictorLayers(self.num_predictor_layers));
92        }
93        if self.num_heads == 0 {
94            return Err(ConfigError::ZeroHeads(self.num_heads));
95        }
96        if self.embed_dim % self.num_heads != 0 {
97            return Err(ConfigError::HeadDimMismatch {
98                embed_dim: self.embed_dim,
99                num_heads: self.num_heads,
100            });
101        }
102        if self.patch_size.0 == 0 || self.patch_size.1 == 0 {
103            return Err(ConfigError::ZeroPatchSize(
104                self.patch_size.0,
105                self.patch_size.1,
106            ));
107        }
108        if self.tubelet_size.0 == 0 || self.tubelet_size.1 == 0 || self.tubelet_size.2 == 0 {
109            return Err(ConfigError::ZeroTubeletSize(
110                self.tubelet_size.0,
111                self.tubelet_size.1,
112                self.tubelet_size.2,
113            ));
114        }
115        if !(0.0..=1.0).contains(&self.ema_momentum) {
116            return Err(ConfigError::InvalidMomentum(self.ema_momentum));
117        }
118        Ok(())
119    }
120
121    /// Head dimension: `embed_dim / num_heads`.
122    pub fn head_dim(&self) -> usize {
123        self.embed_dim / self.num_heads
124    }
125}
126
127impl JepaConfig {
128    /// ViT-Base/16 preset: 12 layers, 768-d, 12 heads, patch 16x16.
129    ///
130    /// Standard ViT-B configuration used in many JEPA experiments.
131    pub fn vit_base_16() -> Self {
132        Self {
133            embed_dim: 768,
134            predictor_embed_dim: 384,
135            num_encoder_layers: 12,
136            num_predictor_layers: 6,
137            num_heads: 12,
138            patch_size: (16, 16),
139            tubelet_size: (2, 16, 16),
140            ema_momentum: 0.996,
141        }
142    }
143
144    /// ViT-Large/16 preset: 24 layers, 1024-d, 16 heads, patch 16x16.
145    ///
146    /// Used by V-JEPA ViT-L/16 checkpoints.
147    pub fn vit_large_16() -> Self {
148        Self {
149            embed_dim: 1024,
150            predictor_embed_dim: 512,
151            num_encoder_layers: 24,
152            num_predictor_layers: 12,
153            num_heads: 16,
154            patch_size: (16, 16),
155            tubelet_size: (2, 16, 16),
156            ema_momentum: 0.996,
157        }
158    }
159
160    /// ViT-Huge/14 preset: 32 layers, 1280-d, 16 heads, patch 14x14.
161    ///
162    /// Used by I-JEPA ViT-H/14 checkpoints.
163    pub fn vit_huge_14() -> Self {
164        Self {
165            embed_dim: 1280,
166            predictor_embed_dim: 640,
167            num_encoder_layers: 32,
168            num_predictor_layers: 12,
169            num_heads: 16,
170            patch_size: (14, 14),
171            tubelet_size: (2, 14, 14),
172            ema_momentum: 0.996,
173        }
174    }
175
176    /// ViT-giant/14 preset: 40 layers, 1408-d, 16 heads, patch 14x14.
177    ///
178    /// Used by V-JEPA 2 ViT-g checkpoints.
179    pub fn vit_giant_14() -> Self {
180        Self {
181            embed_dim: 1408,
182            predictor_embed_dim: 704,
183            num_encoder_layers: 40,
184            num_predictor_layers: 12,
185            num_heads: 16,
186            patch_size: (14, 14),
187            tubelet_size: (2, 14, 14),
188            ema_momentum: 0.996,
189        }
190    }
191}
192
193/// Builder for [`JepaConfig`] with chainable setters.
194///
195/// # Example
196///
197/// ```
198/// use jepa_core::config::JepaConfigBuilder;
199///
200/// let config = JepaConfigBuilder::new()
201///     .embed_dim(512)
202///     .num_heads(8)
203///     .num_encoder_layers(12)
204///     .build()
205///     .expect("config should be valid");
206/// assert_eq!(config.embed_dim, 512);
207/// assert_eq!(config.head_dim(), 64);
208/// ```
209#[derive(Debug, Clone)]
210pub struct JepaConfigBuilder {
211    config: JepaConfig,
212}
213
214impl JepaConfigBuilder {
215    /// Create a new builder starting from the default config.
216    pub fn new() -> Self {
217        Self {
218            config: JepaConfig::default(),
219        }
220    }
221
222    /// Create a builder starting from a named preset.
223    pub fn from_preset(config: JepaConfig) -> Self {
224        Self { config }
225    }
226
227    /// Set the encoder embedding dimension.
228    pub fn embed_dim(mut self, dim: usize) -> Self {
229        self.config.embed_dim = dim;
230        self
231    }
232
233    /// Set the predictor embedding dimension.
234    pub fn predictor_embed_dim(mut self, dim: usize) -> Self {
235        self.config.predictor_embed_dim = dim;
236        self
237    }
238
239    /// Set the number of encoder transformer layers.
240    pub fn num_encoder_layers(mut self, n: usize) -> Self {
241        self.config.num_encoder_layers = n;
242        self
243    }
244
245    /// Set the number of predictor transformer layers.
246    pub fn num_predictor_layers(mut self, n: usize) -> Self {
247        self.config.num_predictor_layers = n;
248        self
249    }
250
251    /// Set the number of attention heads.
252    pub fn num_heads(mut self, n: usize) -> Self {
253        self.config.num_heads = n;
254        self
255    }
256
257    /// Set the image patch size `(height, width)`.
258    pub fn patch_size(mut self, h: usize, w: usize) -> Self {
259        self.config.patch_size = (h, w);
260        self
261    }
262
263    /// Set the video tubelet size `(temporal, height, width)`.
264    pub fn tubelet_size(mut self, t: usize, h: usize, w: usize) -> Self {
265        self.config.tubelet_size = (t, h, w);
266        self
267    }
268
269    /// Set the EMA momentum.
270    pub fn ema_momentum(mut self, m: f64) -> Self {
271        self.config.ema_momentum = m;
272        self
273    }
274
275    /// Build and validate the config.
276    ///
277    /// Returns `Err(ConfigError)` if validation fails.
278    pub fn build(self) -> Result<JepaConfig, ConfigError> {
279        self.config.validate()?;
280        Ok(self.config)
281    }
282}
283
284impl Default for JepaConfigBuilder {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290impl Default for JepaConfig {
291    fn default() -> Self {
292        Self {
293            embed_dim: 256,
294            predictor_embed_dim: 128,
295            num_encoder_layers: 12,
296            num_predictor_layers: 6,
297            num_heads: 8,
298            patch_size: (16, 16),
299            tubelet_size: (2, 16, 16),
300            ema_momentum: 0.996,
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    //! Unit tests for config validation, presets, and the builder.

    use super::*;

    // --- Validation tests ---

    #[test]
    fn test_default_config_is_valid() {
        let config = JepaConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_head_dim() {
        let config = JepaConfig::default();
        assert_eq!(config.head_dim(), 32); // 256 / 8
    }

    #[test]
    fn test_zero_embed_dim_rejected() {
        let config = JepaConfig {
            embed_dim: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroEmbedDim(0))
        ));
    }

    #[test]
    fn test_zero_predictor_embed_dim_rejected() {
        let config = JepaConfig {
            predictor_embed_dim: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPredictorEmbedDim(0))
        ));
    }

    #[test]
    fn test_zero_encoder_layers_rejected() {
        let config = JepaConfig {
            num_encoder_layers: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroEncoderLayers(0))
        ));
    }

    #[test]
    fn test_zero_predictor_layers_rejected() {
        let config = JepaConfig {
            num_predictor_layers: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPredictorLayers(0))
        ));
    }

    #[test]
    fn test_zero_heads_rejected() {
        // Must be reported as an error rather than panicking in the
        // divisibility check that follows.
        let config = JepaConfig {
            num_heads: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::ZeroHeads(0))));
    }

    #[test]
    fn test_head_dim_mismatch_rejected() {
        let config = JepaConfig {
            embed_dim: 255,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::HeadDimMismatch { .. })
        ));
    }

    #[test]
    fn test_invalid_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: 1.5,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(_))
        ));
    }

    #[test]
    fn test_negative_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: -0.1,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(_))
        ));
    }

    #[test]
    fn test_nan_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: f64::NAN,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(m)) if m.is_nan()
        ));
    }

    #[test]
    fn test_momentum_bounds_accepted() {
        // The documented range is inclusive on both ends.
        for m in [0.0, 1.0] {
            let config = JepaConfig {
                ema_momentum: m,
                ..JepaConfig::default()
            };
            assert!(config.validate().is_ok());
        }
    }

    #[test]
    fn test_zero_patch_size_rejected() {
        let config = JepaConfig {
            patch_size: (0, 16),
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPatchSize(0, 16))
        ));
    }

    #[test]
    fn test_zero_tubelet_size_rejected() {
        let config = JepaConfig {
            tubelet_size: (2, 0, 16),
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroTubeletSize(2, 0, 16))
        ));
    }

    #[test]
    fn test_config_serialization_roundtrip() {
        let config = JepaConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: JepaConfig = serde_json::from_str(&json).unwrap();
        // Compare every field so a dropped or renamed field is caught.
        assert_eq!(deserialized.embed_dim, config.embed_dim);
        assert_eq!(deserialized.predictor_embed_dim, config.predictor_embed_dim);
        assert_eq!(deserialized.num_encoder_layers, config.num_encoder_layers);
        assert_eq!(
            deserialized.num_predictor_layers,
            config.num_predictor_layers
        );
        assert_eq!(deserialized.num_heads, config.num_heads);
        assert_eq!(deserialized.patch_size, config.patch_size);
        assert_eq!(deserialized.tubelet_size, config.tubelet_size);
        // Tolerance rather than exact equality: float parsing through JSON
        // is not assumed to be bit-exact.
        assert!((deserialized.ema_momentum - config.ema_momentum).abs() < 1e-12);
    }

    // --- Preset tests ---

    #[test]
    fn test_vit_base_16_is_valid() {
        let config = JepaConfig::vit_base_16();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.num_heads, 12);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_vit_large_16_is_valid() {
        let config = JepaConfig::vit_large_16();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1024);
        assert_eq!(config.num_heads, 16);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_vit_huge_14_is_valid() {
        let config = JepaConfig::vit_huge_14();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1280);
        assert_eq!(config.num_encoder_layers, 32);
        assert_eq!(config.patch_size, (14, 14));
    }

    #[test]
    fn test_vit_giant_14_is_valid() {
        let config = JepaConfig::vit_giant_14();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1408);
        assert_eq!(config.num_encoder_layers, 40);
    }

    // --- Builder tests ---

    #[test]
    fn test_builder_default_is_valid() {
        let config = JepaConfigBuilder::new().build().unwrap();
        assert_eq!(config.embed_dim, 256);
    }

    #[test]
    fn test_builder_custom_embed_dim() {
        let config = JepaConfigBuilder::new()
            .embed_dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 512);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_builder_from_preset() {
        let config = JepaConfigBuilder::from_preset(JepaConfig::vit_huge_14())
            .ema_momentum(0.999)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 1280);
        assert!((config.ema_momentum - 0.999).abs() < 1e-10);
    }

    #[test]
    fn test_builder_validates_on_build() {
        let result = JepaConfigBuilder::new()
            .embed_dim(255) // not divisible by 8 heads
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_builder_all_setters() {
        let config = JepaConfigBuilder::new()
            .embed_dim(384)
            .predictor_embed_dim(192)
            .num_encoder_layers(6)
            .num_predictor_layers(3)
            .num_heads(6)
            .patch_size(8, 8)
            .tubelet_size(4, 8, 8)
            .ema_momentum(0.999)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 384);
        assert_eq!(config.predictor_embed_dim, 192);
        assert_eq!(config.num_encoder_layers, 6);
        assert_eq!(config.num_predictor_layers, 3);
        assert_eq!(config.num_heads, 6);
        assert_eq!(config.patch_size, (8, 8));
        assert_eq!(config.tubelet_size, (4, 8, 8));
        assert!((config.ema_momentum - 0.999).abs() < 1e-10);
    }
}