use ndarray::{Array1, Array2};
use wass::sinkhorn_log_with_convergence;
/// Returns the Euclidean distance between two colours, treating the R, G and B channels
/// as orthogonal axes of a 3-D space.
fn rgb_distance(a: &[f32; 3], b: &[f32; 3]) -> f32 {
    // Accumulate the squared per-channel differences in R, G, B order.
    let squared: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
/// Prints `name` as a heading, then one indexed line per colour showing its R, G and B
/// channels with one decimal place.
fn print_palette(name: &str, palette: &[[f32; 3]]) {
    println!("{name}:");
    palette
        .iter()
        .enumerate()
        .for_each(|(idx, &[r, g, b])| println!(" [{idx}] R={r:5.1} G={g:5.1} B={b:5.1}"));
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let warm: Vec<[f32; 3]> = vec![
[220.0, 50.0, 30.0], [240.0, 100.0, 20.0], [250.0, 150.0, 30.0], [255.0, 200.0, 50.0], [255.0, 230.0, 80.0], [200.0, 40.0, 60.0], [180.0, 70.0, 20.0], [230.0, 180.0, 60.0], ];
let cool: Vec<[f32; 3]> = vec![
[30.0, 60.0, 200.0], [20.0, 100.0, 220.0], [50.0, 150.0, 210.0], [60.0, 200.0, 200.0], [80.0, 220.0, 180.0], [70.0, 40.0, 180.0], [100.0, 60.0, 160.0], [40.0, 170.0, 190.0], ];
let n = warm.len();
assert_eq!(n, cool.len());
print_palette("Warm palette (source)", &warm);
println!();
print_palette("Cool palette (target)", &cool);
println!();
let mut cost = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
cost[[i, j]] = rgb_distance(&warm[i], &cool[j]);
}
}
let a = Array1::from_elem(n, 1.0 / n as f32);
let b = a.clone();
let reg = 5.0; let max_iter = 500;
let tol = 1e-6;
let (plan, distance, iters) = sinkhorn_log_with_convergence(&a, &b, &cost, reg, max_iter, tol)?;
println!("Sinkhorn distance: {distance:.2} (converged in {iters} iterations, reg={reg})");
println!();
println!("Transport plan (rows=warm, cols=cool), entries * {n}:");
print!("{:>8}", "");
for j in 0..n {
print!(" cool{j}");
}
println!();
for i in 0..n {
print!("warm{i} ",);
for j in 0..n {
print!(" {:.3}", plan[[i, j]] * n as f32);
}
println!();
}
println!();
println!("Transferred palette (warm -> cool via transport plan):");
for i in 0..n {
let row_sum: f32 = plan.row(i).sum();
let mut mapped = [0.0f32; 3];
for j in 0..n {
let weight = plan[[i, j]] / row_sum;
mapped[0] += weight * cool[j][0];
mapped[1] += weight * cool[j][1];
mapped[2] += weight * cool[j][2];
}
println!(
" warm[{i}] ({:5.1},{:5.1},{:5.1}) -> ({:5.1},{:5.1},{:5.1})",
warm[i][0], warm[i][1], warm[i][2], mapped[0], mapped[1], mapped[2],
);
}
Ok(())
}