numra_sde/system.rs
1//! SDE system trait and solver infrastructure.
2//!
3//! Author: Moussa Leblouba
4//! Date: 3 February 2026
5//! Modified: 2 May 2026
6
7use numra_core::Scalar;
8
/// Structure of the noise term `g(t, X) dW(t)` of an SDE.
///
/// The variant determines how many Wiener processes drive the system and how
/// long the diffusion buffer passed to `SdeSystem::diffusion` is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NoiseType {
    /// One independent Wiener process per state component (the default).
    #[default]
    Diagonal,
    /// A single Wiener process shared by every component, each scaled by its own `g_i`.
    Scalar,
    /// Full noise matrix: `n_wiener` Wiener processes feeding the `dim` state components.
    General {
        /// Number of independent Wiener processes.
        n_wiener: usize,
    },
}
20
21/// Trait for stochastic differential equation systems.
22///
23/// Defines an SDE of the form:
24/// ```text
25/// dX(t) = f(t, X) dt + g(t, X) dW(t)
26/// ```
27pub trait SdeSystem<S: Scalar>: Sync {
28 /// Dimension of the state space.
29 fn dim(&self) -> usize;
30
31 /// Evaluate the drift function f(t, x).
32 ///
33 /// # Arguments
34 /// * `t` - Current time
35 /// * `x` - Current state
36 /// * `f` - Output buffer for drift (length = dim)
37 fn drift(&self, t: S, x: &[S], f: &mut [S]);
38
39 /// Evaluate the diffusion function g(t, x).
40 ///
41 /// For diagonal noise, `g` has length `dim`.
42 /// For scalar noise, `g` has length `dim` (same noise scaled differently).
43 /// For general noise, `g` has length `dim * n_wiener`.
44 ///
45 /// # Arguments
46 /// * `t` - Current time
47 /// * `x` - Current state
48 /// * `g` - Output buffer for diffusion
49 fn diffusion(&self, t: S, x: &[S], g: &mut [S]);
50
51 /// Type of noise (default: diagonal).
52 fn noise_type(&self) -> NoiseType {
53 NoiseType::Diagonal
54 }
55
56 /// Number of Wiener processes.
57 fn n_wiener(&self) -> usize {
58 match self.noise_type() {
59 NoiseType::Diagonal => self.dim(),
60 NoiseType::Scalar => 1,
61 NoiseType::General { n_wiener } => n_wiener,
62 }
63 }
64
65 /// Derivative of diffusion w.r.t. state: ∂g/∂x * g
66 ///
67 /// Required for Milstein method. Default implementation uses finite differences.
68 fn diffusion_derivative(&self, t: S, x: &[S], gdg: &mut [S]) {
69 let dim = self.dim();
70 let h_factor = S::EPSILON.sqrt();
71
72 let mut g = vec![S::ZERO; dim];
73 let mut g_plus = vec![S::ZERO; dim];
74 let mut x_pert = x.to_vec();
75
76 self.diffusion(t, x, &mut g);
77
78 for i in 0..dim {
79 let h = h_factor * (S::ONE + x[i].abs());
80 x_pert[i] = x[i] + h;
81 self.diffusion(t, &x_pert, &mut g_plus);
82 x_pert[i] = x[i];
83
84 // (∂g_i/∂x_i) * g_i for diagonal case
85 gdg[i] = (g_plus[i] - g[i]) / h * g[i];
86 }
87 }
88}
89
90/// Options for SDE solvers.
91///
92/// **Divergence from `numra_ode::SolverOptions`** (per Foundation Spec §2.5):
93/// SDE solvers carry stochastic noise, so step size also controls the Wiener
94/// increment `δW ~ N(0, h)` — not just truncation accuracy. The fixed `dt`
95/// field is distinguished from `rtol` / `atol` (used by adaptive SRA-family
96/// methods) rather than collapsed onto a shared `h0` / `h_max`, because the
97/// noise-discretisation interpretation matters at every step. `seed:
98/// Option<u64>` is required for reproducibility — deterministic ODE
99/// configuration has no analog. `save_trajectory: bool` toggles
100/// trajectory-vs-final-only collection for Monte Carlo workloads where
101/// intermediate states aren't kept. See
102/// `docs/architecture/foundation-specification.md` §2.5.
103#[derive(Clone, Debug)]
104pub struct SdeOptions<S: Scalar> {
105 /// Fixed time step (for non-adaptive methods)
106 pub dt: S,
107 /// Relative tolerance (for adaptive methods)
108 pub rtol: S,
109 /// Absolute tolerance (for adaptive methods)
110 pub atol: S,
111 /// Maximum time step
112 pub dt_max: S,
113 /// Minimum time step
114 pub dt_min: S,
115 /// Maximum number of steps
116 pub max_steps: usize,
117 /// Save solution at all steps (vs. just final)
118 pub save_trajectory: bool,
119 /// Random seed (None = use system entropy)
120 pub seed: Option<u64>,
121}
122
123impl<S: Scalar> Default for SdeOptions<S> {
124 fn default() -> Self {
125 Self {
126 dt: S::from_f64(0.01),
127 rtol: S::from_f64(1e-3),
128 atol: S::from_f64(1e-6),
129 dt_max: S::INFINITY,
130 dt_min: S::from_f64(1e-10),
131 max_steps: 1_000_000,
132 save_trajectory: true,
133 seed: None,
134 }
135 }
136}
137
138impl<S: Scalar> SdeOptions<S> {
139 /// Set fixed time step.
140 pub fn dt(mut self, dt: S) -> Self {
141 self.dt = dt;
142 self
143 }
144
145 /// Set relative tolerance.
146 pub fn rtol(mut self, rtol: S) -> Self {
147 self.rtol = rtol;
148 self
149 }
150
151 /// Set absolute tolerance.
152 pub fn atol(mut self, atol: S) -> Self {
153 self.atol = atol;
154 self
155 }
156
157 /// Set maximum time step.
158 pub fn dt_max(mut self, dt_max: S) -> Self {
159 self.dt_max = dt_max;
160 self
161 }
162
163 /// Set random seed for reproducibility.
164 pub fn seed(mut self, seed: u64) -> Self {
165 self.seed = Some(seed);
166 self
167 }
168
169 /// Enable/disable trajectory saving.
170 pub fn save_trajectory(mut self, save: bool) -> Self {
171 self.save_trajectory = save;
172 self
173 }
174}
175
/// Work counters reported by an SDE solver run.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Default, Clone)]
pub struct SdeStats {
    /// How many times the drift `f` was evaluated.
    pub n_drift: usize,
    /// How many times the diffusion `g` was evaluated.
    pub n_diffusion: usize,
    /// How many steps were accepted.
    pub n_accept: usize,
    /// How many steps were rejected (adaptive methods only).
    pub n_reject: usize,
}
188
189/// Result of SDE integration.
190#[derive(Clone, Debug)]
191pub struct SdeResult<S: Scalar> {
192 /// Time points
193 pub t: Vec<S>,
194 /// Solution at each time point (row-major: y[i*dim + j] = y_j(t_i))
195 pub y: Vec<S>,
196 /// Dimension of the system
197 pub dim: usize,
198 /// Solver statistics
199 pub stats: SdeStats,
200 /// Was integration successful?
201 pub success: bool,
202 /// Message (error description if failed)
203 pub message: String,
204}
205
206impl<S: Scalar> SdeResult<S> {
207 /// Create a new successful result.
208 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SdeStats) -> Self {
209 Self {
210 t,
211 y,
212 dim,
213 stats,
214 success: true,
215 message: String::new(),
216 }
217 }
218
219 /// Create a failed result.
220 pub fn failed(message: String, stats: SdeStats) -> Self {
221 Self {
222 t: Vec::new(),
223 y: Vec::new(),
224 dim: 0,
225 stats,
226 success: false,
227 message,
228 }
229 }
230
231 /// Number of time points.
232 pub fn len(&self) -> usize {
233 self.t.len()
234 }
235
236 /// Is result empty?
237 pub fn is_empty(&self) -> bool {
238 self.t.is_empty()
239 }
240
241 /// Get final time.
242 pub fn t_final(&self) -> Option<S> {
243 self.t.last().copied()
244 }
245
246 /// Get final state.
247 pub fn y_final(&self) -> Option<Vec<S>> {
248 if self.t.is_empty() {
249 None
250 } else {
251 let start = (self.t.len() - 1) * self.dim;
252 Some(self.y[start..start + self.dim].to_vec())
253 }
254 }
255
256 /// Get state at index i.
257 pub fn y_at(&self, i: usize) -> &[S] {
258 let start = i * self.dim;
259 &self.y[start..start + self.dim]
260 }
261
262 /// Iterate over (t, y) pairs.
263 pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
264 self.t
265 .iter()
266 .enumerate()
267 .map(move |(i, &t)| (t, self.y_at(i)))
268 }
269}
270
271/// Trait for SDE solvers.
272pub trait SdeSolver<S: Scalar> {
273 /// Solve the SDE problem.
274 ///
275 /// # Arguments
276 /// * `system` - The SDE system to solve
277 /// * `t0` - Initial time
278 /// * `tf` - Final time
279 /// * `x0` - Initial state
280 /// * `options` - Solver options
281 /// * `seed` - Optional random seed (overrides options.seed)
282 fn solve<Sys: SdeSystem<S>>(
283 system: &Sys,
284 t0: S,
285 tf: S,
286 x0: &[S],
287 options: &SdeOptions<S>,
288 seed: Option<u64>,
289 ) -> Result<SdeResult<S>, String>;
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// One-dimensional linear SDE `dX = -X dt + 0.1·X dW`.
    struct TestSde;

    impl SdeSystem<f64> for TestSde {
        fn dim(&self) -> usize {
            1
        }
        fn drift(&self, _t: f64, x: &[f64], f: &mut [f64]) {
            f[0] = -x[0];
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            g[0] = 0.1 * x[0];
        }
    }

    #[test]
    fn test_sde_system_trait() {
        let sys = TestSde;
        assert_eq!(sys.dim(), 1);
        // Default noise is diagonal: one Wiener process per state component.
        assert_eq!(sys.n_wiener(), 1);

        let state = [1.0];
        let (mut f, mut g) = ([0.0], [0.0]);
        sys.drift(0.0, &state, &mut f);
        sys.diffusion(0.0, &state, &mut g);
        assert!((f[0] + 1.0).abs() < 1e-10);
        assert!((g[0] - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_sde_options() {
        let opts = SdeOptions::<f64>::default().seed(42).dt(0.001);
        assert_eq!(opts.seed, Some(42));
        assert!((opts.dt - 0.001).abs() < 1e-10);
    }

    /// Two-component SDE with linear diffusion `g(t, x) = α·x`.
    ///
    /// Here `(∂g_i/∂x_i) · g_i = α · α·x_i = α²·x_i`. That closed form is what
    /// the trait default's forward-FD computation should approximate.
    struct LinearDiffusionSde {
        alpha: f64,
    }

    impl SdeSystem<f64> for LinearDiffusionSde {
        fn dim(&self) -> usize {
            2
        }
        fn drift(&self, _t: f64, _x: &[f64], f: &mut [f64]) {
            f[..2].fill(0.0);
        }
        fn diffusion(&self, _t: f64, x: &[f64], g: &mut [f64]) {
            for (gk, xk) in g[..2].iter_mut().zip(&x[..2]) {
                *gk = self.alpha * xk;
            }
        }
        // `diffusion_derivative` is intentionally left to the trait default,
        // which is what the test below pins.
    }

    #[test]
    fn test_diffusion_derivative_default_large_x_no_scaling_bug() {
        // Regression guard for F-FD-NOSCALE-BUG in the trait-default forward-FD
        // branch. With an unscaled `h = 1e-8`, `x[i] + h == x[i]` in f64 once
        // |x| exceeds ~5e7. Then `g_plus == g` and the result collapses to
        // `0/h * g = 0` instead of `α²·x`. The canonical step
        // `sqrt(EPSILON) * (1 + |x|)` recovers the right answer. Site 4 is a
        // *forward* difference, so its canonical step differs from the
        // central-FD sites in numra-optim and numra-dde.
        let alpha = 0.5_f64;
        let sys = LinearDiffusionSde { alpha };
        let x = [1e8, 1e8];
        let mut gdg = [0.0; 2];
        sys.diffusion_derivative(0.0, &x, &mut gdg);

        // Closed form: α² · x_i = 0.25 · 1e8 = 2.5e7.
        let expected = alpha * alpha * 1e8;
        let tol = 1e-3 * expected.abs();
        assert!(
            (gdg[0] - expected).abs() < tol,
            "gdg[0] = {} should be ≈ {} (within 1e-3 relative); old unscaled formula returns 0",
            gdg[0],
            expected
        );
        assert!(
            (gdg[1] - expected).abs() < tol,
            "gdg[1] = {} should be ≈ {} (within 1e-3 relative); old unscaled formula returns 0",
            gdg[1],
            expected
        );
    }
}