flowmatch 0.1.6

Flow matching primitives (ndarray-first, backend-agnostic), plus experimental semi-discrete flow matching (SD-FM) and Riemannian flow matching (RFM).
Documentation
//! Riemannian Flow Matching (RFM) training loop.
//!
//! Based on Chen & Lipman (2023), "Riemannian Flow Matching on General Geometries".
//!
//! The core objective is to regress a time-dependent vector field `v_t(x)` against
//! the conditional vector field `u_t(x | x_0, x_1)` generated by a geodesic path.
//!
//! Path: `x_t = exp_{x_0}(t * log_{x_0}(x_1))`
//! Target: `u_t(x_t) = dx_t/dt` (tangent vector to the geodesic at `x_t`)
//! Loss: `|| v_t(x_t) - u_t(x_t) ||_{x_t}^2` (Riemannian metric norm)

use crate::sd_fm::SdFmTrainConfig;
use crate::{Error, Result};
use ndarray::{Array1, ArrayView1, ArrayView2};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use skel::Manifold;

/// A learnable, time-dependent vector field on a manifold: `v(x, t) ∈ T_x M`.
///
/// Points and tangent vectors are passed as `f32` coordinate vectors.
pub trait ManifoldVectorField {
    /// Evaluates the field at `point` and time `time`.
    ///
    /// The returned vector is an element of the tangent space `T_point M`.
    fn eval(&self, point: &ArrayView1<f32>, time: f32) -> Array1<f32>;

    /// Applies one stochastic-gradient update, with step size `learning_rate`, that reduces the
    /// squared Riemannian error `|| v(point, time) - target ||_point^2`.
    ///
    /// `target` is a vector in the tangent space `T_point M`.
    fn sgd_step(
        &mut self,
        point: &ArrayView1<f32>,
        time: f32,
        target: &ArrayView1<f32>,
        learning_rate: f32,
    );
}

/// Train a Riemannian Flow Matching model.
///
/// - `manifold`: The manifold geometry (defines exp/log maps).
/// - `field`: The learnable vector field.
/// - `x1_samples`: Target data samples (on the manifold).
/// - `cfg`: Training configuration.
///
/// Note: This implementation assumes a "simple" base distribution `p(x_0)` that we can sample from.
/// For Poincaré ball, this is typically a wrapped normal or uniform on the ball.
/// Here we accept a `sample_x0` closure to abstract the base distribution.
pub fn train_riemannian_fm<M, F, S>(
    manifold: &M,
    field: &mut F,
    x1_samples: &ArrayView2<f32>,
    cfg: &SdFmTrainConfig,
    sample_x0: S,
) -> Result<()>
where
    M: Manifold,
    F: ManifoldVectorField,
    S: Fn(&mut ChaCha8Rng) -> Array1<f32>,
{
    let n_data = x1_samples.nrows();
    if n_data == 0 {
        return Err(Error::Domain("x1_samples must be non-empty"));
    }

    let mut rng = ChaCha8Rng::seed_from_u64(cfg.seed);

    for _step in 0..cfg.steps {
        for _ in 0..cfg.batch_size {
            // 1. Sample t, x0, x1
            let t = cfg.t_schedule.sample_t(&mut rng);
            let x0_f32 = sample_x0(&mut rng);

            // Pick random target x1 (use gen_range for uniform integer sampling
            // without f32 precision limits).
            let idx = rng.random_range(0..n_data);
            let x1_f32 = x1_samples.row(idx);

            // Convert to f64 for manifold ops (skel uses f64)
            let x0 = x0_f32.mapv(|v| v as f64);
            let x1 = x1_f32.mapv(|v| v as f64);

            // 2. Compute geodesic path and target velocity
            // v_init = log_{x0}(x1)
            let v_init = manifold.log_map(&x0.view(), &x1.view());

            // x_t = exp_{x0}(t * v_init)
            let t_v_init = &v_init * (t as f64);
            let xt = manifold.exp_map(&x0.view(), &t_v_init.view());

            // Target velocity u_t at x_t.
            // u_t = P_{x0 -> xt}(v_init)
            let ut = manifold.parallel_transport(&x0.view(), &xt.view(), &v_init.view());

            // 3. Regression step
            let xt_f32 = xt.mapv(|v| v as f32);
            let ut_f32 = ut.mapv(|v| v as f32);

            field.sgd_step(&xt_f32.view(), t, &ut_f32.view(), cfg.lr);
        }
    }

    Ok(())
}