// numra-sde 0.1.4
//
// Stochastic differential equation solvers for Numra: Euler-Maruyama, Milstein,
// adaptive SRA1/SRA2, and an ensemble runner.
//! SDE system trait and solver infrastructure.
//!
//! Author: Moussa Leblouba
//! Date: 3 February 2026
//! Modified: 2 May 2026

use numra_core::Scalar;

/// Structure of the noise term `g(t, X) dW(t)` of an SDE.
///
/// Determines how many Wiener processes drive the system and how the
/// diffusion buffer filled by `SdeSystem::diffusion` is laid out.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum NoiseType {
    /// One independent Wiener process per state component (the default).
    #[default]
    Diagonal,
    /// A single Wiener process shared by every component, each component
    /// scaling it by its own diffusion coefficient.
    Scalar,
    /// Fully general noise: a complete noise matrix coupling `n_wiener`
    /// Wiener processes to the `n` state dimensions.
    General {
        /// Number of independent Wiener processes.
        n_wiener: usize,
    },
}

/// Trait for stochastic differential equation systems.
///
/// Defines an SDE of the form:
/// ```text
/// dX(t) = f(t, X) dt + g(t, X) dW(t)
/// ```
///
/// Implementors provide the state dimension, the drift `f` and the diffusion
/// `g`. The noise structure defaults to [`NoiseType::Diagonal`], and the
/// Milstein correction defaults to a finite-difference approximation; both
/// can be overridden.
pub trait SdeSystem<S: Scalar>: Sync {
    /// Dimension of the state space.
    fn dim(&self) -> usize;

    /// Evaluate the drift function f(t, x).
    ///
    /// # Arguments
    /// * `t` - Current time
    /// * `x` - Current state
    /// * `f` - Output buffer for drift (length = dim)
    fn drift(&self, t: S, x: &[S], f: &mut [S]);

    /// Evaluate the diffusion function g(t, x).
    ///
    /// For diagonal noise, `g` has length `dim`.
    /// For scalar noise, `g` has length `dim` (same noise scaled differently).
    /// For general noise, `g` has length `dim * n_wiener`.
    ///
    /// # Arguments
    /// * `t` - Current time
    /// * `x` - Current state
    /// * `g` - Output buffer for diffusion
    fn diffusion(&self, t: S, x: &[S], g: &mut [S]);

    /// Type of noise (default: diagonal).
    fn noise_type(&self) -> NoiseType {
        NoiseType::Diagonal
    }

    /// Number of Wiener processes driving the system.
    ///
    /// Derived from [`noise_type`](Self::noise_type): `dim` for diagonal
    /// noise, 1 for scalar noise, and the declared count for general noise.
    fn n_wiener(&self) -> usize {
        match self.noise_type() {
            NoiseType::Diagonal => self.dim(),
            NoiseType::Scalar => 1,
            NoiseType::General { n_wiener } => n_wiener,
        }
    }

    /// Diagonal Milstein correction `(∂g_i/∂x_i) · g_i`, written to `gdg[i]`
    /// for every `i in 0..dim`.
    ///
    /// Required for Milstein method. The default implementation approximates
    /// each `∂g_i/∂x_i` with a forward finite difference of nominal step
    /// `sqrt(EPSILON) · (1 + |x_i|)`. Each call invokes
    /// [`diffusion`](Self::diffusion) `dim + 1` times and allocates three
    /// scratch vectors. Only the diagonal partials are formed; cross terms
    /// `∂g_i/∂x_j` (j ≠ i) are ignored. Override this method when an analytic
    /// derivative is available.
    ///
    /// The scratch diffusion buffers honour the length contract of
    /// [`diffusion`](Self::diffusion) (`dim * n_wiener` for general noise), so
    /// a conforming implementation never writes past their end. For general
    /// noise, only the first `dim` entries are read. What those entries mean
    /// depends on the implementor's layout, so general-noise systems used
    /// with Milstein should override this method.
    ///
    /// # Panics
    /// Panics if `x` or `gdg` holds fewer than `dim` elements.
    fn diffusion_derivative(&self, t: S, x: &[S], gdg: &mut [S]) {
        let dim = self.dim();
        // Size the buffers per the `diffusion` contract. A plain `dim`-length
        // buffer makes a general-noise implementation index out of bounds.
        // Never go below `dim`, because entries `0..dim` are read below.
        let g_len = match self.noise_type() {
            NoiseType::General { n_wiener } => dim * n_wiener.max(1),
            NoiseType::Diagonal | NoiseType::Scalar => dim,
        };
        let h_factor = S::EPSILON.sqrt();

        let mut g = vec![S::ZERO; g_len];
        let mut g_plus = vec![S::ZERO; g_len];
        let mut x_pert = x.to_vec();

        self.diffusion(t, x, &mut g);

        for i in 0..dim {
            let xi = x[i];
            x_pert[i] = xi + h_factor * (S::ONE + xi.abs());
            // Divide by the step actually applied: `xi + h` is rounded, so
            // `(xi + h) - xi` generally differs from the nominal `h`.
            let h = x_pert[i] - xi;
            self.diffusion(t, &x_pert, &mut g_plus);
            x_pert[i] = xi;

            // (∂g_i/∂x_i) * g_i for diagonal case
            gdg[i] = (g_plus[i] - g[i]) / h * g[i];
        }
    }
}

/// Configuration shared by the SDE solvers.
///
/// **How this differs from `numra_ode::SolverOptions`** (per Foundation Spec
/// §2.5): in an SDE, the step size governs more than truncation accuracy. It
/// also sets the Wiener increment `δW ~ N(0, h)`. The fixed step `dt` is
/// therefore kept as its own field, distinct from the `rtol` / `atol`
/// tolerances used by the adaptive SRA-family methods. It is not folded into
/// a shared `h0` / `h_max` pair, because the noise-discretisation meaning of
/// the step matters on every step. `seed: Option<u64>` exists so runs can be
/// reproduced; deterministic ODE configuration has nothing comparable.
/// `save_trajectory: bool` chooses between collecting the full trajectory and
/// keeping only the final state. The latter suits Monte Carlo workloads that
/// do not keep intermediate states. See
/// `docs/architecture/foundation-specification.md` §2.5.
#[derive(Clone, Debug)]
pub struct SdeOptions<S: Scalar> {
    /// Step size used by fixed-step (non-adaptive) methods
    pub dt: S,
    /// Relative error tolerance for adaptive methods
    pub rtol: S,
    /// Absolute error tolerance for adaptive methods
    pub atol: S,
    /// Upper bound on the time step
    pub dt_max: S,
    /// Lower bound on the time step
    pub dt_min: S,
    /// Upper bound on the number of steps taken
    pub max_steps: usize,
    /// Record every step (`true`) or only the final state (`false`)
    pub save_trajectory: bool,
    /// RNG seed; `None` draws from system entropy
    pub seed: Option<u64>,
}

impl<S: Scalar> Default for SdeOptions<S> {
    /// Baseline settings: fixed step 0.01, `rtol` 1e-3, `atol` 1e-6, no upper
    /// step bound, a 1e-10 step floor, a budget of one million steps, full
    /// trajectory recording, and no fixed seed.
    fn default() -> Self {
        let dt = S::from_f64(0.01);
        let rtol = S::from_f64(1e-3);
        let atol = S::from_f64(1e-6);
        let dt_max = S::INFINITY;
        let dt_min = S::from_f64(1e-10);

        Self {
            dt,
            rtol,
            atol,
            dt_max,
            dt_min,
            max_steps: 1_000_000,
            save_trajectory: true,
            seed: None,
        }
    }
}

impl<S: Scalar> SdeOptions<S> {
    /// Set fixed time step.
    pub fn dt(mut self, dt: S) -> Self {
        self.dt = dt;
        self
    }

    /// Set relative tolerance.
    pub fn rtol(mut self, rtol: S) -> Self {
        self.rtol = rtol;
        self
    }

    /// Set absolute tolerance.
    pub fn atol(mut self, atol: S) -> Self {
        self.atol = atol;
        self
    }

    /// Set maximum time step.
    pub fn dt_max(mut self, dt_max: S) -> Self {
        self.dt_max = dt_max;
        self
    }

    /// Set minimum time step.
    ///
    /// Completes the builder so every field can be configured by chaining,
    /// alongside [`dt_max`](Self::dt_max).
    pub fn dt_min(mut self, dt_min: S) -> Self {
        self.dt_min = dt_min;
        self
    }

    /// Set maximum number of steps.
    pub fn max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Set random seed for reproducibility.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Enable/disable trajectory saving.
    pub fn save_trajectory(mut self, save: bool) -> Self {
        self.save_trajectory = save;
        self
    }
}

/// Work counters collected by an SDE solver over one integration.
#[derive(Debug, Default, Clone)]
pub struct SdeStats {
    /// How many times the drift `f` was evaluated
    pub n_drift: usize,
    /// How many times the diffusion `g` was evaluated
    pub n_diffusion: usize,
    /// How many steps were accepted
    pub n_accept: usize,
    /// How many steps were rejected (meaningful for adaptive methods)
    pub n_reject: usize,
}

/// Output of one SDE integration: sampled times, flattened states, and
/// solver diagnostics.
#[derive(Clone, Debug)]
pub struct SdeResult<S: Scalar> {
    /// Recorded time points
    pub t: Vec<S>,
    /// States stored row-major, one row per time point:
    /// `y[i*dim + j]` is component `j` at time `t[i]`
    pub y: Vec<S>,
    /// Number of state components per row of `y`
    pub dim: usize,
    /// Work counters gathered during the run
    pub stats: SdeStats,
    /// `true` if the integration completed successfully
    pub success: bool,
    /// Failure description; empty on success
    pub message: String,
}

impl<S: Scalar> SdeResult<S> {
    /// Build a successful result from the recorded times, the row-major
    /// states, the system dimension and the solver statistics.
    /// The message is left empty.
    pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SdeStats) -> Self {
        let message = String::new();
        Self {
            success: true,
            message,
            t,
            y,
            dim,
            stats,
        }
    }

    /// Build a failed result carrying `message` and the statistics gathered
    /// so far. No times or states are stored and `dim` is set to 0.
    pub fn failed(message: String, stats: SdeStats) -> Self {
        Self {
            success: false,
            message,
            stats,
            dim: 0,
            t: Vec::new(),
            y: Vec::new(),
        }
    }

    /// Count of stored time points.
    pub fn len(&self) -> usize {
        self.t.len()
    }

    /// `true` when no time point was stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Last recorded time, or `None` for an empty result.
    pub fn t_final(&self) -> Option<S> {
        self.t.last().copied()
    }

    /// Owned copy of the last recorded state, or `None` for an empty result.
    ///
    /// # Panics
    /// Panics if `y` holds fewer than `len() * dim` values.
    pub fn y_final(&self) -> Option<Vec<S>> {
        let last = self.t.len().checked_sub(1)?;
        Some(self.y_at(last).to_vec())
    }

    /// Borrow the state stored for time index `i`.
    ///
    /// # Panics
    /// Panics if row `i` lies outside `y`.
    pub fn y_at(&self, i: usize) -> &[S] {
        let lo = i * self.dim;
        let hi = lo + self.dim;
        &self.y[lo..hi]
    }

    /// Walk the result as `(time, state)` pairs in recording order.
    pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
        self.t
            .iter()
            .copied()
            .enumerate()
            .map(move |(idx, time)| (time, self.y_at(idx)))
    }
}

/// Common interface implemented by each SDE integration scheme.
pub trait SdeSolver<S: Scalar> {
    /// Integrate `system` over `[t0, tf]` starting from `x0`.
    ///
    /// # Arguments
    /// * `system` - SDE to integrate
    /// * `t0` - Start of the integration interval
    /// * `tf` - End of the integration interval
    /// * `x0` - State at `t0`
    /// * `options` - Step-size, tolerance and output settings
    /// * `seed` - Optional RNG seed; when `Some`, it takes precedence over
    ///   `options.seed`
    ///
    /// # Errors
    /// Failures are reported as `Err(String)`; the exact conditions are
    /// solver-specific.
    fn solve<Sys: SdeSystem<S>>(
        system: &Sys,
        t0: S,
        tf: S,
        x0: &[S],
        options: &SdeOptions<S>,
        seed: Option<u64>,
    ) -> Result<SdeResult<S>, String>;
}

#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional toy SDE: drift `-x`, diffusion `0.1·x`.
    struct TestSde;

    impl SdeSystem<f64> for TestSde {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = -x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = 0.1 * x[0];
        }
    }

    #[test]
    fn test_sde_system_trait() {
        let sys = TestSde;
        assert_eq!(sys.dim(), 1);
        // Diagonal noise is the default: one Wiener process per component.
        assert_eq!(sys.n_wiener(), 1);

        let state = [1.0];
        let (mut drift_out, mut diff_out) = ([0.0], [0.0]);
        sys.drift(0.0, &state, &mut drift_out);
        sys.diffusion(0.0, &state, &mut diff_out);
        assert!((drift_out[0] + 1.0).abs() < 1e-10);
        assert!((diff_out[0] - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_sde_options() {
        let opts = SdeOptions::<f64>::default().seed(42).dt(0.001);
        assert!((opts.dt - 0.001).abs() < 1e-10);
        assert_eq!(opts.seed, Some(42));
    }

    /// Two-component SDE with linear diffusion `g(t, x) = α·x`.
    ///
    /// For this system `(∂g_i/∂x_i) · g_i = α · α·x_i = α²·x_i` in closed form,
    /// which the trait default's forward-FD computation should approximate.
    /// `diffusion_derivative` is deliberately not overridden so the trait
    /// default is what gets exercised.
    struct LinearDiffusionSde {
        alpha: f64,
    }

    impl SdeSystem<f64> for LinearDiffusionSde {
        fn dim(&self) -> usize {
            2
        }
        fn drift(&self, _t: f64, _x: &[f64], f: &mut [f64]) {
            f[..2].fill(0.0);
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            let a = self.alpha;
            g[0] = a * x[0];
            g[1] = a * x[1];
        }
    }

    #[test]
    fn test_diffusion_derivative_default_large_x_no_scaling_bug() {
        // Pins F-FD-NOSCALE-BUG for the trait-default forward-FD branch. With
        // an unscaled `h = 1e-8`, `x[i] + h == x[i]` in f64 once |x| > ~5e7.
        // Then `g_plus == g`, and the result collapses to `0/h * g = 0`
        // instead of `α²·x`. The canonical `sqrt(EPSILON) * (1 + |x|)` step
        // recovers the right answer. (This site is *forward* FD, so its
        // canonical step differs from the central-FD sites in numra-optim
        // and numra-dde.)
        let alpha = 0.5_f64;
        let sys = LinearDiffusionSde { alpha };
        let mut gdg = [0.0; 2];
        sys.diffusion_derivative(0.0, &[1e8, 1e8], &mut gdg);

        // Expected: α² · x_i = 0.25 · 1e8 = 2.5e7
        let expected = alpha * alpha * 1e8;
        let tol = 1e-3 * expected.abs();
        assert!(
            (gdg[0] - expected).abs() < tol,
            "gdg[0] = {} should be ≈ {} (within 1e-3 relative); old unscaled formula returns 0",
            gdg[0],
            expected
        );
        assert!(
            (gdg[1] - expected).abs() < tol,
            "gdg[1] = {} should be ≈ {} (within 1e-3 relative); old unscaled formula returns 0",
            gdg[1],
            expected
        );
    }
}