Skip to main content

jepa_core/
masking.rs

//! Masking strategies for JEPA.
//!
//! Implements RFC-005 (Masking Strategies).
//!
//! Masking determines which input tokens are **context** (visible to the
//! context encoder) and which are **targets** (predicted by the predictor,
//! encoded by the target encoder). The masking strategy is arguably the
//! single most important design decision in JEPA — it determines the
//! *pretext task* and thus what the model learns.
//!
//! Three strategies are provided:
//!
//! | Strategy | Domain | Reference |
//! |----------|--------|-----------|
//! | [`BlockMasking`] | Images | Assran et al. (2023), I-JEPA |
//! | [`SpatiotemporalMasking`] | Video | Bardes et al. (2024), V-JEPA |
//! | [`MultiBlockMasking`] | Images / Video | Assran et al. (2025), V-JEPA 2 |
//!
//! All strategies guarantee disjoint, non-empty context and target sets
//! (see [`MaskSpec::validate`](crate::types::MaskSpec::validate)).
21
22use std::collections::HashSet;
23
24use rand::{Rng, RngExt as _};
25
26use crate::types::{InputShape, MaskSpec};
27
28/// Build a [`MaskSpec`] from a set of target indices, guaranteeing non-empty
29/// context and target partitions.
30///
31/// If `target_set` covers all tokens, one arbitrary target is moved to context.
32/// If `target_set` is empty, one random token is selected as a target.
33fn finalize_mask(mut target_set: HashSet<usize>, total: usize, rng: &mut impl Rng) -> MaskSpec {
34    // Ensure non-empty context
35    if target_set.len() >= total {
36        if let Some(&first) = target_set.iter().next() {
37            target_set.remove(&first);
38        }
39    }
40    // Ensure non-empty target
41    if target_set.is_empty() {
42        target_set.insert(rng.random_range(0..total));
43    }
44
45    let mut target_indices: Vec<usize> = target_set.into_iter().collect();
46    target_indices.sort_unstable();
47
48    let target_lookup: HashSet<usize> = target_indices.iter().copied().collect();
49    let context_indices: Vec<usize> = (0..total).filter(|i| !target_lookup.contains(i)).collect();
50
51    MaskSpec {
52        context_indices,
53        target_indices,
54        total_tokens: total,
55    }
56}
57
58/// Trait for masking strategies.
59///
60/// A masking strategy generates a [`MaskSpec`] that partitions input tokens
61/// into context (visible) and target (hidden) sets.
62///
63/// # Example
64///
65/// ```
66/// use jepa_core::masking::{MaskingStrategy, BlockMasking};
67/// use jepa_core::types::InputShape;
68/// use rand::SeedableRng;
69/// use rand_chacha::ChaCha8Rng;
70///
71/// let masking = BlockMasking {
72///     num_targets: 4,
73///     target_scale: (0.15, 0.2),
74///     target_aspect_ratio: (0.75, 1.5),
75/// };
76/// let shape = InputShape::Image { height: 14, width: 14 };
77/// let mut rng = ChaCha8Rng::seed_from_u64(42);
78/// let mask = masking.generate_mask(&shape, &mut rng);
79///
80/// assert!(mask.validate().is_ok());
81/// assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 196);
82/// ```
83pub trait MaskingStrategy {
84    /// Generate a mask for a given input shape.
85    ///
86    /// # Arguments
87    /// * `shape` - The shape of the input (image grid or video grid)
88    /// * `rng` - Random number generator for stochastic masking
89    ///
90    /// # Returns
91    /// A [`MaskSpec`] with disjoint context and target index sets
92    fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec;
93}
94
/// I-JEPA-style block masking for image patch grids.
///
/// One or more contiguous rectangles of patches become the targets and
/// every remaining patch is context, so the model has to predict large
/// semantic regions from a partial view.
#[derive(Clone, Debug)]
pub struct BlockMasking {
    /// How many rectangular target blocks are sampled for each mask.
    pub num_targets: usize,
    /// `(min, max)` bounds of the sampled target scale, expressed as a
    /// fraction of all patches in the grid.
    pub target_scale: (f64, f64),
    /// `(min, max)` bounds of the sampled block aspect ratio.
    pub target_aspect_ratio: (f64, f64),
}
109
110impl MaskingStrategy for BlockMasking {
111    fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
112        let (height, width) = match shape {
113            InputShape::Image { height, width } => (*height, *width),
114            InputShape::Video {
115                height,
116                width,
117                frames: _,
118            } => (*height, *width),
119        };
120        let total = height * width;
121
122        let mut target_set = HashSet::new();
123
124        for _ in 0..self.num_targets {
125            // Sample scale and aspect ratio
126            let scale = self.target_scale.0
127                + rng.random::<f64>() * (self.target_scale.1 - self.target_scale.0);
128            let aspect = self.target_aspect_ratio.0
129                + rng.random::<f64>() * (self.target_aspect_ratio.1 - self.target_aspect_ratio.0);
130
131            // Compute block dimensions
132            let num_patches = (total as f64 * scale / self.num_targets as f64).round() as usize;
133            let block_h = ((num_patches as f64 * aspect).sqrt()).round() as usize;
134            let block_w = if block_h > 0 {
135                (num_patches / block_h).max(1)
136            } else {
137                1
138            };
139
140            let block_h = block_h.clamp(1, height);
141            let block_w = block_w.clamp(1, width);
142
143            // Random top-left corner
144            let top = rng.random_range(0..=(height - block_h));
145            let left = rng.random_range(0..=(width - block_w));
146
147            for r in top..(top + block_h) {
148                for c in left..(left + block_w) {
149                    target_set.insert(r * width + c);
150                }
151            }
152        }
153
154        finalize_mask(target_set, total, rng)
155    }
156}
157
/// V-JEPA-style spatiotemporal masking for video token grids.
///
/// Targets are contiguous 3D regions ("tubes") spanning both space and
/// time, so the model has to predict spatial structure and temporal
/// dynamics together.
#[derive(Clone, Debug)]
pub struct SpatiotemporalMasking {
    /// How many target tubes are sampled for each mask.
    pub num_targets: usize,
    /// Inclusive `(min, max)` range, in frames, from which the temporal
    /// length of each tube is sampled.
    pub temporal_extent: (usize, usize),
    /// `(min, max)` bounds of each tube's spatial footprint, expressed as a
    /// fraction of a single frame's area.
    pub spatial_scale: (f64, f64),
}
171
172impl MaskingStrategy for SpatiotemporalMasking {
173    fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
174        let (frames, height, width) = match shape {
175            InputShape::Video {
176                frames,
177                height,
178                width,
179            } => (*frames, *height, *width),
180            InputShape::Image { height, width } => (1, *height, *width),
181        };
182        let total = frames * height * width;
183        let frame_area = height * width;
184
185        let mut target_set = HashSet::new();
186
187        for _ in 0..self.num_targets {
188            // Sample temporal extent
189            let t_extent = rng.random_range(self.temporal_extent.0..=self.temporal_extent.1);
190            let t_extent = t_extent.clamp(1, frames);
191
192            // Sample spatial block
193            let scale = self.spatial_scale.0
194                + rng.random::<f64>() * (self.spatial_scale.1 - self.spatial_scale.0);
195            let num_spatial = (frame_area as f64 * scale).round() as usize;
196            let block_side = (num_spatial as f64).sqrt().round() as usize;
197            let block_h = block_side.clamp(1, height);
198            let block_w = block_side.clamp(1, width);
199
200            let t_start = rng.random_range(0..=(frames - t_extent));
201            let top = rng.random_range(0..=(height - block_h));
202            let left = rng.random_range(0..=(width - block_w));
203
204            for t in t_start..(t_start + t_extent) {
205                for r in top..(top + block_h) {
206                    for c in left..(left + block_w) {
207                        target_set.insert(t * frame_area + r * width + c);
208                    }
209                }
210            }
211        }
212
213        finalize_mask(target_set, total, rng)
214    }
215}
216
/// V-JEPA 2-style multi-block masking.
///
/// Several blocks are masked at once, sized so that together they aim for a
/// given overall coverage ratio. Blocks are placed independently and may
/// overlap, so the achieved coverage can be lower than the requested ratio.
#[derive(Clone, Debug)]
pub struct MultiBlockMasking {
    /// Desired fraction of tokens to mask.
    pub mask_ratio: f64,
    /// How many mask blocks are sampled for each mask.
    pub num_blocks: usize,
}
227
228impl MaskingStrategy for MultiBlockMasking {
229    fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
230        let (height, width) = match shape {
231            InputShape::Image { height, width } => (*height, *width),
232            InputShape::Video {
233                height,
234                width,
235                frames: _,
236            } => (*height, *width),
237        };
238        let total = shape.total_tokens();
239        let target_count = ((total as f64) * self.mask_ratio).round() as usize;
240        let per_block = (target_count / self.num_blocks).max(1);
241
242        let mut target_set = HashSet::new();
243
244        for _ in 0..self.num_blocks {
245            let block_side = (per_block as f64).sqrt().round() as usize;
246            let block_h = block_side.clamp(1, height);
247            let block_w = block_side.clamp(1, width);
248
249            let top = rng.random_range(0..=(height - block_h));
250            let left = rng.random_range(0..=(width - block_w));
251
252            for r in top..(top + block_h) {
253                for c in left..(left + block_w) {
254                    target_set.insert(r * width + c);
255                }
256            }
257        }
258
259        finalize_mask(target_set, total, rng)
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // --- Shared fixtures ---

    /// ChaCha8 generator built from `seed`, so every test mask is reproducible.
    fn seeded_rng(seed: u64) -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(seed)
    }

    /// Block-masking configuration shared by most image tests.
    fn default_block_masking() -> BlockMasking {
        BlockMasking {
            num_targets: 4,
            target_scale: (0.15, 0.2),
            target_aspect_ratio: (0.75, 1.5),
        }
    }

    /// Image shape with the given patch grid.
    fn image(height: usize, width: usize) -> InputShape {
        InputShape::Image { height, width }
    }

    /// Video shape with the given token grid.
    fn video(frames: usize, height: usize, width: usize) -> InputShape {
        InputShape::Video {
            frames,
            height,
            width,
        }
    }

    /// Mask produced by the default block masking on a 14x14 image grid.
    fn default_mask(seed: u64) -> MaskSpec {
        default_block_masking().generate_mask(&image(14, 14), &mut seeded_rng(seed))
    }

    /// Number of tokens accounted for by context and target together.
    fn covered(mask: &MaskSpec) -> usize {
        mask.context_indices.len() + mask.target_indices.len()
    }

    /// Assert that neither partition of `mask` is empty.
    fn assert_non_empty(mask: &MaskSpec) {
        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
    }

    // --- Basic behaviour ---

    #[test]
    fn test_block_masking_partitions_all_patches() {
        let mask = default_mask(42);

        // Context + target should cover all tokens (no overlap ensured by construction)
        assert!(mask.validate().is_ok());
        assert_eq!(covered(&mask), 196);
    }

    #[test]
    fn test_block_masking_non_empty_partitions() {
        assert_non_empty(&default_mask(42));
    }

    #[test]
    fn test_block_masking_no_overlap() {
        let mask = default_mask(42);
        let context: HashSet<usize> = mask.context_indices.iter().copied().collect();
        for t in &mask.target_indices {
            assert!(!context.contains(t), "overlap at index {t}");
        }
    }

    #[test]
    fn test_masking_reproducible_with_same_seed() {
        let (first, second) = (default_mask(42), default_mask(42));
        assert_eq!(first.context_indices, second.context_indices);
        assert_eq!(first.target_indices, second.target_indices);
    }

    #[test]
    fn test_masking_different_with_different_seeds() {
        let (first, second) = (default_mask(42), default_mask(43));
        assert_ne!(first.target_indices, second.target_indices);
    }

    #[test]
    fn test_spatiotemporal_masking_valid() {
        let masking = SpatiotemporalMasking {
            num_targets: 2,
            temporal_extent: (2, 4),
            spatial_scale: (0.1, 0.2),
        };
        let mask = masking.generate_mask(&video(8, 14, 14), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_non_empty(&mask);
    }

    #[test]
    fn test_multi_block_masking_valid() {
        let masking = MultiBlockMasking {
            mask_ratio: 0.5,
            num_blocks: 4,
        };
        let mask = masking.generate_mask(&image(14, 14), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_non_empty(&mask);
    }

    // --- Edge-case tests ---

    #[test]
    fn test_block_masking_minimum_grid_2x2() {
        // Smallest grid where block masking can produce both context and target
        let masking = BlockMasking {
            num_targets: 1,
            target_scale: (0.25, 0.5),
            target_aspect_ratio: (1.0, 1.0),
        };
        let mask = masking.generate_mask(&image(2, 2), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_non_empty(&mask);
        assert_eq!(covered(&mask), 4);
    }

    #[test]
    fn test_block_masking_maximum_coverage() {
        // Many targets with high scale — tests the non-empty context guarantee
        let masking = BlockMasking {
            num_targets: 10,
            target_scale: (0.8, 0.99),
            target_aspect_ratio: (0.5, 2.0),
        };
        let mask = masking.generate_mask(&image(4, 4), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert!(
            !mask.context_indices.is_empty(),
            "must always have at least one context token"
        );
    }

    #[test]
    fn test_multi_block_masking_very_high_ratio() {
        // mask_ratio near 1.0 — tests the non-empty context guarantee
        let masking = MultiBlockMasking {
            mask_ratio: 0.99,
            num_blocks: 8,
        };
        let mask = masking.generate_mask(&image(4, 4), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_non_empty(&mask);
    }

    #[test]
    fn test_spatiotemporal_masking_single_frame() {
        // Video with 1 frame — degenerates to image-like behavior
        let masking = SpatiotemporalMasking {
            num_targets: 1,
            temporal_extent: (1, 1),
            spatial_scale: (0.1, 0.2),
        };
        let mask = masking.generate_mask(&video(1, 8, 8), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_eq!(covered(&mask), 64);
    }

    #[test]
    fn test_spatiotemporal_masking_on_image_shape() {
        // Image shape passed to spatiotemporal masking (1 frame fallback)
        let masking = SpatiotemporalMasking {
            num_targets: 2,
            temporal_extent: (1, 1),
            spatial_scale: (0.1, 0.2),
        };
        let mask = masking.generate_mask(&image(8, 8), &mut seeded_rng(42));
        assert!(mask.validate().is_ok());
        assert_eq!(covered(&mask), 64);
    }

    // --- Property-based tests ---

    proptest! {
        #[test]
        fn prop_block_mask_always_valid(
            seed in 0u64..100000,
            grid_h in 4usize..20,
            grid_w in 4usize..20,
            num_targets in 1usize..6,
        ) {
            let masking = BlockMasking {
                num_targets,
                target_scale: (0.1, 0.3),
                target_aspect_ratio: (0.75, 1.5),
            };
            let mask = masking.generate_mask(&image(grid_h, grid_w), &mut seeded_rng(seed));

            // Mask should always be valid
            prop_assert!(mask.validate().is_ok());

            // Context + target = total, no overlap
            let total = grid_h * grid_w;
            prop_assert_eq!(covered(&mask), total);
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());

            // No duplicates in context
            let unique_context: HashSet<usize> = mask.context_indices.iter().copied().collect();
            prop_assert_eq!(unique_context.len(), mask.context_indices.len());

            // No duplicates in target
            let unique_target: HashSet<usize> = mask.target_indices.iter().copied().collect();
            prop_assert_eq!(unique_target.len(), mask.target_indices.len());

            // All indices in bounds
            prop_assert!(mask.context_indices.iter().all(|&i| i < total));
            prop_assert!(mask.target_indices.iter().all(|&i| i < total));
        }

        #[test]
        fn prop_spatiotemporal_mask_always_valid(
            seed in 0u64..100000,
            frames in 4usize..12,
            grid_h in 4usize..12,
            grid_w in 4usize..12,
        ) {
            let masking = SpatiotemporalMasking {
                num_targets: 2,
                temporal_extent: (2, 3),
                spatial_scale: (0.05, 0.15),
            };
            let mask = masking.generate_mask(&video(frames, grid_h, grid_w), &mut seeded_rng(seed));

            prop_assert!(mask.validate().is_ok());

            let total = frames * grid_h * grid_w;
            prop_assert_eq!(covered(&mask), total);
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());
        }

        #[test]
        fn prop_multi_block_mask_always_valid(
            seed in 0u64..100000,
            grid_h in 4usize..16,
            grid_w in 4usize..16,
            mask_ratio in 0.1f64..0.8,
            num_blocks in 1usize..6,
        ) {
            let masking = MultiBlockMasking { mask_ratio, num_blocks };
            let mask = masking.generate_mask(&image(grid_h, grid_w), &mut seeded_rng(seed));

            prop_assert!(mask.validate().is_ok());
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());
        }

        #[test]
        fn prop_masking_is_deterministic(seed in 0u64..100000) {
            let first = default_mask(seed);
            let second = default_mask(seed);
            prop_assert_eq!(first.context_indices, second.context_indices);
            prop_assert_eq!(first.target_indices, second.target_indices);
        }
    }
}