flowmatch 0.1.6

Flow matching primitives (ndarray-first; backend-agnostic) with semidiscrete FM and RFM experiments.
Documentation
#![allow(dead_code)]
//! Simplex-based helpers (for “discrete FM on the simplex” style methods).
//!
//! This module is intentionally minimal: it provides **explicit** utilities for
//! (1) validating simplex constraints and (2) sampling basic relaxations.
//!
//! Public invariant: we do **not** silently normalize in methods named like “validate”.
//! If we normalize, the function name says so (`normalize_*`).
//!
//! # Simplex Flow Matching: Related Work
//!
//! Linear interpolation on the simplex (`simplex_lerp`) has known pathologies: Stark et al.
//! (2024) show it produces discontinuous training targets. Several alternatives exist:
//!
//! - Davis et al. (2024, NeurIPS), “Fisher Flow Matching” -- equips the simplex with the
//!   Fisher-Rao metric from information geometry; geodesics under this metric respect the
//!   simplex boundary and avoid the linear interpolation discontinuity
//! - Cheng et al. (2024, NeurIPS), “Categorical Flow Matching on Statistical Manifolds”
//!   -- complementary Fisher-information-metric approach with “Fisher-efficient” updates
//! - Stark et al. (2024), “Dirichlet Flow Matching with Applications to DNA Sequence
//!   Design” -- uses Dirichlet interpolation paths instead of linear; `sample_dirichlet`
//!   below partially supports this approach
//! - Tang et al. (2025), “Gumbel-Softmax Flow and Score Matching” -- Gumbel-Softmax
//!   reparameterization for scaling simplex FM to higher dimensions

use crate::{Error, Result};

/// Check whether `p` lies on the probability simplex (within `tol`).
///
/// Every entry must be finite and `>= -tol`, and the entries must sum to `1`
/// within `tol`. The input is only inspected: it is never normalized.
///
/// # Errors
///
/// Returns [`Error::Domain`] if `p` is empty, if `tol` is negative or
/// non-finite, if any entry is non-finite, if any entry is below `-tol`, or if
/// the total mass is farther than `tol` from `1`.
pub fn validate_simplex(p: &[f32], tol: f32) -> Result<()> {
    if p.is_empty() {
        return Err(Error::Domain("simplex vector must be non-empty"));
    }
    if !tol.is_finite() || tol < 0.0 {
        return Err(Error::Domain("tol must be finite and >= 0"));
    }
    if p.iter().any(|&x| !x.is_finite()) {
        return Err(Error::Domain("simplex vector contains non-finite values"));
    }
    if p.iter().any(|&x| x < -tol) {
        return Err(Error::Domain("simplex vector has negative entries"));
    }
    // Accumulate in f64. An f32 running sum can absorb small entries entirely
    // (e.g. `1.0 + 1e-8 == 1.0` in f32), which would let a vector whose true
    // mass is off by more than `tol` pass this check.
    let s: f64 = p.iter().map(|&x| f64::from(x)).sum();
    if (s - 1.0).abs() > f64::from(tol) {
        return Err(Error::Domain(
            "simplex vector does not sum to 1 (within tol)",
        ));
    }
    Ok(())
}

/// Explicit normalization to the simplex via `p_i / sum(p)`, with checks.
///
/// The sum and the divisions are carried out in `f64`, and each ratio is then
/// rounded back to `f32`. Relative proportions are preserved up to that
/// rounding.
///
/// # Errors
///
/// Returns [`Error::Domain`] if `p` is empty, contains a non-finite or negative
/// entry, or has zero total mass (e.g. all zeros).
pub fn normalize_simplex(p: &[f32]) -> Result<Vec<f32>> {
    if p.is_empty() {
        return Err(Error::Domain("simplex vector must be non-empty"));
    }
    if p.iter().any(|&x| !x.is_finite()) {
        return Err(Error::Domain("vector contains non-finite values"));
    }
    if p.iter().any(|&x| x < 0.0) {
        return Err(Error::Domain("vector must be nonnegative to normalize"));
    }
    // Sum in f64. An f32 sum of large finite entries can overflow to +inf,
    // which would slip past the `<= 0` guard and map every entry to 0.
    let s: f64 = p.iter().map(|&x| f64::from(x)).sum();
    if s <= 0.0 {
        return Err(Error::Domain("vector must have positive total mass"));
    }
    Ok(p.iter().map(|&x| (f64::from(x) / s) as f32).collect())
}

/// Linear interpolant on the simplex: `p(t) = (1-t)p0 + t p1`.
///
/// No normalization is applied to the result. When both endpoints lie on the
/// simplex and `t` is in `[0,1]`, each output is a convex combination of the
/// endpoints, so `p(t)` also lies on the simplex (up to floating-point
/// rounding).
///
/// # Errors
///
/// - [`Error::Shape`] if `p0` and `p1` have different lengths (checked first).
/// - [`Error::Domain`] if `t` is NaN, infinite, or outside `[0,1]`.
pub fn simplex_lerp(p0: &[f32], p1: &[f32], t: f32) -> Result<Vec<f32>> {
    if p1.len() != p0.len() {
        return Err(Error::Shape("p0 and p1 must have same length"));
    }
    let t_in_range = t.is_finite() && (0.0..=1.0).contains(&t);
    if !t_in_range {
        return Err(Error::Domain("t must be in [0,1] and finite"));
    }
    let weight0 = 1.0 - t;
    let blended = p0
        .iter()
        .zip(p1)
        .map(|(&a, &b)| weight0 * a + t * b)
        .collect();
    Ok(blended)
}

/// Sample a Dirichlet distribution with parameters `alpha` using Gamma draws.
///
/// Returns a simplex vector of the same length as `alpha`.
pub fn sample_dirichlet(alpha: &[f32], rng: &mut impl rand::Rng) -> Result<Vec<f32>> {
    if alpha.is_empty() {
        return Err(Error::Domain("alpha must be non-empty"));
    }
    if alpha.iter().any(|&a| !a.is_finite() || a <= 0.0) {
        return Err(Error::Domain("Dirichlet alpha must be positive and finite"));
    }

    use rand_distr::{Distribution, Gamma};
    let mut xs = vec![0.0f32; alpha.len()];
    let mut s: f64 = 0.0;
    for (i, &a) in alpha.iter().enumerate() {
        let g = Gamma::new(a as f64, 1.0).map_err(|_| Error::Domain("invalid Gamma params"))?;
        let x: f64 = g.sample(rng);
        xs[i] = x as f32;
        s += x;
    }
    if s <= 0.0 {
        return Err(Error::Domain("Dirichlet sampling produced zero total mass"));
    }
    for v in &mut xs {
        *v = (*v as f64 / s) as f32;
    }
    Ok(xs)
}

#[cfg(test)]
mod tests {
    //! Unit tests for the simplex helpers. These cover both the happy paths
    //! and every validation branch that returns an error.
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    #[test]
    fn dirichlet_samples_are_on_simplex() {
        let alpha = vec![0.7f32, 1.3, 2.1, 0.5];
        let mut rng = ChaCha8Rng::seed_from_u64(123);
        let p = sample_dirichlet(&alpha, &mut rng).unwrap();
        validate_simplex(&p, 1e-5).unwrap();
        assert!(p.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn dirichlet_rejects_invalid_alpha() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        assert!(sample_dirichlet(&[], &mut rng).is_err());
        assert!(sample_dirichlet(&[1.0, 0.0], &mut rng).is_err());
        assert!(sample_dirichlet(&[1.0, -0.5], &mut rng).is_err());
        assert!(sample_dirichlet(&[1.0, f32::NAN], &mut rng).is_err());
        assert!(sample_dirichlet(&[1.0, f32::INFINITY], &mut rng).is_err());
    }

    #[test]
    fn dirichlet_is_deterministic_for_fixed_seed() {
        let alpha = [0.7f32, 1.3, 2.1];
        let a = sample_dirichlet(&alpha, &mut ChaCha8Rng::seed_from_u64(7)).unwrap();
        let b = sample_dirichlet(&alpha, &mut ChaCha8Rng::seed_from_u64(7)).unwrap();
        assert_eq!(a, b);
        assert_eq!(a.len(), alpha.len());
    }

    #[test]
    fn validate_simplex_accepts_valid_point() {
        let p = vec![0.2f32, 0.3, 0.5];
        validate_simplex(&p, 1e-6).unwrap();
    }

    #[test]
    fn validate_simplex_rejects_wrong_sum() {
        let p = vec![0.2f32, 0.3, 0.4]; // sums to 0.9
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_negative() {
        let p = vec![-0.1f32, 0.6, 0.5];
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_empty() {
        let p: Vec<f32> = vec![];
        assert!(validate_simplex(&p, 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_non_finite() {
        assert!(validate_simplex(&[f32::NAN, 0.5, 0.5], 1e-6).is_err());
        assert!(validate_simplex(&[f32::INFINITY, 0.0], 1e-6).is_err());
    }

    #[test]
    fn validate_simplex_rejects_bad_tol() {
        let p = vec![0.2f32, 0.3, 0.5];
        assert!(validate_simplex(&p, -1e-6).is_err());
        assert!(validate_simplex(&p, f32::NAN).is_err());
        assert!(validate_simplex(&p, f32::INFINITY).is_err());
    }

    #[test]
    fn normalize_simplex_produces_valid_simplex() {
        let p = vec![1.0f32, 2.0, 3.0, 4.0];
        let n = normalize_simplex(&p).unwrap();
        validate_simplex(&n, 1e-6).unwrap();
        // Check proportions are preserved.
        assert!((n[0] / n[1] - 0.5).abs() < 1e-6);
        assert!((n[2] / n[3] - 0.75).abs() < 1e-6);
    }

    #[test]
    fn normalize_simplex_rejects_negative() {
        let p = vec![1.0f32, -1.0, 2.0];
        assert!(normalize_simplex(&p).is_err());
    }

    #[test]
    fn normalize_simplex_rejects_all_zeros() {
        let p = vec![0.0f32, 0.0, 0.0];
        assert!(normalize_simplex(&p).is_err());
    }

    #[test]
    fn normalize_simplex_rejects_empty_and_non_finite() {
        assert!(normalize_simplex(&[]).is_err());
        assert!(normalize_simplex(&[1.0, f32::NAN]).is_err());
        assert!(normalize_simplex(&[1.0, f32::INFINITY]).is_err());
    }

    #[test]
    fn simplex_lerp_at_boundaries() {
        let p0 = vec![1.0f32, 0.0, 0.0];
        let p1 = vec![0.0f32, 0.0, 1.0];

        // t=0 -> p0
        let at0 = simplex_lerp(&p0, &p1, 0.0).unwrap();
        for (a, b) in at0.iter().zip(p0.iter()) {
            assert!((a - b).abs() < 1e-7);
        }

        // t=1 -> p1
        let at1 = simplex_lerp(&p0, &p1, 1.0).unwrap();
        for (a, b) in at1.iter().zip(p1.iter()) {
            assert!((a - b).abs() < 1e-7);
        }

        // t=0.5 -> midpoint
        let mid = simplex_lerp(&p0, &p1, 0.5).unwrap();
        assert!((mid[0] - 0.5).abs() < 1e-7);
        assert!((mid[1] - 0.0).abs() < 1e-7);
        assert!((mid[2] - 0.5).abs() < 1e-7);
    }

    #[test]
    fn simplex_lerp_stays_on_simplex() {
        // If p0 and p1 are on the simplex, p(t) is on the simplex for t in [0,1].
        let p0 = vec![0.3f32, 0.5, 0.2];
        let p1 = vec![0.1f32, 0.1, 0.8];
        for i in 0..=10 {
            let t = i as f32 / 10.0;
            let pt = simplex_lerp(&p0, &p1, t).unwrap();
            validate_simplex(&pt, 1e-5).unwrap();
        }
    }

    #[test]
    fn simplex_lerp_rejects_mismatched_lengths() {
        let p0 = vec![0.5f32, 0.5];
        let p1 = vec![0.3f32, 0.3, 0.4];
        assert!(simplex_lerp(&p0, &p1, 0.5).is_err());
    }

    #[test]
    fn simplex_lerp_rejects_out_of_range_t() {
        let p0 = vec![0.5f32, 0.5];
        let p1 = vec![0.3f32, 0.7];
        assert!(simplex_lerp(&p0, &p1, -0.1).is_err());
        assert!(simplex_lerp(&p0, &p1, 1.1).is_err());
    }

    #[test]
    fn simplex_lerp_rejects_non_finite_t() {
        let p0 = vec![0.5f32, 0.5];
        let p1 = vec![0.3f32, 0.7];
        assert!(simplex_lerp(&p0, &p1, f32::NAN).is_err());
        assert!(simplex_lerp(&p0, &p1, f32::INFINITY).is_err());
        assert!(simplex_lerp(&p0, &p1, f32::NEG_INFINITY).is_err());
    }
}