Skip to main content

pathwise_core/
state.rs

1// pathwise-core/src/state.rs
2use rand::Rng;
3use std::ops::{Add, Mul};
4
/// Both Brownian increments for one time step of length `dt`.
///
/// With independent standard normals `z1`, `z2`:
/// - `dw = z1 * sqrt(dt)` is the Wiener increment `dW`;
/// - `dz = (dt/2) * dw - sqrt(dt^3 / 12) * z2` samples `∫_0^dt W(s) ds` jointly with `dw`.
///
/// Derivation: integration by parts gives `dZ = dt*dW - I_{(0,1)}` with
/// `I_{(0,1)} = ∫_0^dt s dW(s)`. Conditional on `dW`, `I_{(0,1)}` has mean
/// `(dt/2)*dW` and variance `dt^3/12`, hence the formula above. Because `z2` is
/// symmetric and independent of `z1`, the sign of the `z2` term does not change
/// the joint distribution; the negative sign follows the derivation.
/// Resulting moments: `E[dZ] = 0`, `Var[dZ] = dt^3/3`, `Cov(dW, dZ) = dt^2/2`.
///
/// Euler and Milstein ignore `dz`. SRI uses both.
///
/// No bound is placed on `B` in the struct definition; the derived impls carry
/// the matching bound (`B: Clone`, `B: Debug`, `B: PartialEq`) themselves.
#[derive(Clone, Debug, PartialEq)]
pub struct Increment<B> {
    /// Wiener increment `dW` (as sampled by this module's `NoiseIncrement` impls,
    /// `N(0, dt)` per component).
    pub dw: B,
    /// Time-integrated Brownian motion `∫_0^dt W(s) ds`, correlated with `dw`.
    pub dz: B,
}
18
/// Algebraic requirements for SDE state types.
///
/// A state must be `Clone + Send + Sync + 'static`, closed under `+`, and
/// scalable by an `f64` on the right.
/// `f64` and `nalgebra::SVector<f64, N>` both satisfy this automatically.
pub trait State
where
    Self: Clone + Send + Sync + 'static,
    Self: Add<Output = Self> + Mul<f64, Output = Self>,
{
    /// Additive identity of the state space.
    fn zero() -> Self;
}

impl State for f64 {
    /// `0.0`, via `f64`'s `Default` impl.
    fn zero() -> Self {
        f64::default()
    }
}
32
33impl<const N: usize> State for nalgebra::SVector<f64, N> {
34    fn zero() -> Self {
35        nalgebra::SVector::zeros()
36    }
37}
38
39/// Types that can sample a Brownian increment for a given dt.
40///
41/// # Object safety
42/// This trait is not object-safe because `sample` is generic over `R: Rng`.
43/// All uses are monomorphic — `dyn NoiseIncrement` will not compile.
44pub trait NoiseIncrement: Clone + Send + Sync + 'static {
45    fn sample<R: Rng>(rng: &mut R, dt: f64) -> Increment<Self>;
46}
47
48impl NoiseIncrement for f64 {
49    fn sample<R: Rng>(rng: &mut R, dt: f64) -> Increment<Self> {
50        use rand_distr::{Distribution, Normal};
51        let normal = Normal::new(0.0_f64, 1.0).unwrap();
52        let z1 = normal.sample(rng);
53        let z2 = normal.sample(rng);
54        let dw = z1 * dt.sqrt();
55        let dz = (dt / 2.0) * dw - (dt.powi(3) / 12.0).sqrt() * z2;
56        Increment { dw, dz }
57    }
58}
59
60impl<const M: usize> NoiseIncrement for nalgebra::SVector<f64, M> {
61    fn sample<R: Rng>(rng: &mut R, dt: f64) -> Increment<Self> {
62        use rand_distr::{Distribution, Normal};
63        let normal = Normal::new(0.0_f64, 1.0).unwrap();
64        let mut dw = nalgebra::SVector::<f64, M>::zeros();
65        let mut dz = nalgebra::SVector::<f64, M>::zeros();
66        for i in 0..M {
67            let z1 = normal.sample(rng);
68            let z2 = normal.sample(rng);
69            dw[i] = z1 * dt.sqrt();
70            dz[i] = (dt / 2.0) * dw[i] - (dt.powi(3) / 12.0).sqrt() * z2;
71        }
72        Increment { dw, dz }
73    }
74}
75
76/// Diffusion coefficient interface.
77/// `apply(x, t, dw)` returns the diffusion contribution `g(x,t) * dw` directly.
78///
79/// Blanket impl for scalar closures: `f(x,t) -> f64` satisfies `Diffusion<f64, f64>`
80/// by computing `f(x,t) * dw`.
81///
82/// Blanket impl for nD diagonal closures: `f(x,t) -> SVector<N>` satisfies
83/// `Diffusion<SVector<N>, SVector<N>>` by element-wise (Hadamard) product.
84///
85/// Full-matrix processes (Heston, CorrOU) provide concrete struct impls.
86///
87/// # Calling convention
88/// Scalar diffusions (`B = f64`) receive `x` by value because `f64: Copy`.
89/// Vector diffusions (`B = SVector<f64, N>`) receive `x` by reference.
90/// Concrete struct impls (e.g. `HestonDiffusion`) use whichever is appropriate for their state type.
91pub trait Diffusion<S: State, B: NoiseIncrement>: Send + Sync {
92    fn apply(&self, x: &S, t: f64, dw: &B) -> S;
93}
94
95// Scalar blanket impl: closure returns g(x,t); apply multiplies by dw.
96impl<F: Fn(f64, f64) -> f64 + Send + Sync> Diffusion<f64, f64> for F {
97    fn apply(&self, x: &f64, t: f64, dw: &f64) -> f64 {
98        self(*x, t) * dw
99    }
100}
101
102// nD diagonal blanket impl: closure returns sigma vector; apply component-multiplies with dw.
103impl<const N: usize, F> Diffusion<nalgebra::SVector<f64, N>, nalgebra::SVector<f64, N>> for F
104where
105    F: Fn(&nalgebra::SVector<f64, N>, f64) -> nalgebra::SVector<f64, N> + Send + Sync,
106{
107    fn apply(
108        &self,
109        x: &nalgebra::SVector<f64, N>,
110        t: f64,
111        dw: &nalgebra::SVector<f64, N>,
112    ) -> nalgebra::SVector<f64, N> {
113        self(x, t).component_mul(dw)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Raw second moments `(E[w^2], E[z^2], E[w*z])` over `(w, z)` samples.
    /// The increments have zero mean, so these estimate Var, Var and Cov directly.
    fn second_moments(samples: &[(f64, f64)]) -> (f64, f64, f64) {
        let n = samples.len() as f64;
        let (sww, szz, swz) = samples
            .iter()
            .fold((0.0, 0.0, 0.0), |(a, b, c), &(w, z)| {
                (a + w * w, b + z * z, c + w * z)
            });
        (sww / n, szz / n, swz / n)
    }

    /// Asserts `actual` is within relative error `rel_tol` of a non-zero `expected`.
    fn assert_rel_close(label: &str, actual: f64, expected: f64, rel_tol: f64) {
        assert!(
            (actual - expected).abs() / expected.abs() < rel_tol,
            "{}={:e} expected {:e}",
            label,
            actual,
            expected
        );
    }

    #[test]
    fn increment_f64_variance() {
        // Var[dW] ≈ dt; Var[dZ] ≈ dt^3/3; Cov(dW,dZ) ≈ dt^2/2
        let dt = 0.01_f64;
        let n = 100_000_usize;
        let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
        let samples: Vec<(f64, f64)> = (0..n)
            .map(|_| {
                let inc = <f64 as NoiseIncrement>::sample(&mut rng, dt);
                (inc.dw, inc.dz)
            })
            .collect();
        let (var_dw, var_dz, cov) = second_moments(&samples);
        assert_rel_close("Var[dW]", var_dw, dt, 0.02);
        assert_rel_close("Var[dZ]", var_dz, dt.powi(3) / 3.0, 0.03);
        assert_rel_close("Cov(dW,dZ)", cov, dt.powi(2) / 2.0, 0.02);
    }

    #[test]
    fn increment_svector_has_correct_length() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
        let inc = <nalgebra::SVector<f64, 3> as NoiseIncrement>::sample(&mut rng, 0.01);
        assert_eq!(inc.dw.len(), 3);
        assert_eq!(inc.dz.len(), 3);
    }

    #[test]
    fn increment_svector_components_match_scalar_moments() {
        // Every component must have the scalar moments, and distinct components
        // must be uncorrelated.
        let dt = 0.01_f64;
        let n = 100_000_usize;
        let mut rng = rand::rngs::SmallRng::seed_from_u64(7);
        let mut per_component = [Vec::with_capacity(n), Vec::with_capacity(n)];
        let mut cross = 0.0_f64;
        for _ in 0..n {
            let inc = <nalgebra::SVector<f64, 2> as NoiseIncrement>::sample(&mut rng, dt);
            for (k, samples) in per_component.iter_mut().enumerate() {
                samples.push((inc.dw[k], inc.dz[k]));
            }
            cross += inc.dw[0] * inc.dw[1];
        }
        for samples in &per_component {
            let (var_dw, var_dz, cov) = second_moments(samples);
            assert_rel_close("Var[dW_i]", var_dw, dt, 0.03);
            assert_rel_close("Var[dZ_i]", var_dz, dt.powi(3) / 3.0, 0.03);
            assert_rel_close("Cov(dW_i,dZ_i)", cov, dt.powi(2) / 2.0, 0.03);
        }
        let cross = cross / n as f64;
        assert!(
            cross.abs() < 0.03 * dt,
            "Cov(dW_0,dW_1)={:e} expected ~0",
            cross
        );
    }

    #[test]
    fn scalar_diffusion_closure_scales_by_dw() {
        let g = |x: f64, _t: f64| 2.0 * x;
        let out = Diffusion::<f64, f64>::apply(&g, &3.0, 0.0, &0.5);
        assert_eq!(out, 3.0);
    }

    #[test]
    fn diagonal_diffusion_closure_is_componentwise() {
        type V2 = nalgebra::SVector<f64, 2>;
        let sigma = |x: &V2, _t: f64| V2::new(x[0], 2.0 * x[1]);
        let x = V2::new(1.5, 0.25);
        let dw = V2::new(2.0, -4.0);
        let out = Diffusion::<V2, V2>::apply(&sigma, &x, 0.0, &dw);
        assert_eq!(out, V2::new(3.0, -2.0));
    }

    #[test]
    fn state_zero_is_additive_identity() {
        assert_eq!(<f64 as State>::zero(), 0.0);
        let v = nalgebra::SVector::<f64, 3>::new(1.0, -2.0, 0.5);
        assert_eq!(v + <nalgebra::SVector<f64, 3> as State>::zero(), v);
    }
}