differential_equations/ode/methods/runge_kutta/implicit/
adaptive_step.rs

1//! Adaptive step size implicit Runge-Kutta methods for solving ordinary differential equations.
2
3/// Macro to create an adaptive implicit Runge-Kutta solver from a Butcher tableau.
4///
5/// This macro generates the necessary struct and trait implementations for an adaptive-step
6/// implicit Runge-Kutta method. It uses Newton's iteration to solve the
7/// implicit stage equations and estimates the error by comparing the result from the
8/// primary `b` weights with a secondary set of weights `b_hat`.
9///
10/// # Arguments
11///
12/// * `name`: Name of the solver struct to create
13/// * `a`: Matrix of coefficients for intermediate stages (can be non-zero on diagonal/upper triangle)
14/// * `b`: 2D array where the first row is the primary weights (`b`) and the second row is the secondary weights (`b_hat`) for error estimation.
15/// * `c`: Time offsets for each stage
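/// * `order`: Order of accuracy of the primary method. It is passed to the initial step-size
///   estimate and drives step size control, where `h` is rescaled by
///   `safety_factor * err_norm^(-1/order)` clamped to `[min_scale, max_scale]`.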
17/// * `stages`: Number of stages in the method
18///
19/// # Note on Solver and Error Estimation
/// - The implicit stage equations `k_i = f(t_n + c_i*h, y_n + h * sum_j(a_{ij}*k_j))` are solved
///   for all stages at once with a simplified Newton iteration: the Jacobian is requested from
///   the ODE system once per step attempt, at `(t_n, y_n)`, and reused for every iteration.
22/// - Error estimation uses the difference between solutions computed with `b` and `b_hat`.
23///   The validity of `b_hat` as an error estimator depends on the specific method's tableau.
24///   For methods like Gauss-Legendre, this might not be the standard approach.
25///
26/// # Example (Illustrative - Requires a valid tableau with error estimator)
27/// ```rust
28/// // Assuming a hypothetical 2-stage, 2nd order implicit method with error estimator
29/// /*
30/// use differential_equations::adaptive_implicit_runge_kutta_method;
31/// adaptive_implicit_runge_kutta_method!(
32///     name: AdaptiveImplicitExample,
33///     a: [[0.5, 0.0], [0.5, 0.5]], // Example 'a' matrix
34///     b: [
35///         [0.5, 0.5], // Primary weights (e.g., order 2)
36///         [1.0, 0.0]  // Secondary weights (e.g., order 1)
37///     ],
38///     c: [0.5, 1.0],
39///     order: 2,
40///     stages: 2
41/// );
42/// */
43/// ```
44#[macro_export]
45macro_rules! adaptive_implicit_runge_kutta_method {
46    (
47        $(#[$attr:meta])*
48        name: $name:ident,
49        a: $a:expr,
50        b: $b:expr,
51        c: $c:expr,
52        order: $order:expr,
53        stages: $stages:expr
54        $(,)? // Optional trailing comma
55    ) => {
56        $(#[$attr])*
57        #[doc = "\n\n"]
58        #[doc = "This adaptive implicit solver was automatically generated using the `adaptive_implicit_runge_kutta_method` macro."]
59        #[doc = " It uses Newton iteration and embedded error estimation (via b/b_hat vectors)."]
60        #[doc = " The ODE system itself must provide the Jacobian via the `ODE` trait if `use_analytical_jacobian` is true (default)."]
61        #[doc = " Otherwise, finite differences are used to approximate the Jacobian."]
62        pub struct $name<
63            T: $crate::traits::Real,
64            V: $crate::traits::State<T>,
65            D: $crate::traits::CallBackData,
66        > {
67            // Initial Step Size
68            pub h0: T,
69            // Current Step Size
70            h: T,
71
72            // Current State
73            t: T,
74            y: V,
75            dydt: V, // Derivative at t
76
77            // Previous State
78            t_prev: T,
79            y_prev: V,
80            dydt_prev: V, // Derivative at t_prev
81
82            // Stage derivatives (k_i)
83            k: [V; $stages],
84            // Temporary storage for stage values during iteration
85            y_stage: [V; $stages],
86            f_at_stages: [V; $stages], // Stores f(t_stage, y_stage) during Newton iteration
87
88            // Constants from Butcher tableau (fixed size arrays)
89            a: [[T; $stages]; $stages],
90            b_higher: [T; $stages], // Primary weights (b)
91            b_lower: [T; $stages],  // Secondary weights (b_hat) for error estimation
92            c: [T; $stages],
93
94            // --- Adaptive Step Settings ---
95            pub rtol: T,
96            pub atol: T,
97            pub h_max: T,
98            pub h_min: T,
99            pub max_steps: usize,
100            pub max_rejects: usize,
101            pub safety_factor: T,
102            pub min_scale: T,
103            pub max_scale: T,
104
105            // --- Implicit Solver Settings ---
106            pub max_iter: usize, // Max iterations for Newton solver
107            pub tol: T,          // Tolerance for Newton solver convergence
108            fd_epsilon_sqrt: T, // Stores sqrt(machine_epsilon) for FD
109
110            // Iteration tracking & Status
111            reject: bool,
112            n_stiff: usize,
113            steps: usize,
114            status: $crate::Status<T, V, D>,
115
116            // --- Jacobian and Newton Solver Data ---
117            jacobian_matrix: nalgebra::DMatrix<T>, // Jacobian of f: J(t,y)
118            newton_matrix: nalgebra::DMatrix<T>,   // Matrix for Newton system (M)
119            rhs_newton: nalgebra::DVector<T>,      // RHS vector for Newton system (-phi)
120            delta_k_vec: nalgebra::DVector<T>,     // Solution of Newton system (delta_k)
121        }
122
123        impl<
124            T: $crate::traits::Real,
125            V: $crate::traits::State<T>,
126            D: $crate::traits::CallBackData,
127        > Default for $name<T, V, D> {
128            fn default() -> Self {
129                // Convert Butcher tableau values to type T
130                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));
131                let b_higher_t: [T; $stages] = $b[0].map(|x| T::from_f64(x).unwrap());
132                let b_lower_t: [T; $stages] = $b[1].map(|x| T::from_f64(x).unwrap());
133                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());
134
135                $name {
136                    h0: T::zero(), // Indicate auto-calculation
137                    h: T::zero(),
138                    t: T::zero(),
139                    y: V::zeros(),
140                    dydt: V::zeros(),
141                    t_prev: T::zero(),
142                    y_prev: V::zeros(),
143                    dydt_prev: V::zeros(),
144                    k: [V::zeros(); $stages],
145                    y_stage: [V::zeros(); $stages],
146                    f_at_stages: [V::zeros(); $stages],
147                    a: a_t,
148                    b_higher: b_higher_t,
149                    b_lower: b_lower_t,
150                    c: c_t,
151                    // Adaptive defaults
152                    rtol: T::from_f64(1.0e-6).unwrap(),
153                    atol: T::from_f64(1.0e-6).unwrap(),
154                    h_max: T::infinity(),
155                    h_min: T::zero(),
156                    max_steps: 10000,
157                    max_rejects: 100,
158                    safety_factor: T::from_f64(0.9).unwrap(),
159                    min_scale: T::from_f64(0.2).unwrap(),
160                    max_scale: T::from_f64(10.0).unwrap(),
161                    // Implicit defaults
162                    max_iter: 50,
163                    tol: T::from_f64(1e-8).unwrap(),
164                    fd_epsilon_sqrt: T::zero(),
165                    // Status
166                    reject: false,
167                    n_stiff: 0,
168                    steps: 0,
169                    status: $crate::Status::Uninitialized,
170                    // Initialize nalgebra structures (empty, to be sized in init)
171                    jacobian_matrix: nalgebra::DMatrix::zeros(0, 0),
172                    newton_matrix: nalgebra::DMatrix::zeros(0, 0),
173                    rhs_newton: nalgebra::DVector::zeros(0),
174                    delta_k_vec: nalgebra::DVector::zeros(0),
175                }
176            }
177        }
178
179        impl<
180            T: $crate::traits::Real,
181            V: $crate::traits::State<T>,
182            D: $crate::traits::CallBackData,
183        > $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
184            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
185            where
186                F: $crate::ode::ODE<T, V, D>, // ODE trait now includes Jacobian
187            {
188                let mut evals = $crate::alias::Evals::new();
189
190                // Calculate initial derivative f(t0, y0)
191                let mut initial_dydt = V::zeros();
192                ode.diff(t0, y0, &mut initial_dydt);
193                evals.fcn += 1;
194
195                // If h0 is zero calculate h0 using initial derivative
196                if self.h0 == T::zero() {
197                    self.h0 = $crate::ode::methods::h_init(ode, t0, tf, y0, $order, self.rtol, self.atol, self.h_min, self.h_max);
198                }
199
200                // Check bounds
201                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf)?;
202
203                // Initialize Statistics
204                self.reject = false;
205                self.n_stiff = 0;
206                self.steps = 0;
207
208                // Initialize State
209                self.t = t0;
210                self.y = *y0;
211                self.dydt = initial_dydt; // Store f(t0, y0)
212
213                // Initialize previous state (same as current initially)
214                self.t_prev = t0;
215                self.y_prev = *y0;
216                self.dydt_prev = initial_dydt;
217
218                // Initialize fd_epsilon_sqrt
219                self.fd_epsilon_sqrt = T::default_epsilon().sqrt();
220
221                // Initialize Status
222                self.status = $crate::Status::Initialized;
223
224                // Initialize Jacobian and Newton-related matrices/vectors with correct dimensions
225                let dim = y0.len();
226                self.jacobian_matrix = nalgebra::DMatrix::zeros(dim, dim);
227                let newton_system_size = $stages * dim;
228                self.newton_matrix = nalgebra::DMatrix::zeros(newton_system_size, newton_system_size);
229                self.rhs_newton = nalgebra::DVector::zeros(newton_system_size);
230                self.delta_k_vec = nalgebra::DVector::zeros(newton_system_size);
231                self.f_at_stages = [V::zeros(); $stages];
232
233                Ok(evals)
234            }
235
236            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
237            where
238                F: $crate::ode::ODE<T, V, D>, // ODE trait now includes Jacobian
239            {
240                let mut evals = $crate::alias::Evals::new();
241                let dim = self.y.len();
242
243                // Check step size validity
244                if self.h.abs() < self.h_min || self.h.abs() < T::default_epsilon() {
245                    self.status = $crate::Status::Error($crate::Error::StepSize { t: self.t, y: self.y });
246                    return Err($crate::Error::StepSize { t: self.t, y: self.y });
247                }
248
249                // Check max steps
250                if self.steps >= self.max_steps {
251                    self.status = $crate::Status::Error($crate::Error::MaxSteps { t: self.t, y: self.y });
252                    return Err($crate::Error::MaxSteps { t: self.t, y: self.y });
253                }
254                self.steps += 1;
255
256                // --- Newton Iteration for stage derivatives k_i ---
257                // Initial guess for k_i: k_i^(0) = f(t_n, y_n) (stored in self.dydt)
258                for i in 0..$stages {
259                    self.k[i] = self.dydt;
260                }
261
262                // Calculate Jacobian J_n = df/dy(t_n, y_n) once per step attempt
263                ode.jacobian(self.t, &self.y, &mut self.jacobian_matrix);
264                evals.jac += 1;
265
266                let mut converged = false;
267                for _iter in 0..self.max_iter {
268                    // 1. Compute residual phi(K_current) and store -phi in rhs_newton
269                    for i in 0..$stages {
270                        self.y_stage[i] = self.y; // y_n
271                        for j in 0..$stages {
272                            self.y_stage[i] += self.k[j] * (self.a[i][j] * self.h);
273                        }
274
275                        ode.diff(self.t + self.c[i] * self.h, &self.y_stage[i], &mut self.f_at_stages[i]);
276                        evals.fcn += 1;
277
278                        for row_idx in 0..dim {
279                            self.rhs_newton[i * dim + row_idx] = self.f_at_stages[i].get(row_idx) - self.k[i].get(row_idx);
280                        }
281                    }
282
283                    // 2. Form Newton matrix M
284                    for i in 0..$stages { // block row index
285                        for l in 0..$stages { // block column index
286                            let scale_factor = -self.h * self.a[i][l];
287                            for r in 0..dim { // row index within the block
288                                for c_col in 0..dim { // column index within the block (renamed from c to avoid conflict)
289                                    // Direct assignment to the element in newton_matrix
290                                    self.newton_matrix[(i * dim + r, l * dim + c_col)] = 
291                                        self.jacobian_matrix[(r, c_col)] * scale_factor;
292                                }
293                            }
294
295                            if i == l { // If it's a diagonal block, add Identity
296                                for d_idx in 0..dim { // index for the diagonal of the block
297                                    self.newton_matrix[(i * dim + d_idx, l * dim + d_idx)] += T::one();
298                                }
299                            }
300                        }
301                    }
302
303                    // 3. Solve M * delta_k_vec = rhs_newton
304                    let lu_decomp = nalgebra::LU::new(self.newton_matrix.clone());
305                    if let Some(solution) = lu_decomp.solve(&self.rhs_newton) {
306                        self.delta_k_vec.copy_from(&solution);
307                    } else {
308                        converged = false;
309                        break;
310                    }
311
312                    // 4. Update K: self.k[i] += delta_k_vec_i
313                    let mut norm_delta_k_sq = T::zero();
314                    for i in 0..$stages {
315                        for row_idx in 0..dim {
316                            let delta_val = self.delta_k_vec[i * dim + row_idx];
317                            let current_val = self.k[i].get(row_idx);
318                            self.k[i].set(row_idx, current_val + delta_val);
319                            norm_delta_k_sq += delta_val * delta_val;
320                        }
321                    }
322
323                    // 5. Check convergence: ||delta_k_vec|| < self.tol
324                    if norm_delta_k_sq < self.tol * self.tol {
325                        converged = true;
326                        break;
327                    }
328                }
329
330                if !converged {
331                    self.h *= T::from_f64(0.25).unwrap();
332                    self.h = $crate::utils::constrain_step_size(self.h, self.h_min, self.h_max);
333                    self.reject = true;
334                    self.n_stiff += 1;
335
336                    if self.n_stiff >= self.max_rejects {
337                        self.status = $crate::Status::Error($crate::Error::Stiffness { t: self.t, y: self.y });
338                        return Err($crate::Error::Stiffness { t: self.t, y: self.y });
339                    }
340                    return Ok(evals);
341                }
342
343                // --- Iteration converged, compute solutions and error ---
344                let mut delta_y_high = V::zeros();
345                for i in 0..$stages {
346                    delta_y_high += self.k[i] * (self.b_higher[i] * self.h);
347                }
348                let y_high = self.y + delta_y_high;
349
350                let mut delta_y_low = V::zeros();
351                for i in 0..$stages {
352                    delta_y_low += self.k[i] * (self.b_lower[i] * self.h);
353                }
354                let y_low = self.y + delta_y_low;
355
356                let err = y_high - y_low;
357
358                let mut err_norm = T::zero();
359                for n in 0..self.y.len() {
360                    let scale = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
361                    if scale > T::zero() {
362                        err_norm = err_norm.max((err.get(n) / scale).abs());
363                    }
364                }
365                err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
366
367                let order_inv = T::one() / T::from_usize($order).unwrap();
368                let mut scale = self.safety_factor * err_norm.powf(-order_inv);
369                scale = scale.max(self.min_scale).min(self.max_scale);
370                let h_new = self.h * scale;
371
372                if err_norm <= T::one() {
373                    self.status = $crate::Status::Solving;
374
375                    self.t_prev = self.t;
376                    self.y_prev = self.y;
377                    self.dydt_prev = self.dydt;
378
379                    self.t += self.h;
380                    self.y = y_high;
381
382                    ode.diff(self.t, &self.y, &mut self.dydt);
383                    evals.fcn += 1;
384
385                    if self.reject {
386                        self.n_stiff = 0;
387                        self.reject = false;
388                    }
389
390                    self.h = $crate::utils::constrain_step_size(h_new, self.h_min, self.h_max);
391                } else {
392                    self.status = $crate::Status::RejectedStep;
393                    self.reject = true;
394                    self.n_stiff += 1;
395
396                    if self.n_stiff >= self.max_rejects {
397                        self.status = $crate::Status::Error($crate::Error::Stiffness { t: self.t, y: self.y });
398                        return Err($crate::Error::Stiffness { t: self.t, y: self.y });
399                    }
400
401                    self.h = $crate::utils::constrain_step_size(h_new, self.h_min, self.h_max);
402                    return Ok(evals);
403                }
404
405                Ok(evals)
406            }
407
408            fn t(&self) -> T { self.t }
409            fn y(&self) -> &V { &self.y }
410            fn t_prev(&self) -> T { self.t_prev }
411            fn y_prev(&self) -> &V { &self.y_prev }
412            fn h(&self) -> T { self.h }
413            fn set_h(&mut self, h: T) { self.h = h; }
414            fn status(&self) -> &$crate::Status<T, V, D> { &self.status }
415            fn set_status(&mut self, status: $crate::Status<T, V, D>) { self.status = status; }
416        }
417
418        impl<
419            T: $crate::traits::Real,
420            V: $crate::traits::State<T>,
421            D: $crate::traits::CallBackData,
422        > $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
423            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
424                if self.t == self.t_prev {
425                    if t_interp == self.t_prev {
426                        return Ok(self.y_prev);
427                    } else {
428                        return Err($crate::Error::OutOfBounds { t_interp, t_prev: self.t_prev, t_curr: self.t });
429                    }
430                }
431                if t_interp < self.t_prev || t_interp > self.t {
432                    return Err($crate::Error::OutOfBounds {
433                        t_interp,
434                        t_prev: self.t_prev,
435                        t_curr: self.t });
436                }
437
438                let y_interp = $crate::interpolate::cubic_hermite_interpolate(
439                    self.t_prev, self.t,
440                    &self.y_prev, &self.y,
441                    &self.dydt_prev, &self.dydt,
442                    t_interp
443                );
444
445                Ok(y_interp)
446            }
447        }
448
449// --- Builder Pattern Methods ---
450        impl<
451            T: $crate::traits::Real,
452            V: $crate::traits::State<T>,
453            D: $crate::traits::CallBackData,
454        > $name<T, V, D> {
455            pub fn new() -> Self {
456                Self::default()
457            }
458
459            pub fn h0(mut self, h0: T) -> Self { self.h0 = h0; self }
460            pub fn rtol(mut self, rtol: T) -> Self { self.rtol = rtol; self }
461            pub fn atol(mut self, atol: T) -> Self { self.atol = atol; self }
462            pub fn h_min(mut self, h_min: T) -> Self { self.h_min = h_min; self }
463            pub fn h_max(mut self, h_max: T) -> Self { self.h_max = h_max; self }
464            pub fn max_steps(mut self, max_steps: usize) -> Self { self.max_steps = max_steps; self }
465            pub fn max_rejects(mut self, max_rejects: usize) -> Self { self.max_rejects = max_rejects; self }
466            pub fn safety_factor(mut self, safety_factor: T) -> Self { self.safety_factor = safety_factor; self }
467            pub fn min_scale(mut self, min_scale: T) -> Self { self.min_scale = min_scale; self }
468            pub fn max_scale(mut self, max_scale: T) -> Self { self.max_scale = max_scale; self }
469            pub fn max_iter(mut self, iter: usize) -> Self { self.max_iter = iter; self }
470            pub fn tol(mut self, tol: T) -> Self { self.tol = tol; self }
471        }
472    };
473}
474
/// `sqrt(3)` rounded to the nearest `f64`.
///
/// Written at full double precision: a shorter literal would put a relative error far above
/// machine epsilon into every Gauss-Legendre (order 4) tableau coefficient derived from it.
const SQRT3: f64 = 1.732_050_807_568_877_2;
/// `sqrt(15)` rounded to the nearest `f64`; used by the Gauss-Legendre (order 6) tableau.
const SQRT15: f64 = 3.872_983_346_207_417;
477
478adaptive_implicit_runge_kutta_method!(
479    /// Gauss-Legendre method of order 4.
480    ///
481    /// This is a 2-stage implicit Runge-Kutta method.
482    /// It is A-stable and self-adjoint.
483    /// The error estimation is based on the second 'b' row provided in the tableau,
484    /// which corresponds to simplifying order conditions rather than a standard
485    /// embedded lower-order method. Use with caution for adaptive stepping.
486    ///
487    /// Butcher Tableau:
488    /// ```text
489    /// c1 | a11 a12
490    /// c2 | a21 a22
491    /// -------------
492    ///    | b1  b2    (Order 4)
493    ///    | bh1 bh2   (Simplifying conditions)
494    ///
495    /// c1 = 1/2 - sqrt(3)/6, c2 = 1/2 + sqrt(3)/6
496    /// a11 = 1/4, a12 = 1/4 - sqrt(3)/6
497    /// a21 = 1/4 + sqrt(3)/6, a22 = 1/4
498    /// b1 = 1/2, b2 = 1/2
499    /// bh1 = 1/2 + sqrt(3)/2, bh2 = 1/2 - sqrt(3)/2
500    /// ```
501    name: GaussLegendre4,
502    a: [
503        [0.25, 0.25 - SQRT3 / 6.0],
504        [0.25 + SQRT3 / 6.0, 0.25]
505    ],
506    b: [
507        [0.5, 0.5],
508        [0.5 + SQRT3 / 2.0, 0.5 - SQRT3 / 2.0]
509    ],
510    c: [0.5 - SQRT3 / 6.0, 0.5 + SQRT3 / 6.0],
511    order: 4,
512    stages: 2
513);
514
515adaptive_implicit_runge_kutta_method!(
516    /// Gauss-Legendre method of order 6.
517    ///
518    /// This is a 3-stage implicit Runge-Kutta method.
519    /// It is A-stable and self-adjoint.
520    /// The error estimation is based on the second 'b' row provided in the tableau,
521    /// which corresponds to simplifying order conditions rather than a standard
522    /// embedded lower-order method. Use with caution for adaptive stepping.
523    ///
524    /// Butcher Tableau:
525    /// ```text
526    /// c1 | a11 a12 a13
527    /// c2 | a21 a22 a23
528    /// c3 | a31 a32 a33
529    /// -----------------
530    ///    | b1  b2  b3   (Order 6)
531    ///    | bh1 bh2 bh3  (Simplifying conditions)
532    ///
533    /// c1 = 1/2 - sqrt(15)/10, c2 = 1/2, c3 = 1/2 + sqrt(15)/10
534    /// a11 = 5/36, a12 = 2/9 - sqrt(15)/15, a13 = 5/36 - sqrt(15)/30
535    /// a21 = 5/36 + sqrt(15)/24, a22 = 2/9, a23 = 5/36 - sqrt(15)/24
536    /// a31 = 5/36 + sqrt(15)/30, a32 = 2/9 + sqrt(15)/15, a33 = 5/36
537    /// b1 = 5/18, b2 = 4/9, b3 = 5/18
538    /// bh1 = -5/6, bh2 = 8/3, bh3 = -5/6
539    /// ```
540    name: GaussLegendre6,
541    a: [
542        [5.0/36.0, 2.0/9.0 - SQRT15/15.0, 5.0/36.0 - SQRT15/30.0],
543        [5.0/36.0 + SQRT15/24.0, 2.0/9.0, 5.0/36.0 - SQRT15/24.0],
544        [5.0/36.0 + SQRT15/30.0, 2.0/9.0 + SQRT15/15.0, 5.0/36.0]
545    ],
546    b: [
547        [5.0/18.0, 4.0/9.0, 5.0/18.0],
548        [-5.0/6.0, 8.0/3.0, -5.0/6.0]
549    ],
550    c: [0.5 - SQRT15/10.0, 0.5, 0.5 + SQRT15/10.0],
551    order: 6,
552    stages: 3
553);