use crate::error::{SeqError, SeqResult};
use crate::hmm::forward_backward::logsumexp;
/// Tuning knobs for the log-domain Sinkhorn iteration.
///
/// [`SinkhornConfig::default`] gives `epsilon = 0.1`, `max_iter = 1000`,
/// `tol = 1e-9`; override individual fields with struct-update syntax, e.g.
/// `SinkhornConfig { epsilon: 0.02, ..SinkhornConfig::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkhornConfig {
    /// Entropic regularisation strength. Scores are divided by `epsilon` before
    /// exponentiation, so smaller values give sharper plans (closer to a
    /// permutation). Must be finite and strictly positive.
    pub epsilon: f64,
    /// Upper bound on the number of sweeps (one row update followed by one
    /// column update each). Must be non-zero.
    pub max_iter: usize,
    /// Early-stopping threshold on the largest per-sweep change of the column
    /// log-scalings. Must be finite and non-negative; with `0.0` iteration only
    /// stops early on an exact fixed point.
    pub tol: f64,
}

impl Default for SinkhornConfig {
    /// `epsilon = 0.1`, `max_iter = 1000`, `tol = 1e-9`.
    fn default() -> Self {
        Self {
            epsilon: 0.1,
            max_iter: 1000,
            tol: 1e-9,
        }
    }
}
impl SinkhornConfig {
    /// Checks that this configuration can drive a Sinkhorn run.
    ///
    /// Fields are examined in the order `epsilon`, `tol`, `max_iter`, and the
    /// first offending one is reported.
    ///
    /// # Errors
    /// - [`SeqError::InvalidParameter`] named `"epsilon"` when `epsilon` is not
    ///   finite and strictly positive;
    /// - [`SeqError::InvalidParameter`] named `"tol"` when `tol` is not finite
    ///   and non-negative;
    /// - [`SeqError::InvalidConfiguration`] when `max_iter` is zero.
    pub fn validate(&self) -> SeqResult<()> {
        let bad_param = |name: &str, value: f64| SeqError::InvalidParameter {
            name: name.to_string(),
            value,
        };

        let epsilon_ok = self.epsilon > 0.0 && self.epsilon.is_finite();
        if !epsilon_ok {
            return Err(bad_param("epsilon", self.epsilon));
        }

        let tol_ok = self.tol >= 0.0 && self.tol.is_finite();
        if !tol_ok {
            return Err(bad_param("tol", self.tol));
        }

        match self.max_iter {
            0 => Err(SeqError::InvalidConfiguration(
                "max_iter must be > 0".into(),
            )),
            _ => Ok(()),
        }
    }
}
/// Output of a Sinkhorn run on an `n × m` score matrix.
#[derive(Debug, Clone, PartialEq)]
pub struct SinkhornResult {
    /// Number of rows of the score matrix / plan.
    pub n: usize,
    /// Number of columns of the score matrix / plan.
    pub m: usize,
    /// Transport plan, stored row-major: entry `(i, j)` lives at `plan[i * m + j]`.
    pub plan: Vec<f64>,
    /// Row potentials, `epsilon * log u_i` (length `n`).
    pub f: Vec<f64>,
    /// Column potentials, `epsilon * log v_j` (length `m`).
    pub g: Vec<f64>,
    /// `Σ_ij plan_ij · s_ij`, the score mass moved by the plan.
    pub transport_score: f64,
    /// Number of sweeps actually performed (at most `max_iter`).
    pub iterations: usize,
    /// Largest `|Δ log v_j|` observed in the final sweep; compare against `tol`
    /// to tell whether the run converged or hit `max_iter`.
    pub residual: f64,
}
impl SinkhornResult {
    /// Borrows row `i` of the transport plan (its `m` entries).
    ///
    /// # Errors
    /// [`SeqError::IndexOutOfBounds`] when `i >= n`.
    pub fn row(&self, i: usize) -> SeqResult<&[f64]> {
        if i < self.n {
            let start = i * self.m;
            Ok(&self.plan[start..start + self.m])
        } else {
            Err(SeqError::IndexOutOfBounds {
                index: i,
                len: self.n,
            })
        }
    }

    /// For every row, the column index holding that row's largest plan entry.
    ///
    /// Ties resolve to the lowest column index, NaN entries are never chosen,
    /// and a row with no entry above `-inf` maps to column `0`.
    pub fn argmax_assignment(&self) -> Vec<usize> {
        (0..self.n)
            .map(|i| {
                let start = i * self.m;
                self.plan[start..start + self.m]
                    .iter()
                    .enumerate()
                    .fold((0usize, f64::NEG_INFINITY), |(best_j, best_v), (j, &v)| {
                        if v > best_v {
                            (j, v)
                        } else {
                            (best_j, best_v)
                        }
                    })
                    .0
            })
            .collect()
    }
}
/// Sinkhorn normalisation with uniform margins: every row targets mass `1 / n`
/// and every column mass `1 / m`, so the resulting plan sums to one.
///
/// Thin wrapper around [`sinkhorn_normalize_with_margins`]; see it for the
/// iteration details and the full list of errors (including
/// [`SeqError::EmptyInput`] when `n == 0` or `m == 0`).
pub fn sinkhorn_normalize(
    s: &[f64],
    n: usize,
    m: usize,
    config: &SinkhornConfig,
) -> SeqResult<SinkhornResult> {
    // `max(1)` keeps the division finite for an empty dimension; such inputs are
    // rejected by the callee rather than here.
    let uniform = |len: usize| vec![1.0 / len.max(1) as f64; len];
    let row_margin = uniform(n);
    let col_margin = uniform(m);
    sinkhorn_normalize_with_margins(s, n, m, &row_margin, &col_margin, config)
}
/// Log-domain Sinkhorn normalisation of a row-major `n × m` score matrix `s`
/// towards row margins `a` and column margins `b`.
///
/// With `K = s / ε` (`ε = config.epsilon`), starting from `log u = log v = 0`,
/// each sweep performs
///
/// ```text
/// log u_i = log a_i − logsumexp_j (K_ij + log v_j)    for every row i
/// log v_j = log b_j − logsumexp_i (K_ij + log u_i)    for every column j
/// ```
///
/// and the returned plan is `P_ij = exp(K_ij + log u_i + log v_j)`. Because the
/// column update comes last, column sums match `b` up to rounding; row sums
/// match `a` up to the convergence error.
///
/// Iteration stops after the first sweep whose largest `|Δ log v_j|` is
/// `<= config.tol`, or after `config.max_iter` sweeps. Exhausting `max_iter` is
/// **not** reported as an error: the result is returned with `residual > tol`,
/// so callers that require convergence must inspect `residual`.
///
/// The returned potentials are `f = ε · log u` and `g = ε · log v`, and
/// `transport_score = Σ_ij P_ij · s_ij`.
///
/// # Errors
/// - anything reported by [`SinkhornConfig::validate`];
/// - [`SeqError::EmptyInput`] if `n == 0` or `m == 0`;
/// - [`SeqError::ShapeMismatch`] if `s.len() != n * m`;
/// - [`SeqError::LengthMismatch`] if `a.len() != n` or `b.len() != m`;
/// - [`SeqError::NumericalInstability`] if `s` contains a non-finite entry, or if
///   `Σa` and `Σb` differ by more than a relative `1e-6`;
/// - [`SeqError::InvalidParameter`] (named `a[idx]` / `b[idx]`) if a margin entry
///   is not finite and strictly positive;
/// - [`SeqError::NotConverged`] if a scaling update becomes non-finite
///   (e.g. because `s / ε` overflows).
pub fn sinkhorn_normalize_with_margins(
    s: &[f64],
    n: usize,
    m: usize,
    a: &[f64],
    b: &[f64],
    config: &SinkhornConfig,
) -> SeqResult<SinkhornResult> {
    // Relative tolerance for the Σa ≈ Σb check. A fixed absolute threshold is
    // below f64 rounding for large-mass margins (spuriously rejecting valid input)
    // and far too loose for tiny ones (it would accept Σa = 1e-9 vs Σb = 2e-9).
    const MARGIN_MASS_RTOL: f64 = 1e-6;

    config.validate()?;
    if n == 0 || m == 0 {
        return Err(SeqError::EmptyInput);
    }
    if s.len() != n * m {
        return Err(SeqError::ShapeMismatch {
            expected: n * m,
            got: s.len(),
        });
    }
    if a.len() != n {
        return Err(SeqError::LengthMismatch { a: a.len(), b: n });
    }
    if b.len() != m {
        return Err(SeqError::LengthMismatch { a: b.len(), b: m });
    }
    if s.iter().any(|v| !v.is_finite()) {
        return Err(SeqError::NumericalInstability(
            "score matrix contains a non-finite entry".into(),
        ));
    }
    let mass_a = checked_margin_mass("a", a)?;
    let mass_b = checked_margin_mass("b", b)?;
    // Unequal masses admit no coupling: the row and column updates would keep
    // undoing each other instead of converging.
    if (mass_a - mass_b).abs() > MARGIN_MASS_RTOL * mass_a.max(mass_b) {
        return Err(SeqError::NumericalInstability(format!(
            "margin masses differ: Σa={mass_a}, Σb={mass_b}"
        )));
    }

    let eps = config.epsilon;
    // K = s / ε is invariant across sweeps: compute it once instead of
    // re-dividing 2·n·m times per sweep (values are bit-identical).
    let log_k: Vec<f64> = s.iter().map(|&x| x / eps).collect();
    let log_a: Vec<f64> = a.iter().map(|&x| x.ln()).collect();
    let log_b: Vec<f64> = b.iter().map(|&x| x.ln()).collect();
    let mut log_u = vec![0.0f64; n];
    let mut log_v = vec![0.0f64; m];
    // Scratch buffers reused by every logsumexp call.
    let mut row_buf = vec![0.0f64; m];
    let mut col_buf = vec![0.0f64; n];
    let mut iterations = 0usize;
    let mut residual = f64::INFINITY;

    for sweep in 0..config.max_iter {
        iterations = sweep + 1;

        // Row update: rescale each row towards its target mass a_i.
        for (i, log_u_i) in log_u.iter_mut().enumerate() {
            let k_row = &log_k[i * m..(i + 1) * m];
            for ((dst, &k), &lv) in row_buf.iter_mut().zip(k_row).zip(&log_v) {
                *dst = k + lv;
            }
            *log_u_i = log_a[i] - logsumexp(&row_buf);
            if !log_u_i.is_finite() {
                return Err(SeqError::NotConverged { iter: iterations });
            }
        }

        // Column update: rescale each column towards its target mass b_j, tracking
        // the largest change of log v as the convergence residual.
        let mut max_delta = 0.0f64;
        for (j, log_v_j) in log_v.iter_mut().enumerate() {
            let k_col = log_k.iter().skip(j).step_by(m);
            for ((dst, &k), &lu) in col_buf.iter_mut().zip(k_col).zip(&log_u) {
                *dst = k + lu;
            }
            let new_v = log_b[j] - logsumexp(&col_buf);
            if !new_v.is_finite() {
                return Err(SeqError::NotConverged { iter: iterations });
            }
            max_delta = max_delta.max((new_v - *log_v_j).abs());
            *log_v_j = new_v;
        }

        residual = max_delta;
        if max_delta <= config.tol {
            break;
        }
    }

    let mut plan = Vec::with_capacity(n * m);
    let mut transport_score = 0.0f64;
    for (i, &lu) in log_u.iter().enumerate() {
        let base = i * m;
        for (j, &lv) in log_v.iter().enumerate() {
            let p = (log_k[base + j] + lu + lv).exp();
            plan.push(p);
            transport_score += p * s[base + j];
        }
    }

    let f: Vec<f64> = log_u.iter().map(|&x| eps * x).collect();
    let g: Vec<f64> = log_v.iter().map(|&x| eps * x).collect();
    Ok(SinkhornResult {
        n,
        m,
        plan,
        f,
        g,
        transport_score,
        iterations,
        residual,
    })
}

/// Checks that every entry of the margin `w` is finite and strictly positive
/// (reporting offenders as `name[idx]`) and returns the sum of its entries.
fn checked_margin_mass(name: &str, w: &[f64]) -> SeqResult<f64> {
    let mut mass = 0.0;
    for (idx, &v) in w.iter().enumerate() {
        if !(v.is_finite() && v > 0.0) {
            return Err(SeqError::InvalidParameter {
                name: format!("{name}[{idx}]"),
                value: v,
            });
        }
        mass += v;
    }
    Ok(mass)
}
/// Sinkhorn "CRF" layer: a validated [`SinkhornConfig`] bundled with entry
/// points for the forward pass and a structured loss.
#[derive(Debug, Clone)]
pub struct SinkhornCrf {
    config: SinkhornConfig,
}

impl SinkhornCrf {
    /// Creates a layer, rejecting configurations that fail
    /// [`SinkhornConfig::validate`].
    pub fn new(config: SinkhornConfig) -> SeqResult<Self> {
        config.validate().map(|()| Self { config })
    }

    /// The configuration this layer was built with.
    pub fn config(&self) -> &SinkhornConfig {
        &self.config
    }

    /// Runs Sinkhorn with uniform margins (`1 / n` per row, `1 / m` per column).
    pub fn forward(&self, s: &[f64], n: usize, m: usize) -> SeqResult<SinkhornResult> {
        sinkhorn_normalize(s, n, m, &self.config)
    }

    /// Runs Sinkhorn with caller-supplied row margins `a` and column margins `b`.
    pub fn forward_with_margins(
        &self,
        s: &[f64],
        n: usize,
        m: usize,
        a: &[f64],
        b: &[f64],
    ) -> SeqResult<SinkhornResult> {
        sinkhorn_normalize_with_margins(s, n, m, a, b, &self.config)
    }

    /// Structured loss `Σ_k (gold_k − P_k) · s_k`, where `P` is the
    /// uniform-margin plan from [`forward`](Self::forward). Returns the loss
    /// together with that prediction.
    ///
    /// # Errors
    /// [`SeqError::ShapeMismatch`] if `gold.len() != n * m` (checked before the
    /// forward pass), plus anything [`forward`](Self::forward) reports.
    pub fn structured_loss(
        &self,
        s: &[f64],
        n: usize,
        m: usize,
        gold: &[f64],
    ) -> SeqResult<(f64, SinkhornResult)> {
        let expected = n * m;
        if gold.len() != expected {
            return Err(SeqError::ShapeMismatch {
                expected,
                got: gold.len(),
            });
        }
        let pred = self.forward(s, n, m)?;
        // `forward` has already checked `s.len() == n * m`, so gold, plan and s
        // all have the same length and are walked in lockstep.
        let loss = gold
            .iter()
            .zip(&pred.plan)
            .zip(s)
            .fold(0.0f64, |acc, ((&g, &p), &x)| acc + (g - p) * x);
        Ok((loss, pred))
    }

    /// The loss from [`structured_loss`](Self::structured_loss) together with the
    /// elementwise direction `gold − P`.
    ///
    /// NOTE(review): `gold − P` treats the plan `P` as a constant; it is not the
    /// full derivative of the returned loss with respect to `s`, which also
    /// depends on `s` through `P`. Confirm this is the intended update signal.
    pub fn structured_loss_grad(
        &self,
        s: &[f64],
        n: usize,
        m: usize,
        gold: &[f64],
    ) -> SeqResult<(f64, Vec<f64>)> {
        let (loss, pred) = self.structured_loss(s, n, m, gold)?;
        let grad: Vec<f64> = gold
            .iter()
            .zip(&pred.plan)
            .map(|(&g, &p)| g - p)
            .collect();
        Ok((loss, grad))
    }
}
// Unit tests for the Sinkhorn layer: margin feasibility, sharpening towards a
// permutation, early convergence, rectangular margins, the structured loss,
// determinism, and input validation.
#[cfg(test)]
mod tests {
    use super::*;
    /// Uniform margins on a 4 × 4 matrix with a dominant diagonal: every row and
    /// every column of the plan must sum to 1/4 within 1e-7.
    #[test]
    fn plan_respects_uniform_margins() {
        let n = 4;
        let m = 4;
        let mut s = vec![0.0; n * m];
        for i in 0..n {
            for j in 0..m {
                s[i * m + j] = if i == j { 3.0 } else { 0.0 };
            }
        }
        let cfg = SinkhornConfig {
            epsilon: 0.05,
            max_iter: 2000,
            tol: 1e-12,
        };
        let res = sinkhorn_normalize(&s, n, m, &cfg).expect("sinkhorn");
        for i in 0..n {
            let row_sum: f64 = res.row(i).expect("row").iter().sum();
            assert!(
                (row_sum - 1.0 / n as f64).abs() < 1e-7,
                "row {i} sum {row_sum}"
            );
        }
        // Columns are summed by hand: `SinkhornResult` only exposes row views.
        for j in 0..m {
            let col_sum: f64 = (0..n).map(|i| res.plan[i * m + j]).sum();
            assert!(
                (col_sum - 1.0 / m as f64).abs() < 1e-7,
                "col {j} sum {col_sum}"
            );
        }
    }
    /// Small epsilon (0.02) with a strong diagonal (5.0): the argmax must be the
    /// identity permutation and each diagonal entry must carry more than 95% of
    /// its row's 1/n mass (hence the `n *` rescaling).
    #[test]
    fn sharpens_to_permutation() {
        let n = 5;
        let mut s = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                s[i * n + j] = if i == j { 5.0 } else { 0.0 };
            }
        }
        let cfg = SinkhornConfig {
            epsilon: 0.02,
            max_iter: 5000,
            tol: 1e-12,
        };
        let res = sinkhorn_normalize(&s, n, n, &cfg).expect("sinkhorn");
        assert_eq!(res.argmax_assignment(), vec![0, 1, 2, 3, 4]);
        for i in 0..n {
            let on_diag = n as f64 * res.plan[i * n + i];
            assert!(on_diag > 0.95, "diag {i} = {on_diag}");
        }
    }
    /// A smooth 3 × 3 problem must stop early: the reported residual is within
    /// `tol` and fewer than `max_iter` sweeps were used.
    #[test]
    fn converges_below_tolerance() {
        let n = 3;
        let m = 3;
        let s = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let cfg = SinkhornConfig {
            epsilon: 0.2,
            max_iter: 1000,
            tol: 1e-10,
        };
        let res = sinkhorn_normalize(&s, n, m, &cfg).expect("sinkhorn");
        assert!(res.residual <= 1e-10, "residual {}", res.residual);
        assert!(res.iterations < cfg.max_iter);
    }
    /// Non-square (2 × 3) problem with non-uniform margins: row sums must match
    /// `a` and column sums must match `b` within 1e-7.
    #[test]
    fn rectangular_with_margins() {
        let n = 2;
        let m = 3;
        let s = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0];
        let a = vec![0.6, 0.4];
        let b = vec![0.3, 0.3, 0.4];
        let cfg = SinkhornConfig {
            epsilon: 0.1,
            max_iter: 4000,
            tol: 1e-12,
        };
        let res = sinkhorn_normalize_with_margins(&s, n, m, &a, &b, &cfg).expect("sinkhorn");
        for i in 0..n {
            let row_sum: f64 = res.row(i).expect("row").iter().sum();
            assert!((row_sum - a[i]).abs() < 1e-7, "row {i} = {row_sum}");
        }
        for j in 0..m {
            let col_sum: f64 = (0..n).map(|i| res.plan[i * m + j]).sum();
            assert!((col_sum - b[j]).abs() < 1e-7, "col {j} = {col_sum}");
        }
    }
    /// Using the layer's own plan as gold yields zero loss and a zero gradient;
    /// a diagonal gold plan against diagonal scores yields a non-negative loss.
    #[test]
    fn structured_loss_and_grad() {
        let n = 3;
        let mut s = vec![0.0; n * n];
        for i in 0..n {
            s[i * n + i] = 2.0;
        }
        let cfg = SinkhornConfig {
            epsilon: 0.1,
            max_iter: 3000,
            tol: 1e-12,
        };
        let layer = SinkhornCrf::new(cfg).expect("layer");
        let pred = layer.forward(&s, n, n).expect("fwd");
        let (loss0, grad0) = layer
            .structured_loss_grad(&s, n, n, &pred.plan)
            .expect("loss");
        assert!(loss0.abs() < 1e-9, "loss0 = {loss0}");
        for g in &grad0 {
            assert!(g.abs() < 1e-9);
        }
        // Hard diagonal assignment, scaled to the same 1/n row margins as the plan.
        let mut gold = vec![0.0; n * n];
        for i in 0..n {
            gold[i * n + i] = 1.0 / n as f64;
        }
        let (loss1, _) = layer.structured_loss(&s, n, n, &gold).expect("loss1");
        assert!(loss1 >= -1e-9, "loss1 = {loss1}");
    }
    /// Two runs on identical input must give bit-identical plans and potentials.
    #[test]
    fn deterministic() {
        let n = 4;
        let m = 4;
        let s: Vec<f64> = (0..n * m).map(|k| (k as f64 * 0.37).sin()).collect();
        let cfg = SinkhornConfig::default();
        let r1 = sinkhorn_normalize(&s, n, m, &cfg).expect("r1");
        let r2 = sinkhorn_normalize(&s, n, m, &cfg).expect("r2");
        assert_eq!(r1.plan, r2.plan);
        assert_eq!(r1.f, r2.f);
        assert_eq!(r1.g, r2.g);
    }
    /// Each malformed input is rejected with an error rather than a panic.
    #[test]
    fn validation_errors() {
        let cfg = SinkhornConfig::default();
        // Empty dimension.
        assert!(sinkhorn_normalize(&[], 0, 3, &cfg).is_err());
        // `s` has 2 entries but a 2 × 2 shape was requested.
        assert!(sinkhorn_normalize(&[1.0, 2.0], 2, 2, &cfg).is_err());
        // Non-finite score entry.
        assert!(sinkhorn_normalize(&[f64::NAN, 0.0, 0.0, 0.0], 2, 2, &cfg).is_err());
        // Non-positive epsilon.
        let bad = SinkhornConfig {
            epsilon: 0.0,
            ..SinkhornConfig::default()
        };
        assert!(bad.validate().is_err());
        // Margin masses differ (Σa = 1, Σb = 2).
        let a = vec![0.5, 0.5];
        let b = vec![1.0, 1.0];
        assert!(sinkhorn_normalize_with_margins(&[0.0; 4], 2, 2, &a, &b, &cfg).is_err());
    }
}