flowmatch
Flow matching in Rust. Train generative models that learn to transport noise to data via ODE vector fields.
Problem
You have a set of target points -- protein backbone angles, earthquake epicenters, token embeddings -- and want to train a vector field that transforms Gaussian noise into samples from the same distribution. Flow matching [1] does this by regressing a conditional vector field along straight (or geodesic) interpolation paths, then sampling via ODE integration.
This library provides the training loop, OT-based coupling, ODE integration, and evaluation metrics. It works on flat spaces and on Riemannian manifolds.
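The regression target for the straight-path case can be sketched in a few lines. This is a standalone illustration, not this crate's API: the function names `interpolate`, `target_velocity`, and `fm_loss` are hypothetical, and the sketch assumes the linear interpolation path with constant target velocity `x1 - x0`.

```rust
/// Point on the straight path from noise `x0` to data `x1` at time `t` in [0, 1].
/// (Hypothetical helper, not part of this crate.)
fn interpolate(x0: &[f64], x1: &[f64], t: f64) -> Vec<f64> {
    x0.iter().zip(x1).map(|(a, b)| (1.0 - t) * a + t * b).collect()
}

/// Regression target along the straight path: the constant velocity `x1 - x0`.
fn target_velocity(x0: &[f64], x1: &[f64]) -> Vec<f64> {
    x0.iter().zip(x1).map(|(a, b)| b - a).collect()
}

/// Squared-error loss between a predicted velocity and the target velocity.
fn fm_loss(pred: &[f64], target: &[f64]) -> f64 {
    pred.iter().zip(target).map(|(p, u)| (p - u).powi(2)).sum()
}
```

During training, a model would be evaluated at `interpolate(x0, x1, t)` for a sampled `t`, and `fm_loss` would compare its output against `target_velocity(x0, x1)`.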
Examples
Transport noise to discrete targets (simplest case). Semidiscrete FM pairs Gaussian noise with fixed target points via optimal transport, trains a linear conditional field, and integrates an ODE to produce samples:
n=16 d=8
pot_cfg: steps=2000 batch=1024 seed=7
fm_cfg: steps=800 batch=256 lr=0.008 seed=9 euler_steps=40
sample_mse_to_assigned_y = 0.0367
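The sampling step above can be illustrated with a fixed-step Euler integrator from t = 0 to t = 1. This is a minimal sketch, not the crate's integrator: `euler_integrate` and `mse` are hypothetical names, and the reading of `sample_mse_to_assigned_y` as a per-coordinate mean squared error against the assigned target is an assumption.

```rust
/// Fixed-step explicit Euler integration of dx/dt = field(x, t) from t = 0 to t = 1.
/// (Hypothetical helper, not part of this crate.)
fn euler_integrate<F>(field: F, x0: &[f64], steps: usize) -> Vec<f64>
where
    F: Fn(&[f64], f64) -> Vec<f64>,
{
    let dt = 1.0 / steps as f64;
    let mut x = x0.to_vec();
    for k in 0..steps {
        let t = k as f64 * dt;
        let v = field(&x, t);
        for (xi, vi) in x.iter_mut().zip(&v) {
            *xi += dt * vi;
        }
    }
    x
}

/// Mean squared error between a sample and its target, averaged over coordinates.
/// Assumed to resemble the reported metric; the crate's exact definition may differ.
fn mse(sample: &[f64], target: &[f64]) -> f64 {
    let sum: f64 = sample.iter().zip(target).map(|(s, y)| (s - y).powi(2)).sum();
    sum / sample.len() as f64
}
```

With a constant field equal to `y - x0`, Euler is exact up to rounding, so the integrated sample lands on `y`; a trained field is only approximately straight, which is where the nonzero error comes from.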
Straighter trajectories via minibatch OT. Rectified flow matching [7] uses Sinkhorn coupling within each minibatch so that noise-to-data paths cross less, reducing integration error:
sample_mse_to_assigned_y = 0.0684
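To make the minibatch coupling concrete, here is a small self-contained Sinkhorn sketch with uniform marginals. It does not use the `wass` API; `sq_dist_cost`, `sinkhorn_plan`, and `hard_pairing` are hypothetical names, and taking the row-wise argmax of the plan is just one assumed way to turn a soft plan into pairs.

```rust
/// Squared Euclidean cost matrix between noise points `xs` and data points `ys`.
fn sq_dist_cost(xs: &[Vec<f64>], ys: &[Vec<f64>]) -> Vec<Vec<f64>> {
    xs.iter()
        .map(|x| {
            ys.iter()
                .map(|y| x.iter().zip(y).map(|(a, b)| (a - b).powi(2)).sum::<f64>())
                .collect()
        })
        .collect()
}

/// Entropic OT plan between uniform marginals via plain Sinkhorn iterations.
fn sinkhorn_plan(cost: &[Vec<f64>], eps: f64, iters: usize) -> Vec<Vec<f64>> {
    let (n, m) = (cost.len(), cost[0].len());
    // Gibbs kernel K = exp(-C / eps).
    let k: Vec<Vec<f64>> = cost
        .iter()
        .map(|row| row.iter().map(|c| (-c / eps).exp()).collect())
        .collect();
    let (a, b) = (1.0 / n as f64, 1.0 / m as f64);
    let mut u = vec![1.0; n];
    let mut v = vec![1.0; m];
    for _ in 0..iters {
        // Rescale rows, then columns, to match the target marginals.
        for i in 0..n {
            let s: f64 = (0..m).map(|j| k[i][j] * v[j]).sum();
            u[i] = a / s;
        }
        for j in 0..m {
            let s: f64 = (0..n).map(|i| k[i][j] * u[i]).sum();
            v[j] = b / s;
        }
    }
    (0..n)
        .map(|i| (0..m).map(|j| u[i] * k[i][j] * v[j]).collect())
        .collect()
}

/// For each noise point, the index of the data point with the largest plan mass.
fn hard_pairing(plan: &[Vec<f64>]) -> Vec<usize> {
    plan.iter()
        .map(|row| {
            let mut best = 0;
            for j in 1..row.len() {
                if row[j] > row[best] {
                    best = j;
                }
            }
            best
        })
        .collect()
}
```

Pairing each noise point with a nearby data point, rather than a random one, is what shortens the paths and reduces crossings.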
Protein torsion angles on a torus. Backbone phi/psi angles live on S^1 x S^1. This example trains on real angles from PDB 1BPI (BPTI), then measures sample quality by JS divergence between generated and observed Ramachandran histograms:
PDB 1BPI φ/ψ (n=56) as a torus via R^4 embedding
Ramachandran histogram JS divergence (lower is better):
- baseline (Gaussian decode): 0.6391
- trained (RFM+minibatch OT): 0.4105
- ratio trained/baseline: 0.642
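For reference, the JS divergence between two histograms can be computed as below. This is a standalone sketch rather than the `logp` API; the function names are hypothetical, and the natural logarithm is an assumption (the log base behind the numbers above is not stated here).

```rust
/// Normalize raw bin counts into a probability vector.
fn normalize(counts: &[f64]) -> Vec<f64> {
    let total: f64 = counts.iter().sum();
    counts.iter().map(|c| c / total).collect()
}

/// KL(p || m) where m = (p + q) / 2; bins with p = 0 contribute nothing.
fn kl_to_mixture(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q)
        .filter(|(pi, _)| **pi > 0.0)
        .map(|(pi, qi)| {
            let m = 0.5 * (pi + qi);
            pi * (pi / m).ln()
        })
        .sum()
}

/// Jensen-Shannon divergence in nats; ranges from 0 (identical) to ln 2 (disjoint).
fn js_divergence(a_counts: &[f64], b_counts: &[f64]) -> f64 {
    assert_eq!(a_counts.len(), b_counts.len(), "histograms must share bins");
    let (p, q) = (normalize(a_counts), normalize(b_counts));
    0.5 * kl_to_mixture(&p, &q) + 0.5 * kl_to_mixture(&q, &p)
}
```

Both histograms must use the same binning of the phi/psi plane for the comparison to be meaningful.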
Earthquake locations on a sphere. USGS M6+ earthquake epicenters (2024) mapped to S^2. Evaluation uses entropic OT cost between generated and observed locations:
USGS earthquakes (n=50), embedding=R^3 with S^2 projection
OT cost (lower is better):
- baseline (near-noise): 0.6496
- trained (RFM+minibatch OT): 0.3129
- ratio trained/baseline: 0.482
Some generated samples (lat, lon):
0: lat= 12.63°, lon= -104.96°
1: lat= 58.20°, lon= 169.16°
2: lat= -13.11°, lon= -167.62°
3: lat= -35.47°, lon= -79.28°
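The R^3 embedding with S^2 projection can be sketched as follows. The helper names are hypothetical and the axis convention (z toward the north pole, x through longitude 0) is an assumption; the example in the repository may use a different one.

```rust
/// Map latitude/longitude in degrees to a unit vector in R^3.
/// Assumed convention: z points to the north pole, x through lon = 0.
fn latlon_to_unit(lat_deg: f64, lon_deg: f64) -> [f64; 3] {
    let (lat, lon) = (lat_deg.to_radians(), lon_deg.to_radians());
    [lat.cos() * lon.cos(), lat.cos() * lon.sin(), lat.sin()]
}

/// Project a nonzero vector in R^3 back onto the unit sphere.
fn project_to_sphere(v: [f64; 3]) -> [f64; 3] {
    let n = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    [v[0] / n, v[1] / n, v[2] / n]
}

/// Recover (lat, lon) in degrees from a unit vector.
fn unit_to_latlon(p: [f64; 3]) -> (f64, f64) {
    let lat = p[2].clamp(-1.0, 1.0).asin().to_degrees();
    let lon = p[1].atan2(p[0]).to_degrees();
    (lat, lon)
}
```

An integrator that works in the ambient R^3 can call `project_to_sphere` after each step so that samples stay on S^2 before being converted back to latitude and longitude.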
Geodesics on the Poincare ball. Riemannian ODE integration on hyperbolic space, using the skel::Manifold trait implemented by hyperball.
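As a standalone illustration (not the `skel` or `hyperball` API), geodesics through the origin of the unit Poincare ball with curvature -1 have a simple closed form: they are straight radial lines whose Euclidean radius follows `tanh`. The function names below are hypothetical.

```rust
/// Euclidean norm of a vector.
fn norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// Exponential map at the origin of the unit Poincare ball (curvature -1):
/// exp_0(v) = tanh(|v|) * v / |v|, with |v| the Euclidean norm of v.
fn exp_origin(v: &[f64]) -> Vec<f64> {
    let r = norm(v);
    if r == 0.0 {
        return v.to_vec();
    }
    let scale = r.tanh() / r;
    v.iter().map(|x| scale * x).collect()
}

/// Point at time `t` on the geodesic leaving the origin with velocity `v`.
fn geodesic_from_origin(v: &[f64], t: f64) -> Vec<f64> {
    let scaled: Vec<f64> = v.iter().map(|x| t * x).collect();
    exp_origin(&scaled)
}

/// Hyperbolic distance from the origin: d(0, x) = 2 * artanh(|x|).
fn dist_from_origin(x: &[f64]) -> f64 {
    2.0 * norm(x).atanh()
}
```

Under this convention the hyperbolic distance travelled by time `t` is `2 * t * |v|`, while the point itself never leaves the open unit ball. A closed form like this gives a reference to compare a numerical Riemannian integrator against.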
All examples
| Example | What it shows |
|---|---|
| sd_fm_semidiscrete_linear | Gaussian noise to discrete targets via semidiscrete OT |
| rfm_minibatch_ot_linear | Minibatch Sinkhorn coupling for straighter trajectories |
| rfm_minibatch_outlier_partial | Outlier forcing problem and partial pairing fix |
| rfm_protein_torsions_1bpi | Real protein phi/psi angles on the torus, JS divergence metric |
| rfm_usgs_earthquakes_sphere | Real earthquake locations on S^2, OT cost metric |
| rfm_textish_tokens | Token embeddings with TF-IDF weights |
| rfm_torsions_nfe_curve | Sample quality vs. ODE steps (torsion data) |
| rfm_usgs_nfe_curve | Sample quality vs. ODE steps (earthquake data) |
| rfm_usgs_solver_nfe_tradeoff | Euler vs. Heun under equal compute budgets |
| ode_comparison | Euler vs. Heun on a 2D circular ODE (radius preservation) |
| rfm_poincare_geodesic_ode | Riemannian ODE on Poincare ball (--features riemannian) |
| discrete_ctmc_path_evolution | CTMC path evolution with time-dependent generators |
| rfm_conditional_2d | 2D conditional flow matching visualization |
| rfm_two_moons | Two-moons distribution transport |
| burn_sd_fm_semidiscrete_linear | Semidiscrete FM with Burn backend (--features burn) |
| burn_rfm_minibatch_ot_linear | RFM with Burn backend (--features burn) |
| mmd_flow_eval | MMD (kernel two-sample test) as a flow quality metric |
| riemannian_fm_poincare | Riemannian FM on the Poincare disk (--features riemannian) |
| profile_breakdown_* | Where training time goes (Sinkhorn vs. SGD) |
Requires --features sheaf-evals:
| Example | What it shows |
|---|---|
| rfm_usgs_earthquakes_cluster_mass | Do generated samples preserve cluster structure? |
| rfm_usgs_knn_leiden | kNN graph + Leiden community detection on generated data |
| rfm_usgs_full_pipeline_report | Full pipeline with all metrics and timings |
What it provides
Training: Semidiscrete FM, rectified flow matching with minibatch OT coupling, time schedules (uniform, U-shaped, logit-normal).
Sampling: Fixed-step ODE integrators (Euler, Heun) for Euclidean spaces and Riemannian manifolds.
Coupling: Sinkhorn OT pairing, greedy matching, partial/selective pairing for outlier handling.
Discrete FM: CTMC generator scaffolding with cosine-squared schedule [3], conditional probability paths, conditional rate matrices.
Evaluation: JS divergence on histograms, entropic OT cost.
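The difference between the two fixed-step integrators listed under Sampling shows up clearly on a circular ODE, dx/dt = (-y, x), whose exact solution keeps the radius constant. The sketch below is standalone and does not use this crate's integrator types; all names in it are hypothetical.

```rust
/// A time-independent 2D vector field.
type Field = fn([f64; 2]) -> [f64; 2];

/// Rotation field: exact solutions are circles around the origin.
fn rotation(p: [f64; 2]) -> [f64; 2] {
    [-p[1], p[0]]
}

/// One explicit Euler step (one field evaluation).
fn euler_step(f: Field, p: [f64; 2], h: f64) -> [f64; 2] {
    let k = f(p);
    [p[0] + h * k[0], p[1] + h * k[1]]
}

/// One Heun step (two field evaluations): Euler predictor, trapezoidal corrector.
fn heun_step(f: Field, p: [f64; 2], h: f64) -> [f64; 2] {
    let k1 = f(p);
    let pred = [p[0] + h * k1[0], p[1] + h * k1[1]];
    let k2 = f(pred);
    [
        p[0] + 0.5 * h * (k1[0] + k2[0]),
        p[1] + 0.5 * h * (k1[1] + k2[1]),
    ]
}

/// Radius after one full revolution starting from (1, 0); the exact value is 1.
fn final_radius(step: fn(Field, [f64; 2], f64) -> [f64; 2], steps: usize) -> f64 {
    let h = std::f64::consts::TAU / steps as f64;
    let mut p = [1.0, 0.0];
    for _ in 0..steps {
        p = step(rotation, p, h);
    }
    p[0].hypot(p[1])
}
```

On this field each Euler step multiplies the radius by sqrt(1 + h^2), while each Heun step multiplies it by sqrt(1 + h^4/4), so Heun drifts far less at the same step count. Heun costs two field evaluations per step, so a fair comparison fixes the number of evaluations rather than the number of steps.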
Dependencies
- wass -- optimal transport (Sinkhorn, coupling)
- skel -- manifold trait (exp/log/transport)
- logp -- information theory (JS divergence)
- hyperball -- hyperbolic geometry (dev-dependency for Riemannian tests)
Status
MSRV: 1.80.
Tests
References
1. Lipman et al., Flow Matching for Generative Modeling (2022)
2. Lipman et al., Flow Matching Guide and Code (2024) -- comprehensive tutorial
3. Gat et al., Discrete Flow Matching (NeurIPS 2024) -- CTMC-based discrete FM
4. Chen & Lipman, Riemannian Flow Matching on General Geometries (2023)
5. de Kruiff et al., Pullback Flow Matching on Data Manifolds (2024) -- FM on implicit manifolds without closed-form exp/log maps
6. Sherry & Smets, Flow Matching on Lie Groups (2025) -- specialization to SO(3) and SE(3)
7. Liu et al., Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow (2022) -- rectified flow
License
MIT OR Apache-2.0