/// A single point on an ODE solution: the independent variable together
/// with the state vector at that point.
#[derive(Clone, Debug)]
pub struct OdeState {
    /// Value of the independent variable (time) for this sample.
    pub t: f64,
    /// State vector at `t`; its length sets the size of the buffers handed
    /// to [`OdeSystem::evaluate`].
    pub y: Vec<f64>,
}
9
/// Right-hand side of a first-order system `dy/dt = f(t, y)`.
///
/// Implementors must be `Send + Sync`.
pub trait OdeSystem: Send + Sync {
    /// Number of state variables reported by the system.
    fn dimension(&self) -> usize;

    /// Evaluates the derivative at `(t, y)` and stores it in `dydt`.
    ///
    /// `dydt` is an output buffer; implementations are expected to write
    /// one entry per component of `y`.
    fn evaluate(&self, t: f64, y: &[f64], dydt: &mut [f64]);
}
15
/// Integration scheme selected on an [`OdeSolver`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OdeMethod {
    /// Explicit (forward) Euler.
    Euler,
    /// Classical fixed-step fourth-order Runge–Kutta.
    RungeKutta4,
    /// Adaptive-step Runge–Kutta; the solver adjusts its `dt` between
    /// `dt_min` and `dt_max` based on `tolerance`.
    RungeKutta45,
    /// Backward-Euler family; handled by the solver's `implicit_euler_step`.
    ImplicitEuler,
    /// Trapezoidal-rule scheme that averages the derivative at both ends of
    /// the step; handled by the solver's `crank_nicolson_step`.
    CrankNicolson,
    /// Velocity-Verlet update; the first half of the state vector is
    /// treated as positions and the second half as velocities.
    Verlet,
    /// Leapfrog; the solver currently routes this to the same update as
    /// [`OdeMethod::Verlet`].
    Leapfrog,
}
27
28pub struct OdeSolver {
30 pub method: OdeMethod,
31 pub dt: f64,
32 pub dt_min: f64,
33 pub dt_max: f64,
34 pub tolerance: f64,
35 pub max_steps: u64,
36 work: Vec<f64>,
37}
38
39impl OdeSolver {
40 pub fn new(method: OdeMethod, dt: f64) -> Self {
41 Self {
42 method, dt, dt_min: 1e-8, dt_max: 1.0,
43 tolerance: 1e-6, max_steps: 1_000_000,
44 work: Vec::new(),
45 }
46 }
47
48 pub fn rk4(dt: f64) -> Self { Self::new(OdeMethod::RungeKutta4, dt) }
49 pub fn adaptive(dt: f64, tol: f64) -> Self {
50 let mut s = Self::new(OdeMethod::RungeKutta45, dt);
51 s.tolerance = tol;
52 s
53 }
54 pub fn verlet(dt: f64) -> Self { Self::new(OdeMethod::Verlet, dt) }
55
56 pub fn step(&mut self, system: &dyn OdeSystem, state: &OdeState) -> OdeState {
58 match self.method {
59 OdeMethod::Euler => self.euler_step(system, state),
60 OdeMethod::RungeKutta4 => self.rk4_step(system, state),
61 OdeMethod::RungeKutta45 => self.rk45_step(system, state),
62 OdeMethod::Verlet => self.verlet_step(system, state),
63 OdeMethod::Leapfrog => self.leapfrog_step(system, state),
64 OdeMethod::ImplicitEuler => self.implicit_euler_step(system, state),
65 OdeMethod::CrankNicolson => self.crank_nicolson_step(system, state),
66 }
67 }
68
69 pub fn integrate(&mut self, system: &dyn OdeSystem, initial: &OdeState, t_end: f64) -> Vec<OdeState> {
71 let mut states = vec![initial.clone()];
72 let mut current = initial.clone();
73 let mut steps = 0u64;
74
75 while current.t < t_end && steps < self.max_steps {
76 current = self.step(system, ¤t);
77 states.push(current.clone());
78 steps += 1;
79 }
80 states
81 }
82
83 pub fn solve(&mut self, system: &dyn OdeSystem, initial: &OdeState, t_end: f64) -> OdeState {
85 let mut current = initial.clone();
86 let mut steps = 0u64;
87 while current.t < t_end && steps < self.max_steps {
88 current = self.step(system, ¤t);
89 steps += 1;
90 }
91 current
92 }
93
94 fn euler_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
97 let n = s.y.len();
98 let mut dydt = vec![0.0; n];
99 sys.evaluate(s.t, &s.y, &mut dydt);
100 let y: Vec<f64> = s.y.iter().zip(dydt.iter()).map(|(y, dy)| y + dy * self.dt).collect();
101 OdeState { t: s.t + self.dt, y }
102 }
103
104 fn rk4_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
105 let n = s.y.len();
106 let h = self.dt;
107 let mut k1 = vec![0.0; n]; sys.evaluate(s.t, &s.y, &mut k1);
108 let y2: Vec<f64> = (0..n).map(|i| s.y[i] + k1[i] * h * 0.5).collect();
109 let mut k2 = vec![0.0; n]; sys.evaluate(s.t + h * 0.5, &y2, &mut k2);
110 let y3: Vec<f64> = (0..n).map(|i| s.y[i] + k2[i] * h * 0.5).collect();
111 let mut k3 = vec![0.0; n]; sys.evaluate(s.t + h * 0.5, &y3, &mut k3);
112 let y4: Vec<f64> = (0..n).map(|i| s.y[i] + k3[i] * h).collect();
113 let mut k4 = vec![0.0; n]; sys.evaluate(s.t + h, &y4, &mut k4);
114
115 let y: Vec<f64> = (0..n).map(|i| {
116 s.y[i] + h / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i])
117 }).collect();
118 OdeState { t: s.t + h, y }
119 }
120
121 fn rk45_step(&mut self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
122 let n = s.y.len();
124 let h = self.dt;
125
126 let state4 = self.rk4_step(sys, s);
128
129 let mut half = self.dt;
131 self.dt = h * 0.5;
132 let mid = self.rk4_step(sys, s);
133 let final_ = self.rk4_step(sys, &mid);
134 self.dt = half;
135
136 let error: f64 = state4.y.iter().zip(final_.y.iter())
137 .map(|(a, b)| (a - b).abs())
138 .fold(0.0, f64::max);
139
140 if error > self.tolerance && h > self.dt_min {
142 self.dt = (h * 0.5).max(self.dt_min);
143 } else if error < self.tolerance * 0.1 && h < self.dt_max {
144 self.dt = (h * 1.5).min(self.dt_max);
145 }
146
147 final_
149 }
150
151 fn verlet_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
152 let n = s.y.len();
154 let half = n / 2;
155 let h = self.dt;
156
157 let mut acc = vec![0.0; n];
158 sys.evaluate(s.t, &s.y, &mut acc);
159
160 let mut y = s.y.clone();
161 for i in 0..half {
163 y[i] = s.y[i] + s.y[half + i] * h + 0.5 * acc[half + i] * h * h;
164 }
165
166 let mut acc_new = vec![0.0; n];
168 sys.evaluate(s.t + h, &y, &mut acc_new);
169
170 for i in 0..half {
172 y[half + i] = s.y[half + i] + 0.5 * (acc[half + i] + acc_new[half + i]) * h;
173 }
174
175 OdeState { t: s.t + h, y }
176 }
177
178 fn leapfrog_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
179 self.verlet_step(sys, s) }
181
182 fn implicit_euler_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
183 let n = s.y.len();
185 let h = self.dt;
186 let mut dydt = vec![0.0; n];
187 sys.evaluate(s.t + h, &s.y, &mut dydt);
188 let y: Vec<f64> = s.y.iter().zip(dydt.iter()).map(|(y, dy)| y + dy * h).collect();
189 OdeState { t: s.t + h, y }
190 }
191
192 fn crank_nicolson_step(&self, sys: &dyn OdeSystem, s: &OdeState) -> OdeState {
193 let n = s.y.len();
195 let h = self.dt;
196 let mut f_n = vec![0.0; n];
197 sys.evaluate(s.t, &s.y, &mut f_n);
198 let y_euler: Vec<f64> = s.y.iter().zip(f_n.iter()).map(|(y, dy)| y + dy * h).collect();
199 let mut f_n1 = vec![0.0; n];
200 sys.evaluate(s.t + h, &y_euler, &mut f_n1);
201 let y: Vec<f64> = (0..n).map(|i| s.y[i] + 0.5 * h * (f_n[i] + f_n1[i])).collect();
202 OdeState { t: s.t + h, y }
203 }
204}
205
206pub struct LorenzSystem { pub sigma: f64, pub rho: f64, pub beta: f64 }
210impl Default for LorenzSystem { fn default() -> Self { Self { sigma: 10.0, rho: 28.0, beta: 8.0/3.0 } } }
211impl OdeSystem for LorenzSystem {
212 fn dimension(&self) -> usize { 3 }
213 fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
214 dydt[0] = self.sigma * (y[1] - y[0]);
215 dydt[1] = y[0] * (self.rho - y[2]) - y[1];
216 dydt[2] = y[0] * y[1] - self.beta * y[2];
217 }
218}
219
220pub struct HarmonicOscillator { pub omega: f64 }
222impl OdeSystem for HarmonicOscillator {
223 fn dimension(&self) -> usize { 2 }
224 fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
225 dydt[0] = y[1]; dydt[1] = -self.omega * self.omega * y[0]; }
228}
229
230pub struct VanDerPol { pub mu: f64 }
232impl OdeSystem for VanDerPol {
233 fn dimension(&self) -> usize { 2 }
234 fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
235 dydt[0] = y[1];
236 dydt[1] = self.mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
237 }
238}
239
240pub struct RosslerSystem { pub a: f64, pub b: f64, pub c: f64 }
242impl Default for RosslerSystem { fn default() -> Self { Self { a: 0.2, b: 0.2, c: 5.7 } } }
243impl OdeSystem for RosslerSystem {
244 fn dimension(&self) -> usize { 3 }
245 fn evaluate(&self, _t: f64, y: &[f64], dydt: &mut [f64]) {
246 dydt[0] = -y[1] - y[2];
247 dydt[1] = y[0] + self.a * y[1];
248 dydt[2] = self.b + y[2] * (y[0] - self.c);
249 }
250}
251
252pub struct CustomOde<F: Fn(f64, &[f64], &mut [f64]) + Send + Sync> {
254 pub dim: usize,
255 pub func: F,
256}
257impl<F: Fn(f64, &[f64], &mut [f64]) + Send + Sync> OdeSystem for CustomOde<F> {
258 fn dimension(&self) -> usize { self.dim }
259 fn evaluate(&self, t: f64, y: &[f64], dydt: &mut [f64]) { (self.func)(t, y, dydt); }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-frequency oscillator together with the start state x = 1, v = 0.
    fn unit_oscillator() -> (HarmonicOscillator, OdeState) {
        (
            HarmonicOscillator { omega: 1.0 },
            OdeState { t: 0.0, y: vec![1.0, 0.0] },
        )
    }

    /// Energy of the unit oscillator for a `[x, v]` state.
    fn oscillator_energy(state: &OdeState) -> f64 {
        0.5 * state.y[0].powi(2) + 0.5 * state.y[1].powi(2)
    }

    // After half a period the oscillator should be near x = -1.
    #[test]
    fn euler_harmonic() {
        let (sys, initial) = unit_oscillator();
        let mut solver = OdeSolver::new(OdeMethod::Euler, 0.001);
        let final_state = solver.solve(&sys, &initial, std::f64::consts::PI);
        let x = final_state.y[0];
        assert!((x + 1.0).abs() < 0.1, "x={}", x);
    }

    // After a full period the oscillator should be back near x = 1.
    #[test]
    fn rk4_harmonic_accurate() {
        let (sys, initial) = unit_oscillator();
        let mut solver = OdeSolver::rk4(0.01);
        let final_state = solver.solve(&sys, &initial, std::f64::consts::TAU);
        let x = final_state.y[0];
        assert!((x - 1.0).abs() < 0.01, "x={}", x);
    }

    // Every component must stay bounded over the integration window.
    #[test]
    fn lorenz_doesnt_diverge() {
        let sys = LorenzSystem::default();
        let initial = OdeState { t: 0.0, y: vec![1.0, 1.0, 1.0] };
        let mut solver = OdeSolver::rk4(0.01);
        let final_state = solver.solve(&sys, &initial, 10.0);
        assert!(
            final_state.y.iter().all(|v| v.abs() < 100.0),
            "Lorenz diverged: {:?}",
            final_state.y
        );
    }

    #[test]
    fn integrate_returns_trajectory() {
        let (sys, initial) = unit_oscillator();
        let mut solver = OdeSolver::rk4(0.1);
        let trajectory = solver.integrate(&sys, &initial, 1.0);
        assert!(trajectory.len() > 5);
    }

    // Energy at the end of a long run must stay close to the initial energy.
    #[test]
    fn verlet_conserves_energy() {
        let (sys, initial) = unit_oscillator();
        let mut solver = OdeSolver::verlet(0.01);
        let energy_start = oscillator_energy(&initial);
        let final_state = solver.solve(&sys, &initial, 100.0);
        let energy_end = oscillator_energy(&final_state);
        let drift = (energy_start - energy_end).abs();
        assert!(drift < 0.01, "Energy drift: {}", drift);
    }
}