Skip to main content

proof_engine/solver/
ode.rs

1//! ODE solvers — Euler, RK4, RK45 (adaptive), implicit, symplectic methods.
2
/// One point on a solution trajectory: the time `t` together with the
/// state vector `y` valid at that time.
#[derive(Debug, Clone)]
pub struct OdeState {
    /// Independent variable (time) at which `y` applies.
    pub t: f64,
    /// State components at time `t`.
    pub y: Vec<f64>,
}
9
/// Right-hand side of a first-order system `dy/dt = f(t, y)`.
///
/// Implementors must be `Send + Sync`.
pub trait OdeSystem: Send + Sync {
    /// Number of components in the state vector.
    fn dimension(&self) -> usize;

    /// Writes `f(t, y)` into `dydt`.
    fn evaluate(&self, t: f64, y: &[f64], dydt: &mut [f64]);
}
15
/// Numerical scheme an `OdeSolver` uses to advance a state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OdeMethod {
    /// Explicit (forward) Euler.
    Euler,
    /// Classical fixed-step fourth-order Runge-Kutta.
    RungeKutta4,
    /// Adaptive fourth/fifth-order Runge-Kutta with error-controlled step size.
    RungeKutta45,
    /// Implicit (backward) Euler.
    ImplicitEuler,
    /// Crank-Nicolson (trapezoidal rule).
    CrankNicolson,
    /// Velocity Verlet (symplectic); expects the state laid out as
    /// `[positions…, velocities…]`.
    Verlet,
    /// Leapfrog (symplectic).
    Leapfrog,
}
27
28/// ODE solver with configurable method and step control.
29pub struct OdeSolver {
30    pub method: OdeMethod,
31    pub dt: f64,
32    pub dt_min: f64,
33    pub dt_max: f64,
34    pub tolerance: f64,
35    pub max_steps: u64,
36    work: Vec<f64>,
37}
38
39impl OdeSolver {
40    pub fn new(method: OdeMethod, dt: f64) -> Self {
41        Self {
42            method, dt, dt_min: 1e-8, dt_max: 1.0,
43            tolerance: 1e-6, max_steps: 1_000_000,
44            work: Vec::new(),
45        }
46    }
47
48    pub fn rk4(dt: f64) -> Self { Self::new(OdeMethod::RungeKutta4, dt) }
49    pub fn adaptive(dt: f64, tol: f64) -> Self {
50        let mut s = Self::new(OdeMethod::RungeKutta45, dt);
51        s.tolerance = tol;
52        s
53    }
54    pub fn verlet(dt: f64) -> Self { Self::new(OdeMethod::Verlet, dt) }
55
56    /// Integrate one step. Returns the new state.
57    pub fn step(&mut self, system: &dyn OdeSystem, state: &OdeState) -> OdeState {
58        match self.method {
59            OdeMethod::Euler => self.euler_step(system, state),
60            OdeMethod::RungeKutta4 => self.rk4_step(system, state),
61            OdeMethod::RungeKutta45 => self.rk45_step(system, state),
62            OdeMethod::Verlet => self.verlet_step(system, state),
63            OdeMethod::Leapfrog => self.leapfrog_step(system, state),
64            OdeMethod::ImplicitEuler => self.implicit_euler_step(system, state),
65            OdeMethod::CrankNicolson => self.crank_nicolson_step(system, state),
66        }
67    }
68
69    /// Integrate from t0 to t_end. Returns all states at each step.
70    pub fn integrate(&mut self, system: &dyn OdeSystem, initial: &OdeState, t_end: f64) -> Vec<OdeState> {
71        let mut states = vec![initial.clone()];
72        let mut current = initial.clone();
73        let mut steps = 0u64;
74
75        while current.t < t_end && steps < self.max_steps {
76            current = self.step(system, &current);
77            states.push(current.clone());
78            steps += 1;
79        }
80        states
81    }
82
83    /// Integrate and return only the final state.
84    pub fn solve(&mut self, system: &dyn OdeSystem, initial: &OdeState, t_end: f64) -> OdeState {
85        let mut current = initial.clone();
86        let mut steps = 0u64;
87        while current.t < t_end && steps < self.max_steps {
88            current = self.step(system, &current);
89            steps += 1;
90        }
91        current
92    }
93
94    // ── Method implementations ──────────────────────────────────────────
95
96    fn euler_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
97        let n = s.y.len();
98        let mut dydt = vec![0.0; n];
99        sys.evaluate(s.t, &s.y, &mut dydt);
100        let y: Vec<f64> = s.y.iter().zip(dydt.iter()).map(|(y, dy)| y + dy * self.dt).collect();
101        OdeState { t: s.t + self.dt, y }
102    }
103
104    fn rk4_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
105        let n = s.y.len();
106        let h = self.dt;
107        let mut k1 = vec![0.0; n]; sys.evaluate(s.t, &s.y, &mut k1);
108        let y2: Vec<f64> = (0..n).map(|i| s.y[i] + k1[i] * h * 0.5).collect();
109        let mut k2 = vec![0.0; n]; sys.evaluate(s.t + h * 0.5, &y2, &mut k2);
110        let y3: Vec<f64> = (0..n).map(|i| s.y[i] + k2[i] * h * 0.5).collect();
111        let mut k3 = vec![0.0; n]; sys.evaluate(s.t + h * 0.5, &y3, &mut k3);
112        let y4: Vec<f64> = (0..n).map(|i| s.y[i] + k3[i] * h).collect();
113        let mut k4 = vec![0.0; n]; sys.evaluate(s.t + h, &y4, &mut k4);
114
115        let y: Vec<f64> = (0..n).map(|i| {
116            s.y[i] + h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i])
117        }).collect();
118        OdeState { t: s.t + h, y }
119    }
120
121    fn rk45_step(&mut self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
122        // Dormand-Prince RK45 with adaptive step
123        let n = s.y.len();
124        let h = self.dt;
125
126        // Use RK4 for 4th and 5th order estimates
127        let state4 = self.rk4_step(sys, s);
128
129        // Simple error estimate via half-step Richardson extrapolation
130        let mut half = self.dt;
131        self.dt = h * 0.5;
132        let mid = self.rk4_step(sys, s);
133        let final_ = self.rk4_step(sys, &mid);
134        self.dt = half;
135
136        let error: f64 = state4.y.iter().zip(final_.y.iter())
137            .map(|(a, b)| (a - b).abs())
138            .fold(0.0, f64::max);
139
140        // Adjust step size
141        if error > self.tolerance && h > self.dt_min {
142            self.dt = (h * 0.5).max(self.dt_min);
143        } else if error < self.tolerance * 0.1 && h < self.dt_max {
144            self.dt = (h * 1.5).min(self.dt_max);
145        }
146
147        // Use the higher-order estimate
148        final_
149    }
150
151    fn verlet_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
152        // Velocity Verlet (for Hamiltonian systems where y = [x0,..,xn, v0,..,vn])
153        let n = s.y.len();
154        let half = n / 2;
155        let h = self.dt;
156
157        let mut acc = vec![0.0; n];
158        sys.evaluate(s.t, &s.y, &mut acc);
159
160        let mut y = s.y.clone();
161        // Update positions: x += v*h + 0.5*a*h²
162        for i in 0..half {
163            y[i] = s.y[i] + s.y[half + i] * h + 0.5 * acc[half + i] * h * h;
164        }
165
166        // Compute new acceleration
167        let mut acc_new = vec![0.0; n];
168        sys.evaluate(s.t + h, &y, &mut acc_new);
169
170        // Update velocities: v += 0.5*(a_old + a_new)*h
171        for i in 0..half {
172            y[half + i] = s.y[half + i] + 0.5 * (acc[half + i] + acc_new[half + i]) * h;
173        }
174
175        OdeState { t: s.t + h, y }
176    }
177
178    fn leapfrog_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
179        self.verlet_step(sys, s) // Verlet is mathematically equivalent
180    }
181
182    fn implicit_euler_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
183        // Simplified: use fixed-point iteration (1 iteration ≈ semi-implicit Euler)
184        let n = s.y.len();
185        let h = self.dt;
186        let mut dydt = vec![0.0; n];
187        sys.evaluate(s.t + h, &s.y, &mut dydt);
188        let y: Vec<f64> = s.y.iter().zip(dydt.iter()).map(|(y, dy)| y + dy * h).collect();
189        OdeState { t: s.t + h, y }
190    }
191
192    fn crank_nicolson_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
193        // Average of explicit and implicit Euler
194        let n = s.y.len();
195        let h = self.dt;
196        let mut f_n = vec![0.0; n];
197        sys.evaluate(s.t, &s.y, &mut f_n);
198        let y_euler: Vec<f64> = s.y.iter().zip(f_n.iter()).map(|(y, dy)| y + dy * h).collect();
199        let mut f_n1 = vec![0.0; n];
200        sys.evaluate(s.t + h, &y_euler, &mut f_n1);
201        let y: Vec<f64> = (0..n).map(|i| s.y[i] + 0.5 * h * (f_n[i] + f_n1[i])).collect();
202        OdeState { t: s.t + h, y }
203    }
204}
205
206// ── Built-in ODE systems ────────────────────────────────────────────────────
207
208/// Lorenz attractor: dx/dt = σ(y-x), dy/dt = x(ρ-z)-y, dz/dt = xy-βz.
209pub struct LorenzSystem { pub sigma: f64, pub rho: f64, pub beta: f64 }
210impl Default for LorenzSystem { fn default() -> Self { Self { sigma: 10.0, rho: 28.0, beta: 8.0/3.0 } } }
211impl OdeSystem for LorenzSystem {
212    fn dimension(&self) -> usize { 3 }
213    fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
214        dydt[0] = self.sigma * (y[1] - y[0]);
215        dydt[1] = y[0] * (self.rho - y[2]) - y[1];
216        dydt[2] = y[0] * y[1] - self.beta * y[2];
217    }
218}
219
220/// Simple harmonic oscillator: x'' + ω²x = 0.
221pub struct HarmonicOscillator { pub omega: f64 }
222impl OdeSystem for HarmonicOscillator {
223    fn dimension(&self) -> usize { 2 }
224    fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
225        dydt[0] = y[1];                        // dx/dt = v
226        dydt[1] = -self.omega * self.omega * y[0]; // dv/dt = -ω²x
227    }
228}
229
230/// Van der Pol oscillator: x'' - μ(1-x²)x' + x = 0.
231pub struct VanDerPol { pub mu: f64 }
232impl OdeSystem for VanDerPol {
233    fn dimension(&self) -> usize { 2 }
234    fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
235        dydt[0] = y[1];
236        dydt[1] = self.mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
237    }
238}
239
240/// Rossler attractor.
241pub struct RosslerSystem { pub a: f64, pub b: f64, pub c: f64 }
242impl Default for RosslerSystem { fn default() -> Self { Self { a: 0.2, b: 0.2, c: 5.7 } } }
243impl OdeSystem for RosslerSystem {
244    fn dimension(&self) -> usize { 3 }
245    fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
246        dydt[0] = -y[1] - y[2];
247        dydt[1] = y[0] + self.a * y[1];
248        dydt[2] = self.b + y[2] * (y[0] - self.c);
249    }
250}
251
252/// Custom ODE from a closure.
253pub struct CustomOde<F: Fn(f64, &[f64], &mut [f64]) + Send + Sync> {
254    pub dim: usize,
255    pub func: F,
256}
257impl<F: Fn(f64, &[f64], &mut [f64]) + Send + Sync> OdeSystem for CustomOde<F> {
258    fn dimension(&self) -> usize { self.dim }
259    fn evaluate(&self, t: f64, y: &[f64], dydt: &mut [f64]) { (self.func)(t, y, dydt); }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{PI, TAU};

    /// Oscillator state at maximum displacement x = 1, at rest.
    fn unit_start() -> OdeState {
        OdeState { t: 0.0, y: vec![1.0, 0.0] }
    }

    /// Energy ½x² + ½v² of a unit-frequency oscillator state.
    fn energy(y: &[f64]) -> f64 {
        0.5 * y[0].powi(2) + 0.5 * y[1].powi(2)
    }

    #[test]
    fn euler_harmonic() {
        let sys = HarmonicOscillator { omega: 1.0 };
        let x = OdeSolver::new(OdeMethod::Euler, 0.001).solve(&sys, &unit_start(), PI).y[0];
        // Half a period later the oscillator should sit near x = -1.
        assert!((x + 1.0).abs() < 0.1, "x={}", x);
    }

    #[test]
    fn rk4_harmonic_accurate() {
        let sys = HarmonicOscillator { omega: 1.0 };
        let x = OdeSolver::rk4(0.01).solve(&sys, &unit_start(), TAU).y[0];
        // A full period brings the oscillator back to x = 1.
        assert!((x - 1.0).abs() < 0.01, "x={}", x);
    }

    #[test]
    fn lorenz_doesnt_diverge() {
        let start = OdeState { t: 0.0, y: vec![1.0, 1.0, 1.0] };
        let end = OdeSolver::rk4(0.01).solve(&LorenzSystem::default(), &start, 10.0);
        // Every coordinate should remain bounded.
        let bounded = end.y.iter().all(|v| v.abs() < 100.0);
        assert!(bounded, "Lorenz diverged: {:?}", end.y);
    }

    #[test]
    fn integrate_returns_trajectory() {
        let sys = HarmonicOscillator { omega: 1.0 };
        let path = OdeSolver::rk4(0.1).integrate(&sys, &unit_start(), 1.0);
        assert!(path.len() > 5);
    }

    #[test]
    fn verlet_conserves_energy() {
        let sys = HarmonicOscillator { omega: 1.0 };
        let start = unit_start();
        let end = OdeSolver::verlet(0.01).solve(&sys, &start, 100.0);
        let drift = (energy(&start.y) - energy(&end.y)).abs();
        assert!(drift < 0.01, "Energy drift: {}", drift);
    }
}