use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Identifier of this operation; used as the generator of the region emitted by
/// `sinkhorn_scale` and passed to the invalid-program helper on bad input.
pub const OP_ID: &str = "vyre-primitives::math::sinkhorn_scale";
/// Value substituted for a divisor element that is exactly zero in `sinkhorn_scale`.
pub const DIVISOR_FLOOR: u32 = 1;
/// Builds the element-wise scaling program `out[t] = target[t] / divisor[t]`.
///
/// Every invocation whose id `t` (axis 0) is below `count` loads `target[t]` and
/// `divisor[t]`, replaces a zero divisor with [`DIVISOR_FLOOR`], and stores the
/// quotient into `out[t]`. Invocations at or beyond `count` store nothing.
///
/// * `target`, `divisor` - names of the read-only `U32` input buffers (bindings 0 and 1).
/// * `out` - name of the read-write `U32` output buffer (binding 2).
/// * `count` - element count declared for all three buffers.
///
/// When `count` is zero, the result of `crate::invalid_output_program` is returned
/// instead of a scaling program.
#[must_use]
pub fn sinkhorn_scale(target: &str, divisor: &str, out: &str, count: u32) -> Program {
    if count == 0 {
        let message = format!("Fix: sinkhorn_scale requires count > 0, got {count}.");
        return crate::invalid_output_program(OP_ID, out, DataType::U32, message);
    }

    let lane = Expr::InvocationId { axis: 0 };

    // Guard the division: a divisor of exactly zero is swapped for DIVISOR_FLOOR.
    let raw_divisor = Expr::load(divisor, lane.clone());
    let divisor_is_zero = Expr::eq(raw_divisor.clone(), Expr::u32(0));
    let guarded_divisor = Expr::select(divisor_is_zero, Expr::u32(DIVISOR_FLOOR), raw_divisor);
    let quotient = Expr::div(Expr::load(target, lane.clone()), guarded_divisor);

    // Only lanes below `count` perform the store.
    let in_bounds = Expr::lt(lane.clone(), Expr::u32(count));
    let guarded_store = Node::if_then(in_bounds, vec![Node::store(out, lane, quotient)]);

    let buffers = vec![
        BufferDecl::storage(target, 0, BufferAccess::ReadOnly, DataType::U32).with_count(count),
        BufferDecl::storage(divisor, 1, BufferAccess::ReadOnly, DataType::U32).with_count(count),
        BufferDecl::storage(out, 2, BufferAccess::ReadWrite, DataType::U32).with_count(count),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(vec![guarded_store]),
    };

    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// Runs one CPU Sinkhorn scaling sweep, allocating its own scratch vectors.
///
/// Convenience wrapper around [`sinkhorn_iter_cpu_into`]: the two temporary
/// vectors are created here and dropped on return. Callers that iterate many
/// times can call [`sinkhorn_iter_cpu_into`] directly to keep the scratch
/// storage alive between sweeps.
///
/// * `k` - kernel matrix, indexed as `k[i * n + j]`.
/// * `a`, `b` - numerators for the `u` and `v` updates.
/// * `u`, `v` - scaling vectors, updated in place.
/// * `m`, `n` - row and column counts.
pub fn sinkhorn_iter_cpu(
    k: &[f64],
    a: &[f64],
    b: &[f64],
    u: &mut [f64],
    v: &mut [f64],
    m: u32,
    n: u32,
) {
    let (mut row_scratch, mut col_scratch) = (Vec::new(), Vec::new());
    sinkhorn_iter_cpu_into(k, a, b, u, v, m, n, &mut row_scratch, &mut col_scratch);
}
/// Runs one CPU Sinkhorn scaling sweep using caller-provided scratch vectors.
///
/// The sweep has two halves:
///
/// 1. `kv[i] = sum_j k[i * n + j] * v[j]`, then `u[i] = a[i] / max(kv[i], 1e-30)`.
/// 2. `ktu[j] = sum_i k[i * n + j] * u[i]` (using the `u` just written), then
///    `v[j] = b[j] / max(ktu[j], 1e-30)`.
///
/// Any read past the end of `k`, `a`, `b`, `u` or `v` yields `0.0`, and only the
/// elements of `u` and `v` that actually exist are written, so short slices
/// never panic.
///
/// * `k` - kernel matrix, indexed as `k[i * n + j]`.
/// * `a`, `b` - numerators for the `u` and `v` updates.
/// * `u`, `v` - scaling vectors, updated in place.
/// * `m`, `n` - row and column counts.
/// * `kv`, `ktu` - scratch vectors; cleared and resized to `m` and `n` entries,
///   and left holding the two matrix-vector products on return.
#[allow(clippy::too_many_arguments)]
pub fn sinkhorn_iter_cpu_into(
    k: &[f64],
    a: &[f64],
    b: &[f64],
    u: &mut [f64],
    v: &mut [f64],
    m: u32,
    n: u32,
    kv: &mut Vec<f64>,
    ktu: &mut Vec<f64>,
) {
    /// Smallest denominator used when dividing by a matrix-vector product.
    const FLOOR: f64 = 1e-30;

    /// Checked read that yields `0.0` for an out-of-range index.
    fn at(values: &[f64], index: usize) -> f64 {
        values.get(index).copied().unwrap_or(0.0)
    }

    let rows = m as usize;
    let cols = n as usize;

    // First half: kv = K * v, then rescale u.
    kv.clear();
    kv.resize(rows, 0.0);
    {
        let v_now: &[f64] = v;
        for (row, acc) in kv.iter_mut().enumerate() {
            // Explicit fold from +0.0 keeps the left-to-right accumulation order.
            *acc = (0..cols).fold(0.0_f64, |sum, col| {
                sum + at(k, row * cols + col) * at(v_now, col)
            });
        }
    }
    for (row, (u_row, &denom)) in u.iter_mut().zip(kv.iter()).enumerate() {
        *u_row = at(a, row) / denom.max(FLOOR);
    }

    // Second half: ktu = K^T * u (with the updated u), then rescale v.
    ktu.clear();
    ktu.resize(cols, 0.0);
    {
        let u_now: &[f64] = u;
        for (col, acc) in ktu.iter_mut().enumerate() {
            *acc = (0..rows).fold(0.0_f64, |sum, row| {
                sum + at(k, row * cols + col) * at(u_now, row)
            });
        }
    }
    for (col, (v_col, &denom)) in v.iter_mut().zip(ktu.iter()).enumerate() {
        *v_col = at(b, col) / denom.max(FLOOR);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative tolerance used by `approx_eq`.
    const EPS: f64 = 1e-6;

    /// Returns `true` when `a` and `b` agree within a tolerance scaled by their magnitudes.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS * (1.0 + a.abs() + b.abs())
    }

    #[test]
    fn cpu_uniform_marginals_converge_to_uniform_plan() {
        let k = vec![1.0, 1.0, 1.0, 1.0];
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let mut u = vec![1.0, 1.0];
        let mut v = vec![1.0, 1.0];
        sinkhorn_iter_cpu(&k, &a, &b, &mut u, &mut v, 2, 2);
        assert!(approx_eq(u[0], 0.25));
        assert!(approx_eq(u[1], 0.25));
        assert!(approx_eq(v[0], 1.0));
        assert!(approx_eq(v[1], 1.0));
    }

    #[test]
    fn cpu_iterations_converge_to_balanced_plan() {
        let k = vec![1.0, 0.5, 0.5, 1.0];
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let mut u = vec![1.0, 1.0];
        let mut v = vec![1.0, 1.0];
        for _ in 0..50 {
            sinkhorn_iter_cpu(&k, &a, &b, &mut u, &mut v, 2, 2);
        }
        // Marginals of the plan diag(u) * K * diag(v).
        let row_sum_0 = u[0] * (k[0] * v[0] + k[1] * v[1]);
        let row_sum_1 = u[1] * (k[2] * v[0] + k[3] * v[1]);
        let col_sum_0 = v[0] * (k[0] * u[0] + k[2] * u[1]);
        let col_sum_1 = v[1] * (k[1] * u[0] + k[3] * u[1]);
        assert!(approx_eq(row_sum_0, a[0]));
        assert!(approx_eq(row_sum_1, a[1]));
        assert!(approx_eq(col_sum_0, b[0]));
        assert!(approx_eq(col_sum_1, b[1]));
    }

    #[test]
    fn cpu_zero_in_divisor_handled() {
        // The first kernel row is all zeros, so its row product is exactly zero.
        let k = vec![0.0, 0.0, 1.0, 1.0];
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let mut u = vec![1.0, 1.0];
        let mut v = vec![1.0, 1.0];
        sinkhorn_iter_cpu(&k, &a, &b, &mut u, &mut v, 2, 2);
        assert!(u[0].is_finite());
        assert!(u[1].is_finite());
        // The zero row divides by the 1e-30 floor; the other row is unaffected.
        assert!(approx_eq(u[0], 0.5 / 1e-30));
        assert!(approx_eq(u[1], 0.25));
        // The column update must stay finite as well.
        assert!(v.iter().all(|x| x.is_finite()));
        assert!(approx_eq(v[0], 2.0));
        assert!(approx_eq(v[1], 2.0));
    }

    #[test]
    fn cpu_into_reuses_sinkhorn_temporaries() {
        let k = vec![1.0, 1.0, 1.0, 1.0];
        let a = vec![0.5, 0.5];
        let b = vec![0.5, 0.5];
        let mut u = vec![1.0, 1.0];
        let mut v = vec![1.0, 1.0];
        let mut kv = Vec::new();
        let mut ktu = Vec::new();
        sinkhorn_iter_cpu_into(&k, &a, &b, &mut u, &mut v, 2, 2, &mut kv, &mut ktu);
        let kv_ptr = kv.as_ptr();
        let ktu_ptr = ktu.as_ptr();
        sinkhorn_iter_cpu_into(&k, &a, &b, &mut u, &mut v, 2, 2, &mut kv, &mut ktu);
        // Unchanged pointers show the second sweep did not reallocate the scratch.
        assert_eq!(kv.as_ptr(), kv_ptr);
        assert_eq!(ktu.as_ptr(), ktu_ptr);
    }

    #[test]
    fn cpu_short_inputs_update_available_lanes_only() {
        let mut u = vec![1.0];
        let mut v = vec![1.0];
        let mut kv = Vec::new();
        let mut ktu = Vec::new();
        sinkhorn_iter_cpu_into(&[1.0], &[], &[], &mut u, &mut v, 2, 2, &mut kv, &mut ktu);
        assert_eq!(u.len(), 1);
        assert_eq!(v.len(), 1);
        // Missing numerators read as zero, so the single existing lane of each
        // scaling vector is overwritten with zero.
        assert_eq!(u, vec![0.0]);
        assert_eq!(v, vec![0.0]);
        // Scratch vectors are sized by the requested dimensions, not the slices.
        assert_eq!(kv, vec![1.0, 0.0]);
        assert_eq!(ktu, vec![0.0, 0.0]);
    }

    #[test]
    fn ir_program_buffer_layout() {
        let p = sinkhorn_scale("a", "kv", "u", 32);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["a", "kv", "u"]);
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 32);
        }
    }

    #[test]
    fn zero_count_traps() {
        let p = sinkhorn_scale("a", "kv", "u", 0);
        assert!(p.stats().trap());
    }
}