use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
use crate::head::predictor::PredictorHead;
/// Layer widths used to build a SimSiam prediction head.
///
/// [`SimSiamPredictor::from_config`] builds its head as
/// `PredictorHead::new(d_proj, d_pred, d_proj, rng)`, so `d_proj` is used on
/// both ends and `d_pred` is the middle argument (presumably the hidden /
/// bottleneck width — confirm against `PredictorHead::new`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimSiamConfig {
    /// Projection width: passed as both the first and third argument of
    /// `PredictorHead::new`.
    pub d_proj: usize,
    /// Predictor width: passed as the middle argument of `PredictorHead::new`.
    pub d_pred: usize,
}
impl Default for SimSiamConfig {
    /// Returns `d_proj = 128`, `d_pred = 64`.
    fn default() -> Self {
        Self {
            d_proj: 128,
            d_pred: 64,
        }
    }
}
/// Mean negative cosine similarity between predictions `p` and targets `z`.
///
/// Both slices are read as `n` row-major rows of width `d`, and row `i` of
/// `p` is paired with row `i` of `z`. The loss is `-1` when every pair is
/// aligned, `0` when every pair is orthogonal, and `+1` when every pair is
/// antiparallel.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `p.len()` or `z.len()` differs from
///   `n * d` (`p` is checked first).
pub fn simsiam_loss(p: &[f32], z: &[f32], n: usize, d: usize) -> SslResult<f32> {
    validate_batch(p, z, n, d).map(|()| neg_cosine_mean(p, z, n, d))
}
/// Symmetrised SimSiam loss: `0.5 * (D(p1, z2) + D(p2, z1))`.
///
/// `D` is the mean negative cosine similarity computed by [`simsiam_loss`].
/// Prediction `p1` is scored against `z2`, and `p2` against `z1`. Every
/// slice is read as `n` row-major rows of width `d`.
///
/// Both pairs are validated before either term is computed.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if any slice length differs from
///   `n * d`. The `(p1, z2)` pair is checked before `(p2, z1)`.
pub fn simsiam_loss_batch(
    p1: &[f32],
    z2: &[f32],
    p2: &[f32],
    z1: &[f32],
    n: usize,
    d: usize,
) -> SslResult<f32> {
    for (pred, target) in [(p1, z2), (p2, z1)] {
        validate_batch(pred, target, n, d)?;
    }
    let cross_1 = neg_cosine_mean(p1, z2, n, d);
    let cross_2 = neg_cosine_mean(p2, z1, n, d);
    Ok(0.5 * (cross_1 + cross_2))
}
/// Heuristic collapse detector for a batch of projections.
///
/// `z` is read as `n` row-major rows of width `d`. The detector works as
/// follows:
///
/// 1. Every row is L2-normalised, with its norm floored at `1e-12` so
///    all-zero rows stay finite.
/// 2. The population standard deviation of each column is computed across
///    the batch.
/// 3. The batch is reported as collapsed when the mean of those per-column
///    standard deviations is strictly below `threshold`.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::DimensionMismatch`] if `z.len() != n * d`. If `n * d` would
///   overflow `usize`, the reported `expected` saturates to `usize::MAX`.
pub fn is_collapsed(z: &[f32], n: usize, d: usize, threshold: f32) -> SslResult<bool> {
    if n == 0 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    // Saturate instead of overflowing: no slice can be `usize::MAX` long, so
    // an overflowing `n * d` is always reported as a mismatch.
    let expected = n.saturating_mul(d);
    if z.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: z.len(),
        });
    }
    // Per-column running sums of the normalised values and their squares.
    // Filling them row by row needs O(d) memory instead of an n×d normalised
    // copy. Each column is still summed in row order, so the totals match a
    // column-wise pass exactly.
    let mut col_sum = vec![0.0_f64; d];
    let mut col_sum_sq = vec![0.0_f64; d];
    // `d > 0` and `z.len() == n * d`, so this yields exactly `n` full rows.
    for row in z.chunks_exact(d) {
        let norm = row
            .iter()
            .map(|&v| f64::from(v) * f64::from(v))
            .sum::<f64>()
            .sqrt()
            .max(1e-12_f64);
        for ((&v, sum), sum_sq) in row.iter().zip(&mut col_sum).zip(&mut col_sum_sq) {
            let x = f64::from(v) / norm;
            *sum += x;
            *sum_sq += x * x;
        }
    }
    let n_f = n as f64;
    let mut mean_std = 0.0_f64;
    for (&sum, &sum_sq) in col_sum.iter().zip(&col_sum_sq) {
        let mean = sum / n_f;
        // E[x²] − E[x]² can dip slightly below zero from rounding; clamp it.
        let var = (sum_sq / n_f - mean * mean).max(0.0_f64);
        mean_std += var.sqrt();
    }
    mean_std /= d as f64;
    Ok(mean_std < f64::from(threshold))
}
/// SimSiam predictor: a thin wrapper around a [`PredictorHead`].
#[derive(Debug, Clone)]
pub struct SimSiamPredictor {
    /// The wrapped head; `forward` delegates to it.
    pub predictor: PredictorHead,
}
impl SimSiamPredictor {
    /// Wraps an already-constructed [`PredictorHead`].
    #[must_use]
    pub fn new(predictor: PredictorHead) -> Self {
        Self { predictor }
    }
    /// Builds the head as `PredictorHead::new(cfg.d_proj, cfg.d_pred, cfg.d_proj, rng)`.
    ///
    /// `d_proj` is passed as both the first and last argument. This module's
    /// tests expect the output length to equal `d_proj`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PredictorHead::new`.
    pub fn from_config(cfg: &SimSiamConfig, rng: &mut LcgRng) -> SslResult<Self> {
        PredictorHead::new(cfg.d_proj, cfg.d_pred, cfg.d_proj, rng).map(Self::new)
    }
    /// Runs the wrapped head on `z` and returns its result unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PredictorHead::forward`.
    pub fn forward(&self, z: &[f32]) -> SslResult<Vec<f32>> {
        self.predictor.forward(z)
    }
}
/// Checks that `p` and `z` each hold exactly `n` rows of `d` values.
///
/// The expected length `n * d` uses saturating multiplication. If the product
/// would overflow `usize`, the expected length becomes `usize::MAX`, which no
/// `&[f32]` can have. The call then returns a mismatch error instead of
/// panicking (debug builds) or silently wrapping (release builds); wrapping
/// could let an undersized slice pass and panic later during indexing.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::DimensionMismatch`] for the first of `p`, `z` (checked in
///   that order) whose length differs from `n * d`.
#[inline]
fn validate_batch(p: &[f32], z: &[f32], n: usize, d: usize) -> SslResult<()> {
    if n == 0 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    let expected = n.saturating_mul(d);
    for got in [p.len(), z.len()] {
        if got != expected {
            return Err(SslError::DimensionMismatch { expected, got });
        }
    }
    Ok(())
}
/// Mean of `-cos(p_i, z_i)` over the `n` paired rows of `p` and `z`.
///
/// Both slices are read as `n` row-major rows of width `d`. Dot products and
/// norms are accumulated in `f64`, and only the final mean is narrowed to
/// `f32`. Each row norm is floored at `1e-12`, so an all-zero row contributes
/// a cosine of `0` instead of dividing by zero.
///
/// This helper does no validation of its own; the public loss functions call
/// `validate_batch` first. It panics if either slice is shorter than `n * d`,
/// and it returns NaN when `n == 0`.
#[inline]
fn neg_cosine_mean(p: &[f32], z: &[f32], n: usize, d: usize) -> f32 {
    // Euclidean length of one row, clamped away from zero.
    let row_norm = |row: &[f32]| -> f64 {
        row.iter()
            .map(|&v| f64::from(v) * f64::from(v))
            .sum::<f64>()
            .sqrt()
            .max(1e-12_f64)
    };
    let neg_cos_total = (0..n).fold(0.0_f64, |acc, row_idx| {
        let start = row_idx * d;
        let end = start + d;
        let (p_row, z_row) = (&p[start..end], &z[start..end]);
        let dot: f64 = p_row
            .iter()
            .zip(z_row)
            .map(|(&a, &b)| f64::from(a) * f64::from(b))
            .sum();
        acc - dot / (row_norm(p_row) * row_norm(z_row))
    });
    (neg_cos_total / n as f64) as f32
}
#[cfg(test)]
mod tests {
    //! Unit tests for the SimSiam losses, the collapse detector and the
    //! predictor wrapper.
    use super::*;
    use crate::handle::LcgRng;

    #[test]
    fn simsiam_loss_aligned_gives_minus_one() {
        let v = [1.0_f32, 0.0, 0.0, 0.0];
        let l = simsiam_loss(&v, &v, 1, 4).expect("simsiam_loss should succeed");
        assert!((l + 1.0).abs() < 1e-5, "loss = {l}");
    }

    #[test]
    fn simsiam_loss_orthogonal_gives_zero() {
        let p = [1.0_f32, 0.0];
        let z = [0.0_f32, 1.0];
        let l = simsiam_loss(&p, &z, 1, 2).expect("simsiam_loss should succeed");
        assert!(l.abs() < 1e-5, "loss = {l}");
    }

    #[test]
    fn simsiam_loss_antiparallel_gives_plus_one() {
        let p = [1.0_f32, 0.0];
        let z = [-1.0_f32, 0.0];
        let l = simsiam_loss(&p, &z, 1, 2).expect("simsiam_loss should succeed");
        assert!((l - 1.0).abs() < 1e-5, "loss = {l}");
    }

    #[test]
    fn simsiam_loss_zero_vector_gives_zero() {
        // The norm floor keeps an all-zero row finite; its cosine is 0.
        let p = [0.0_f32, 0.0];
        let z = [1.0_f32, 0.0];
        let l = simsiam_loss(&p, &z, 1, 2).expect("simsiam_loss should succeed");
        assert!(l.is_finite() && l.abs() < 1e-5, "loss = {l}");
    }

    #[test]
    fn simsiam_loss_batch_symmetric() {
        // (p1, z2) is aligned -> -1; (p2, z1) is orthogonal -> 0.
        let p1 = [1.0_f32, 0.0];
        let z2 = [1.0_f32, 0.0];
        let p2 = [0.0_f32, 1.0];
        let z1 = [1.0_f32, 0.0];
        let sym = simsiam_loss_batch(&p1, &z2, &p2, &z1, 1, 2)
            .expect("simsiam_loss_batch should succeed");
        let expected = (-1.0_f32 + 0.0_f32) * 0.5;
        assert!((sym - expected).abs() < 1e-5, "sym = {sym}");
    }

    #[test]
    fn simsiam_loss_batch_equals_single_when_symmetric_inputs() {
        let p: Vec<f32> = (0..12).map(|i| i as f32 * 0.1 + 0.5).collect();
        let z: Vec<f32> = (0..12).map(|i| (12 - i) as f32 * 0.1 + 0.3).collect();
        let single = simsiam_loss(&p, &z, 3, 4).expect("simsiam_loss should succeed");
        let batch =
            simsiam_loss_batch(&p, &z, &p, &z, 3, 4).expect("simsiam_loss_batch should succeed");
        assert!(
            (single - batch).abs() < 1e-5,
            "single={single} batch={batch}"
        );
    }

    #[test]
    fn simsiam_loss_batch_rejects_mismatched_second_pair() {
        let ok = [1.0_f32, 0.0];
        let bad = [1.0_f32, 0.0, 0.0];
        let err = simsiam_loss_batch(&ok, &ok, &bad, &ok, 1, 2);
        assert!(
            matches!(err, Err(SslError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[test]
    fn default_config_values() {
        let cfg = SimSiamConfig::default();
        assert_eq!(cfg.d_proj, 128);
        assert_eq!(cfg.d_pred, 64);
    }

    #[test]
    fn simsiam_predictor_forward_shape() {
        let mut rng = LcgRng::new(42);
        let cfg = SimSiamConfig {
            d_proj: 16,
            d_pred: 8,
        };
        let pred =
            SimSiamPredictor::from_config(&cfg, &mut rng).expect("from_config should succeed");
        let z = [0.5_f32; 16];
        let p = pred.forward(&z).expect("forward should succeed");
        assert_eq!(p.len(), 16, "output dim must equal d_proj");
    }

    #[test]
    fn collapse_detection_constant_projections_collapsed() {
        let n = 8;
        let d = 4;
        let z: Vec<f32> = (0..n * d)
            .map(|idx| if idx % d == 0 { 1.0_f32 } else { 0.0_f32 })
            .collect();
        let collapsed = is_collapsed(&z, n, d, 0.1).expect("is_collapsed should succeed");
        assert!(
            collapsed,
            "constant projections must be detected as collapsed"
        );
    }

    #[test]
    fn collapse_detection_diverse_projections_not_collapsed() {
        let n = 4;
        let d = 4;
        // Identity matrix: row i has a single 1.0 in column i.
        let z: Vec<f32> = (0..n * d)
            .map(|idx| if idx / d == idx % d { 1.0_f32 } else { 0.0_f32 })
            .collect();
        let collapsed = is_collapsed(&z, n, d, 0.1).expect("is_collapsed should succeed");
        assert!(
            !collapsed,
            "orthogonal projections must not be detected as collapsed"
        );
    }

    #[test]
    fn is_collapsed_dimension_mismatch_returns_error() {
        let z = [1.0_f32, 0.0, 0.0];
        let err = is_collapsed(&z, 1, 2, 0.1);
        assert!(
            matches!(err, Err(SslError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[test]
    fn empty_input_returns_error() {
        assert!(simsiam_loss(&[], &[], 0, 0).is_err());
        assert!(simsiam_loss_batch(&[], &[], &[], &[], 0, 0).is_err());
        assert!(is_collapsed(&[], 0, 0, 0.1).is_err());
    }

    #[test]
    fn dimension_mismatch_returns_error() {
        let p = [1.0_f32, 0.0, 0.0];
        let z = [1.0_f32, 0.0];
        let err = simsiam_loss(&p, &z, 1, 2);
        assert!(
            matches!(err, Err(SslError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {err:?}"
        );
    }

    #[test]
    fn single_sample_valid() {
        let p = [0.6_f32, 0.8];
        let z = [0.8_f32, 0.6];
        let l = simsiam_loss(&p, &z, 1, 2).expect("simsiam_loss should succeed");
        assert!(l.is_finite(), "loss must be finite for n=1");
    }
}