use crate::error::{SslError, SslResult};
/// Hyper-parameters read by [`swav_loss`].
///
/// The fields are public, so a value can be built with a struct literal as well as through
/// [`SwavConfig::new`]; only `new` checks `temperature` and `epsilon`.
#[derive(Debug, Clone)]
pub struct SwavConfig {
/// Divisor applied to the prototype scores before the row-wise softmax in [`swav_loss`].
pub temperature: f32,
/// Divisor applied to the prototype scores inside the exponential whose output is handed to
/// [`sinkhorn_knopp`] by [`swav_loss`].
pub epsilon: f32,
/// Iteration count that [`swav_loss`] passes to [`sinkhorn_knopp`].
pub sinkhorn_iters: usize,
}
impl Default for SwavConfig {
fn default() -> Self {
Self {
temperature: 0.1,
epsilon: 0.05,
sinkhorn_iters: 3,
}
}
}
impl SwavConfig {
pub fn new(temperature: f32, epsilon: f32, sinkhorn_iters: usize) -> SslResult<Self> {
if !(temperature.is_finite() && temperature > 0.0) {
return Err(SslError::InvalidTemperature { temp: temperature });
}
if !(epsilon.is_finite() && epsilon > 0.0) {
return Err(SslError::InvalidLossWeight { weight: epsilon });
}
Ok(Self {
temperature,
epsilon,
sinkhorn_iters,
})
}
}
/// Balances a score matrix in place with the Sinkhorn-Knopp iteration.
///
/// `q` is read as a row-major `n x k` matrix: element `(i, j)` lives at `q[i * k + j]`.
/// The procedure is:
///
/// 1. scale the whole matrix so its entries sum to one;
/// 2. `iters` times: divide every column by `column_sum * k`, then every row by `row_sum * n`;
/// 3. rescale every row whose sum is positive so that it sums to one.
///
/// Entries are expected to be non-negative; this is not checked element by element, only
/// through the sums below.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] when `n == 0`.
/// * [`SslError::NumPrototypesTooSmall`] when `k < 2`.
/// * [`SslError::DimensionMismatch`] when `q.len() != n * k` (including when `n * k` does not
///   fit in `usize`).
/// * [`SslError::SinkhornDiverged`] with `iters: 0` when the total mass is not a positive
///   finite number or is too small to invert, and with the current iteration index when a
///   column or row sum stops being positive and finite. In the latter case `q` has already
///   been partially rescaled.
pub fn sinkhorn_knopp(q: &mut [f32], n: usize, k: usize, iters: usize) -> SslResult<()> {
    if n == 0 {
        return Err(SslError::EmptyInput);
    }
    if k < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }
    // Saturating keeps an overflowing `n * k` from panicking (debug) or wrapping (release);
    // no slice can be `usize::MAX` long, so such a request is reported as a plain mismatch.
    let expected = n.saturating_mul(k);
    if q.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: q.len(),
        });
    }

    let total: f32 = q.iter().sum();
    let inv_total = 1.0 / total;
    // `total <= 0.0` alone would let NaN and +inf through (every comparison with NaN is
    // false), and a subnormal total makes the reciprocal infinite. With `iters == 0` nothing
    // downstream would catch that, so reject all of these here.
    if !(total.is_finite() && total > 0.0 && inv_total.is_finite()) {
        return Err(SslError::SinkhornDiverged { iters: 0 });
    }
    for v in q.iter_mut() {
        *v *= inv_total;
    }

    let k_f = k as f32;
    let n_f = n as f32;
    // One scratch buffer reused by every iteration.
    let mut col_sum = vec![0.0_f32; k];

    for it in 0..iters {
        // Column pass: each column is rescaled to carry mass 1 / k.
        col_sum.fill(0.0);
        for row in q.chunks_exact(k) {
            for (acc, &v) in col_sum.iter_mut().zip(row) {
                *acc += v;
            }
        }
        if col_sum.iter().any(|c| !(c.is_finite() && *c > 0.0)) {
            return Err(SslError::SinkhornDiverged { iters: it });
        }
        for row in q.chunks_exact_mut(k) {
            for (v, &c) in row.iter_mut().zip(&col_sum) {
                *v /= c * k_f;
            }
        }

        // Row pass: each row is rescaled to carry mass 1 / n.
        for row in q.chunks_exact_mut(k) {
            let row_sum = row.iter().fold(0.0_f32, |acc, &v| acc + v);
            if !(row_sum.is_finite() && row_sum > 0.0) {
                return Err(SslError::SinkhornDiverged { iters: it });
            }
            let denom = row_sum * n_f;
            for v in row.iter_mut() {
                *v /= denom;
            }
        }
    }

    // Final pass: turn every row into a distribution. A row without positive mass (only
    // possible when `iters == 0`) is left untouched rather than divided by zero.
    for row in q.chunks_exact_mut(k) {
        let row_sum = row.iter().fold(0.0_f32, |acc, &v| acc + v);
        if row_sum > 0.0 {
            for v in row.iter_mut() {
                *v /= row_sum;
            }
        }
    }
    Ok(())
}
/// Applies a temperature-scaled softmax to every row of a row-major `n x k` matrix.
///
/// Each score is divided by `t`, the row maximum of the scaled scores is subtracted before
/// exponentiating, and the exponentials are accumulated in `f64`. The normaliser is clamped
/// from below at `1e-30` before being inverted. The result is a new `n * k` vector in the
/// same layout as `scores`.
///
/// Panics if `scores` holds fewer than `n * k` elements.
fn row_softmax(scores: &[f32], n: usize, k: usize, t: f32) -> Vec<f32> {
    let mut probs = Vec::with_capacity(n * k);
    for r in 0..n {
        let logits = &scores[r * k..(r + 1) * k];

        // Strict `>` keeps the running maximum unchanged when a scaled score is NaN.
        let peak = logits.iter().fold(f32::NEG_INFINITY, |best, &x| {
            let scaled = x / t;
            if scaled > best {
                scaled
            } else {
                best
            }
        });

        let exps: Vec<f64> = logits
            .iter()
            .map(|&x| f64::from(x / t - peak).exp())
            .collect();
        let denom = exps.iter().fold(0.0_f64, |acc, &e| acc + e);
        let scale = 1.0_f64 / denom.max(1e-30);

        probs.extend(exps.iter().map(|&e| (e * scale) as f32));
    }
    probs
}
/// Mean cross-entropy between target rows `q` and predicted rows `p`.
///
/// Both slices are read as row-major `n x k` matrices. Every prediction is clamped from below
/// at `1e-12` before its logarithm is taken, the terms `-q * ln(p)` are accumulated in `f64`,
/// and the sum is divided by `n`.
///
/// Panics if either slice holds fewer than `n * k` elements.
fn cross_entropy(q: &[f32], p: &[f32], n: usize, k: usize) -> f32 {
    let mut acc = 0.0_f64;
    for r in 0..n {
        let span = r * k..(r + 1) * k;
        let targets = &q[span.clone()];
        let preds = &p[span];
        for (&target, &pred) in targets.iter().zip(preds) {
            let clamped = pred.max(1e-12);
            acc += -f64::from(target) * f64::from(clamped).ln();
        }
    }
    (acc / n as f64) as f32
}
/// Computes the symmetric swapped-prediction (SwAV) loss for two views of a batch.
///
/// * `z_a`, `z_b` - row-major `n x d` embedding matrices, one per view.
/// * `protos` - row-major `k x d` prototype matrix.
/// * `cfg` - temperature, Sinkhorn epsilon and Sinkhorn iteration count.
///
/// For each view the prototype scores `z * protos^T` are computed. Codes are obtained by
/// exponentiating the scores divided by `cfg.epsilon` and balancing them with
/// [`sinkhorn_knopp`]; predictions are the row-wise softmax of the scores divided by
/// `cfg.temperature`. The result is
/// `0.5 * (CE(codes_b, preds_a) + CE(codes_a, preds_b))`, each cross-entropy averaged over
/// the `n` samples.
///
/// # Errors
///
/// * [`SslError::EmptyInput`] when `n == 0` or `d == 0`.
/// * [`SslError::NumPrototypesTooSmall`] when `k < 2`.
/// * [`SslError::DimensionMismatch`] when a slice length disagrees with `n`, `d`, `k`.
/// * [`SslError::InvalidTemperature`] / [`SslError::InvalidLossWeight`] when
///   `cfg.temperature` / `cfg.epsilon` is not a positive finite number.
/// * Any error returned by [`sinkhorn_knopp`].
pub fn swav_loss(
    z_a: &[f32],
    z_b: &[f32],
    protos: &[f32],
    n: usize,
    d: usize,
    k: usize,
    cfg: &SwavConfig,
) -> SslResult<f32> {
    if n == 0 || d == 0 {
        return Err(SslError::EmptyInput);
    }
    if k < 2 {
        return Err(SslError::NumPrototypesTooSmall);
    }

    // Saturating products: an overflowing shape can never match a real slice length, so it is
    // reported as a mismatch instead of panicking (debug) or wrapping (release).
    let embed_len = n.saturating_mul(d);
    for view in [z_a, z_b] {
        if view.len() != embed_len {
            return Err(SslError::DimensionMismatch {
                expected: embed_len,
                got: view.len(),
            });
        }
    }
    let proto_len = k.saturating_mul(d);
    if protos.len() != proto_len {
        return Err(SslError::DimensionMismatch {
            expected: proto_len,
            got: protos.len(),
        });
    }

    // The config fields are public, so `cfg` may not have gone through `SwavConfig::new`.
    // Apply the same checks here: a zero or NaN temperature would otherwise yield a finite
    // but meaningless loss.
    if !(cfg.temperature.is_finite() && cfg.temperature > 0.0) {
        return Err(SslError::InvalidTemperature {
            temp: cfg.temperature,
        });
    }
    if !(cfg.epsilon.is_finite() && cfg.epsilon > 0.0) {
        return Err(SslError::InvalidLossWeight {
            weight: cfg.epsilon,
        });
    }

    let scores_a = matmul_zct(z_a, protos, n, d, k);
    let scores_b = matmul_zct(z_b, protos, n, d, k);

    // exp(s / epsilon) overflows f32 once s / epsilon exceeds roughly 88. Subtracting the
    // largest score first keeps every exponent <= 0. Sinkhorn-Knopp starts by normalising the
    // total mass, so this common factor cancels and the codes are mathematically unchanged.
    let kernel = |scores: &[f32]| -> Vec<f32> {
        let peak = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        scores
            .iter()
            .map(|&s| ((s - peak) / cfg.epsilon).exp())
            .collect()
    };
    let mut q_a = kernel(&scores_a);
    let mut q_b = kernel(&scores_b);
    sinkhorn_knopp(&mut q_a, n, k, cfg.sinkhorn_iters)?;
    sinkhorn_knopp(&mut q_b, n, k, cfg.sinkhorn_iters)?;

    let p_a = row_softmax(&scores_a, n, k, cfg.temperature);
    let p_b = row_softmax(&scores_b, n, k, cfg.temperature);

    // Swapped prediction: the codes of one view supervise the predictions of the other.
    let l1 = cross_entropy(&q_b, &p_a, n, k);
    let l2 = cross_entropy(&q_a, &p_b, n, k);
    Ok(0.5 * (l1 + l2))
}
/// Computes the score matrix `z * protos^T`.
///
/// `z` is read as a row-major `n x d` matrix and `protos` as a row-major `k x d` matrix. The
/// returned vector is the row-major `n x k` matrix whose `(i, j)` entry is the dot product of
/// row `i` of `z` with row `j` of `protos`, accumulated in `f32`.
///
/// Panics if a required row lies outside `z` or `protos`.
fn matmul_zct(z: &[f32], protos: &[f32], n: usize, d: usize, k: usize) -> Vec<f32> {
    let mut scores = Vec::with_capacity(n * k);
    for sample in 0..n {
        scores.extend((0..k).map(|proto_idx| {
            let embedding = &z[sample * d..(sample + 1) * d];
            let prototype = &protos[proto_idx * d..(proto_idx + 1) * d];
            embedding
                .iter()
                .zip(prototype)
                .fold(0.0_f32, |dot, (a, b)| dot + a * b)
        }));
    }
    scores
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented constants.
    #[test]
    fn swav_default_config() {
        let SwavConfig {
            temperature,
            epsilon,
            sinkhorn_iters,
        } = SwavConfig::default();
        assert!((temperature - 0.1).abs() < 1e-7);
        assert!((epsilon - 0.05).abs() < 1e-7);
        assert_eq!(sinkhorn_iters, 3);
    }

    /// Zero and negative temperatures are refused by the constructor.
    #[test]
    fn swav_rejects_invalid_temperature() {
        for bad in [0.0_f32, -1.0] {
            assert!(SwavConfig::new(bad, 0.05, 3).is_err());
        }
    }

    /// A zero epsilon is refused by the constructor.
    #[test]
    fn swav_rejects_invalid_epsilon() {
        let outcome = SwavConfig::new(0.1, 0.0, 3);
        assert!(outcome.is_err());
    }

    /// A constant matrix ends up with every entry equal to `1 / k`.
    #[test]
    fn sinkhorn_normalises_uniform_input() {
        let (n, k) = (4_usize, 3_usize);
        let mut q = vec![1.0_f32; n * k];
        sinkhorn_knopp(&mut q, n, k, 5).expect("sinkhorn_knopp should succeed");
        let target = 1.0 / k as f32;
        for v in &q {
            assert!((*v - target).abs() < 1e-4, "v = {v}");
        }
    }

    /// After balancing, every row is a probability distribution.
    #[test]
    fn sinkhorn_rows_sum_to_one() {
        let (n, k) = (4_usize, 3_usize);
        let mut q: Vec<f32> = (0..n * k).map(|i| (i as f32 + 1.0) * 0.1).collect();
        sinkhorn_knopp(&mut q, n, k, 5).expect("sinkhorn_knopp should succeed");
        for row in q.chunks(k) {
            let s: f32 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-4, "row sum = {s}");
        }
    }

    /// An all-zero matrix has no mass to balance.
    #[test]
    fn sinkhorn_rejects_zero_total() {
        let (n, k) = (2_usize, 2_usize);
        let mut q = vec![0.0_f32; n * k];
        let outcome = sinkhorn_knopp(&mut q, n, k, 3);
        assert!(outcome.is_err());
    }

    /// At least two prototypes are required.
    #[test]
    fn sinkhorn_rejects_too_few_prototypes() {
        let mut q = vec![1.0_f32];
        let outcome = sinkhorn_knopp(&mut q, 1, 1, 3);
        assert!(outcome.is_err());
    }

    /// Deterministic pseudo-random inputs give a finite, strictly positive loss.
    #[test]
    fn swav_loss_finite_on_random_data() {
        let (n, d, k) = (8_usize, 4_usize, 6_usize);
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
        let z_b: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.027).cos()).collect();
        let protos: Vec<f32> = (0..k * d).map(|i| (i as f32 * 0.041).sin()).collect();
        let cfg = SwavConfig::default();
        let l = swav_loss(&z_a, &z_b, &protos, n, d, k, &cfg).expect("swav_loss should succeed");
        assert!(l.is_finite() && l > 0.0);
    }

    /// A prototype buffer of the wrong length is reported as an error.
    #[test]
    fn swav_rejects_dim_mismatch_protos() {
        let (n, d, k) = (2_usize, 4_usize, 3_usize);
        let z = vec![1.0_f32; n * d];
        // Deliberately `k * 2` elements instead of `k * d`.
        let protos = vec![1.0_f32; k * 2];
        let cfg = SwavConfig::default();
        let outcome = swav_loss(&z, &z, &protos, n, d, k, &cfg);
        assert!(outcome.is_err());
    }
}