#![warn(missing_docs)]
//! # flowmatch
//!
//! Flow matching as a library primitive.
//!
//! This crate is intentionally small:
//!
//! - it implements **training loops** and **sampling** for flow-matching style models,
//! - it depends on `wass` for OT-shaped coupling primitives (e.g. semidiscrete assignments),
//! - it does not provide a CLI or experiment runner (that belongs in L6 / apps).
//!
//! ## Public invariants (must not change)
//!
//! - **Determinism knobs are explicit**: training/sampling functions take `seed` (or configs do).
//! - **No hidden normalization**: if inputs are normalized, it is stated in the doc comment.
//! - **Backend-agnostic by default**: this crate uses `ndarray` and simple SGD; no GPU framework types
//!   leak through the public API in the default feature set.
//!   - Optional training backends (e.g. `burn`) are **feature-gated**.
//!
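/*!
The explicit-seed invariant can be illustrated with a minimal, std-only sketch. `SplitMix64` and `sample_times` are hypothetical helpers written for this example, not part of the `flowmatch` API (the crate may use a different RNG); the only point is that every random draw is a pure function of the `seed` argument.

```rust
// Hypothetical helpers (not the flowmatch API): a SplitMix64 generator and a
// seeded sampler for flow times t.

/// Minimal SplitMix64 pseudo-random generator.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Next 64-bit output of the SplitMix64 sequence.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform draw in [0, 1) built from the top 53 bits of `next_u64`.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Draws `n` flow times t ~ U[0, 1); the output depends only on `seed` and `n`.
fn sample_times(seed: u64, n: usize) -> Vec<f64> {
    let mut rng = SplitMix64::new(seed);
    (0..n).map(|_| rng.next_f64()).collect()
}
```

Calling `sample_times(42, 8)` twice returns identical vectors; changing the seed changes the draws.
*/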
//! ## How this maps to “Flow Matching” (papers)
//!
//! The core training objective used here is the standard *conditional flow matching* regression
//! (sample \(t\), sample a point on a path \(x_t\), regress a vector field \(v_\theta(x_t,t;\cdot)\)
//! toward a target velocity \(u_t\)). Concretely:
//!
//! - `sd_fm::train_sd_fm_semidiscrete_linear` uses a **linear interpolation path**
//!   \(x_t = (1-t)x_0 + t y_j\) and target \(u_t = y_j - x_0\).
//! - A semidiscrete “pick \(j\)” step is provided by `wass::semidiscrete` (potentials + hard assignment),
//!   which acts as a simple coupling / conditioning mechanism.
//!
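/*!
The two bullets above combine into one training-pair construction. This is a minimal sketch over plain slices rather than `ndarray`; `assign_semidiscrete`, `linear_path`, `target_velocity`, and `cfm_loss` are hypothetical names, not this crate's API. The assignment rule assumes squared-Euclidean cost and the convention \(j(x_0) = \arg\min_j \|x_0 - y_j\|^2 - g_j\) for dual potentials \(g_j\); `wass::semidiscrete` may use a different sign or scaling.

```rust
// Hypothetical sketch (not the flowmatch API) of one SD-FM training pair.

/// Squared Euclidean distance between two points of equal dimension.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// Hard semidiscrete assignment under the assumed convention
/// j = argmin_j ||x0 - y_j||^2 - g_j.
fn assign_semidiscrete(x0: &[f64], targets: &[Vec<f64>], g: &[f64]) -> usize {
    let mut best = 0;
    let mut best_score = f64::INFINITY;
    for (j, (y, gj)) in targets.iter().zip(g).enumerate() {
        let score = sq_dist(x0, y) - gj;
        if score < best_score {
            best_score = score;
            best = j;
        }
    }
    best
}

/// Point on the linear path: x_t = (1 - t) x0 + t y.
fn linear_path(x0: &[f64], y: &[f64], t: f64) -> Vec<f64> {
    x0.iter().zip(y).map(|(a, b)| (1.0 - t) * a + t * b).collect()
}

/// Target velocity u_t = y - x0 (constant along the linear path).
fn target_velocity(x0: &[f64], y: &[f64]) -> Vec<f64> {
    x0.iter().zip(y).map(|(a, b)| b - a).collect()
}

/// Per-sample CFM regression loss ||v_pred - u_t||^2.
fn cfm_loss(v_pred: &[f64], u: &[f64]) -> f64 {
    sq_dist(v_pred, u)
}
```

In a training step, \(x_0\) and \(t\) would come from a seeded RNG, \(j\) from the assignment, and the model output \(v_\theta(x_t, t)\) would be regressed onto \(u_t\) with `cfm_loss`.
*/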
//! ## References (conceptual anchors, not full implementations)
//!
//! - Lipman et al., *Flow Matching for Generative Modeling* (arXiv:2210.02747):
//!   the canonical FM objective and linear-path baselines.
//! - Lipman et al., *Flow Matching Guide and Code* (arXiv:2412.06264):
//!   a comprehensive reference covering the full design space.
//! - Li et al., *Flow Matching Meets Biology and Life Science: A Survey* (arXiv:2507.17731, 2025):
//!   a taxonomy of variants (CFM/RFM, non-Euclidean, discrete) and a map of applications/tooling.
//! - Gat et al., *Discrete Flow Matching* (NeurIPS 2024):
//!   extending the FM paradigm to discrete data (language, graphs).
//! - Chen & Lipman, *Riemannian Flow Matching on General Geometries* (arXiv:2302.03660):
//!   the foundation for FM on manifolds (e.g. the Poincaré ball in `hyperball`).
//!
//! Related variants that are **not** implemented here (yet):
//!
//! - Dao et al., *Flow Matching in Latent Space* (arXiv:2307.08698):
//!   latent FM and guidance details.
//! - Klein et al., *Equivariant Flow Matching* (NeurIPS 2023):
//!   symmetry constraints.
//! - Zaghen et al., *Towards Variational Flow Matching on General Geometries* (arXiv:2502.12981, 2025):
//!   variational objectives with Riemannian Gaussians (RG-VFM).
//!
//! **Applications & Extensions**:
//!
//! - Qin et al., *DeFoG: Discrete Flow Matching for Graph Generation* (arXiv:2410.04263, 2025).
//! - *FlowMM: Generating Materials with Riemannian Flow Matching* (2024/2025).
//!
//! ## What can change later
//!
//! - The parameterization of vector fields (linear vs MLP vs backend-specific).
//! - ODE integrators (Euler → Heun/RK).
//! - Optional Tweedie correction utilities (diffusion-specific).
//!
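/*!
Moving from Euler to Heun changes only the per-step update: Heun takes an Euler predictor step and then averages the slopes at both ends of the step, which makes it second-order accurate. The sketch below integrates \(\dot x = v(x, t)\) from \(t = 0\) to \(t = 1\) with `steps` fixed steps; `axpy`, `euler`, and `heun` are hypothetical free functions, not the signatures of this crate's `ode` module.

```rust
// Hypothetical fixed-step integrators (not the flowmatch `ode` API).

/// Returns x + a * d, elementwise.
fn axpy(x: &[f64], a: f64, d: &[f64]) -> Vec<f64> {
    x.iter().zip(d).map(|(xi, di)| xi + a * di).collect()
}

/// Explicit Euler solve of dx/dt = v(x, t) on [0, 1].
fn euler<F>(v: F, x0: &[f64], steps: usize) -> Vec<f64>
where
    F: Fn(&[f64], f64) -> Vec<f64>,
{
    let dt = 1.0 / steps as f64;
    let mut x = x0.to_vec();
    for k in 0..steps {
        let t = k as f64 * dt;
        x = axpy(&x, dt, &v(&x, t));
    }
    x
}

/// Heun (explicit trapezoidal) solve of dx/dt = v(x, t) on [0, 1].
fn heun<F>(v: F, x0: &[f64], steps: usize) -> Vec<f64>
where
    F: Fn(&[f64], f64) -> Vec<f64>,
{
    let dt = 1.0 / steps as f64;
    let mut x = x0.to_vec();
    for k in 0..steps {
        let t = k as f64 * dt;
        let k1 = v(&x, t);
        // Euler predictor, then slope at the end of the step.
        let x_pred = axpy(&x, dt, &k1);
        let k2 = v(&x_pred, t + dt);
        // Corrector: average the two slopes.
        let avg: Vec<f64> = k1.iter().zip(&k2).map(|(a, b)| 0.5 * (a + b)).collect();
        x = axpy(&x, dt, &avg);
    }
    x
}
```

On \(\dot x = x\) with \(x(0) = 1\) and 10 steps, Euler returns \(1.1^{10} \approx 2.594\) and Heun returns \(1.105^{10} \approx 2.714\), against the exact value \(e \approx 2.718\).
*/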
//! ## Module map
//!
//! - `sd_fm`: semidiscrete conditional FM training and sampling
//! - `rfm`: rectified-flow coupling helpers (minibatch OT pairing)
//! - `linear`: simple linear vector-field parameterizations
//! - `ode`: fixed-step ODE integrators (`Euler`, `Heun`)
//! - `metrics`: evaluation metrics (JS divergence, entropic OT cost)
//! - `discrete_ctmc`: CTMC generator scaffolding for discrete FM
//! - `simplex`: simplex utilities for discrete FM variants
//! - `riemannian`: Riemannian FM training (feature-gated: `riemannian`)
//! - `riemannian_ode`: manifold ODE integrators (feature-gated: `riemannian`)
//! - `burn_euclidean`: Burn-backed Euclidean FM training (feature-gated: `burn`)
//! - `burn_sd_fm`: Burn-backed SD-FM/RFM training (feature-gated: `burn`)
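/*!
For the `metrics` entry, a Jensen-Shannon divergence between two discrete histograms can be sketched as below. The names `kl` and `js_divergence`, the natural-log base (nats), and the \(0 \ln 0 = 0\) convention are assumptions of this sketch, not a description of the crate's `metrics` module. In keeping with the no-hidden-normalization invariant, the inputs must already be probability vectors; nothing is renormalized.

```rust
// Hypothetical metric helpers (not the flowmatch `metrics` API).

/// KL(p || q) in nats; terms with p_i = 0 contribute 0.
fn kl(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q)
        .filter(|(pi, _)| **pi > 0.0)
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum()
}

/// JS(p, q) = 0.5 KL(p || m) + 0.5 KL(q || m), with m the midpoint of p and q.
/// Inputs must already be probability vectors of equal length.
fn js_divergence(p: &[f64], q: &[f64]) -> f64 {
    assert_eq!(p.len(), q.len(), "histograms must have equal length");
    let m: Vec<f64> = p.iter().zip(q).map(|(a, b)| 0.5 * (a + b)).collect();
    // m_i >= p_i / 2, so m_i > 0 wherever p_i > 0 and no division by zero occurs.
    0.5 * kl(p, &m) + 0.5 * kl(q, &m)
}
```

JS is symmetric and bounded by \(\ln 2\) in nats; two histograms with disjoint support attain that bound.
*/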

pub mod discrete_ctmc;
pub mod linear;
pub mod metrics;
pub mod ode;
pub mod rfm;
#[cfg(feature = "riemannian")]
pub mod riemannian;
#[cfg(feature = "riemannian")]
pub mod riemannian_ode;
pub mod sd_fm;
pub mod simplex;

#[cfg(feature = "burn")]
pub mod burn_euclidean;

#[cfg(feature = "burn")]
pub mod burn_sd_fm;

/// Error type returned by `flowmatch` operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Array shape or dimension mismatch.
    #[error("shape mismatch: {0}")]
    Shape(&'static str),
    /// Value outside the valid domain (e.g., `t` not in `[0, 1]`).
    #[error("domain error: {0}")]
    Domain(&'static str),
}

/// Convenience alias for `std::result::Result<T, flowmatch::Error>`.
pub type Result<T> = std::result::Result<T, Error>;