Skip to main content

numerics_ode/
system.rs

//! System-of-ODEs solvers.
//!
//! Provides vector-valued versions of the forward Euler, classical RK4, and
//! two-step Adams–Bashforth methods for first-order systems of the form:
//! ```text
//! dy⃗/dx = f(x, y⃗),   y⃗(x₀) = y⃗₀
//! ```
//!
//! Every solver in this module returns `(Vec<f64>, Vec<Vec<f64>>)`: the `x`
//! grid and the state vector at each grid point.
10
/// Forward (explicit) Euler method for systems of ODEs.
///
/// Integrates `dy/dx = f(x, y)` from `x0` to `x_end` using `n` equal steps of
/// size `h = (x_end - x0) / n`, applying `y_{i+1} = y_i + h * f(x_i, y_i)`
/// component-wise.
///
/// Returns `(xs, ys)`, both of length `n + 1`: `xs[0] == x0`,
/// `xs[n] == x_end`, and `ys[i]` is the approximate state at `xs[i]`
/// (`ys[0]` is a copy of `y0`).
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a vector whose length differs from
/// `y0.len()`.
pub fn euler(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    assert!(n >= 1, "euler: the number of steps `n` must be at least 1");
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    // Grid points are computed from the step index instead of by repeated
    // `x += h`, so round-off does not accumulate and the final point is
    // exactly `x_end`.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    let mut y = y0.to_vec();
    for i in 0..n {
        let dy = f(grid(i), &y);
        // Report a dimension mismatch clearly instead of an index panic
        // (too few components) or silently ignored extras (too many).
        assert_eq!(
            dy.len(),
            m,
            "euler: f returned {} components, expected {}",
            dy.len(),
            m
        );
        for (yj, &dj) in y.iter_mut().zip(&dy) {
            *yj += h * dj;
        }
        xs.push(grid(i + 1));
        ys.push(y.clone());
    }
    (xs, ys)
}
39
/// Classical fourth-order Runge–Kutta (RK4) method for systems of ODEs.
///
/// Integrates `dy/dx = f(x, y)` from `x0` to `x_end` using `n` equal steps of
/// size `h = (x_end - x0) / n`. Each step evaluates `f` four times:
///
/// ```text
/// k1 = f(x_i,       y_i)
/// k2 = f(x_i + h/2, y_i + h/2 * k1)
/// k3 = f(x_i + h/2, y_i + h/2 * k2)
/// k4 = f(x_{i+1},   y_i + h   * k3)
/// y_{i+1} = y_i + h/6 * (k1 + 2*k2 + 2*k3 + k4)
/// ```
///
/// Returns `(xs, ys)`, both of length `n + 1`: `xs[0] == x0`,
/// `xs[n] == x_end`, and `ys[i]` is the approximate state at `xs[i]`
/// (`ys[0]` is a copy of `y0`).
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a vector whose length differs from
/// `y0.len()`.
pub fn rk4(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    /// Writes `out[j] = y[j] + c * k[j]` for every component.
    fn offset_state(out: &mut [f64], y: &[f64], c: f64, k: &[f64]) {
        for ((o, &yj), &kj) in out.iter_mut().zip(y).zip(k) {
            *o = yj + c * kj;
        }
    }

    assert!(n >= 1, "rk4: the number of steps `n` must be at least 1");
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    // Grid points are computed from the step index instead of by repeated
    // `x += h`, so round-off does not accumulate and the final point is
    // exactly `x_end`.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };
    // Evaluates `f` and reports a dimension mismatch clearly instead of an
    // index panic (too few components) or silently ignored extras.
    let eval = |x: f64, y: &[f64]| -> Vec<f64> {
        let d = f(x, y);
        assert_eq!(
            d.len(),
            m,
            "rk4: f returned {} components, expected {}",
            d.len(),
            m
        );
        d
    };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    let mut y = y0.to_vec();
    // Scratch state for the intermediate stages, allocated once and reused.
    let mut y_stage = vec![0.0; m];
    for i in 0..n {
        let x = grid(i);
        let x_mid = x + h / 2.0;
        let x_next = grid(i + 1);

        let k1 = eval(x, &y);
        offset_state(&mut y_stage, &y, h / 2.0, &k1);
        let k2 = eval(x_mid, &y_stage);
        offset_state(&mut y_stage, &y, h / 2.0, &k2);
        let k3 = eval(x_mid, &y_stage);
        offset_state(&mut y_stage, &y, h, &k3);
        let k4 = eval(x_next, &y_stage);

        for (j, yj) in y.iter_mut().enumerate() {
            *yj += (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]);
        }
        xs.push(x_next);
        ys.push(y.clone());
    }
    (xs, ys)
}
81
/// Two-step Adams–Bashforth method (AB2) for systems of ODEs.
///
/// Integrates `dy/dx = f(x, y)` from `x0` to `x_end` using `n` equal steps of
/// size `h = (x_end - x0) / n`. AB2 needs the derivative at two previous grid
/// points, so the first step is bootstrapped with one classical RK4 step;
/// every later step applies
///
/// ```text
/// y_{i+1} = y_i + h/2 * (3 * f(x_i, y_i) - f(x_{i-1}, y_{i-1}))
/// ```
///
/// which costs one new evaluation of `f` per step. When `n == 1` the result
/// is just the RK4 bootstrap step.
///
/// Returns `(xs, ys)`, both of length `n + 1`: `xs[0] == x0`,
/// `xs[n] == x_end`, and `ys[i]` is the approximate state at `xs[i]`
/// (`ys[0]` is a copy of `y0`).
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a vector whose length differs from
/// `y0.len()`.
pub fn adams_bashforth(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    /// Writes `out[j] = y[j] + c * k[j]` for every component.
    fn offset_state(out: &mut [f64], y: &[f64], c: f64, k: &[f64]) {
        for ((o, &yj), &kj) in out.iter_mut().zip(y).zip(k) {
            *o = yj + c * kj;
        }
    }

    assert!(
        n >= 1,
        "adams_bashforth: the number of steps `n` must be at least 1"
    );
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    // Grid points are computed from the step index instead of by repeated
    // `x += h`, so round-off does not accumulate and the final point is
    // exactly `x_end`.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };
    // Evaluates `f` and reports a dimension mismatch clearly instead of an
    // index panic (too few components) or silently ignored extras.
    let eval = |x: f64, y: &[f64]| -> Vec<f64> {
        let d = f(x, y);
        assert_eq!(
            d.len(),
            m,
            "adams_bashforth: f returned {} components, expected {}",
            d.len(),
            m
        );
        d
    };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    // Bootstrap: one RK4 step from x0 to x1 supplies the second starting
    // value that the two-step formula needs (and is the whole answer when
    // n == 1).
    let x1 = grid(1);
    let x_mid = x0 + h / 2.0;
    let mut y_stage = vec![0.0; m];
    let k1 = eval(x0, y0);
    offset_state(&mut y_stage, y0, h / 2.0, &k1);
    let k2 = eval(x_mid, &y_stage);
    offset_state(&mut y_stage, y0, h / 2.0, &k2);
    let k3 = eval(x_mid, &y_stage);
    offset_state(&mut y_stage, y0, h, &k3);
    let k4 = eval(x1, &y_stage);
    let y1: Vec<f64> = (0..m)
        .map(|j| y0[j] + (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]))
        .collect();
    xs.push(x1);
    ys.push(y1);

    // `f_prev` holds f(x_{i-1}, y_{i-1}); the RK4 stage k1 is exactly f(x0, y0).
    let mut f_prev = k1;
    for i in 1..n {
        let y_curr = &ys[i];
        let f_curr = eval(grid(i), y_curr);
        let y_next: Vec<f64> = y_curr
            .iter()
            .zip(&f_curr)
            .zip(&f_prev)
            .map(|((&yj, &fc), &fp)| yj + (h / 2.0) * (3.0 * fc - fp))
            .collect();
        xs.push(grid(i + 1));
        ys.push(y_next);
        f_prev = f_curr;
    }

    (xs, ys)
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Harmonic oscillator: dy0/dx = y1, dy1/dx = -y0.
    /// With y(0) = (1, 0) the exact solution is (cos x, -sin x).
    fn harmonic(_x: f64, y: &[f64]) -> Vec<f64> {
        vec![y[1], -y[0]]
    }

    #[test]
    fn euler_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = euler(&harmonic, 0.0, &y0, PI / 2.0, 10_000);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-3, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-3, "y1 = {}", last[1]);
    }

    #[test]
    fn rk4_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = rk4(&harmonic, 0.0, &y0, PI / 2.0, 100);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-9, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-9, "y1 = {}", last[1]);
    }

    #[test]
    fn ab2_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = adams_bashforth(&harmonic, 0.0, &y0, PI / 2.0, 200);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-4, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-4, "y1 = {}", last[1]);
    }

    /// Lotka–Volterra predator–prey model with a = b = c = d = 1:
    /// dx/dt = a*x - b*x*y, dy/dt = -c*y + d*x*y.
    fn lotka_volterra(_x: f64, y: &[f64]) -> Vec<f64> {
        let a = 1.0;
        let b = 1.0;
        let c = 1.0;
        let d = 1.0;
        vec![
            a * y[0] - b * y[0] * y[1],
            -c * y[1] + d * y[0] * y[1],
        ]
    }

    /// First integral of `lotka_volterra`: `V = d*x - c*ln(x) + b*y - a*ln(y)`,
    /// here with all coefficients equal to 1. It is constant along exact
    /// trajectories, so its drift measures the integration error.
    fn lotka_volterra_invariant(y: &[f64]) -> f64 {
        y[0] - y[0].ln() + y[1] - y[1].ln()
    }

    #[test]
    fn rk4_lotka_volterra_invariant() {
        let y0 = vec![2.0, 2.0];
        let (_, ys) = rk4(&lotka_volterra, 0.0, &y0, 10.0, 10_000);
        let v0 = lotka_volterra_invariant(&y0);
        for y in &ys {
            // Populations must stay positive (this also keeps `ln` defined).
            assert!(y[0] > 0.0, "prey went negative: {}", y[0]);
            assert!(y[1] > 0.0, "predator went negative: {}", y[1]);
            let v = lotka_volterra_invariant(y);
            assert!((v - v0).abs() / v0 < 1e-8, "invariant drift: {v} vs {v0}");
        }
    }

    /// Decoupled system: dy0/dx = y0, dy1/dx = -y1
    #[test]
    fn euler_decoupled_system() {
        let f = |_x: f64, y: &[f64]| vec![y[0], -y[1]];
        let y0 = vec![1.0, 1.0];
        let (_, ys) = euler(&f, 0.0, &y0, 1.0, 10_000);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0_f64.exp()).abs() < 1e-3);
        assert!((last[1] - (-1.0_f64).exp()).abs() < 1e-3);
    }

    #[test]
    fn rk4_decoupled_system() {
        let f = |_x: f64, y: &[f64]| vec![y[0], -y[1]];
        let y0 = vec![1.0, 1.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 100);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0_f64.exp()).abs() < 1e-9);
        assert!((last[1] - (-1.0_f64).exp()).abs() < 1e-9);
    }

    /// 3D system: torque-free rigid-body rotation (Euler's equations) with
    /// principal moments I = (1, 2, 3); `y` holds the angular velocity.
    fn rigid_body(_x: f64, y: &[f64]) -> Vec<f64> {
        let i1 = 1.0;
        let i2 = 2.0;
        let i3 = 3.0;
        vec![
            (i2 - i3) / (i1) * y[1] * y[2],
            (i3 - i1) / (i2) * y[0] * y[2],
            (i1 - i2) / (i3) * y[0] * y[1],
        ]
    }

    #[test]
    fn rk4_rigid_body_energy_conservation() {
        let y0 = vec![1.0, 1.0, 1.0];
        let (_, ys) = rk4(&rigid_body, 0.0, &y0, 10.0, 10_000);
        // Energy E = 0.5*(I1*y0² + I2*y1² + I3*y2²) should be conserved
        let e0 = 0.5 * (1.0 * y0[0].powi(2) + 2.0 * y0[1].powi(2) + 3.0 * y0[2].powi(2));
        for y in &ys {
            let e = 0.5 * (1.0 * y[0].powi(2) + 2.0 * y[1].powi(2) + 3.0 * y[2].powi(2));
            assert!((e - e0).abs() / e0 < 1e-6, "energy drift: {e} vs {e0}");
        }
    }

    #[test]
    fn constant_system() {
        let f = |_x: f64, y: &[f64]| vec![0.0; y.len()];
        let y0 = vec![1.0, 2.0, 3.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 10);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0).abs() < 1e-12);
        assert!((last[1] - 2.0).abs() < 1e-12);
        assert!((last[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn linear_system() {
        let f = |_x: f64, _y: &[f64]| vec![1.0, 2.0];
        let y0 = vec![0.0, 0.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 100);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0).abs() < 1e-10);
        assert!((last[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn output_length_matches() {
        let y0 = vec![1.0, 0.0];
        // Cover the single-step edge case as well as a typical step count,
        // for every solver in the module.
        for n in [1, 2, 50] {
            for (xs, ys) in [
                euler(&harmonic, 0.0, &y0, 1.0, n),
                rk4(&harmonic, 0.0, &y0, 1.0, n),
                adams_bashforth(&harmonic, 0.0, &y0, 1.0, n),
            ] {
                assert_eq!(xs.len(), n + 1);
                assert_eq!(ys.len(), n + 1);
                assert!(ys.iter().all(|y| y.len() == 2));
            }
        }
    }
}