Skip to main content

oxicuda_ssl/augment/
multi_crop.rs

//! Multi-crop augmentation strategy used by SwAV and DINO.
//!
//! Each image generates 2 *global* crops (e.g. 224×224) and `n_local_crops`
//! *local* crops (e.g. 96×96); every crop is described by a single square
//! output size. Local crops are passed only to the student network; global
//! crops are passed to both student and teacher.
//! This module computes the deterministic crop sizes only — pixel-level
//! cropping (selecting each crop window in the source image) is the caller's
//! responsibility.
8
9use crate::error::{SslError, SslResult};
10
/// Multi-crop configuration.
///
/// The fields are public, so a value built with a struct literal skips the
/// validation performed by [`MultiCropConfig::new`]; prefer `new` or
/// [`Default`] to obtain a checked configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MultiCropConfig {
    /// Target spatial size of each global crop (e.g. 224).
    pub global_crop_size: usize,
    /// Target spatial size of each local crop (e.g. 96).
    pub local_crop_size: usize,
    /// Number of additional local crops per image (default 6).
    pub n_local_crops: usize,
}
21
22impl Default for MultiCropConfig {
23    fn default() -> Self {
24        Self {
25            global_crop_size: 224,
26            local_crop_size: 96,
27            n_local_crops: 6,
28        }
29    }
30}
31
32impl MultiCropConfig {
33    /// Validated config.
34    ///
35    /// # Errors
36    /// - [`SslError::InvalidNumCrops`] if `n_local_crops < 1`.
37    /// - [`SslError::EmptyInput`] if either crop size is zero.
38    pub fn new(
39        global_crop_size: usize,
40        local_crop_size: usize,
41        n_local_crops: usize,
42    ) -> SslResult<Self> {
43        if n_local_crops == 0 {
44            return Err(SslError::InvalidNumCrops);
45        }
46        if global_crop_size == 0 || local_crop_size == 0 {
47            return Err(SslError::EmptyInput);
48        }
49        Ok(Self {
50            global_crop_size,
51            local_crop_size,
52            n_local_crops,
53        })
54    }
55
56    /// Total number of crops produced per image (`2 globals + n_local`).
57    #[must_use]
58    pub fn n_crops(&self) -> usize {
59        2 + self.n_local_crops
60    }
61}
62
/// Per-crop spec returned by [`multi_crop`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CropSpec {
    /// Side length of the crop (assumed square, i.e. `size × size`).
    pub size: usize,
    /// True if this is a global crop, false if local.
    pub is_global: bool,
}
71
72/// Generate the deterministic list of crop specs given a configuration.
73///
74/// The pixel-level cropping (random window selection within the original image)
75/// is the caller's responsibility — this helper only enumerates the desired
76/// `(size, is_global)` tuples.
77///
78/// # Errors
79/// Propagates errors from [`MultiCropConfig`].
80pub fn multi_crop(cfg: &MultiCropConfig) -> SslResult<Vec<CropSpec>> {
81    let n = cfg.n_crops();
82    let mut crops = Vec::with_capacity(n);
83    crops.push(CropSpec {
84        size: cfg.global_crop_size,
85        is_global: true,
86    });
87    crops.push(CropSpec {
88        size: cfg.global_crop_size,
89        is_global: true,
90    });
91    for _ in 0..cfg.n_local_crops {
92        crops.push(CropSpec {
93            size: cfg.local_crop_size,
94            is_global: false,
95        });
96    }
97    Ok(crops)
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_dino_canonical() {
        let cfg = MultiCropConfig::default();
        assert_eq!(cfg.global_crop_size, 224);
        assert_eq!(cfg.local_crop_size, 96);
        assert_eq!(cfg.n_local_crops, 6);
        assert_eq!(cfg.n_crops(), 8);
    }

    #[test]
    fn rejects_zero_local_crops() {
        // `new` documents the exact variant, so pin it rather than just `is_err()`.
        assert!(matches!(
            MultiCropConfig::new(224, 96, 0),
            Err(SslError::InvalidNumCrops)
        ));
    }

    #[test]
    fn rejects_zero_crop_sizes() {
        assert!(matches!(
            MultiCropConfig::new(0, 96, 6),
            Err(SslError::EmptyInput)
        ));
        assert!(matches!(
            MultiCropConfig::new(224, 0, 6),
            Err(SslError::EmptyInput)
        ));
    }

    #[test]
    fn multi_crop_returns_two_globals_then_n_local() {
        let cfg = MultiCropConfig::new(160, 64, 4).expect("new should succeed");
        let crops = multi_crop(&cfg).expect("multi_crop should succeed");
        let global = CropSpec {
            size: 160,
            is_global: true,
        };
        let local = CropSpec {
            size: 64,
            is_global: false,
        };
        // Compare the whole list so both contents and ordering are checked.
        assert_eq!(crops, vec![global, global, local, local, local, local]);
    }

    #[test]
    fn multi_crop_default_yields_8_specs() {
        let cfg = MultiCropConfig::default();
        let crops = multi_crop(&cfg).expect("multi_crop should succeed");
        assert_eq!(crops.len(), 8);
        assert_eq!(crops.len(), cfg.n_crops());
        assert_eq!(crops.iter().filter(|c| c.is_global).count(), 2);
    }
}