use vyre_foundation::ir::Program;
use vyre_primitives::math::sinkhorn_iterate::sinkhorn_iterate;
/// Fully-qualified identifier string naming this operation
/// (`vyre-libs::self_substrate::sinkhorn_full_clustering`).
pub const OP_ID: &str = "vyre-libs::self_substrate::sinkhorn_full_clustering";
/// Builds the Sinkhorn full-clustering [`Program`].
///
/// This is a thin, instrumented entry point. It records the invocation with
/// `crate::observability::bump` and then forwards every argument to
/// [`sinkhorn_iterate`], unchanged and in the same order. It returns the
/// program that primitive produces.
///
/// # Parameters
///
/// The ten `&str` arguments are buffer names handed straight to
/// `sinkhorn_iterate`. Their roles below are inferred from the names only and
/// should be confirmed against `sinkhorn_iterate`:
/// - `k` and `k_t`: presumably the kernel and its transpose.
/// - `a` and `b`: presumably the marginals.
/// - `u_curr`, `u_next` and `v`: presumably the scaling vectors.
/// - `kv` and `ktu`: presumably the products.
/// - `changed`: presumably a flag.
///
/// `m`, `n` and `max_iterations` are likewise forwarded verbatim.
#[must_use]
// The wide signature mirrors `sinkhorn_iterate`, whose arguments are forwarded 1:1.
#[allow(clippy::too_many_arguments)]
pub fn sinkhorn_full_clustering_program(
    k: &str,
    k_t: &str,
    a: &str,
    b: &str,
    u_curr: &str,
    u_next: &str,
    v: &str,
    kv: &str,
    ktu: &str,
    changed: &str,
    m: u32,
    n: u32,
    max_iterations: u32,
) -> Program {
    use crate::observability as obs;

    // Record this call with the observability layer before delegating.
    obs::bump(&obs::sinkhorn_full_clustering_calls);

    sinkhorn_iterate(
        k, k_t, a, b, u_curr, u_next, v, kv, ktu, changed, m, n, max_iterations,
    )
}
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: build a program (m = 10, n = 20, 5 iterations) and check the
// declared buffer list.
#[test]
fn test_sinkhorn_clustering_program() {
let p = sinkhorn_full_clustering_program(
"k", "kt", "a", "b", "uc", "un", "v", "kv", "ktu", "c", 10, 20, 5,
);
// Ten buffer names are passed in, so exactly ten buffers are expected.
assert_eq!(p.buffers().len(), 10);
// The caller-supplied name for `u_curr` must appear among the buffers.
assert!(p.buffers().iter().any(|b| b.name() == "uc"));
}
// Composition test: build three independent programs whose buffer names carry
// distinct suffixes. Splice their entry nodes and buffers into one wrapped
// program, then count the top-level `Node::Region` entries.
#[test]
fn test_multi_region_sinkhorn() {
let p1 = sinkhorn_full_clustering_program(
"k1", "kt1", "a1", "b1", "uc1", "un1", "v1", "kv1", "ktu1", "c1", 2, 2, 1,
);
let p2 = sinkhorn_full_clustering_program(
"k2", "kt2", "a2", "b2", "uc2", "un2", "v2", "kv2", "ktu2", "c2", 2, 2, 1,
);
let p3 = sinkhorn_full_clustering_program(
"k3", "kt3", "a3", "b3", "uc3", "un3", "v3", "kv3", "ktu3", "c3", 2, 2, 1,
);
// Concatenate the entry node lists of p1, p2 and p3, in that order.
let mut entry = p1.entry().to_vec();
entry.extend(p2.entry().to_vec());
entry.extend(p3.entry().to_vec());
// Concatenate buffer declarations. The suffixed names keep them distinct.
let mut buffers = p1.buffers().to_vec();
buffers.extend(p2.buffers().to_vec());
buffers.extend(p3.buffers().to_vec());
// NOTE(review): `[256, 1, 1]` is presumably the workgroup size expected by
// `Program::wrapped` — confirm.
let final_p = Program::wrapped(buffers, [256, 1, 1], entry);
let region_count = final_p
.entry()
.iter()
.filter(|n| matches!(n, vyre_foundation::ir::Node::Region { .. }))
.count();
// Presumably at least one top-level region per composed sub-program;
// `>=` tolerates extra regions — confirm the exact count against
// `Program::wrapped`.
assert!(region_count >= 3);
}
// Runs a 2x2, single-iteration program through the reference interpreter and
// checks the first element of the first result buffer.
#[test]
fn test_end_to_end_sinkhorn_parity() {
// NOTE(review): values presumably use 16.16 fixed point (65536 == 1.0,
// 32768 == 0.5) — confirm against sinkhorn_iterate.
let k = vec![65536, 65536, 65536, 65536];
let a = vec![32768, 32768];
let b = vec![32768, 32768];
let u_c = vec![65536, 65536];
let v_in = vec![65536, 65536];
let p = sinkhorn_full_clustering_program(
"k", "kt", "a", "b", "uc", "un", "v", "kv", "ktu", "c", 2, 2, 1,
);
use std::sync::Arc;
use vyre_reference::reference_eval;
use vyre_reference::value::Value;
// Encode a u32 slice as little-endian bytes wrapped in `Value::Bytes`.
let to_value = |data: &[u32]| {
let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
Value::Bytes(Arc::from(bytes))
};
// Ten inputs, matching the ten declared buffers. The order is presumably
// uc, un, c, k, kt, a, b, v, kv, ktu, i.e. the buffer declaration order
// chosen by sinkhorn_iterate (not visible here) — verify if it changes.
// `kt` is fed the same data as `k`, which is its own transpose here
// because every entry is equal.
let inputs = vec![
to_value(&u_c),
to_value(&[0_u32, 0]),
to_value(&[0]),
to_value(&k),
to_value(&k), to_value(&a),
to_value(&b),
to_value(&v_in),
to_value(&[0_u32, 0]),
to_value(&[0_u32, 0]),
];
let results = reference_eval(&p, &inputs).expect("Fix: interpreter failed");
// Assumes results[0] is the `uc` buffer after evaluation — confirm.
let actual_bytes = results[0].to_bytes();
// Decode the returned bytes back into little-endian u32 words.
let actual_u: Vec<u32> = actual_bytes
.chunks_exact(4)
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
.collect();
// NOTE(review): assume the 16.16 encoding and the update u = a / (K·v).
// A single step would then give 0.5 / 2.0 = 0.25 (16384). This expectation
// of 32768 therefore depends on sinkhorn_iterate's exact update order or
// normalization — confirm.
assert_eq!(actual_u[0], 32768);
}
}