// oxicuda_ssl/augment/multi_crop.rs
use crate::error::{SslError, SslResult};
10
/// Configuration describing how many crops of which size a multi-crop
/// augmentation should produce.
///
/// The fields are public, so a value can be built directly with a struct
/// literal; doing so bypasses whatever validation a constructor performs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiCropConfig {
    /// Size assigned to each global crop.
    pub global_crop_size: usize,
    /// Size assigned to each local crop.
    pub local_crop_size: usize,
    /// Number of local crops to produce.
    pub n_local_crops: usize,
}
21
22impl Default for MultiCropConfig {
23 fn default() -> Self {
24 Self {
25 global_crop_size: 224,
26 local_crop_size: 96,
27 n_local_crops: 6,
28 }
29 }
30}
31
32impl MultiCropConfig {
33 pub fn new(
39 global_crop_size: usize,
40 local_crop_size: usize,
41 n_local_crops: usize,
42 ) -> SslResult<Self> {
43 if n_local_crops == 0 {
44 return Err(SslError::InvalidNumCrops);
45 }
46 if global_crop_size == 0 || local_crop_size == 0 {
47 return Err(SslError::EmptyInput);
48 }
49 Ok(Self {
50 global_crop_size,
51 local_crop_size,
52 n_local_crops,
53 })
54 }
55
56 #[must_use]
58 pub fn n_crops(&self) -> usize {
59 2 + self.n_local_crops
60 }
61}
62
/// Description of a single crop: its size and whether it is a global crop.
///
/// Both fields are plain `Eq + Hash` values, so the type derives the full
/// set of comparison and hashing traits and can be used in sets or as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CropSpec {
    /// Size of the crop.
    pub size: usize,
    /// `true` for a global crop, `false` for a local crop.
    pub is_global: bool,
}
71
72pub fn multi_crop(cfg: &MultiCropConfig) -> SslResult<Vec<CropSpec>> {
81 let n = cfg.n_crops();
82 let mut crops = Vec::with_capacity(n);
83 crops.push(CropSpec {
84 size: cfg.global_crop_size,
85 is_global: true,
86 });
87 crops.push(CropSpec {
88 size: cfg.global_crop_size,
89 is_global: true,
90 });
91 for _ in 0..cfg.n_local_crops {
92 crops.push(CropSpec {
93 size: cfg.local_crop_size,
94 is_global: false,
95 });
96 }
97 Ok(crops)
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields 224/96 crop sizes with 6 local crops, 8 crops in total.
    #[test]
    fn default_config_dino_canonical() {
        let cfg = MultiCropConfig::default();
        assert_eq!(cfg.global_crop_size, 224);
        assert_eq!(cfg.local_crop_size, 96);
        assert_eq!(cfg.n_local_crops, 6);
        assert_eq!(cfg.n_crops(), 8);
    }

    /// A zero local-crop count is rejected with `InvalidNumCrops`, and that
    /// check runs before the size checks.
    #[test]
    fn rejects_zero_local_crops() {
        assert!(matches!(
            MultiCropConfig::new(224, 96, 0),
            Err(SslError::InvalidNumCrops)
        ));
        // The crop-count check takes precedence over a zero crop size.
        assert!(matches!(
            MultiCropConfig::new(0, 0, 0),
            Err(SslError::InvalidNumCrops)
        ));
    }

    /// A zero global or local crop size is rejected with `EmptyInput`.
    #[test]
    fn rejects_zero_crop_sizes() {
        assert!(matches!(
            MultiCropConfig::new(0, 96, 6),
            Err(SslError::EmptyInput)
        ));
        assert!(matches!(
            MultiCropConfig::new(224, 0, 6),
            Err(SslError::EmptyInput)
        ));
    }

    /// The output is ordered: two global specs first, then the local specs.
    #[test]
    fn multi_crop_returns_two_globals_then_n_local() {
        let cfg = MultiCropConfig::new(160, 64, 4).expect("new should succeed");
        let crops = multi_crop(&cfg).expect("multi_crop should succeed");
        assert_eq!(crops.len(), 6);

        let (globals, locals) = crops.split_at(2);
        for g in globals {
            assert_eq!(g.size, 160);
            assert!(g.is_global);
        }
        assert_eq!(locals.len(), 4);
        for l in locals {
            assert_eq!(l.size, 64);
            assert!(!l.is_global);
        }
    }

    /// The default configuration expands to 8 crop specs.
    #[test]
    fn multi_crop_default_yields_8_specs() {
        let cfg = MultiCropConfig::default();
        let crops = multi_crop(&cfg).expect("multi_crop should succeed");
        assert_eq!(crops.len(), 8);
    }
}