1use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct JepaConfig {
37 pub embed_dim: usize,
39 pub predictor_embed_dim: usize,
41 pub num_encoder_layers: usize,
43 pub num_predictor_layers: usize,
45 pub num_heads: usize,
47 pub patch_size: (usize, usize),
49 pub tubelet_size: (usize, usize, usize),
51 pub ema_momentum: f64,
53}
54
55#[derive(Debug, thiserror::Error)]
57pub enum ConfigError {
58 #[error("embed_dim must be positive, got {0}")]
59 ZeroEmbedDim(usize),
60 #[error("predictor_embed_dim must be positive, got {0}")]
61 ZeroPredictorEmbedDim(usize),
62 #[error("num_encoder_layers must be positive, got {0}")]
63 ZeroEncoderLayers(usize),
64 #[error("num_predictor_layers must be positive, got {0}")]
65 ZeroPredictorLayers(usize),
66 #[error("num_heads must be positive, got {0}")]
67 ZeroHeads(usize),
68 #[error("embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")]
69 HeadDimMismatch { embed_dim: usize, num_heads: usize },
70 #[error("patch_size dimensions must be positive, got ({0}, {1})")]
71 ZeroPatchSize(usize, usize),
72 #[error("tubelet_size dimensions must be positive, got ({0}, {1}, {2})")]
73 ZeroTubeletSize(usize, usize, usize),
74 #[error("ema_momentum must be in [0.0, 1.0], got {0}")]
75 InvalidMomentum(f64),
76}
77
78impl JepaConfig {
79 pub fn validate(&self) -> Result<(), ConfigError> {
81 if self.embed_dim == 0 {
82 return Err(ConfigError::ZeroEmbedDim(self.embed_dim));
83 }
84 if self.predictor_embed_dim == 0 {
85 return Err(ConfigError::ZeroPredictorEmbedDim(self.predictor_embed_dim));
86 }
87 if self.num_encoder_layers == 0 {
88 return Err(ConfigError::ZeroEncoderLayers(self.num_encoder_layers));
89 }
90 if self.num_predictor_layers == 0 {
91 return Err(ConfigError::ZeroPredictorLayers(self.num_predictor_layers));
92 }
93 if self.num_heads == 0 {
94 return Err(ConfigError::ZeroHeads(self.num_heads));
95 }
96 if self.embed_dim % self.num_heads != 0 {
97 return Err(ConfigError::HeadDimMismatch {
98 embed_dim: self.embed_dim,
99 num_heads: self.num_heads,
100 });
101 }
102 if self.patch_size.0 == 0 || self.patch_size.1 == 0 {
103 return Err(ConfigError::ZeroPatchSize(
104 self.patch_size.0,
105 self.patch_size.1,
106 ));
107 }
108 if self.tubelet_size.0 == 0 || self.tubelet_size.1 == 0 || self.tubelet_size.2 == 0 {
109 return Err(ConfigError::ZeroTubeletSize(
110 self.tubelet_size.0,
111 self.tubelet_size.1,
112 self.tubelet_size.2,
113 ));
114 }
115 if !(0.0..=1.0).contains(&self.ema_momentum) {
116 return Err(ConfigError::InvalidMomentum(self.ema_momentum));
117 }
118 Ok(())
119 }
120
121 pub fn head_dim(&self) -> usize {
123 self.embed_dim / self.num_heads
124 }
125}
126
127impl JepaConfig {
128 pub fn vit_base_16() -> Self {
132 Self {
133 embed_dim: 768,
134 predictor_embed_dim: 384,
135 num_encoder_layers: 12,
136 num_predictor_layers: 6,
137 num_heads: 12,
138 patch_size: (16, 16),
139 tubelet_size: (2, 16, 16),
140 ema_momentum: 0.996,
141 }
142 }
143
144 pub fn vit_large_16() -> Self {
148 Self {
149 embed_dim: 1024,
150 predictor_embed_dim: 512,
151 num_encoder_layers: 24,
152 num_predictor_layers: 12,
153 num_heads: 16,
154 patch_size: (16, 16),
155 tubelet_size: (2, 16, 16),
156 ema_momentum: 0.996,
157 }
158 }
159
160 pub fn vit_huge_14() -> Self {
164 Self {
165 embed_dim: 1280,
166 predictor_embed_dim: 640,
167 num_encoder_layers: 32,
168 num_predictor_layers: 12,
169 num_heads: 16,
170 patch_size: (14, 14),
171 tubelet_size: (2, 14, 14),
172 ema_momentum: 0.996,
173 }
174 }
175
176 pub fn vit_giant_14() -> Self {
180 Self {
181 embed_dim: 1408,
182 predictor_embed_dim: 704,
183 num_encoder_layers: 40,
184 num_predictor_layers: 12,
185 num_heads: 16,
186 patch_size: (14, 14),
187 tubelet_size: (2, 14, 14),
188 ema_momentum: 0.996,
189 }
190 }
191}
192
193#[derive(Debug, Clone)]
210pub struct JepaConfigBuilder {
211 config: JepaConfig,
212}
213
214impl JepaConfigBuilder {
215 pub fn new() -> Self {
217 Self {
218 config: JepaConfig::default(),
219 }
220 }
221
222 pub fn from_preset(config: JepaConfig) -> Self {
224 Self { config }
225 }
226
227 pub fn embed_dim(mut self, dim: usize) -> Self {
229 self.config.embed_dim = dim;
230 self
231 }
232
233 pub fn predictor_embed_dim(mut self, dim: usize) -> Self {
235 self.config.predictor_embed_dim = dim;
236 self
237 }
238
239 pub fn num_encoder_layers(mut self, n: usize) -> Self {
241 self.config.num_encoder_layers = n;
242 self
243 }
244
245 pub fn num_predictor_layers(mut self, n: usize) -> Self {
247 self.config.num_predictor_layers = n;
248 self
249 }
250
251 pub fn num_heads(mut self, n: usize) -> Self {
253 self.config.num_heads = n;
254 self
255 }
256
257 pub fn patch_size(mut self, h: usize, w: usize) -> Self {
259 self.config.patch_size = (h, w);
260 self
261 }
262
263 pub fn tubelet_size(mut self, t: usize, h: usize, w: usize) -> Self {
265 self.config.tubelet_size = (t, h, w);
266 self
267 }
268
269 pub fn ema_momentum(mut self, m: f64) -> Self {
271 self.config.ema_momentum = m;
272 self
273 }
274
275 pub fn build(self) -> Result<JepaConfig, ConfigError> {
279 self.config.validate()?;
280 Ok(self.config)
281 }
282}
283
284impl Default for JepaConfigBuilder {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290impl Default for JepaConfig {
291 fn default() -> Self {
292 Self {
293 embed_dim: 256,
294 predictor_embed_dim: 128,
295 num_encoder_layers: 12,
296 num_predictor_layers: 6,
297 num_heads: 8,
298 patch_size: (16, 16),
299 tubelet_size: (2, 16, 16),
300 ema_momentum: 0.996,
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    // --- validation -------------------------------------------------------

    #[test]
    fn test_default_config_is_valid() {
        let config = JepaConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_head_dim() {
        let config = JepaConfig::default();
        // 256 / 8
        assert_eq!(config.head_dim(), 32);
    }

    #[test]
    fn test_zero_embed_dim_rejected() {
        let config = JepaConfig {
            embed_dim: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroEmbedDim(0))
        ));
    }

    #[test]
    fn test_zero_predictor_embed_dim_rejected() {
        let config = JepaConfig {
            predictor_embed_dim: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPredictorEmbedDim(0))
        ));
    }

    #[test]
    fn test_zero_encoder_layers_rejected() {
        let config = JepaConfig {
            num_encoder_layers: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroEncoderLayers(0))
        ));
    }

    #[test]
    fn test_zero_predictor_layers_rejected() {
        let config = JepaConfig {
            num_predictor_layers: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPredictorLayers(0))
        ));
    }

    #[test]
    fn test_zero_heads_rejected() {
        // Must be reported as `ZeroHeads` rather than panicking in the
        // divisibility check that follows it.
        let config = JepaConfig {
            num_heads: 0,
            ..JepaConfig::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::ZeroHeads(0))));
    }

    #[test]
    fn test_head_dim_mismatch_rejected() {
        let config = JepaConfig {
            embed_dim: 255,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::HeadDimMismatch {
                embed_dim: 255,
                num_heads: 8
            })
        ));
    }

    #[test]
    fn test_invalid_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: 1.5,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(_))
        ));
    }

    #[test]
    fn test_negative_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: -0.1,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(_))
        ));
    }

    #[test]
    fn test_nan_momentum_rejected() {
        let config = JepaConfig {
            ema_momentum: f64::NAN,
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidMomentum(m)) if m.is_nan()
        ));
    }

    #[test]
    fn test_momentum_bounds_are_inclusive() {
        for momentum in [0.0, 1.0] {
            let config = JepaConfig {
                ema_momentum: momentum,
                ..JepaConfig::default()
            };
            assert!(config.validate().is_ok(), "momentum {momentum} rejected");
        }
    }

    #[test]
    fn test_zero_patch_size_rejected() {
        let config = JepaConfig {
            patch_size: (0, 16),
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroPatchSize(0, 16))
        ));
    }

    #[test]
    fn test_zero_tubelet_size_rejected() {
        let config = JepaConfig {
            tubelet_size: (2, 0, 16),
            ..JepaConfig::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::ZeroTubeletSize(2, 0, 16))
        ));
    }

    // --- serialization ----------------------------------------------------

    #[test]
    fn test_config_serialization_roundtrip() {
        let config = JepaConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: JepaConfig = serde_json::from_str(&json).unwrap();

        // Every field is compared so a renamed or skipped field is caught.
        assert_eq!(deserialized.embed_dim, config.embed_dim);
        assert_eq!(deserialized.predictor_embed_dim, config.predictor_embed_dim);
        assert_eq!(deserialized.num_encoder_layers, config.num_encoder_layers);
        assert_eq!(
            deserialized.num_predictor_layers,
            config.num_predictor_layers
        );
        assert_eq!(deserialized.num_heads, config.num_heads);
        assert_eq!(deserialized.patch_size, config.patch_size);
        assert_eq!(deserialized.tubelet_size, config.tubelet_size);
        assert!((deserialized.ema_momentum - config.ema_momentum).abs() < 1e-12);
    }

    // --- presets ----------------------------------------------------------

    #[test]
    fn test_vit_base_16_is_valid() {
        let config = JepaConfig::vit_base_16();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.num_heads, 12);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_vit_large_16_is_valid() {
        let config = JepaConfig::vit_large_16();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1024);
        assert_eq!(config.num_heads, 16);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_vit_huge_14_is_valid() {
        let config = JepaConfig::vit_huge_14();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1280);
        assert_eq!(config.num_encoder_layers, 32);
        assert_eq!(config.patch_size, (14, 14));
    }

    #[test]
    fn test_vit_giant_14_is_valid() {
        let config = JepaConfig::vit_giant_14();
        assert!(config.validate().is_ok());
        assert_eq!(config.embed_dim, 1408);
        assert_eq!(config.num_encoder_layers, 40);
    }

    // --- builder ----------------------------------------------------------

    #[test]
    fn test_builder_default_is_valid() {
        let config = JepaConfigBuilder::new().build().unwrap();
        assert_eq!(config.embed_dim, 256);
    }

    #[test]
    fn test_builder_custom_embed_dim() {
        let config = JepaConfigBuilder::new()
            .embed_dim(512)
            .num_heads(8)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 512);
        assert_eq!(config.head_dim(), 64);
    }

    #[test]
    fn test_builder_from_preset() {
        let config = JepaConfigBuilder::from_preset(JepaConfig::vit_huge_14())
            .ema_momentum(0.999)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 1280);
        assert!((config.ema_momentum - 0.999).abs() < 1e-10);
    }

    #[test]
    fn test_builder_validates_on_build() {
        // 255 is not divisible by the default 8 heads.
        let result = JepaConfigBuilder::new().embed_dim(255).build();
        assert!(matches!(
            result,
            Err(ConfigError::HeadDimMismatch {
                embed_dim: 255,
                num_heads: 8
            })
        ));
    }

    #[test]
    fn test_builder_all_setters() {
        let config = JepaConfigBuilder::new()
            .embed_dim(384)
            .predictor_embed_dim(192)
            .num_encoder_layers(6)
            .num_predictor_layers(3)
            .num_heads(6)
            .patch_size(8, 8)
            .tubelet_size(4, 8, 8)
            .ema_momentum(0.999)
            .build()
            .unwrap();
        assert_eq!(config.embed_dim, 384);
        assert_eq!(config.predictor_embed_dim, 192);
        assert_eq!(config.num_encoder_layers, 6);
        assert_eq!(config.num_predictor_layers, 3);
        assert_eq!(config.num_heads, 6);
        assert_eq!(config.patch_size, (8, 8));
        assert_eq!(config.tubelet_size, (4, 8, 8));
        assert!((config.ema_momentum - 0.999).abs() < 1e-10);
    }
}