use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
/// Hyper-parameters for SimMIM-style masked image modelling.
///
/// Fields are public, so values assigned directly bypass the checks performed by
/// [`SimMimConfig::new`]; prefer the constructor when the inputs are not trusted.
#[derive(Debug, Clone)]
pub struct SimMimConfig {
    /// Fraction of patches to mask. [`SimMimConfig::new`] only accepts finite values in
    /// `[0.0, 1.0)`.
    pub mask_ratio: f32,
    /// Patch size; must be non-zero. Not interpreted by this module — presumably the patch side
    /// length in pixels (TODO confirm against callers).
    pub patch_size: usize,
}

impl Default for SimMimConfig {
    /// Returns `mask_ratio = 0.6` and `patch_size = 32`.
    fn default() -> Self {
        SimMimConfig {
            patch_size: 32,
            mask_ratio: 0.6,
        }
    }
}

impl SimMimConfig {
    /// Builds a validated configuration.
    ///
    /// # Errors
    /// - [`SslError::InvalidMaskRatio`] unless `mask_ratio` is finite and in `[0.0, 1.0)`
    ///   (checked first).
    /// - [`SslError::InvalidParameter`] if `patch_size` is zero.
    pub fn new(mask_ratio: f32, patch_size: usize) -> SslResult<Self> {
        // `0.0 <= x && x < 1.0` is false for NaN and for both infinities, so this one range test
        // also rejects every non-finite ratio.
        let ratio_in_range = (0.0..1.0).contains(&mask_ratio);
        match (ratio_in_range, patch_size) {
            (false, _) => Err(SslError::InvalidMaskRatio { ratio: mask_ratio }),
            (true, 0) => Err(SslError::InvalidParameter {
                name: "patch_size".into(),
                reason: "must be > 0".into(),
            }),
            (true, _) => Ok(SimMimConfig {
                patch_size,
                mask_ratio,
            }),
        }
    }
}
/// Draws a random SimMIM patch mask of length `n_patches`.
///
/// `(n_patches as f32 * mask_ratio) as usize` entries (truncated toward zero, and never more
/// than `n_patches`) are set to `true`. The masked set is the leading slice of `0..n_patches`
/// after one `rng.shuffle`, so the result is fully determined by the state of `rng`.
///
/// # Errors
/// - [`SslError::EmptyInput`] if `n_patches == 0`.
/// - [`SslError::InvalidMaskRatio`] unless `mask_ratio` is finite and in `[0.0, 1.0)`.
pub fn simmim_random_mask(
    n_patches: usize,
    mask_ratio: f32,
    rng: &mut LcgRng,
) -> SslResult<Vec<bool>> {
    if n_patches == 0 {
        return Err(SslError::EmptyInput);
    }
    let ratio_ok = mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio);
    if !ratio_ok {
        return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
    }

    // Truncating cast: the realised count rounds down (e.g. 196 * 0.6 -> 117).
    let masked_count = (n_patches as f32 * mask_ratio) as usize;

    let mut order: Vec<usize> = (0..n_patches).collect();
    rng.shuffle(&mut order);

    let mut mask = vec![false; n_patches];
    // `take` (rather than slicing) keeps this safe even if float rounding ever pushed
    // `masked_count` past `n_patches`.
    order
        .into_iter()
        .take(masked_count)
        .for_each(|patch| mask[patch] = true);
    Ok(mask)
}
/// Builds a block-wise patch mask over an `n_patches_h × n_patches_w` grid, stored row-major
/// (`index = row * n_patches_w + col`).
///
/// Each round draws a rectangle of 2–4 rows by 2–4 columns, clamped to the grid. It places the
/// rectangle at a random in-bounds origin and marks every patch it covers. Rounds continue until
/// at least `(h * w as f32 * mask_ratio) as usize` patches are masked, or until the budget of
/// `(target + 1) * 16 + 1` rounds runs out. The realised ratio can therefore overshoot the target,
/// because the final rectangle is applied whole. If the budget is exhausted, it can undershoot.
///
/// # Errors
/// - [`SslError::EmptyInput`] if either grid dimension is zero.
/// - [`SslError::InvalidMaskRatio`] unless `mask_ratio` is finite and in `[0.0, 1.0)`.
pub fn simmim_block_mask(
    n_patches_h: usize,
    n_patches_w: usize,
    mask_ratio: f32,
    rng: &mut LcgRng,
) -> SslResult<Vec<bool>> {
    if n_patches_h == 0 || n_patches_w == 0 {
        return Err(SslError::EmptyInput);
    }
    let ratio_ok = mask_ratio.is_finite() && (0.0..1.0).contains(&mask_ratio);
    if !ratio_ok {
        return Err(SslError::InvalidMaskRatio { ratio: mask_ratio });
    }

    /// Picks a start offset so that a run of `span` cells fits inside `extent`.
    /// When the run fills the whole extent, the offset is 0 and no RNG draw happens.
    fn random_origin(rng: &mut LcgRng, extent: usize, span: usize) -> usize {
        if extent > span {
            rng.next_usize(extent - span + 1)
        } else {
            0
        }
    }

    let grid_len = n_patches_h * n_patches_w;
    let goal = (grid_len as f32 * mask_ratio) as usize;
    // Round budget: guarantees termination even when new rectangles keep landing on
    // already-masked cells.
    let budget = (goal + 1) * 16 + 1;

    let mut mask = vec![false; grid_len];
    let mut masked_so_far = 0usize;
    for _ in 0..budget {
        if masked_so_far >= goal {
            break;
        }
        // RNG draw order is height, width, row origin, column origin.
        let block_h = (rng.next_usize(3) + 2).min(n_patches_h);
        let block_w = (rng.next_usize(3) + 2).min(n_patches_w);
        let top = random_origin(rng, n_patches_h, block_h);
        let left = random_origin(rng, n_patches_w, block_w);

        let rows = mask[top * n_patches_w..]
            .chunks_mut(n_patches_w)
            .take(block_h);
        for row in rows {
            for cell in &mut row[left..left + block_w] {
                if !*cell {
                    *cell = true;
                    masked_so_far += 1;
                }
            }
        }
    }
    Ok(mask)
}
/// Mean absolute error (L1) over the elements of the masked patches only.
///
/// `pred` and `target` are flat buffers holding `n_patches` consecutive patches of `patch_dim`
/// values each. `mask[i] == true` selects patch `i`, and unmasked patches do not contribute.
/// Element differences are taken in `f32`, accumulated in `f64`, divided by
/// `masked_patches * patch_dim`, and the result is cast back to `f32`.
///
/// # Errors
/// - [`SslError::EmptyInput`] if `n_patches` or `patch_dim` is zero, or if no patch is masked.
/// - [`SslError::DimensionMismatch`] if a buffer length disagrees with `n_patches`/`patch_dim`.
pub fn simmim_l1_loss(
    pred: &[f32],
    target: &[f32],
    mask: &[bool],
    n_patches: usize,
    patch_dim: usize,
) -> SslResult<f32> {
    masked_mean_error(pred, target, mask, n_patches, patch_dim, f64::abs)
}

/// Mean squared error (L2) over the elements of the masked patches only.
///
/// Same buffer layout, normalisation and errors as [`simmim_l1_loss`]. The only difference is
/// that each element contributes `diff * diff`.
///
/// # Errors
/// - [`SslError::EmptyInput`] if `n_patches` or `patch_dim` is zero, or if no patch is masked.
/// - [`SslError::DimensionMismatch`] if a buffer length disagrees with `n_patches`/`patch_dim`.
pub fn simmim_l2_loss(
    pred: &[f32],
    target: &[f32],
    mask: &[bool],
    n_patches: usize,
    patch_dim: usize,
) -> SslResult<f32> {
    masked_mean_error(pred, target, mask, n_patches, patch_dim, |diff| diff * diff)
}

/// Dispatches to [`simmim_l1_loss`] when `use_l1` is `true`, otherwise to [`simmim_l2_loss`].
///
/// # Errors
/// Propagates the errors of whichever loss is selected.
pub fn simmim_reconstruction_loss(
    pred: &[f32],
    target: &[f32],
    mask: &[bool],
    n_patches: usize,
    patch_dim: usize,
    use_l1: bool,
) -> SslResult<f32> {
    if use_l1 {
        simmim_l1_loss(pred, target, mask, n_patches, patch_dim)
    } else {
        simmim_l2_loss(pred, target, mask, n_patches, patch_dim)
    }
}

/// Shared kernel for the masked reconstruction losses.
///
/// Validates the inputs, then sums `penalty(pred - target)` over every element of every masked
/// patch, in patch order and then element order. The sum is divided by the number of contributing
/// elements.
fn masked_mean_error(
    pred: &[f32],
    target: &[f32],
    mask: &[bool],
    n_patches: usize,
    patch_dim: usize,
    penalty: impl Fn(f64) -> f64,
) -> SslResult<f32> {
    validate_inputs(pred, target, mask, n_patches, patch_dim)?;

    // After validation, both buffers split into exactly `n_patches` chunks of `patch_dim`
    // (and `patch_dim != 0`, so `chunks_exact` cannot panic).
    let patches = pred
        .chunks_exact(patch_dim)
        .zip(target.chunks_exact(patch_dim));

    let mut total = 0.0_f64;
    let mut masked_patches = 0usize;
    for ((pred_patch, target_patch), &selected) in patches.zip(mask) {
        if !selected {
            continue;
        }
        masked_patches += 1;
        for (&p, &t) in pred_patch.iter().zip(target_patch) {
            // Difference is taken in f32 before widening.
            total += penalty((p - t) as f64);
        }
    }

    if masked_patches == 0 {
        return Err(SslError::EmptyInput);
    }
    Ok((total / (masked_patches * patch_dim) as f64) as f32)
}
/// Checks that the flat `pred`/`target` buffers and the per-patch `mask` agree with
/// `n_patches` × `patch_dim`.
///
/// Checks run in this order, and the first failure is returned:
/// 1. `n_patches == 0` or `patch_dim == 0` → [`SslError::EmptyInput`].
/// 2. `n_patches * patch_dim` overflows `usize` → [`SslError::InvalidParameter`].
/// 3. `pred.len() != n_patches * patch_dim` → [`SslError::DimensionMismatch`].
/// 4. `target.len() != n_patches * patch_dim` → [`SslError::DimensionMismatch`].
/// 5. `mask.len() != n_patches` → [`SslError::DimensionMismatch`].
#[inline]
fn validate_inputs(
    pred: &[f32],
    target: &[f32],
    mask: &[bool],
    n_patches: usize,
    patch_dim: usize,
) -> SslResult<()> {
    if n_patches == 0 || patch_dim == 0 {
        return Err(SslError::EmptyInput);
    }
    // A plain multiply would panic in debug builds and silently wrap in release builds. A
    // wrapped value would then be reported as a bogus `expected` length.
    let expected = n_patches
        .checked_mul(patch_dim)
        .ok_or_else(|| SslError::InvalidParameter {
            name: "patch_dim".into(),
            reason: "n_patches * patch_dim overflows usize".into(),
        })?;
    if pred.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pred.len(),
        });
    }
    if target.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: target.len(),
        });
    }
    if mask.len() != n_patches {
        return Err(SslError::DimensionMismatch {
            expected: n_patches,
            got: mask.len(),
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for SimMIM masking and masked reconstruction losses.
    use super::*;

    #[test]
    fn simmim_l1_loss_all_zero_pred_nonzero_target() {
        let n = 8;
        let d = 4;
        let pred = vec![0.0_f32; n * d];
        let target = vec![1.0_f32; n * d];
        let mask = vec![true; n];
        let loss = simmim_l1_loss(&pred, &target, &mask, n, d).unwrap();
        assert!(loss > 0.0, "loss should be > 0, got {loss}");
        assert!((loss - 1.0).abs() < 1e-5, "expected 1.0, got {loss}");
    }

    #[test]
    fn simmim_l1_loss_perfect_reconstruction_zero() {
        let n = 10;
        let d = 8;
        let target: Vec<f32> = (0..n * d).map(|i| i as f32 * 0.1).collect();
        let pred = target.clone();
        let mask = vec![
            true, false, true, false, true, false, true, false, true, false,
        ];
        let loss = simmim_l1_loss(&pred, &target, &mask, n, d).unwrap();
        assert!(loss.abs() < 1e-7, "perfect reconstruction: loss = {loss}");
    }

    #[test]
    fn simmim_l1_vs_l2_ordering() {
        let n = 20;
        let d = 16;
        let target = vec![0.0_f32; n * d];
        let pred = vec![2.0_f32; n * d];
        let mask = vec![true; n];
        let l1 = simmim_l1_loss(&pred, &target, &mask, n, d).unwrap();
        let l2 = simmim_l2_loss(&pred, &target, &mask, n, d).unwrap();
        assert!(
            l1 <= l2,
            "expected L1 ≤ L2 when errors ≥ 1, got L1={l1} L2={l2}"
        );
    }

    #[test]
    fn simmim_l2_loss_vs_manual() {
        let n = 4;
        let d = 2;
        let target = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let pred = vec![2.0_f32, 3.0, 3.0, 4.0, 6.0, 7.0, 7.0, 8.0];
        let mask = vec![true, false, true, false];
        let loss = simmim_l2_loss(&pred, &target, &mask, n, d).unwrap();
        assert!((loss - 1.0).abs() < 1e-5, "expected 1.0, got {loss}");
    }

    #[test]
    fn simmim_random_mask_ratio_approx() {
        let mut rng = LcgRng::new(42);
        let n = 200;
        let ratio = 0.6_f32;
        let mask = simmim_random_mask(n, ratio, &mut rng).unwrap();
        let n_masked = mask.iter().filter(|&&v| v).count();
        let realised = n_masked as f32 / n as f32;
        assert!(
            (realised - ratio).abs() < 0.05,
            "realised ratio {realised} too far from target {ratio}"
        );
    }

    #[test]
    fn simmim_random_mask_length_correct() {
        let mut rng = LcgRng::new(7);
        let n = 196;
        let mask = simmim_random_mask(n, 0.6, &mut rng).unwrap();
        assert_eq!(mask.len(), n, "mask length mismatch");
    }

    #[test]
    fn simmim_random_mask_masks_exact_truncated_count() {
        // 10 * 0.5 == 5.0 exactly, so exactly five patches must be masked.
        let mut rng = LcgRng::new(5);
        let mask = simmim_random_mask(10, 0.5, &mut rng).unwrap();
        assert_eq!(mask.iter().filter(|&&v| v).count(), 5);
    }

    #[test]
    fn simmim_block_mask_ratio_approx() {
        let mut rng = LcgRng::new(99);
        let h = 14;
        let w = 14;
        let ratio = 0.5_f32;
        let mask = simmim_block_mask(h, w, ratio, &mut rng).unwrap();
        let total = h * w;
        assert_eq!(mask.len(), total);
        let n_masked = mask.iter().filter(|&&v| v).count();
        let realised = n_masked as f32 / total as f32;
        assert!(
            (realised - ratio).abs() < 0.20,
            "realised ratio {realised} too far from target {ratio} (tol 0.20)"
        );
    }

    #[test]
    fn simmim_block_mask_zero_ratio_masks_nothing() {
        let mut rng = LcgRng::new(11);
        let mask = simmim_block_mask(6, 5, 0.0, &mut rng).unwrap();
        assert_eq!(mask.len(), 30);
        assert!(mask.iter().all(|&v| !v));
    }

    #[test]
    fn simmim_block_mask_grid_smaller_than_block() {
        // A 1-row grid forces the block height to clamp to 1. The target is
        // (3 * 0.9) as usize == 2, and the first block (2 or 3 columns wide) already reaches it.
        let mut rng = LcgRng::new(13);
        let mask = simmim_block_mask(1, 3, 0.9, &mut rng).unwrap();
        assert_eq!(mask.len(), 3);
        let n_masked = mask.iter().filter(|&&v| v).count();
        assert!((2..=3).contains(&n_masked), "masked {n_masked} of 3");
    }

    #[test]
    fn simmim_reconstruction_loss_dispatch_l1() {
        let n = 6;
        let d = 4;
        let pred: Vec<f32> = (0..n * d).map(|i| (i as f32) * 0.05).collect();
        let target = vec![0.5_f32; n * d];
        let mask = vec![true, false, true, false, true, false];
        let loss = simmim_reconstruction_loss(&pred, &target, &mask, n, d, true).unwrap();
        assert!(loss.is_finite(), "L1 dispatch returned non-finite: {loss}");
    }

    #[test]
    fn simmim_reconstruction_loss_dispatch_l2() {
        let n = 6;
        let d = 4;
        let pred: Vec<f32> = (0..n * d).map(|i| (i as f32) * 0.05).collect();
        let target = vec![0.5_f32; n * d];
        let mask = vec![true, false, true, false, true, false];
        let loss = simmim_reconstruction_loss(&pred, &target, &mask, n, d, false).unwrap();
        assert!(loss.is_finite(), "L2 dispatch returned non-finite: {loss}");
    }

    #[test]
    fn simmim_loss_only_unmasked_ignored() {
        let n = 6;
        let d = 3;
        let target = vec![1.0_f32; n * d];
        let mask = vec![true, false, true, false, false, true];
        let pred_base = vec![0.0_f32; n * d];
        let loss_base = simmim_l1_loss(&pred_base, &target, &mask, n, d).unwrap();
        let mut pred_mutated = pred_base.clone();
        for &i in &[1_usize, 3, 4] {
            for k in 0..d {
                pred_mutated[i * d + k] = 999.0;
            }
        }
        let loss_mutated = simmim_l1_loss(&pred_mutated, &target, &mask, n, d).unwrap();
        assert!(
            (loss_base - loss_mutated).abs() < 1e-6,
            "unmasked patches affected loss: {loss_base} vs {loss_mutated}"
        );
    }

    #[test]
    fn empty_input_returns_error() {
        assert_eq!(
            simmim_l1_loss(&[], &[], &[], 0, 4),
            Err(SslError::EmptyInput)
        );
        assert_eq!(
            simmim_l1_loss(&[], &[], &[], 4, 0),
            Err(SslError::EmptyInput)
        );
        let mut rng = LcgRng::new(0);
        assert_eq!(
            simmim_random_mask(0, 0.5, &mut rng),
            Err(SslError::EmptyInput)
        );
        assert_eq!(
            simmim_block_mask(0, 4, 0.5, &mut rng),
            Err(SslError::EmptyInput)
        );
    }

    #[test]
    fn mismatched_lengths_report_dimension_mismatch() {
        let mask = vec![true, false, true, false];
        assert_eq!(
            simmim_l1_loss(&[0.0; 7], &[0.0; 8], &mask, 4, 2),
            Err(SslError::DimensionMismatch {
                expected: 8,
                got: 7
            })
        );
        assert_eq!(
            simmim_l2_loss(&[0.0; 8], &[0.0; 9], &mask, 4, 2),
            Err(SslError::DimensionMismatch {
                expected: 8,
                got: 9
            })
        );
        assert_eq!(
            simmim_l1_loss(&[0.0; 8], &[0.0; 8], &mask[..3], 4, 2),
            Err(SslError::DimensionMismatch {
                expected: 4,
                got: 3
            })
        );
    }

    #[test]
    fn invalid_mask_ratio_is_rejected() {
        let mut rng = LcgRng::new(1);
        for &bad in &[1.0_f32, 1.5, -0.1, f32::NAN, f32::INFINITY] {
            assert!(matches!(
                simmim_random_mask(16, bad, &mut rng),
                Err(SslError::InvalidMaskRatio { .. })
            ));
            assert!(matches!(
                simmim_block_mask(4, 4, bad, &mut rng),
                Err(SslError::InvalidMaskRatio { .. })
            ));
            assert!(matches!(
                SimMimConfig::new(bad, 16),
                Err(SslError::InvalidMaskRatio { .. })
            ));
        }
    }

    #[test]
    fn config_new_validates_and_default_values() {
        let cfg = SimMimConfig::new(0.75, 16).unwrap();
        assert_eq!(cfg.patch_size, 16);
        assert!((cfg.mask_ratio - 0.75).abs() < f32::EPSILON);
        assert!(matches!(
            SimMimConfig::new(0.5, 0),
            Err(SslError::InvalidParameter { .. })
        ));
        let default_cfg = SimMimConfig::default();
        assert_eq!(default_cfg.patch_size, 32);
        assert!((default_cfg.mask_ratio - 0.6).abs() < f32::EPSILON);
    }

    #[test]
    fn zero_mask_ratio_masks_nothing_and_loss_errors() {
        let mut rng = LcgRng::new(3);
        let n = 16;
        let mask = simmim_random_mask(n, 0.0, &mut rng).unwrap();
        assert!(mask.iter().all(|&v| !v));
        let pred = vec![1.0_f32; n * 4];
        let target = vec![0.0_f32; n * 4];
        let result = simmim_l1_loss(&pred, &target, &mask, n, 4);
        assert!(result.is_err(), "expected error for all-unmasked input");
    }
}