Skip to main content

proof_engine/stochastic/
sde.rs

//! Stochastic differential equations: generic SDE solver with Euler-Maruyama
//! and Milstein methods, plus preset SDEs for common processes.
3
4use super::brownian::Rng;
5use glam::Vec2;
6
// ---------------------------------------------------------------------------
// SDE
// ---------------------------------------------------------------------------
10
/// A scalar stochastic differential equation `dX = a(t, X) dt + b(t, X) dW`.
///
/// Both coefficients are stored as boxed closures that take `(t, x)` and
/// return an `f64`.
pub struct SDE {
    /// Drift coefficient `a(t, x)` — the factor multiplying `dt`.
    pub drift: Box<dyn Fn(f64, f64) -> f64>,
    /// Diffusion coefficient `b(t, x)` — the factor multiplying `dW`.
    pub diffusion: Box<dyn Fn(f64, f64) -> f64>,
}

impl SDE {
    /// Builds an SDE from its drift and diffusion closures.
    ///
    /// The closures are stored as given; nothing is evaluated here.
    pub fn new(
        drift: Box<dyn Fn(f64, f64) -> f64>,
        diffusion: Box<dyn Fn(f64, f64) -> f64>,
    ) -> Self {
        SDE { drift, diffusion }
    }
}
27
// ---------------------------------------------------------------------------
// Solvers
// ---------------------------------------------------------------------------
31
32/// Euler-Maruyama method for solving an SDE.
33///
34/// X_{n+1} = X_n + a(t_n, X_n) * dt + b(t_n, X_n) * dW_n
35pub fn euler_maruyama(sde: &SDE, x0: f64, dt: f64, steps: usize, rng: &mut Rng) -> Vec<f64> {
36    let mut path = Vec::with_capacity(steps + 1);
37    path.push(x0);
38    let mut x = x0;
39    let mut t = 0.0;
40    let sqrt_dt = dt.sqrt();
41
42    for _ in 0..steps {
43        let dw = rng.normal() * sqrt_dt;
44        let a = (sde.drift)(t, x);
45        let b = (sde.diffusion)(t, x);
46        x += a * dt + b * dw;
47        t += dt;
48        path.push(x);
49    }
50    path
51}
52
53/// Milstein method for solving an SDE.
54///
55/// X_{n+1} = X_n + a*dt + b*dW + 0.5*b*b'*(dW^2 - dt)
56/// where b' = db/dx (diffusion_derivative).
57pub fn milstein(
58    sde: &SDE,
59    diffusion_derivative: &dyn Fn(f64, f64) -> f64,
60    x0: f64,
61    dt: f64,
62    steps: usize,
63    rng: &mut Rng,
64) -> Vec<f64> {
65    let mut path = Vec::with_capacity(steps + 1);
66    path.push(x0);
67    let mut x = x0;
68    let mut t = 0.0;
69    let sqrt_dt = dt.sqrt();
70
71    for _ in 0..steps {
72        let dw = rng.normal() * sqrt_dt;
73        let a = (sde.drift)(t, x);
74        let b = (sde.diffusion)(t, x);
75        let b_prime = diffusion_derivative(t, x);
76        x += a * dt + b * dw + 0.5 * b * b_prime * (dw * dw - dt);
77        t += dt;
78        path.push(x);
79    }
80    path
81}
82
83/// Heun's method (improved Euler / predictor-corrector) for SDEs.
84pub fn heun(sde: &SDE, x0: f64, dt: f64, steps: usize, rng: &mut Rng) -> Vec<f64> {
85    let mut path = Vec::with_capacity(steps + 1);
86    path.push(x0);
87    let mut x = x0;
88    let mut t = 0.0;
89    let sqrt_dt = dt.sqrt();
90
91    for _ in 0..steps {
92        let dw = rng.normal() * sqrt_dt;
93        let a1 = (sde.drift)(t, x);
94        let b1 = (sde.diffusion)(t, x);
95
96        let x_tilde = x + a1 * dt + b1 * dw;
97        let t_next = t + dt;
98
99        let a2 = (sde.drift)(t_next, x_tilde);
100        let b2 = (sde.diffusion)(t_next, x_tilde);
101
102        x += 0.5 * (a1 + a2) * dt + 0.5 * (b1 + b2) * dw;
103        t = t_next;
104        path.push(x);
105    }
106    path
107}
108
// ---------------------------------------------------------------------------
// Error measures
// ---------------------------------------------------------------------------
112
/// Strong error: `max_i |exact(t_i) - numerical(t_i)|`.
///
/// The two slices are compared pairwise; if their lengths differ, only the
/// common prefix is compared. Two empty inputs yield `0.0`.
///
/// If any pairwise difference is NaN (for example because the numerical path
/// diverged), the result is NaN. A plain `f64::max` fold would silently drop
/// NaN values and could report an error of `0.0` for a broken path, which is
/// also inconsistent with `rmse`, where NaN propagates through the sum.
pub fn strong_error(exact: &[f64], numerical: &[f64]) -> f64 {
    let mut worst = 0.0_f64;
    for (e, n) in exact.iter().zip(numerical) {
        let gap = (e - n).abs();
        if gap.is_nan() {
            // `f64::max` ignores NaN operands, so surface it explicitly.
            return f64::NAN;
        }
        worst = worst.max(gap);
    }
    worst
}
121
/// Weak error: `|E[exact(T)] - E[numerical(T)]|`.
///
/// Takes the two already-computed means and returns the absolute value of
/// their difference.
pub fn weak_error(exact_mean: f64, numerical_mean: f64) -> f64 {
    let bias = exact_mean - numerical_mean;
    bias.abs()
}
126
/// Root mean square error between two paths.
///
/// Only the common prefix of the two slices is compared, and the mean is
/// taken over that many samples. Returns `0.0` when either slice is empty.
pub fn rmse(exact: &[f64], numerical: &[f64]) -> f64 {
    match exact.len().min(numerical.len()) {
        0 => 0.0,
        count => {
            let squared_total: f64 = exact
                .iter()
                .zip(numerical)
                .map(|(reference, approx)| (reference - approx).powi(2))
                .sum();
            (squared_total / count as f64).sqrt()
        }
    }
}
136
// ---------------------------------------------------------------------------
// Preset SDEs
// ---------------------------------------------------------------------------
140
141/// Geometric Brownian Motion: dS = mu*S*dt + sigma*S*dW.
142pub fn sde_gbm(mu: f64, sigma: f64) -> SDE {
143    SDE {
144        drift: Box::new(move |_t, x| mu * x),
145        diffusion: Box::new(move |_t, x| sigma * x),
146    }
147}
148
/// Derivative of the GBM diffusion coefficient: `d(sigma*x)/dx = sigma`.
///
/// The returned closure ignores both `t` and `x` and always yields `sigma`.
pub fn sde_gbm_diffusion_deriv(sigma: f64) -> Box<dyn Fn(f64, f64) -> f64> {
    let constant_slope = move |_t: f64, _x: f64| sigma;
    Box::new(constant_slope)
}
153
154/// Ornstein-Uhlenbeck: dX = theta*(mu - X)*dt + sigma*dW.
155pub fn sde_ou(theta: f64, mu: f64, sigma: f64) -> SDE {
156    SDE {
157        drift: Box::new(move |_t, x| theta * (mu - x)),
158        diffusion: Box::new(move |_t, _x| sigma),
159    }
160}
161
/// Derivative of the OU diffusion coefficient: `d(sigma)/dx = 0`.
///
/// The returned closure ignores its arguments and always yields `0.0`.
pub fn sde_ou_diffusion_deriv() -> Box<dyn Fn(f64, f64) -> f64> {
    let flat = |_t: f64, _x: f64| 0.0_f64;
    Box::new(flat)
}
166
167/// Cox-Ingersoll-Ross: dX = kappa*(theta - X)*dt + sigma*sqrt(X)*dW.
168pub fn sde_cir(kappa: f64, theta: f64, sigma: f64) -> SDE {
169    SDE {
170        drift: Box::new(move |_t, x| kappa * (theta - x)),
171        diffusion: Box::new(move |_t, x| sigma * x.max(0.0).sqrt()),
172    }
173}
174
/// Derivative of the CIR diffusion coefficient:
/// `d(sigma*sqrt(x))/dx = sigma / (2*sqrt(x))`.
///
/// `x` is floored at `1e-15` before the square root so the division never
/// sees zero or a negative radicand; for `x` below that floor the result is
/// the constant `sigma / (2*sqrt(1e-15))`.
pub fn sde_cir_diffusion_deriv(sigma: f64) -> Box<dyn Fn(f64, f64) -> f64> {
    let slope = move |_t: f64, x: f64| {
        let floored_root = x.max(1e-15).sqrt();
        sigma / (2.0 * floored_root)
    };
    Box::new(slope)
}
182
183/// Constant Elasticity of Variance (CEV): dS = mu*S*dt + sigma*S^gamma*dW.
184pub fn sde_cev(mu: f64, sigma: f64, gamma: f64) -> SDE {
185    SDE {
186        drift: Box::new(move |_t, x| mu * x),
187        diffusion: Box::new(move |_t, x| sigma * x.abs().powf(gamma)),
188    }
189}
190
191/// Langevin equation: dV = -gamma*V*dt + sigma*dW (velocity process).
192pub fn sde_langevin(gamma: f64, sigma: f64) -> SDE {
193    SDE {
194        drift: Box::new(move |_t, v| -gamma * v),
195        diffusion: Box::new(move |_t, _v| sigma),
196    }
197}
198
// ---------------------------------------------------------------------------
// SDERenderer
// ---------------------------------------------------------------------------
202
203/// Render SDE solution paths with drift/diffusion visualization.
204pub struct SDERenderer {
205    pub path_character: char,
206    pub path_color: [f32; 4],
207    pub drift_color: [f32; 4],
208    pub x_scale: f32,
209    pub y_scale: f32,
210}
211
212impl SDERenderer {
213    pub fn new() -> Self {
214        Self {
215            path_character: '·',
216            path_color: [0.2, 0.8, 1.0, 1.0],
217            drift_color: [1.0, 0.5, 0.2, 0.5],
218            x_scale: 0.1,
219            y_scale: 1.0,
220        }
221    }
222
223    /// Render a single solution path.
224    pub fn render_path(&self, path: &[f64]) -> Vec<(Vec2, char, [f32; 4])> {
225        path.iter()
226            .enumerate()
227            .map(|(i, &val)| {
228                (
229                    Vec2::new(i as f32 * self.x_scale, val as f32 * self.y_scale),
230                    self.path_character,
231                    self.path_color,
232                )
233            })
234            .collect()
235    }
236
237    /// Render multiple sample paths.
238    pub fn render_paths(&self, paths: &[Vec<f64>]) -> Vec<(Vec2, char, [f32; 4])> {
239        let n = paths.len().max(1);
240        let mut glyphs = Vec::new();
241        for (pi, path) in paths.iter().enumerate() {
242            let alpha = 0.2 + 0.6 * (pi as f32 / n as f32);
243            let color = [self.path_color[0], self.path_color[1], self.path_color[2], alpha];
244            for (i, &val) in path.iter().enumerate() {
245                glyphs.push((
246                    Vec2::new(i as f32 * self.x_scale, val as f32 * self.y_scale),
247                    self.path_character,
248                    color,
249                ));
250            }
251        }
252        glyphs
253    }
254
255    /// Render drift field as arrows at sampled points.
256    pub fn render_drift_field(
257        &self,
258        sde: &SDE,
259        t_range: (f64, f64),
260        x_range: (f64, f64),
261        grid: (usize, usize),
262    ) -> Vec<(Vec2, char, [f32; 4])> {
263        let mut glyphs = Vec::new();
264        let (t_steps, x_steps) = grid;
265
266        for ti in 0..t_steps {
267            for xi in 0..x_steps {
268                let t = t_range.0 + (t_range.1 - t_range.0) * ti as f64 / t_steps as f64;
269                let x = x_range.0 + (x_range.1 - x_range.0) * xi as f64 / x_steps as f64;
270                let drift = (sde.drift)(t, x);
271
272                let ch = if drift > 0.1 {
273                    '↑'
274                } else if drift < -0.1 {
275                    '↓'
276                } else {
277                    '·'
278                };
279
280                let pos = Vec2::new(t as f32 * self.x_scale, x as f32 * self.y_scale);
281                glyphs.push((pos, ch, self.drift_color));
282            }
283        }
284        glyphs
285    }
286}
287
288impl Default for SDERenderer {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
297
#[cfg(test)]
mod tests {
    use super::*;

    /// The Monte-Carlo mean of Euler-Maruyama GBM endpoints at T = 1 should
    /// be within 10% of the closed-form mean `s0 * exp(mu * T)`.
    #[test]
    fn test_euler_maruyama_gbm() {
        // Compare EM solution of GBM SDE against closed-form
        let mu = 0.05;
        let sigma = 0.2;
        let s0 = 100.0;
        let dt = 0.001;
        let steps = 1000;
        let trials = 5000;

        let mut rng = Rng::new(42);
        let sde = sde_gbm(mu, sigma);

        let endpoints: Vec<f64> = (0..trials)
            .map(|_| {
                let path = euler_maruyama(&sde, s0, dt, steps, &mut rng);
                *path.last().unwrap()
            })
            .collect();

        let empirical_mean = endpoints.iter().sum::<f64>() / trials as f64;
        // dt * steps = 1.0, hence the horizon of 1.0 in the exponent.
        let expected_mean = s0 * (mu * 1.0).exp();

        assert!(
            (empirical_mean - expected_mean).abs() / expected_mean < 0.1,
            "EM GBM mean {} should be ~{}", empirical_mean, expected_mean
        );
    }

    /// Milstein on GBM yields a path of the right length that starts at
    /// `s0` and stays positive for this seed.
    #[test]
    fn test_milstein_gbm() {
        let mu = 0.05;
        let sigma = 0.2;
        let s0 = 100.0;
        let dt = 0.01;
        let steps = 100;

        let sde = sde_gbm(mu, sigma);
        let deriv = sde_gbm_diffusion_deriv(sigma);
        let mut rng = Rng::new(42);

        let path = milstein(&sde, &*deriv, s0, dt, steps, &mut rng);
        assert_eq!(path.len(), steps + 1);
        assert!((path[0] - s0).abs() < 1e-10);
        // GBM should stay positive with high probability
        assert!(path.iter().all(|&x| x > 0.0));
    }

    /// The second half of a long OU path should average close to `mu`.
    #[test]
    fn test_euler_maruyama_ou() {
        let theta = 2.0;
        let mu = 5.0;
        let sigma = 1.0;
        let sde = sde_ou(theta, mu, sigma);
        let mut rng = Rng::new(42);

        let path = euler_maruyama(&sde, 0.0, 0.01, 10_000, &mut rng);
        // Tail should be near mu. The path holds steps + 1 samples, so divide
        // by the actual tail length rather than a hard-coded count.
        let tail = &path[5000..];
        let tail_mean: f64 = tail.iter().sum::<f64>() / tail.len() as f64;
        assert!(
            (tail_mean - mu).abs() < 1.0,
            "OU tail mean {} should be near {}", tail_mean, mu
        );
    }

    /// An Euler-Maruyama CIR path satisfying the Feller condition should not
    /// dip meaningfully below zero.
    #[test]
    fn test_cir_non_negative() {
        let sde = sde_cir(1.0, 0.05, 0.1);
        let mut rng = Rng::new(42);
        let path = euler_maruyama(&sde, 0.05, 0.001, 10_000, &mut rng);
        // CIR with Feller condition 2*kappa*theta > sigma^2 should stay positive
        // 2*1*0.05 = 0.1 > 0.01 = sigma^2, so it should
        let min_val = path.iter().cloned().fold(f64::INFINITY, f64::min);
        // EM can go slightly negative; just check it's not too negative
        assert!(
            min_val > -0.01,
            "CIR should stay roughly non-negative, min = {}", min_val
        );
    }

    /// `strong_error` returns the largest pointwise absolute difference.
    #[test]
    fn test_strong_error() {
        let exact = vec![0.0, 1.0, 2.0, 3.0];
        let numerical = vec![0.0, 1.1, 1.8, 3.2];
        let err = strong_error(&exact, &numerical);
        assert!((err - 0.2).abs() < 1e-10);
    }

    /// `weak_error` is the absolute difference of the two means.
    #[test]
    fn test_weak_error() {
        assert!((weak_error(5.0, 4.8) - 0.2).abs() < 1e-10);
    }

    /// Identical paths have zero RMSE.
    #[test]
    fn test_rmse() {
        let exact = vec![0.0, 1.0, 2.0];
        let numerical = vec![0.0, 1.0, 2.0];
        assert!(rmse(&exact, &numerical) < 1e-10);
    }

    /// Smoke test: both solvers run on the same seed and produce paths of
    /// the expected length. Despite the name, no accuracy comparison is
    /// asserted here.
    #[test]
    fn test_milstein_better_than_euler() {
        // Milstein should have better strong convergence than Euler for GBM
        let mu = 0.1;
        let sigma = 0.3;
        let s0 = 1.0;
        let dt = 0.01;
        let steps = 100;
        let sde = sde_gbm(mu, sigma);
        let deriv = sde_gbm_diffusion_deriv(sigma);

        // Use same noise for both — we just check they produce valid paths
        let mut rng1 = Rng::new(42);
        let em_path = euler_maruyama(&sde, s0, dt, steps, &mut rng1);
        let mut rng2 = Rng::new(42);
        let sde2 = sde_gbm(mu, sigma);
        let mil_path = milstein(&sde2, &*deriv, s0, dt, steps, &mut rng2);

        assert_eq!(em_path.len(), steps + 1);
        assert_eq!(mil_path.len(), steps + 1);
    }

    /// Heun on an OU process started at 5 should mean-revert toward 0.
    #[test]
    fn test_heun_method() {
        let sde = sde_ou(1.0, 0.0, 1.0);
        let mut rng = Rng::new(42);
        let path = heun(&sde, 5.0, 0.01, 1000, &mut rng);
        assert_eq!(path.len(), 1001);
        // Should mean-revert toward 0. Divide by the actual tail length: the
        // slice has 501 samples, not 500.
        let tail = &path[500..];
        let tail_mean: f64 = tail.iter().sum::<f64>() / tail.len() as f64;
        assert!(tail_mean.abs() < 2.0);
    }

    /// `render_path` emits one glyph per path sample.
    #[test]
    fn test_sde_renderer() {
        let sde = sde_ou(1.0, 0.0, 1.0);
        let mut rng = Rng::new(42);
        let path = euler_maruyama(&sde, 0.0, 0.01, 100, &mut rng);
        let renderer = SDERenderer::new();
        let glyphs = renderer.render_path(&path);
        assert_eq!(glyphs.len(), 101);
    }

    /// `render_drift_field` emits one glyph per grid cell.
    #[test]
    fn test_drift_field_render() {
        let sde = sde_ou(1.0, 0.0, 1.0);
        let renderer = SDERenderer::new();
        let glyphs = renderer.render_drift_field(&sde, (0.0, 1.0), (-2.0, 2.0), (5, 5));
        assert_eq!(glyphs.len(), 25);
    }
}