/// Output of [`solve_adaptive`]: the accepted trajectory plus step statistics.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveResult {
    /// Abscissa of the initial point followed by that of every accepted step.
    pub xs: Vec<f64>,
    /// Solution values paired element-for-element with `xs`.
    pub ys: Vec<f64>,
    /// Trial steps discarded because their error estimate exceeded `tol`.
    pub rejected_steps: usize,
    /// Trial steps accepted; always equal to `xs.len() - 1`.
    pub accepted_steps: usize,
}

/// Performs one Dormand–Prince 5(4) trial step of size `h` from `(x, y)`.
///
/// Returns `(y_new, err)`:
/// - `y_new` is the 5th-order approximation of `y(x + h)`.
/// - `err` is the signed local error estimate. It equals `h` times the
///   slopes weighted by the difference between the 5th- and embedded
///   4th-order weights.
fn dopri5_step(f: &dyn Fn(f64, f64) -> f64, x: f64, y: f64, h: f64) -> (f64, f64) {
    // Runge–Kutta matrix: row `i` combines slopes k_1..k_{i+1} into the
    // input of stage i + 2 (entries past column `i` are zero).
    const A: [[f64; 5]; 5] = [
        [1.0 / 5.0, 0.0, 0.0, 0.0, 0.0],
        [3.0 / 40.0, 9.0 / 40.0, 0.0, 0.0, 0.0],
        [44.0 / 45.0, -56.0 / 15.0, 32.0 / 9.0, 0.0, 0.0],
        [19372.0 / 6561.0, -25360.0 / 2187.0, 64448.0 / 6561.0, -212.0 / 729.0, 0.0],
        [9017.0 / 3168.0, -355.0 / 33.0, 46732.0 / 5247.0, 49.0 / 176.0, -5103.0 / 18656.0],
    ];
    // 5th-order weights for k_1..k_6 (the weight of k_7 is zero).
    const B5: [f64; 6] = [
        35.0 / 384.0,
        0.0,
        500.0 / 1113.0,
        125.0 / 192.0,
        -2187.0 / 6784.0,
        11.0 / 84.0,
    ];
    // Error weights for k_1..k_7: 5th-order minus embedded 4th-order weights.
    const E: [f64; 7] = [
        71.0 / 57600.0,
        0.0,
        -71.0 / 16695.0,
        71.0 / 1920.0,
        -17253.0 / 339200.0,
        22.0 / 525.0,
        -1.0 / 40.0,
    ];

    let mut k = [0.0_f64; 7];
    k[0] = f(x, y);
    for (i, row) in A.iter().enumerate() {
        let coeffs = &row[..=i];
        // The stage node c_{i+2} is the row sum, so it is not stored separately.
        let c: f64 = coeffs.iter().sum();
        let dy = coeffs.iter().zip(&k).fold(0.0, |acc, (a, kj)| acc + a * kj);
        k[i + 1] = f(x + h * c, y + h * dy);
    }
    let y_new = y + h * B5.iter().zip(&k).fold(0.0, |acc, (b, kj)| acc + b * kj);
    // k_7 is evaluated at the new point and only feeds the error estimate.
    k[6] = f(x + h, y_new);
    let err = h * E.iter().zip(&k).fold(0.0, |acc, (e, kj)| acc + e * kj);
    (y_new, err)
}

/// Integrates the scalar ODE `y' = f(x, y)` from `x0` to `x_end` using the
/// Dormand–Prince 5(4) embedded Runge–Kutta pair with adaptive step control.
///
/// The 5th-order solution is propagated. The absolute value of the embedded
/// error estimate is compared against `tol`:
/// - **Accepted step:** the next step is scaled by `0.9 * (tol / err)^(1/5)`,
///   growing by at most 5×. An exactly-zero estimate doubles it.
/// - **Rejected step:** the step is shrunk by the same formula, but by at
///   most 10×.
///
/// # Arguments
/// * `f` – right-hand side `f(x, y)`.
/// * `x0`, `y0` – initial condition.
/// * `x_end` – end of the interval. It may be smaller than `x0`, in which
///   case integration runs backwards.
/// * `h_init` – magnitude of the first trial step; its sign is ignored. It
///   should be nonzero unless `h_min > 0`, or the solver makes no progress.
/// * `tol` – absolute tolerance on the local error estimate of each step.
/// * `h_min` – smallest step magnitude.
///   - A step this small is accepted regardless of its error estimate.
///   - Exception: the final step may be shorter than `h_min` so that it
///     lands exactly on `x_end`.
///   - If `h_min > h_max`, `h_min` takes precedence.
/// * `h_max` – largest step magnitude, enforced from the very first step.
/// * `max_steps` – cap on trial steps, accepted plus rejected. When it is
///   reached, integration stops, possibly short of `x_end`.
///
/// # Returns
/// An [`AdaptiveResult`] whose trajectory:
/// - starts at `(x0, y0)`;
/// - moves monotonically towards `x_end`;
/// - ends exactly at `x_end` unless `max_steps` ran out first.
///
/// If `x0 == x_end`, only the initial point is returned.
// The positional signature is kept for existing callers; grouping the
// step-control knobs into a config struct would be a breaking change.
#[allow(clippy::too_many_arguments)]
pub fn solve_adaptive(
    f: &dyn Fn(f64, f64) -> f64,
    x0: f64,
    y0: f64,
    x_end: f64,
    h_init: f64,
    tol: f64,
    h_min: f64,
    h_max: f64,
    max_steps: usize,
) -> AdaptiveResult {
    let sign = if x_end >= x0 { 1.0 } else { -1.0 };
    let h_min_abs = h_min.abs();
    let h_max_abs = h_max.abs();
    // A step that would stop within this distance of `x_end` is stretched to
    // land on it exactly, so round-off in `x` never leaves a sliver behind.
    let end_eps = 1e-14 * x_end.abs().max(1.0);

    let mut xs = vec![x0];
    let mut ys = vec![y0];
    let (mut x, mut y) = (x0, y0);
    // Signed trial step; it always points from x0 towards x_end (or is zero).
    let mut h = sign * h_init.abs();
    let (mut accepted, mut rejected, mut total) = (0usize, 0usize, 0usize);

    loop {
        let remaining = x_end - x;
        if sign * remaining <= 0.0 || total >= max_steps {
            break;
        }

        // Clamp to [h_min, h_max]. h_min is applied last so it wins if the
        // bounds conflict.
        let mut h_step = sign * h.abs().min(h_max_abs).max(h_min_abs);
        // Final step: take exactly what is left, even if that is shorter than
        // h_min. Flooring it to h_min would overshoot x_end.
        let last = h_step.abs() >= remaining.abs() - end_eps;
        if last {
            h_step = remaining;
        }

        let (y_new, err) = dopri5_step(f, x, y, h_step);
        let err_norm = err.abs();
        total += 1;

        // A step at or below h_min cannot be shrunk further, so it is
        // accepted whatever its error estimate.
        if err_norm <= tol || h_step.abs() <= h_min_abs {
            x = if last { x_end } else { x + h_step };
            y = y_new;
            xs.push(x);
            ys.push(y);
            accepted += 1;
            let factor = if err_norm > 0.0 {
                (0.9 * (tol / err_norm).powf(0.2)).min(5.0)
            } else {
                2.0
            };
            h = h_step * factor;
        } else {
            rejected += 1;
            h = h_step * (0.9 * (tol / err_norm).powf(0.2)).max(0.1);
        }
    }

    AdaptiveResult {
        xs,
        ys,
        rejected_steps: rejected,
        accepted_steps: accepted,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    /// Last recorded solution value of a run.
    fn final_y(ys: &[f64]) -> f64 {
        *ys.last().expect("solver always records the initial point")
    }

    // y' = y, y(0) = 1  =>  y(1) = e.
    #[test]
    fn exponential_growth() {
        let rhs = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, 1e-8, 1e-10, 0.5, 10_000).ys);
        assert!((last - exact).abs() < 1e-6, "got {last}, exact {exact}");
    }

    // y' = 0 leaves y at its initial value.
    #[test]
    fn constant_rhs() {
        let rhs = |_x: f64, _y: f64| 0.0;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 42.0, 1.0, 0.1, 1e-6, 1e-10, 0.5, 1000).ys);
        assert!((last - 42.0).abs() < 1e-12);
    }

    // y' = 5, y(0) = 0  =>  y(2) = 10.
    #[test]
    fn linear_rhs() {
        let rhs = |_x: f64, _y: f64| 5.0;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 0.0, 2.0, 0.1, 1e-8, 1e-10, 0.5, 1000).ys);
        assert!((last - 10.0).abs() < 1e-8);
    }

    // y' = cos x, y(0) = 0  =>  y(pi/2) = 1.
    #[test]
    fn sinusoidal_rhs() {
        let rhs = |x: f64, _y: f64| x.cos();
        let run = solve_adaptive(&rhs, 0.0, 0.0, FRAC_PI_2, 0.1, 1e-10, 1e-12, 0.2, 10_000);
        let last = final_y(&run.ys);
        assert!((last - 1.0).abs() < 1e-8);
    }

    // y' = -y, y(0) = 1  =>  y(2) = e^-2.
    #[test]
    fn exponential_decay() {
        let rhs = |_x: f64, y: f64| -y;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 2.0, 0.1, 1e-8, 1e-10, 0.5, 10_000).ys);
        assert!((last - (-2.0_f64).exp()).abs() < 1e-6);
    }

    // A zero right-hand side gives zero error estimates, so the step grows
    // quickly up to h_max.
    #[test]
    fn adaptive_reduces_steps_for_easy_problems() {
        let rhs = |_x: f64, _y: f64| 0.0;
        let steps =
            solve_adaptive(&rhs, 0.0, 1.0, 10.0, 0.1, 1e-6, 1e-10, 1.0, 10_000).accepted_steps;
        assert!(steps < 20, "took {} steps for trivial problem", steps);
    }

    // y' = 2x, y(0) = 0  =>  y(3) = 9.
    #[test]
    fn quadratic_rhs() {
        let rhs = |x: f64, _y: f64| 2.0 * x;
        let last = final_y(&solve_adaptive(&rhs, 0.0, 0.0, 3.0, 0.1, 1e-10, 1e-12, 0.5, 10_000).ys);
        assert!((last - 9.0).abs() < 1e-8);
    }

    // The trajectory starts with the initial condition.
    #[test]
    fn result_contains_initial_point() {
        let rhs = |_x: f64, y: f64| y;
        let run = solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, 1e-6, 1e-10, 0.5, 1000);
        let (x_first, y_first) = (run.xs[0], run.ys[0]);
        assert_eq!(x_first, 0.0);
        assert_eq!(y_first, 1.0);
    }

    // A tighter tolerance must give a smaller final error on y' = y.
    #[test]
    fn tol_affects_accuracy() {
        let rhs = |_x: f64, y: f64| y;
        let exact = 1.0_f64.exp();
        let error_with = |tol: f64, h_min: f64| {
            let run = solve_adaptive(&rhs, 0.0, 1.0, 1.0, 0.1, tol, h_min, 0.5, 10_000);
            (final_y(&run.ys) - exact).abs()
        };
        let e_loose = error_with(1e-4, 1e-10);
        let e_tight = error_with(1e-12, 1e-14);
        assert!(e_tight < e_loose, "tight tol should be more accurate: {e_tight} vs {e_loose}");
    }

    // y' = -50y, y(0) = 1  =>  y(0.1) = e^-5.
    #[test]
    fn stiff_problem() {
        let rhs = |_x: f64, y: f64| -50.0 * y;
        let exact = (-5.0_f64).exp();
        let last = final_y(&solve_adaptive(&rhs, 0.0, 1.0, 0.1, 0.001, 1e-6, 1e-8, 0.01, 100_000).ys);
        assert!((last - exact).abs() < 1e-4, "got {last}, exact {exact}");
    }

    // Integrating y' = y backwards from (1, e) must recover y(0) = 1.
    #[test]
    fn negative_direction() {
        let rhs = |_x: f64, y: f64| y;
        let run = solve_adaptive(&rhs, 1.0, 1.0_f64.exp(), 0.0, 0.1, 1e-8, 1e-12, 0.5, 10_000);
        let last = final_y(&run.ys);
        assert!((last - 1.0).abs() < 1e-5, "got {last}");
    }
}