use crate::error::{SslError, SslResult};
/// Configuration for multi-crop view generation: the size used for global
/// crops, the size used for local crops, and how many local crops to produce.
///
/// Two global crops are always implied (see [`MultiCropConfig::n_crops`]);
/// only the number of local crops is configurable.
///
/// All fields are public, so a value built with a struct literal bypasses the
/// checks performed by [`MultiCropConfig::new`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiCropConfig {
    /// Size assigned to each global crop (unit is not specified here;
    /// presumably a side length in pixels — confirm against the cropping code).
    pub global_crop_size: usize,
    /// Size assigned to each local crop (same unit as `global_crop_size`).
    pub local_crop_size: usize,
    /// Number of local crops produced in addition to the two global crops.
    pub n_local_crops: usize,
}
impl Default for MultiCropConfig {
fn default() -> Self {
Self {
global_crop_size: 224,
local_crop_size: 96,
n_local_crops: 6,
}
}
}
impl MultiCropConfig {
    /// Creates a configuration after validating its arguments.
    ///
    /// # Errors
    ///
    /// * [`SslError::InvalidNumCrops`] if `n_local_crops` is zero. This check
    ///   takes precedence over the size checks.
    /// * [`SslError::EmptyInput`] if `global_crop_size` or `local_crop_size`
    ///   is zero.
    pub fn new(
        global_crop_size: usize,
        local_crop_size: usize,
        n_local_crops: usize,
    ) -> SslResult<Self> {
        // Arm order encodes error precedence: crop count first, then sizes.
        match (n_local_crops, global_crop_size, local_crop_size) {
            (0, _, _) => Err(SslError::InvalidNumCrops),
            (_, 0, _) | (_, _, 0) => Err(SslError::EmptyInput),
            _ => Ok(Self {
                global_crop_size,
                local_crop_size,
                n_local_crops,
            }),
        }
    }

    /// Returns the total number of crops: the two global crops plus
    /// `n_local_crops` local crops.
    #[must_use]
    pub fn n_crops(&self) -> usize {
        const GLOBAL_CROPS: usize = 2;
        GLOBAL_CROPS + self.n_local_crops
    }
}
/// Description of a single crop to take: its size and whether it is a global
/// or a local crop.
///
/// Both fields are plain integers/booleans, so the type is cheap to copy and
/// supports full equality and hashing (usable as a `HashSet`/`HashMap` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CropSpec {
    /// Size of the crop (unit is not specified here; presumably a side length
    /// in pixels — confirm against the cropping code).
    pub size: usize,
    /// `true` for a global crop, `false` for a local crop.
    pub is_global: bool,
}
/// Expands a [`MultiCropConfig`] into the ordered list of crop specs.
///
/// The result holds two global specs (size `cfg.global_crop_size`) followed by
/// `cfg.n_local_crops` local specs (size `cfg.local_crop_size`).
///
/// # Errors
///
/// This function has no failing path at present and always returns `Ok`; the
/// configuration is not re-validated here.
pub fn multi_crop(cfg: &MultiCropConfig) -> SslResult<Vec<CropSpec>> {
    let global_spec = CropSpec {
        size: cfg.global_crop_size,
        is_global: true,
    };
    let local_spec = CropSpec {
        size: cfg.local_crop_size,
        is_global: false,
    };

    let mut specs = Vec::with_capacity(cfg.n_crops());
    // Globals come first, then every local crop.
    specs.extend(std::iter::repeat(global_spec).take(2));
    specs.extend(std::iter::repeat(local_spec).take(cfg.n_local_crops));
    Ok(specs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is 224 / 96 / 6, giving eight crops in total.
    #[test]
    fn default_config_dino_canonical() {
        let defaults = MultiCropConfig::default();
        assert_eq!(
            (
                defaults.global_crop_size,
                defaults.local_crop_size,
                defaults.n_local_crops,
            ),
            (224, 96, 6)
        );
        assert_eq!(defaults.n_crops(), 8);
    }

    /// A local-crop count of zero is refused by the constructor.
    #[test]
    fn rejects_zero_local_crops() {
        let outcome = MultiCropConfig::new(224, 96, 0);
        assert!(outcome.is_err());
    }

    /// A zero size for either crop kind is refused by the constructor.
    #[test]
    fn rejects_zero_crop_sizes() {
        for &(global, local) in &[(0, 96), (224, 0)] {
            assert!(MultiCropConfig::new(global, local, 6).is_err());
        }
    }

    /// The spec list starts with two global entries and ends with the locals.
    #[test]
    fn multi_crop_returns_two_globals_then_n_local() {
        let cfg = MultiCropConfig::new(160, 64, 4).unwrap();
        let specs = multi_crop(&cfg).unwrap();
        assert_eq!(specs.len(), 6);

        let (globals, locals) = specs.split_at(2);
        for spec in globals {
            assert_eq!(*spec, CropSpec { size: 160, is_global: true });
        }
        for spec in locals {
            assert_eq!(*spec, CropSpec { size: 64, is_global: false });
        }
    }

    /// The default configuration expands to eight specs.
    #[test]
    fn multi_crop_default_yields_8_specs() {
        let specs = multi_crop(&MultiCropConfig::default()).unwrap();
        assert_eq!(specs.len(), 8);
    }
}