use crate::error::{SslError, SslResult};
/// Settings for the Barlow Twins loss.
#[derive(Clone, Debug)]
pub struct BarlowTwinsConfig {
    /// Weight multiplying the off-diagonal term of the loss.
    pub lambda: f32,
}
impl Default for BarlowTwinsConfig {
fn default() -> Self {
Self { lambda: 0.005 }
}
}
impl BarlowTwinsConfig {
    /// Builds a configuration after validating `lambda`.
    ///
    /// # Errors
    ///
    /// Returns [`SslError::InvalidLossWeight`] carrying the rejected value when
    /// `lambda` is NaN, infinite, or negative.
    pub fn new(lambda: f32) -> SslResult<Self> {
        let rejected = !lambda.is_finite() || lambda < 0.0;
        if rejected {
            Err(SslError::InvalidLossWeight { weight: lambda })
        } else {
            Ok(Self { lambda })
        }
    }
}
/// Returns a copy of `z` with every column shifted to zero mean and scaled to
/// (approximately) unit standard deviation.
///
/// `z` is read as `n` rows of `d` values, element `(row, col)` living at index
/// `row * d + col`. The variance is the population variance (sum of squared
/// deviations divided by `n`) with `1e-5` added before the square root, so a
/// constant column maps to zeros instead of dividing by zero.
///
/// When `n < 2` the input is returned unchanged.
fn standardise(z: &[f32], n: usize, d: usize) -> Vec<f32> {
    let mut scaled = z.to_vec();
    if n >= 2 {
        let recip_rows = 1.0_f32 / n as f32;
        for col in 0..d {
            // Column mean, accumulated top to bottom in f32.
            let centre =
                (0..n).fold(0.0_f32, |acc, row| acc + scaled[row * d + col]) * recip_rows;
            // Sum of squared deviations from that mean.
            let sq_dev = (0..n).fold(0.0_f32, |acc, row| {
                let delta = scaled[row * d + col] - centre;
                acc + delta * delta
            });
            let scale = 1.0 / (sq_dev * recip_rows + 1e-5).sqrt();
            for row in 0..n {
                let cell = &mut scaled[row * d + col];
                *cell = (*cell - centre) * scale;
            }
        }
    }
    scaled
}
/// Computes the Barlow Twins loss for two batches of embeddings.
///
/// `z_a` and `z_b` are each read as `n` rows of `d` values, with element
/// `(row, col)` stored at index `row * d + col`. Both buffers are passed through
/// `standardise`, then every entry `c_ij` of the `d x d` cross-correlation
/// matrix (the mean over rows of `za[row][i] * zb[row][j]`) contributes to
///
/// `sum_i (1 - c_ii)^2 + cfg.lambda * sum_{i != j} c_ij^2`.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] if `n == 0` or `d == 0`.
/// * [`SslError::BatchTooSmall`] if `n < 2`.
/// * [`SslError::DimensionMismatch`] if `z_a` or `z_b` does not hold exactly
///   `n * d` values (`z_a` is checked first). If `n * d` overflows `usize`,
///   `expected` is reported as `usize::MAX`.
pub fn barlow_twins_loss(
    z_a: &[f32],
    z_b: &[f32],
    n: usize,
    d: usize,
    cfg: &BarlowTwinsConfig,
) -> SslResult<f32> {
    if n == 0 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if n < 2 {
        return Err(SslError::BatchTooSmall);
    }
    // A plain `n * d` panics in debug builds and wraps in release builds when the
    // product overflows; a wrapped value could match a short slice and lead to an
    // out-of-bounds index later. Saturating yields `usize::MAX`, which no `f32`
    // slice length can equal, so oversized dimensions are reported as a mismatch.
    let expected = n.saturating_mul(d);
    for got in [z_a.len(), z_b.len()] {
        if got != expected {
            return Err(SslError::DimensionMismatch { expected, got });
        }
    }

    let za = standardise(z_a, n, d);
    let zb = standardise(z_b, n, d);
    let inv_n = 1.0_f32 / n as f32;

    // Squared residuals are accumulated in f64; each correlation entry itself is
    // computed in f32.
    let mut diag_sum = 0.0_f64;
    let mut off_sum = 0.0_f64;
    for i in 0..d {
        for j in 0..d {
            // Walk both buffers one row at a time (each chunk is one row of `d`
            // values), summing the product of column `i` of `za` with column `j`
            // of `zb` in row order.
            let dot = za
                .chunks_exact(d)
                .zip(zb.chunks_exact(d))
                .fold(0.0_f32, |acc, (row_a, row_b)| acc + row_a[i] * row_b[j]);
            let c = dot * inv_n;
            if i == j {
                // Diagonal entries are pulled towards 1.
                let r = 1.0 - c;
                diag_sum += (r * r) as f64;
            } else {
                // Off-diagonal entries are pulled towards 0.
                off_sum += (c * c) as f64;
            }
        }
    }
    Ok(diag_sum as f32 + cfg.lambda * off_sum as f32)
}
// Unit tests for `BarlowTwinsConfig` and `barlow_twins_loss`.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration must carry lambda = 0.005 (within 1e-7).
#[test]
fn barlow_default_lambda() {
let cfg = BarlowTwinsConfig::default();
assert!((cfg.lambda - 0.005).abs() < 1e-7);
}
// Passing the same sine-generated 16 x 4 buffer as both views must give a
// loss below 1.0.
#[test]
fn barlow_identical_inputs_low_diag_loss() {
let n = 16;
let d = 4;
let mut z = vec![0.0_f32; n * d];
for (i, v) in z.iter_mut().enumerate() {
*v = ((i as f32) * 0.31415).sin();
}
let cfg = BarlowTwinsConfig::default();
let l = barlow_twins_loss(&z, &z, n, d, &cfg).expect("barlow_twins_loss should succeed");
assert!(l < 1.0, "l = {l}");
}
// Two 256 x 4 buffers filled from separately seeded pseudo-random streams
// must give a loss above 2.0.
#[test]
fn barlow_uncorrelated_inputs_high_diag_loss() {
let n = 256;
let d = 4;
// Independent generator states (seeds 13 and 17), one per view.
let mut rng_a = 13u64;
let mut rng_b = 17u64;
let mut z_a = vec![0.0_f32; n * d];
let mut z_b = vec![0.0_f32; n * d];
for v in z_a.iter_mut() {
// Wrapping multiply-add recurrence (a linear congruential generator step).
rng_a = rng_a
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
// Take the high bits of the state, scale them down and shift by -0.5.
*v = ((rng_a >> 33) as f32 / (u32::MAX as f32 + 1.0)) - 0.5;
}
for v in z_b.iter_mut() {
// Same recurrence and scaling as above, driven by the second state.
rng_b = rng_b
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
*v = ((rng_b >> 33) as f32 / (u32::MAX as f32 + 1.0)) - 0.5;
}
let cfg = BarlowTwinsConfig::default();
let l =
barlow_twins_loss(&z_a, &z_b, n, d, &cfg).expect("barlow_twins_loss should succeed");
assert!(l > 2.0, "l = {l}");
}
// A batch of a single row (n = 1) must be rejected.
#[test]
fn barlow_rejects_n_lt_2() {
let z = vec![1.0_f32, 2.0];
let cfg = BarlowTwinsConfig::default();
assert!(barlow_twins_loss(&z, &z, 1, 2, &cfg).is_err());
}
// `b` holds 6 values where n * d = 2 * 4 = 8 are required, so the call must fail.
#[test]
fn barlow_rejects_dim_mismatch() {
let a = vec![1.0_f32; 8];
let b = vec![1.0_f32; 6];
let cfg = BarlowTwinsConfig::default();
assert!(barlow_twins_loss(&a, &b, 2, 4, &cfg).is_err());
}
// The validating constructor must refuse a negative weight.
#[test]
fn barlow_rejects_negative_lambda() {
assert!(BarlowTwinsConfig::new(-0.1).is_err());
}
// With lambda = 0 the off-diagonal term is multiplied by zero, so the loss
// under the default lambda must not be smaller (allowing 1e-4 of slack).
#[test]
fn barlow_lambda_zero_is_diagonal_only() {
let n = 8;
let d = 3;
let mut z = vec![0.0_f32; n * d];
for (i, v) in z.iter_mut().enumerate() {
*v = (i as f32) * 0.1;
}
let cfg_zero = BarlowTwinsConfig::new(0.0).expect("new should succeed");
let cfg_default = BarlowTwinsConfig::default();
let l_zero =
barlow_twins_loss(&z, &z, n, d, &cfg_zero).expect("barlow_twins_loss should succeed");
let l_default = barlow_twins_loss(&z, &z, n, d, &cfg_default)
.expect("barlow_twins_loss should succeed");
assert!(l_default >= l_zero - 1e-4);
}
}