1use std::collections::HashSet;
23
24use rand::{Rng, RngExt as _};
25
26use crate::types::{InputShape, MaskSpec};
27
28fn finalize_mask(mut target_set: HashSet<usize>, total: usize, rng: &mut impl Rng) -> MaskSpec {
34 if target_set.len() >= total {
36 if let Some(&first) = target_set.iter().next() {
37 target_set.remove(&first);
38 }
39 }
40 if target_set.is_empty() {
42 target_set.insert(rng.random_range(0..total));
43 }
44
45 let mut target_indices: Vec<usize> = target_set.into_iter().collect();
46 target_indices.sort_unstable();
47
48 let target_lookup: HashSet<usize> = target_indices.iter().copied().collect();
49 let context_indices: Vec<usize> = (0..total).filter(|i| !target_lookup.contains(i)).collect();
50
51 MaskSpec {
52 context_indices,
53 target_indices,
54 total_tokens: total,
55 }
56}
57
58pub trait MaskingStrategy {
84 fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec;
93}
94
/// Rectangular block masking on the spatial patch grid.
///
/// Each call samples `num_targets` rectangles whose union forms the target set; every
/// other token is context. For `Video` inputs only the `height x width` grid is used and
/// the frame count is ignored.
#[derive(Debug, Clone)]
pub struct BlockMasking {
    /// Number of rectangles sampled per mask. Rectangles may overlap.
    pub num_targets: usize,
    /// `(min, max)` range for the scale drawn once per rectangle. The scale is shared
    /// across all rectangles: each one nominally covers
    /// `round(height * width * scale / num_targets)` patches before clamping.
    pub target_scale: (f64, f64),
    /// `(min, max)` range for each rectangle's height-to-width ratio.
    pub target_aspect_ratio: (f64, f64),
}
109
110impl MaskingStrategy for BlockMasking {
111 fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
112 let (height, width) = match shape {
113 InputShape::Image { height, width } => (*height, *width),
114 InputShape::Video {
115 height,
116 width,
117 frames: _,
118 } => (*height, *width),
119 };
120 let total = height * width;
121
122 let mut target_set = HashSet::new();
123
124 for _ in 0..self.num_targets {
125 let scale = self.target_scale.0
127 + rng.random::<f64>() * (self.target_scale.1 - self.target_scale.0);
128 let aspect = self.target_aspect_ratio.0
129 + rng.random::<f64>() * (self.target_aspect_ratio.1 - self.target_aspect_ratio.0);
130
131 let num_patches = (total as f64 * scale / self.num_targets as f64).round() as usize;
133 let block_h = ((num_patches as f64 * aspect).sqrt()).round() as usize;
134 let block_w = if block_h > 0 {
135 (num_patches / block_h).max(1)
136 } else {
137 1
138 };
139
140 let block_h = block_h.clamp(1, height);
141 let block_w = block_w.clamp(1, width);
142
143 let top = rng.random_range(0..=(height - block_h));
145 let left = rng.random_range(0..=(width - block_w));
146
147 for r in top..(top + block_h) {
148 for c in left..(left + block_w) {
149 target_set.insert(r * width + c);
150 }
151 }
152 }
153
154 finalize_mask(target_set, total, rng)
155 }
156}
157
/// Spatiotemporal masking: square spatial blocks repeated across runs of consecutive frames.
///
/// `Image` inputs are treated as a single frame.
#[derive(Debug, Clone)]
pub struct SpatiotemporalMasking {
    /// Number of spatiotemporal blocks sampled per mask. Blocks may overlap.
    pub num_targets: usize,
    /// Inclusive `(min, max)` number of consecutive frames per block. The drawn value is
    /// clamped to `1..=frames`; `min` must not exceed `max`.
    pub temporal_extent: (usize, usize),
    /// `(min, max)` fraction of one frame's area covered by a block's square footprint.
    pub spatial_scale: (f64, f64),
}
171
172impl MaskingStrategy for SpatiotemporalMasking {
173 fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
174 let (frames, height, width) = match shape {
175 InputShape::Video {
176 frames,
177 height,
178 width,
179 } => (*frames, *height, *width),
180 InputShape::Image { height, width } => (1, *height, *width),
181 };
182 let total = frames * height * width;
183 let frame_area = height * width;
184
185 let mut target_set = HashSet::new();
186
187 for _ in 0..self.num_targets {
188 let t_extent = rng.random_range(self.temporal_extent.0..=self.temporal_extent.1);
190 let t_extent = t_extent.clamp(1, frames);
191
192 let scale = self.spatial_scale.0
194 + rng.random::<f64>() * (self.spatial_scale.1 - self.spatial_scale.0);
195 let num_spatial = (frame_area as f64 * scale).round() as usize;
196 let block_side = (num_spatial as f64).sqrt().round() as usize;
197 let block_h = block_side.clamp(1, height);
198 let block_w = block_side.clamp(1, width);
199
200 let t_start = rng.random_range(0..=(frames - t_extent));
201 let top = rng.random_range(0..=(height - block_h));
202 let left = rng.random_range(0..=(width - block_w));
203
204 for t in t_start..(t_start + t_extent) {
205 for r in top..(top + block_h) {
206 for c in left..(left + block_w) {
207 target_set.insert(t * frame_area + r * width + c);
208 }
209 }
210 }
211 }
212
213 finalize_mask(target_set, total, rng)
214 }
215}
216
/// Masking with several equally sized square blocks sized to hit a target token ratio.
#[derive(Debug, Clone)]
pub struct MultiBlockMasking {
    /// Fraction of all tokens the blocks are sized to cover together, before overlap.
    pub mask_ratio: f64,
    /// Number of square blocks; each is sized for an equal share of the masked tokens.
    pub num_blocks: usize,
}
227
228impl MaskingStrategy for MultiBlockMasking {
229 fn generate_mask(&self, shape: &InputShape, rng: &mut impl Rng) -> MaskSpec {
230 let (height, width) = match shape {
231 InputShape::Image { height, width } => (*height, *width),
232 InputShape::Video {
233 height,
234 width,
235 frames: _,
236 } => (*height, *width),
237 };
238 let total = shape.total_tokens();
239 let target_count = ((total as f64) * self.mask_ratio).round() as usize;
240 let per_block = (target_count / self.num_blocks).max(1);
241
242 let mut target_set = HashSet::new();
243
244 for _ in 0..self.num_blocks {
245 let block_side = (per_block as f64).sqrt().round() as usize;
246 let block_h = block_side.clamp(1, height);
247 let block_w = block_side.clamp(1, width);
248
249 let top = rng.random_range(0..=(height - block_h));
250 let left = rng.random_range(0..=(width - block_w));
251
252 for r in top..(top + block_h) {
253 for c in left..(left + block_w) {
254 target_set.insert(r * width + c);
255 }
256 }
257 }
258
259 finalize_mask(target_set, total, rng)
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Seeded RNG so every test run sees the same random stream.
    fn rng(seed: u64) -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(seed)
    }

    /// Block-masking configuration shared by most of the example-based tests.
    fn standard_block_masking() -> BlockMasking {
        BlockMasking {
            num_targets: 4,
            target_scale: (0.15, 0.2),
            target_aspect_ratio: (0.75, 1.5),
        }
    }

    /// Image shape with a `height x width` patch grid.
    fn image(height: usize, width: usize) -> InputShape {
        InputShape::Image { height, width }
    }

    /// `true` if some value appears more than once in `indices`.
    fn has_duplicates(indices: &[usize]) -> bool {
        let mut sorted = indices.to_vec();
        sorted.sort_unstable();
        sorted.windows(2).any(|pair| pair[0] == pair[1])
    }

    #[test]
    fn test_block_masking_partitions_all_patches() {
        let mask = standard_block_masking().generate_mask(&image(14, 14), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 196);
    }

    #[test]
    fn test_block_masking_non_empty_partitions() {
        let mask = standard_block_masking().generate_mask(&image(14, 14), &mut rng(42));

        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
    }

    #[test]
    fn test_block_masking_no_overlap() {
        let mask = standard_block_masking().generate_mask(&image(14, 14), &mut rng(42));

        let context: std::collections::HashSet<usize> =
            mask.context_indices.iter().copied().collect();
        for t in &mask.target_indices {
            assert!(!context.contains(t), "overlap at index {t}");
        }
    }

    #[test]
    fn test_masking_reproducible_with_same_seed() {
        let masking = standard_block_masking();
        let shape = image(14, 14);

        let first = masking.generate_mask(&shape, &mut rng(42));
        let second = masking.generate_mask(&shape, &mut rng(42));

        assert_eq!(first.context_indices, second.context_indices);
        assert_eq!(first.target_indices, second.target_indices);
    }

    #[test]
    fn test_masking_different_with_different_seeds() {
        let masking = standard_block_masking();
        let shape = image(14, 14);

        let a = masking.generate_mask(&shape, &mut rng(42));
        let b = masking.generate_mask(&shape, &mut rng(43));

        assert_ne!(a.target_indices, b.target_indices);
    }

    #[test]
    fn test_spatiotemporal_masking_valid() {
        let masking = SpatiotemporalMasking {
            num_targets: 2,
            temporal_extent: (2, 4),
            spatial_scale: (0.1, 0.2),
        };
        let shape = InputShape::Video {
            frames: 8,
            height: 14,
            width: 14,
        };

        let mask = masking.generate_mask(&shape, &mut rng(42));

        assert!(mask.validate().is_ok());
        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
    }

    #[test]
    fn test_multi_block_masking_valid() {
        let masking = MultiBlockMasking {
            mask_ratio: 0.5,
            num_blocks: 4,
        };

        let mask = masking.generate_mask(&image(14, 14), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
    }

    #[test]
    fn test_block_masking_minimum_grid_2x2() {
        // Smallest grid on which both partitions can still be non-empty.
        let masking = BlockMasking {
            num_targets: 1,
            target_scale: (0.25, 0.5),
            target_aspect_ratio: (1.0, 1.0),
        };

        let mask = masking.generate_mask(&image(2, 2), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 4);
    }

    #[test]
    fn test_block_masking_maximum_coverage() {
        // Many large targets on a tiny grid; the context must still survive.
        let masking = BlockMasking {
            num_targets: 10,
            target_scale: (0.8, 0.99),
            target_aspect_ratio: (0.5, 2.0),
        };

        let mask = masking.generate_mask(&image(4, 4), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert!(
            !mask.context_indices.is_empty(),
            "must always have at least one context token"
        );
    }

    #[test]
    fn test_multi_block_masking_very_high_ratio() {
        let masking = MultiBlockMasking {
            mask_ratio: 0.99,
            num_blocks: 8,
        };

        let mask = masking.generate_mask(&image(4, 4), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert!(!mask.context_indices.is_empty());
        assert!(!mask.target_indices.is_empty());
    }

    #[test]
    fn test_spatiotemporal_masking_single_frame() {
        let masking = SpatiotemporalMasking {
            num_targets: 1,
            temporal_extent: (1, 1),
            spatial_scale: (0.1, 0.2),
        };
        let shape = InputShape::Video {
            frames: 1,
            height: 8,
            width: 8,
        };

        let mask = masking.generate_mask(&shape, &mut rng(42));

        assert!(mask.validate().is_ok());
        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 64);
    }

    #[test]
    fn test_spatiotemporal_masking_on_image_shape() {
        // Image inputs are handled as a single frame.
        let masking = SpatiotemporalMasking {
            num_targets: 2,
            temporal_extent: (1, 1),
            spatial_scale: (0.1, 0.2),
        };

        let mask = masking.generate_mask(&image(8, 8), &mut rng(42));

        assert!(mask.validate().is_ok());
        assert_eq!(mask.context_indices.len() + mask.target_indices.len(), 64);
    }

    proptest! {
        #[test]
        fn prop_block_mask_always_valid(
            seed in 0u64..100000,
            grid_h in 4usize..20,
            grid_w in 4usize..20,
            num_targets in 1usize..6,
        ) {
            let masking = BlockMasking {
                num_targets,
                target_scale: (0.1, 0.3),
                target_aspect_ratio: (0.75, 1.5),
            };
            let mask = masking.generate_mask(&image(grid_h, grid_w), &mut rng(seed));
            let total = grid_h * grid_w;

            prop_assert!(mask.validate().is_ok());

            // The two sides together account for every token exactly once.
            prop_assert_eq!(mask.context_indices.len() + mask.target_indices.len(), total);
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());
            prop_assert!(!has_duplicates(&mask.context_indices));
            prop_assert!(!has_duplicates(&mask.target_indices));

            // Every index refers to a real token.
            prop_assert!(mask
                .context_indices
                .iter()
                .chain(&mask.target_indices)
                .all(|&i| i < total));
        }

        #[test]
        fn prop_spatiotemporal_mask_always_valid(
            seed in 0u64..100000,
            frames in 4usize..12,
            grid_h in 4usize..12,
            grid_w in 4usize..12,
        ) {
            let masking = SpatiotemporalMasking {
                num_targets: 2,
                temporal_extent: (2, 3),
                spatial_scale: (0.05, 0.15),
            };
            let shape = InputShape::Video { frames, height: grid_h, width: grid_w };
            let mask = masking.generate_mask(&shape, &mut rng(seed));
            let total = frames * grid_h * grid_w;

            prop_assert!(mask.validate().is_ok());
            prop_assert_eq!(mask.context_indices.len() + mask.target_indices.len(), total);
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());
        }

        #[test]
        fn prop_multi_block_mask_always_valid(
            seed in 0u64..100000,
            grid_h in 4usize..16,
            grid_w in 4usize..16,
            mask_ratio in 0.1f64..0.8,
            num_blocks in 1usize..6,
        ) {
            let masking = MultiBlockMasking { mask_ratio, num_blocks };
            let mask = masking.generate_mask(&image(grid_h, grid_w), &mut rng(seed));

            prop_assert!(mask.validate().is_ok());
            prop_assert!(!mask.context_indices.is_empty());
            prop_assert!(!mask.target_indices.is_empty());
        }

        #[test]
        fn prop_masking_is_deterministic(seed in 0u64..100000) {
            let masking = standard_block_masking();
            let shape = image(14, 14);

            let first = masking.generate_mask(&shape, &mut rng(seed));
            let second = masking.generate_mask(&shape, &mut rng(seed));

            prop_assert_eq!(first.context_indices, second.context_indices);
            prop_assert_eq!(first.target_indices, second.target_indices);
        }
    }
}