// differential_equations/ode/methods/runge_kutta/implicit/adaptive_step.rs
//! Adaptive step size implicit Runge-Kutta methods for solving ordinary differential equations.

/// Macro to create an adaptive implicit Runge-Kutta solver from a Butcher tableau.
///
/// The macro generates a solver struct together with its `Default`,
/// `ODENumericalMethod`, `Interpolation`, and builder-method implementations.
/// The coupled implicit stage equations are solved with a simplified Newton
/// iteration. The local error is estimated by comparing the solution built from
/// the primary weights `b` with the one built from the secondary weights `b_hat`.
///
/// # Arguments
///
/// * `name`: Name of the solver struct to create.
/// * `a`: `stages x stages` matrix of stage coefficients (may be non-zero on and
///   above the diagonal).
/// * `b`: 2D array whose first row holds the primary weights (`b`) and whose
///   second row holds the secondary weights (`b_hat`) used for error estimation.
/// * `c`: Time offsets (nodes) for each stage.
/// * `order`: Order of accuracy of the primary method. Its reciprocal is the
///   exponent used by the step-size controller.
/// * `stages`: Number of stages in the method.
///
/// # Notes on the solver and error estimation
/// - The implicit stage equations `k_i = f(t_n + c_i*h, y_n + h * sum_j(a_ij*k_j))`
///   are solved by a simplified Newton iteration:
///   - The Jacobian of `f` is evaluated once per step attempt at `(t_n, y_n)`
///     through `ODE::jacobian`.
///   - The Newton matrix `I - h * kron(A, J)` is LU-factored once and reused for
///     every iteration of that attempt.
/// - The error estimate is the difference between the solutions built with `b`
///   and `b_hat`. Its quality depends entirely on the tableau. Gauss-Legendre
///   methods, for instance, have no standard embedded pair.
///
/// # Example
/// ```ignore
/// use differential_equations::adaptive_implicit_runge_kutta_method;
///
/// // Implicit trapezoidal rule (2-stage Lobatto IIIA, order 2) with the
/// // explicit Euler weights (order 1) as the error estimator.
/// adaptive_implicit_runge_kutta_method!(
///     name: AdaptiveTrapezoid,
///     a: [[0.0, 0.0], [0.5, 0.5]],
///     b: [
///         [0.5, 0.5], // Primary weights (order 2)
///         [1.0, 0.0], // Secondary weights (order 1)
///     ],
///     c: [0.0, 1.0],
///     order: 2,
///     stages: 2
/// );
/// ```
#[macro_export]
macro_rules! adaptive_implicit_runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {
        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This adaptive implicit solver was automatically generated using the `adaptive_implicit_runge_kutta_method` macro."]
        #[doc = " It uses a simplified Newton iteration and embedded error estimation (via b/b_hat vectors)."]
        #[doc = " The ODE system must provide its Jacobian through `ODE::jacobian`; it is evaluated once per step attempt."]
        pub struct $name<
            T: $crate::traits::Real,
            V: $crate::traits::State<T>,
            D: $crate::traits::CallBackData,
        > {
            // Initial step size (zero requests automatic selection in `init`)
            pub h0: T,
            // Current step size
            h: T,

            // Current state
            t: T,
            y: V,
            dydt: V, // f(t, y)

            // Previous state
            t_prev: T,
            y_prev: V,
            dydt_prev: V, // f(t_prev, y_prev)

            // Stage derivatives (k_i)
            k: [V; $stages],
            // Stage values y_n + h * sum_j(a_ij * k_j) during the Newton iteration
            y_stage: [V; $stages],
            // f(t_n + c_i*h, y_stage_i) during the Newton iteration
            f_at_stages: [V; $stages],

            // Butcher tableau converted to T
            a: [[T; $stages]; $stages],
            b_higher: [T; $stages], // Primary weights (b)
            b_lower: [T; $stages],  // Secondary weights (b_hat) for error estimation
            c: [T; $stages],

            // --- Adaptive step settings ---
            pub rtol: T,
            pub atol: T,
            pub h_max: T,
            pub h_min: T,
            pub max_steps: usize,
            pub max_rejects: usize,
            pub safety_factor: T,
            pub min_scale: T,
            pub max_scale: T,

            // --- Implicit solver settings ---
            pub max_iter: usize, // Max Newton iterations per step attempt
            pub tol: T,          // Threshold on the Euclidean norm of the Newton update

            // Iteration tracking & status
            reject: bool,
            n_stiff: usize, // Rejections since the last accepted step
            steps: usize,
            status: $crate::Status<T, V, D>,

            // --- Jacobian and Newton solver data ---
            jacobian_matrix: nalgebra::DMatrix<T>, // J = df/dy at (t_n, y_n)
            newton_matrix: nalgebra::DMatrix<T>,   // M = I - h * kron(A, J)
            rhs_newton: nalgebra::DVector<T>,      // -phi(K) = f(stages) - K
            delta_k_vec: nalgebra::DVector<T>,     // Newton update delta_K
        }

        impl<
            T: $crate::traits::Real,
            V: $crate::traits::State<T>,
            D: $crate::traits::CallBackData,
        > Default for $name<T, V, D> {
            fn default() -> Self {
                // Convert Butcher tableau values to type T
                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));
                let b_higher_t: [T; $stages] = $b[0].map(|x| T::from_f64(x).unwrap());
                let b_lower_t: [T; $stages] = $b[1].map(|x| T::from_f64(x).unwrap());
                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h0: T::zero(), // Indicate auto-calculation
                    h: T::zero(),
                    t: T::zero(),
                    y: V::zeros(),
                    dydt: V::zeros(),
                    t_prev: T::zero(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    k: [V::zeros(); $stages],
                    y_stage: [V::zeros(); $stages],
                    f_at_stages: [V::zeros(); $stages],
                    a: a_t,
                    b_higher: b_higher_t,
                    b_lower: b_lower_t,
                    c: c_t,
                    // Adaptive defaults
                    rtol: T::from_f64(1.0e-6).unwrap(),
                    atol: T::from_f64(1.0e-6).unwrap(),
                    h_max: T::infinity(),
                    h_min: T::zero(),
                    max_steps: 10000,
                    max_rejects: 100,
                    safety_factor: T::from_f64(0.9).unwrap(),
                    min_scale: T::from_f64(0.2).unwrap(),
                    max_scale: T::from_f64(10.0).unwrap(),
                    // Implicit defaults
                    max_iter: 50,
                    tol: T::from_f64(1e-8).unwrap(),
                    // Status
                    reject: false,
                    n_stiff: 0,
                    steps: 0,
                    status: $crate::Status::Uninitialized,
                    // nalgebra storage starts empty and is sized in `init`
                    jacobian_matrix: nalgebra::DMatrix::zeros(0, 0),
                    newton_matrix: nalgebra::DMatrix::zeros(0, 0),
                    rhs_newton: nalgebra::DVector::zeros(0),
                    delta_k_vec: nalgebra::DVector::zeros(0),
                }
            }
        }

        impl<
            T: $crate::traits::Real,
            V: $crate::traits::State<T>,
            D: $crate::traits::CallBackData,
        > $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            /// Prepares the solver for a run starting at `(t0, y0)` towards `tf`.
            ///
            /// This method performs several setup tasks:
            /// - Evaluates `f(t0, y0)`.
            /// - Chooses `h0` automatically when it is zero.
            /// - Validates the step-size bounds.
            /// - Resets the statistics.
            /// - Sizes the Jacobian/Newton buffers for `stages * dim` unknowns.
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // Calculate initial derivative f(t0, y0)
                let mut initial_dydt = V::zeros();
                ode.diff(t0, y0, &mut initial_dydt);
                evals.fcn += 1;

                // If h0 is zero, estimate a starting step size
                if self.h0 == T::zero() {
                    self.h0 = $crate::ode::methods::h_init(ode, t0, tf, y0, $order, self.rtol, self.atol, self.h_min, self.h_max);
                }

                // Check bounds
                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf)?;

                // Initialize statistics
                self.reject = false;
                self.n_stiff = 0;
                self.steps = 0;

                // Initialize state
                self.t = t0;
                self.y = *y0;
                self.dydt = initial_dydt;

                // Previous state equals the current state until the first accepted step
                self.t_prev = t0;
                self.y_prev = *y0;
                self.dydt_prev = initial_dydt;

                self.status = $crate::Status::Initialized;

                // Size the Jacobian and Newton-system storage
                let dim = y0.len();
                self.jacobian_matrix = nalgebra::DMatrix::zeros(dim, dim);
                let newton_system_size = $stages * dim;
                self.newton_matrix = nalgebra::DMatrix::zeros(newton_system_size, newton_system_size);
                self.rhs_newton = nalgebra::DVector::zeros(newton_system_size);
                self.delta_k_vec = nalgebra::DVector::zeros(newton_system_size);
                self.f_at_stages = [V::zeros(); $stages];

                Ok(evals)
            }

            /// Attempts one step of size `h`.
            ///
            /// Possible outcomes:
            /// - On success, advances `(t, y)`, stores the previous state, and
            ///   proposes the next `h`.
            /// - On Newton failure or a failed error test, leaves `(t, y)` untouched,
            ///   sets `Status::RejectedStep`, and shrinks `h`.
            /// - Returns an error when the step size, step count, or rejection
            ///   limits are exceeded.
            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();
                let dim = self.y.len();

                // Check step size validity
                if self.h.abs() < self.h_min || self.h.abs() < T::default_epsilon() {
                    self.status = $crate::Status::Error($crate::Error::StepSize { t: self.t, y: self.y });
                    return Err($crate::Error::StepSize { t: self.t, y: self.y });
                }

                // Check max steps
                if self.steps >= self.max_steps {
                    self.status = $crate::Status::Error($crate::Error::MaxSteps { t: self.t, y: self.y });
                    return Err($crate::Error::MaxSteps { t: self.t, y: self.y });
                }
                self.steps += 1;

                // --- Simplified Newton iteration for the stage derivatives k_i ---
                // Initial guess: k_i^(0) = f(t_n, y_n) for every stage.
                self.k = [self.dydt; $stages];

                // Jacobian J = df/dy(t_n, y_n), evaluated once per step attempt.
                ode.jacobian(self.t, &self.y, &mut self.jacobian_matrix);
                evals.jac += 1;

                // Newton matrix M = I - h * kron(A, J): block (i, l) is
                // delta_il * I - h * a_il * J. J and h are fixed for the whole
                // iteration, so M is assembled and LU-factored once here rather
                // than on every iteration.
                for i in 0..$stages {
                    for l in 0..$stages {
                        let scale_factor = -self.h * self.a[i][l];
                        for r in 0..dim {
                            for c_col in 0..dim {
                                self.newton_matrix[(i * dim + r, l * dim + c_col)] =
                                    self.jacobian_matrix[(r, c_col)] * scale_factor;
                            }
                        }
                        if i == l {
                            for d_idx in 0..dim {
                                self.newton_matrix[(i * dim + d_idx, l * dim + d_idx)] += T::one();
                            }
                        }
                    }
                }
                let lu_decomp = nalgebra::LU::new(self.newton_matrix.clone());

                let mut converged = false;
                for _iter in 0..self.max_iter {
                    // 1. Residual: rhs = -phi(K) = f(t_n + c_i*h, Y_i) - k_i
                    for i in 0..$stages {
                        self.y_stage[i] = self.y;
                        for j in 0..$stages {
                            self.y_stage[i] += self.k[j] * (self.a[i][j] * self.h);
                        }

                        ode.diff(self.t + self.c[i] * self.h, &self.y_stage[i], &mut self.f_at_stages[i]);
                        evals.fcn += 1;

                        for row_idx in 0..dim {
                            self.rhs_newton[i * dim + row_idx] =
                                self.f_at_stages[i].get(row_idx) - self.k[i].get(row_idx);
                        }
                    }

                    // 2. Solve M * delta_K = -phi(K); a singular M aborts the attempt.
                    match lu_decomp.solve(&self.rhs_newton) {
                        Some(solution) => self.delta_k_vec.copy_from(&solution),
                        None => break,
                    }

                    // 3. Update K and accumulate ||delta_K||^2
                    let mut norm_delta_k_sq = T::zero();
                    for i in 0..$stages {
                        for row_idx in 0..dim {
                            let delta_val = self.delta_k_vec[i * dim + row_idx];
                            let current_val = self.k[i].get(row_idx);
                            self.k[i].set(row_idx, current_val + delta_val);
                            norm_delta_k_sq += delta_val * delta_val;
                        }
                    }

                    // 4. Converged once ||delta_K|| < tol
                    if norm_delta_k_sq < self.tol * self.tol {
                        converged = true;
                        break;
                    }
                }

                if !converged {
                    // Reject the attempt and retry with a much smaller step. The
                    // status must say so, exactly like an error-test rejection,
                    // because (t, y) have not advanced.
                    self.status = $crate::Status::RejectedStep;
                    self.h *= T::from_f64(0.25).unwrap();
                    self.h = $crate::utils::constrain_step_size(self.h, self.h_min, self.h_max);
                    self.reject = true;
                    self.n_stiff += 1;

                    if self.n_stiff >= self.max_rejects {
                        self.status = $crate::Status::Error($crate::Error::Stiffness { t: self.t, y: self.y });
                        return Err($crate::Error::Stiffness { t: self.t, y: self.y });
                    }
                    return Ok(evals);
                }

                // --- Iteration converged: build both solutions and the error ---
                let mut delta_y_high = V::zeros();
                for i in 0..$stages {
                    delta_y_high += self.k[i] * (self.b_higher[i] * self.h);
                }
                let y_high = self.y + delta_y_high;

                let mut delta_y_low = V::zeros();
                for i in 0..$stages {
                    delta_y_low += self.k[i] * (self.b_lower[i] * self.h);
                }
                let y_low = self.y + delta_y_low;

                let err = y_high - y_low;

                // Mixed absolute/relative max-norm of the error estimate
                let mut err_norm = T::zero();
                for n in 0..self.y.len() {
                    let scale = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
                    if scale > T::zero() {
                        err_norm = err_norm.max((err.get(n) / scale).abs());
                    }
                }
                // Floor keeps powf finite when the estimate is (near) zero
                err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());

                // NOTE: the exponent uses the primary order; when `b_hat` is of much
                // lower order the controller reacts more slowly than 1/(q+1) would.
                let order_inv = T::one() / T::from_usize($order).unwrap();
                let mut scale = self.safety_factor * err_norm.powf(-order_inv);
                scale = scale.max(self.min_scale).min(self.max_scale);
                let h_new = self.h * scale;

                if err_norm <= T::one() {
                    self.status = $crate::Status::Solving;

                    self.t_prev = self.t;
                    self.y_prev = self.y;
                    self.dydt_prev = self.dydt;

                    self.t += self.h;
                    self.y = y_high;

                    // f at the new point seeds the next Newton guess and the interpolant
                    ode.diff(self.t, &self.y, &mut self.dydt);
                    evals.fcn += 1;

                    if self.reject {
                        self.n_stiff = 0;
                        self.reject = false;
                    }

                    self.h = $crate::utils::constrain_step_size(h_new, self.h_min, self.h_max);
                } else {
                    self.status = $crate::Status::RejectedStep;
                    self.reject = true;
                    self.n_stiff += 1;

                    if self.n_stiff >= self.max_rejects {
                        self.status = $crate::Status::Error($crate::Error::Stiffness { t: self.t, y: self.y });
                        return Err($crate::Error::Stiffness { t: self.t, y: self.y });
                    }

                    self.h = $crate::utils::constrain_step_size(h_new, self.h_min, self.h_max);
                    return Ok(evals);
                }

                Ok(evals)
            }

            fn t(&self) -> T { self.t }
            fn y(&self) -> &V { &self.y }
            fn t_prev(&self) -> T { self.t_prev }
            fn y_prev(&self) -> &V { &self.y_prev }
            fn h(&self) -> T { self.h }
            fn set_h(&mut self, h: T) { self.h = h; }
            fn status(&self) -> &$crate::Status<T, V, D> { &self.status }
            fn set_status(&mut self, status: $crate::Status<T, V, D>) { self.status = status; }
        }

        impl<
            T: $crate::traits::Real,
            V: $crate::traits::State<T>,
            D: $crate::traits::CallBackData,
        > $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            /// Cubic Hermite interpolation over the last accepted step `[t_prev, t]`.
            ///
            /// Before any step is accepted (`t == t_prev`) only `t_prev` itself can be
            /// queried; any other `t_interp` outside the interval is `OutOfBounds`.
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                if self.t == self.t_prev {
                    if t_interp == self.t_prev {
                        return Ok(self.y_prev);
                    } else {
                        return Err($crate::Error::OutOfBounds { t_interp, t_prev: self.t_prev, t_curr: self.t });
                    }
                }
                if t_interp < self.t_prev || t_interp > self.t {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t,
                    });
                }

                let y_interp = $crate::interpolate::cubic_hermite_interpolate(
                    self.t_prev, self.t,
                    &self.y_prev, &self.y,
                    &self.dydt_prev, &self.dydt,
                    t_interp,
                );

                Ok(y_interp)
            }
        }

        // --- Builder pattern methods ---
        impl<
            T: $crate::traits::Real,
            V: $crate::traits::State<T>,
            D: $crate::traits::CallBackData,
        > $name<T, V, D> {
            /// Creates a solver with default settings (same as `Default::default`).
            pub fn new() -> Self {
                Self::default()
            }

            /// Sets the initial step size (zero selects it automatically).
            pub fn h0(mut self, h0: T) -> Self { self.h0 = h0; self }
            /// Sets the relative tolerance of the error test.
            pub fn rtol(mut self, rtol: T) -> Self { self.rtol = rtol; self }
            /// Sets the absolute tolerance of the error test.
            pub fn atol(mut self, atol: T) -> Self { self.atol = atol; self }
            /// Sets the minimum allowed step size.
            pub fn h_min(mut self, h_min: T) -> Self { self.h_min = h_min; self }
            /// Sets the maximum allowed step size.
            pub fn h_max(mut self, h_max: T) -> Self { self.h_max = h_max; self }
            /// Sets the maximum number of step attempts.
            pub fn max_steps(mut self, max_steps: usize) -> Self { self.max_steps = max_steps; self }
            /// Sets how many rejections in a row trigger a stiffness error.
            pub fn max_rejects(mut self, max_rejects: usize) -> Self { self.max_rejects = max_rejects; self }
            /// Sets the safety factor of the step-size controller.
            pub fn safety_factor(mut self, safety_factor: T) -> Self { self.safety_factor = safety_factor; self }
            /// Sets the smallest step-size scale factor per step.
            pub fn min_scale(mut self, min_scale: T) -> Self { self.min_scale = min_scale; self }
            /// Sets the largest step-size scale factor per step.
            pub fn max_scale(mut self, max_scale: T) -> Self { self.max_scale = max_scale; self }
            /// Sets the maximum number of Newton iterations per step attempt.
            pub fn max_iter(mut self, iter: usize) -> Self { self.max_iter = iter; self }
            /// Sets the Newton convergence threshold on the update norm.
            pub fn tol(mut self, tol: T) -> Self { self.tol = tol; self }
        }
    };
}

// Correctly rounded f64 values of sqrt(3) and sqrt(15), spelled out as literals so
// they can be `const`. The Gauss-Legendre tableaux below are built from them, so
// any truncation here perturbs every coefficient and breaks the methods' order
// conditions at tight tolerances.
const SQRT3: f64 = 1.7320508075688772;
const SQRT15: f64 = 3.872983346207417;

478adaptive_implicit_runge_kutta_method!(
479 /// Gauss-Legendre method of order 4.
480 ///
481 /// This is a 2-stage implicit Runge-Kutta method.
482 /// It is A-stable and self-adjoint.
483 /// The error estimation is based on the second 'b' row provided in the tableau,
484 /// which corresponds to simplifying order conditions rather than a standard
485 /// embedded lower-order method. Use with caution for adaptive stepping.
486 ///
487 /// Butcher Tableau:
488 /// ```text
489 /// c1 | a11 a12
490 /// c2 | a21 a22
491 /// -------------
492 /// | b1 b2 (Order 4)
493 /// | bh1 bh2 (Simplifying conditions)
494 ///
495 /// c1 = 1/2 - sqrt(3)/6, c2 = 1/2 + sqrt(3)/6
496 /// a11 = 1/4, a12 = 1/4 - sqrt(3)/6
497 /// a21 = 1/4 + sqrt(3)/6, a22 = 1/4
498 /// b1 = 1/2, b2 = 1/2
499 /// bh1 = 1/2 + sqrt(3)/2, bh2 = 1/2 - sqrt(3)/2
500 /// ```
501 name: GaussLegendre4,
502 a: [
503 [0.25, 0.25 - SQRT3 / 6.0],
504 [0.25 + SQRT3 / 6.0, 0.25]
505 ],
506 b: [
507 [0.5, 0.5],
508 [0.5 + SQRT3 / 2.0, 0.5 - SQRT3 / 2.0]
509 ],
510 c: [0.5 - SQRT3 / 6.0, 0.5 + SQRT3 / 6.0],
511 order: 4,
512 stages: 2
513);
514
515adaptive_implicit_runge_kutta_method!(
516 /// Gauss-Legendre method of order 6.
517 ///
518 /// This is a 3-stage implicit Runge-Kutta method.
519 /// It is A-stable and self-adjoint.
520 /// The error estimation is based on the second 'b' row provided in the tableau,
521 /// which corresponds to simplifying order conditions rather than a standard
522 /// embedded lower-order method. Use with caution for adaptive stepping.
523 ///
524 /// Butcher Tableau:
525 /// ```text
526 /// c1 | a11 a12 a13
527 /// c2 | a21 a22 a23
528 /// c3 | a31 a32 a33
529 /// -----------------
530 /// | b1 b2 b3 (Order 6)
531 /// | bh1 bh2 bh3 (Simplifying conditions)
532 ///
533 /// c1 = 1/2 - sqrt(15)/10, c2 = 1/2, c3 = 1/2 + sqrt(15)/10
534 /// a11 = 5/36, a12 = 2/9 - sqrt(15)/15, a13 = 5/36 - sqrt(15)/30
535 /// a21 = 5/36 + sqrt(15)/24, a22 = 2/9, a23 = 5/36 - sqrt(15)/24
536 /// a31 = 5/36 + sqrt(15)/30, a32 = 2/9 + sqrt(15)/15, a33 = 5/36
537 /// b1 = 5/18, b2 = 4/9, b3 = 5/18
538 /// bh1 = -5/6, bh2 = 8/3, bh3 = -5/6
539 /// ```
540 name: GaussLegendre6,
541 a: [
542 [5.0/36.0, 2.0/9.0 - SQRT15/15.0, 5.0/36.0 - SQRT15/30.0],
543 [5.0/36.0 + SQRT15/24.0, 2.0/9.0, 5.0/36.0 - SQRT15/24.0],
544 [5.0/36.0 + SQRT15/30.0, 2.0/9.0 + SQRT15/15.0, 5.0/36.0]
545 ],
546 b: [
547 [5.0/18.0, 4.0/9.0, 5.0/18.0],
548 [-5.0/6.0, 8.0/3.0, -5.0/6.0]
549 ],
550 c: [0.5 - SQRT15/10.0, 0.5, 0.5 + SQRT15/10.0],
551 order: 6,
552 stages: 3
553);