use flowmatch::linear::LinearCondField;
use flowmatch::ode::{integrate_fixed, OdeMethod};
use flowmatch::ot_cfm::{ot_cfm_training_step, OtCfmConfig};
use ndarray::{Array1, Array2};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, StandardNormal};
/// Trains a `LinearCondField` on a toy 2-D problem with OT-CFM targets.
///
/// Every training step draws a fresh batch:
/// * `x0`: `batch_size` points with standard-normal coordinates,
/// * `x1`: `batch_size` points placed at one of four fixed centers
///   (assigned round-robin by row index) plus normal noise scaled by 0.3,
/// * `t`: one time value per row, drawn uniformly from `[0, 1)`.
///
/// The batch is passed to `ot_cfm_training_step`, and the returned `x_t`,
/// `coupling` and `u_t` are fed row by row into `LinearCondField::sgd_step`.
/// Every 50 steps, and on the final step, the field is evaluated with
/// `evaluate_sampling` and the resulting average distance is printed.
///
/// # Errors
///
/// Propagates any error returned by `ot_cfm_training_step`, `sgd_step`
/// or `evaluate_sampling`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let batch_size = 32;
    let dim = 2;
    let steps = 200;
    let lr = 0.01_f32;
    // Step size handed to each per-sample update; loop-invariant, so computed once.
    let per_sample_lr = lr / batch_size as f32;

    let centers = [[2.0_f32, 2.0], [-2.0, 2.0], [-2.0, -2.0], [2.0, -2.0]];

    let ot_cfg = OtCfmConfig {
        reg: 1.0,
        max_sinkhorn_iter: 100,
        sinkhorn_tol: 1e-2,
    };

    let mut field = LinearCondField::new_zeros(dim);
    // Fixed seed so repeated runs print the same numbers.
    let mut rng = ChaCha8Rng::seed_from_u64(42);

    println!("OT-CFM training: {dim}D, batch={batch_size}, steps={steps}");
    println!();

    for step in 0..steps {
        // Source batch: standard-normal coordinates.
        let mut x0 = Array2::<f32>::zeros((batch_size, dim));
        for i in 0..batch_size {
            for j in 0..dim {
                x0[[i, j]] = StandardNormal.sample(&mut rng);
            }
        }

        // Target batch: row `i` sits around center `i mod 4` with 0.3-scaled noise.
        let mut x1 = Array2::<f32>::zeros((batch_size, dim));
        for i in 0..batch_size {
            let c = &centers[i % centers.len()];
            for j in 0..dim {
                x1[[i, j]] = c[j]
                    + 0.3 * {
                        // The noise is drawn as f64 and then narrowed to f32.
                        let v: f64 = StandardNormal.sample(&mut rng);
                        v as f32
                    };
            }
        }

        // One time value per row, uniform on [0, 1).
        let t: Vec<f32> = (0..batch_size)
            .map(|_| rng.random_range(0.0_f32..1.0))
            .collect();

        let targets = ot_cfm_training_step(&x0.view(), &x1.view(), &t, &ot_cfg)?;

        // `t` holds exactly `batch_size` entries, so this visits every row once.
        for (i, &t_i) in t.iter().enumerate() {
            let xt_row = targets.x_t.row(i);
            // `coupling[i]` is used as the row of `x1` matched to row `i`.
            let j = targets.coupling[i];
            let y_row = x1.row(j);
            let u_row = targets.u_t.row(i);
            field.sgd_step(&xt_row, t_i, &y_row, &u_row, per_sample_lr)?;
        }

        if step % 50 == 0 || step == steps - 1 {
            let avg_dist = evaluate_sampling(&field, &centers, &mut rng)?;
            println!("Step {step:>3}: avg dist to center = {avg_dist:.4}");
        }
    }

    Ok(())
}
/// Measures how close samples generated by `field` land to their conditioning center.
///
/// For each of 32 samples, a 2-D start point with standard-normal coordinates
/// is drawn from `rng`, one of the four `centers` is picked round-robin by
/// sample index, and the field is integrated with `OdeMethod::Heun` starting
/// at time `0.0` for 20 fixed steps of size `1/20`, using that center as the
/// field's conditioning argument.
///
/// Returns the mean Euclidean distance between each integrated end point and
/// the center it was conditioned on.
///
/// # Errors
///
/// Propagates any error returned by `integrate_fixed`.
fn evaluate_sampling(
    field: &LinearCondField,
    centers: &[[f32; 2]; 4],
    rng: &mut ChaCha8Rng,
) -> Result<f32, Box<dyn std::error::Error>> {
    let sample_count = 32;
    let n_steps = 20;
    let step_size = 1.0 / n_steps as f32;
    let dim = 2;

    let mut dist_sum = 0.0_f32;
    for sample_idx in 0..sample_count {
        // Draw the start point coordinate by coordinate, in index order.
        let mut start = Array1::<f32>::zeros(dim);
        for coord in start.iter_mut() {
            *coord = StandardNormal.sample(rng);
        }

        // Centers are assigned round-robin over the samples.
        let target = &centers[sample_idx % centers.len()];
        let cond = Array1::from_vec(target.to_vec());

        let endpoint = integrate_fixed(
            OdeMethod::Heun,
            &start,
            0.0,
            step_size,
            n_steps,
            |state, time| field.eval(state, time, &cond.view()),
        )?;

        let squared: f32 = endpoint
            .iter()
            .zip(target.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum::<f32>();
        dist_sum += squared.sqrt();
    }

    Ok(dist_sum / sample_count as f32)
}