/// Integrates the ODE system `y' = f(x, y)` from `x0` to `x_end` with the
/// explicit (forward) Euler method using `n` equal steps of size
/// `h = (x_end - x0) / n`.
///
/// Returns `(xs, ys)`, where `xs` holds the `n + 1` grid points and `ys[i]`
/// is the approximate state at `xs[i]`. `xs[0] == x0`, `ys[0] == y0`, and the
/// last grid point is exactly `x_end`.
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a derivative vector whose length
/// differs from `y0.len()`.
pub fn euler(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    assert!(n >= 1);
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    // Grid points are computed from the step index rather than by repeated
    // `x += h`, which drifts by one rounding error per step; the final point
    // is pinned to `x_end` exactly.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    for i in 0..n {
        let x = xs[i];
        let y: &[f64] = &ys[i];
        let dy = f(x, y);
        // A wrong-length derivative is a bug in `f`; fail loudly instead of
        // silently ignoring extra components.
        assert_eq!(
            dy.len(),
            m,
            "euler: f returned {} derivatives for a state of length {}",
            dy.len(),
            m
        );
        let y_next: Vec<f64> = y.iter().zip(&dy).map(|(yj, dj)| yj + h * dj).collect();
        xs.push(grid(i + 1));
        ys.push(y_next);
    }
    (xs, ys)
}
39
/// Integrates the ODE system `y' = f(x, y)` from `x0` to `x_end` with the
/// classical fourth-order Runge–Kutta method using `n` equal steps of size
/// `h = (x_end - x0) / n`.
///
/// Returns `(xs, ys)`, where `xs` holds the `n + 1` grid points and `ys[i]`
/// is the approximate state at `xs[i]`. `xs[0] == x0`, `ys[0] == y0`, and the
/// last grid point is exactly `x_end`.
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a derivative vector whose length
/// differs from `y0.len()`.
pub fn rk4(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    /// Writes `y + scale * k` element-wise into `out` (an RK stage argument).
    fn offset(out: &mut [f64], y: &[f64], k: &[f64], scale: f64) {
        for ((o, yj), kj) in out.iter_mut().zip(y).zip(k) {
            *o = yj + scale * kj;
        }
    }

    assert!(n >= 1);
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    let half_h = h / 2.0;
    // Grid points from the step index (no `x += h` drift); last one is exactly `x_end`.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };
    // Evaluate `f`, insisting the derivative matches the state length; the
    // zip-based stage updates would otherwise silently truncate.
    let eval = |x: f64, y: &[f64]| -> Vec<f64> {
        let dy = f(x, y);
        assert_eq!(
            dy.len(),
            m,
            "rk4: f returned {} derivatives for a state of length {}",
            dy.len(),
            m
        );
        dy
    };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    // Scratch state for the intermediate stages, reused across all steps.
    let mut stage = vec![0.0; m];
    for i in 0..n {
        let x = xs[i];
        let x_next = grid(i + 1);
        let y: &[f64] = &ys[i];

        let k1 = eval(x, y);
        offset(&mut stage, y, &k1, half_h);
        let k2 = eval(x + half_h, &stage);
        offset(&mut stage, y, &k2, half_h);
        let k3 = eval(x + half_h, &stage);
        offset(&mut stage, y, &k3, h);
        let k4 = eval(x_next, &stage);

        let y_next: Vec<f64> = (0..m)
            .map(|j| y[j] + (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]))
            .collect();
        xs.push(x_next);
        ys.push(y_next);
    }
    (xs, ys)
}
81
/// Integrates the ODE system `y' = f(x, y)` from `x0` to `x_end` with the
/// explicit two-step Adams–Bashforth method (AB2) using `n` equal steps of
/// size `h = (x_end - x0) / n`.
///
/// AB2 needs the derivative at two previous points, so the first step
/// (`x0 -> x0 + h`) is taken with one classical RK4 step; every later step
/// uses `y_{i+1} = y_i + h/2 * (3 f(x_i, y_i) - f(x_{i-1}, y_{i-1}))`.
/// The RK4 start-up is used for every `n >= 1`, including `n == 1`.
///
/// Returns `(xs, ys)`, where `xs` holds the `n + 1` grid points and `ys[i]`
/// is the approximate state at `xs[i]`. `xs[0] == x0`, `ys[0] == y0`, and the
/// last grid point is exactly `x_end`.
///
/// # Panics
///
/// Panics if `n == 0`, or if `f` returns a derivative vector whose length
/// differs from `y0.len()`.
pub fn adams_bashforth(
    f: &dyn Fn(f64, &[f64]) -> Vec<f64>,
    x0: f64,
    y0: &[f64],
    x_end: f64,
    n: usize,
) -> (Vec<f64>, Vec<Vec<f64>>) {
    /// Writes `y + scale * k` element-wise into `out` (an RK stage argument).
    fn offset(out: &mut [f64], y: &[f64], k: &[f64], scale: f64) {
        for ((o, yj), kj) in out.iter_mut().zip(y).zip(k) {
            *o = yj + scale * kj;
        }
    }

    assert!(n >= 1);
    let m = y0.len();
    let h = (x_end - x0) / n as f64;
    let half_h = h / 2.0;
    // Grid points from the step index (no `x += h` drift); last one is exactly `x_end`.
    let grid = |i: usize| if i == n { x_end } else { x0 + i as f64 * h };
    // Evaluate `f`, insisting the derivative matches the state length; the
    // zip-based updates would otherwise silently truncate.
    let eval = |x: f64, y: &[f64]| -> Vec<f64> {
        let dy = f(x, y);
        assert_eq!(
            dy.len(),
            m,
            "adams_bashforth: f returned {} derivatives for a state of length {}",
            dy.len(),
            m
        );
        dy
    };

    let mut xs = Vec::with_capacity(n + 1);
    let mut ys: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    xs.push(x0);
    ys.push(y0.to_vec());

    // Start-up: one RK4 step produces y1. Its first stage k1 = f(x0, y0) is
    // reused as the "previous derivative" for the first AB2 step.
    let k1 = eval(x0, y0);
    let mut stage = vec![0.0; m];
    offset(&mut stage, y0, &k1, half_h);
    let k2 = eval(x0 + half_h, &stage);
    offset(&mut stage, y0, &k2, half_h);
    let k3 = eval(x0 + half_h, &stage);
    offset(&mut stage, y0, &k3, h);
    let x1 = grid(1);
    let k4 = eval(x1, &stage);
    let y1: Vec<f64> = (0..m)
        .map(|j| y0[j] + (h / 6.0) * (k1[j] + 2.0 * k2[j] + 2.0 * k3[j] + k4[j]))
        .collect();
    xs.push(x1);
    ys.push(y1);

    // AB2 steps x1 -> x2 -> ... -> x_n (the range is empty when n == 1).
    let mut f_prev = k1;
    for i in 1..n {
        let y: &[f64] = &ys[i];
        let f_curr = eval(xs[i], y);
        let y_next: Vec<f64> = y
            .iter()
            .zip(&f_curr)
            .zip(&f_prev)
            .map(|((yj, fc), fp)| yj + half_h * (3.0 * fc - fp))
            .collect();
        xs.push(grid(i + 1));
        ys.push(y_next);
        f_prev = f_curr;
    }

    (xs, ys)
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Simple harmonic oscillator `y0' = y1, y1' = -y0`; starting from
    /// `(1, 0)` the exact solution is `(cos x, -sin x)`.
    fn harmonic(_x: f64, y: &[f64]) -> Vec<f64> {
        vec![y[1], -y[0]]
    }

    #[test]
    fn euler_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = euler(&harmonic, 0.0, &y0, PI / 2.0, 10_000);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-3, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-3, "y1 = {}", last[1]);
    }

    #[test]
    fn rk4_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = rk4(&harmonic, 0.0, &y0, PI / 2.0, 100);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-9, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-9, "y1 = {}", last[1]);
    }

    #[test]
    fn ab2_harmonic() {
        let y0 = vec![1.0, 0.0];
        let (_, ys) = adams_bashforth(&harmonic, 0.0, &y0, PI / 2.0, 200);
        let last = ys.last().unwrap();
        assert!(last[0].abs() < 1e-4, "y0 = {}", last[0]);
        assert!((last[1] + 1.0).abs() < 1e-4, "y1 = {}", last[1]);
    }

    /// Lotka–Volterra predator–prey model with `a = b = c = d = 1`;
    /// `y[0]` is the prey population, `y[1]` the predator population.
    fn lotka_volterra(_x: f64, y: &[f64]) -> Vec<f64> {
        let a = 1.0;
        let b = 1.0;
        let c = 1.0;
        let d = 1.0;
        vec![
            a * y[0] - b * y[0] * y[1],
            -c * y[1] + d * y[0] * y[1],
        ]
    }

    /// First integral of [`lotka_volterra`]:
    /// `V = d*x - c*ln(x) + b*y - a*ln(y)` is constant along exact
    /// trajectories, which is what makes the orbits closed (periodic).
    fn lotka_volterra_invariant(y: &[f64]) -> f64 {
        let (a, b, c, d) = (1.0, 1.0, 1.0, 1.0);
        d * y[0] - c * y[0].ln() + b * y[1] - a * y[1].ln()
    }

    /// Populations must stay positive and the conserved quantity `V` must not
    /// drift, i.e. the numerical orbit stays on the closed exact orbit.
    #[test]
    fn rk4_lotka_volterra_periodicity() {
        let y0 = vec![2.0, 2.0];
        let (_, ys) = rk4(&lotka_volterra, 0.0, &y0, 10.0, 10_000);
        let v0 = lotka_volterra_invariant(&y0);
        for y in &ys {
            assert!(y[0] > 0.0, "prey went negative: {}", y[0]);
            assert!(y[1] > 0.0, "predator went negative: {}", y[1]);
            let v = lotka_volterra_invariant(y);
            assert!(
                (v - v0).abs() / v0.abs() < 1e-8,
                "invariant drift: {v} vs {v0}"
            );
        }
    }

    #[test]
    fn euler_decoupled_system() {
        let f = |_x: f64, y: &[f64]| vec![y[0], -y[1]];
        let y0 = vec![1.0, 1.0];
        let (_, ys) = euler(&f, 0.0, &y0, 1.0, 10_000);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0_f64.exp()).abs() < 1e-3);
        assert!((last[1] - (-1.0_f64).exp()).abs() < 1e-3);
    }

    #[test]
    fn rk4_decoupled_system() {
        let f = |_x: f64, y: &[f64]| vec![y[0], -y[1]];
        let y0 = vec![1.0, 1.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 100);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0_f64.exp()).abs() < 1e-9);
        assert!((last[1] - (-1.0_f64).exp()).abs() < 1e-9);
    }

    /// Euler's equations for a torque-free rigid body with principal moments
    /// of inertia `(1, 2, 3)`; `y` is the angular velocity in the body frame.
    fn rigid_body(_x: f64, y: &[f64]) -> Vec<f64> {
        let i1 = 1.0;
        let i2 = 2.0;
        let i3 = 3.0;
        vec![
            (i2 - i3) / (i1) * y[1] * y[2],
            (i3 - i1) / (i2) * y[0] * y[2],
            (i1 - i2) / (i3) * y[0] * y[1],
        ]
    }

    /// Rotational kinetic energy `0.5 * sum(I_k * w_k^2)` for the inertias
    /// used in [`rigid_body`]; it is conserved by the exact dynamics.
    fn rigid_body_energy(w: &[f64]) -> f64 {
        0.5 * (1.0 * w[0].powi(2) + 2.0 * w[1].powi(2) + 3.0 * w[2].powi(2))
    }

    #[test]
    fn rk4_rigid_body_energy_conservation() {
        let y0 = vec![1.0, 1.0, 1.0];
        let (_, ys) = rk4(&rigid_body, 0.0, &y0, 10.0, 10_000);
        let e0 = rigid_body_energy(&y0);
        for y in &ys {
            let e = rigid_body_energy(y);
            assert!((e - e0).abs() / e0 < 1e-6, "energy drift: {e} vs {e0}");
        }
    }

    #[test]
    fn constant_system() {
        let f = |_x: f64, y: &[f64]| vec![0.0; y.len()];
        let y0 = vec![1.0, 2.0, 3.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 10);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0).abs() < 1e-12);
        assert!((last[1] - 2.0).abs() < 1e-12);
        assert!((last[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn linear_system() {
        let f = |_x: f64, _y: &[f64]| vec![1.0, 2.0];
        let y0 = vec![0.0, 0.0];
        let (_, ys) = rk4(&f, 0.0, &y0, 1.0, 100);
        let last = ys.last().unwrap();
        assert!((last[0] - 1.0).abs() < 1e-10);
        assert!((last[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn output_length_matches() {
        let y0 = vec![1.0, 0.0];
        let (xs, ys) = rk4(&harmonic, 0.0, &y0, 1.0, 50);
        assert_eq!(xs.len(), 51);
        assert_eq!(ys.len(), 51);
        assert_eq!(ys[0].len(), 2);
    }
}