differential_equations/ode/methods/runge_kutta/implicit/fixed_step.rs

//! Fixed-step implicit Runge-Kutta methods for solving ordinary differential equations.

/// Macro to create a fixed-step implicit Runge-Kutta solver from a Butcher tableau.
///
/// This macro generates the necessary struct and trait implementations for a fixed-step
/// implicit Runge-Kutta method. It uses a simple fixed-point iteration to solve the
/// implicit stage equations.
///
/// # Arguments
///
/// * `name`: Name of the solver struct to create
/// * `a`: Matrix of coefficients for intermediate stages (may be non-zero on or above the diagonal)
/// * `b`: Weights for the final summation
/// * `c`: Time offsets for each stage
/// * `order`: Order of accuracy of the method (currently informational; not used by the generated code)
/// * `stages`: Number of stages in the method
///
/// # Note on Solver
/// The implicit stage equations `k_i = f(t_n + c_i*h, y_n + h * sum(a_{ij}*k_j))` are solved
/// using fixed-point iteration. This is simple but may fail to converge for stiff problems
/// unless `h` is sufficiently small (`h * L < 1`, where `L` is the Lipschitz constant).
/// More robust solvers (like Newton's method) require Jacobians and linear algebra.
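///
/// Each sweep of the solver updates all stage derivatives simultaneously (a Jacobi-style
/// iteration); as a sketch of what the generated `step` method computes:
/// ```text
/// k_i^(m+1) = f(t_n + c_i*h, y_n + h * sum_j(a_ij * k_j^(m))),   i = 1..stages
/// ```
/// The sweep repeats until `max_i ||k_i^(m+1) - k_i^(m)|| < tol` or `max_iter` is reached.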
///
/// # Example
/// ```
/// use differential_equations::implicit_runge_kutta_method;
///
/// // Define Implicit Euler method
/// implicit_runge_kutta_method!(
///     /// Implicit Euler (Backward Euler) Method (1st Order)
///     name: ImplicitEulerExample,
///     a: [[1.0]],
///     b: [1.0],
///     c: [1.0],
///     order: 1,
///     stages: 1
/// );
/// ```
#[macro_export]
macro_rules! implicit_runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {

        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This fixed-step implicit solver was automatically generated using the `implicit_runge_kutta_method` macro."]
        #[doc = " It uses fixed-point iteration to solve the stage equations."]
        pub struct $name<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> {
            // Step Size
            pub h: T,

            // Current State
            t: T,
            y: V,

            // Previous State
            t_prev: T,
            y_prev: V,
            dydt_prev: V, // Derivative f(t_prev, y_prev)
            // Derivative f(t, y) at the current point; seeds the next step's
            // initial guess and is the right-hand slope for interpolation
            dydt: V,

            // Stage derivatives (k_i)
            k: [V; $stages],
            // Temporary storage for stage values during iteration
            y_stage: [V; $stages],
            k_new: [V; $stages],

            // Constants from Butcher tableau (fixed size arrays)
            a: [[T; $stages]; $stages],
            b: [T; $stages],
            c: [T; $stages],

            // --- Solver Settings ---
            pub max_iter: usize, // Max iterations for fixed-point solver
            pub tol: T,          // Tolerance for fixed-point solver convergence

            // Status & Counters
            status: $crate::Status<T, V, D>,
            steps: usize,
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> Default for $name<T, V, D> {
            fn default() -> Self {
                // Convert Butcher tableau values to type T
                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));
                let b_t: [T; $stages] = $b.map(|x| T::from_f64(x).unwrap());
                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h: T::from_f64(0.01).unwrap(), // Default fixed step size
                    t: T::zero(),
                    y: V::zeros(),
                    t_prev: T::zero(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    dydt: V::zeros(),
                    k: [V::zeros(); $stages],
                    y_stage: [V::zeros(); $stages],
                    k_new: [V::zeros(); $stages],
                    a: a_t,
                    b: b_t,
                    c: c_t,
                    max_iter: 50, // Default max iterations
                    tol: T::from_f64(1e-8).unwrap(), // Default tolerance
                    status: $crate::Status::Uninitialized,
                    steps: 0,
                }
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>
            {
                let mut evals = $crate::alias::Evals::new();

                if self.h == T::zero() {
                    return Err($crate::Error::BadInput {
                        msg: concat!(stringify!($name), " requires a non-zero fixed step size 'h' to be set.").to_string(),
                    });
                }
                // Basic validation
                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf)?;

                // Initialize State
                self.t = t0;
                self.y = *y0;
                self.t_prev = t0;
                self.y_prev = *y0;

                // Calculate the initial derivative f(t0, y0): it seeds the first step's
                // stage guess and is the left-hand slope for interpolation.
                ode.diff(t0, y0, &mut self.dydt);
                self.dydt_prev = self.dydt;
                evals.fcn += 1;

                // Reset counters
                self.steps = 0;

                self.status = $crate::Status::Initialized;
                Ok(evals)
            }

            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>
            {
                let mut evals = $crate::alias::Evals::new();

                // --- Fixed-Point Iteration for stage derivatives k_i ---
                // Initial guess: k_i^{(0)} = f(t_n, y_n) (stored in self.dydt)
                for i in 0..$stages {
                    self.k[i] = self.dydt;
                }

                let mut converged = false;
                for _iter in 0..self.max_iter {
                    let mut max_diff_sq = T::zero();

                    // Calculate next iteration k_i^{(m+1)} based on k_j^{(m)}
                    for i in 0..$stages {
                        // Calculate stage value y_stage = y_n + h * sum(a_ij * k_j^{(m)})
                        self.y_stage[i] = self.y;
                        for j in 0..$stages {
                            // k values from the previous sweep (Jacobi-style update)
                            self.y_stage[i] += self.k[j] * (self.a[i][j] * self.h);
                        }

                        // Evaluate f at stage time and value: f(t_n + c_i*h, y_stage)
                        ode.diff(self.t + self.c[i] * self.h, &self.y_stage[i], &mut self.k_new[i]);
                        evals.fcn += 1;
                    }

                    // Check convergence: max ||k_new_i - k_i|| < tol
                    for i in 0..$stages {
                        let diff = self.k_new[i] - self.k[i];
                        let mut error_norm_sq = T::zero();
                        for idx in 0..diff.len() {
                            error_norm_sq += diff.get(idx) * diff.get(idx);
                        }
                        max_diff_sq = max_diff_sq.max(error_norm_sq);

                        // Update k for next iteration
                        self.k[i] = self.k_new[i];
                    }

                    if max_diff_sq.sqrt() < self.tol {
                        converged = true;
                        break;
                    }
                } // End fixed-point iteration loop

                if !converged {
                    self.status = $crate::Status::Error($crate::Error::StepSize { t: self.t, y: self.y });
                    return Err($crate::Error::StepSize { t: self.t, y: self.y });
                }

                // --- Iteration converged, compute final update ---
                self.steps += 1;

                // Store previous state
                self.t_prev = self.t;
                self.y_prev = self.y;
                self.dydt_prev = self.dydt; // f(t_prev, y_prev), kept for interpolation

                // Compute the final update y_{n+1} = y_n + h * sum(b_i * k_i)
                let mut delta_y = V::zeros();
                for i in 0..$stages {
                    delta_y += self.k[i] * (self.b[i] * self.h);
                }

                // Update state
                self.y += delta_y;
                self.t += self.h;

                // Calculate the derivative at the new point; it seeds the next step's
                // initial guess and is the right-hand slope for interpolation.
                ode.diff(self.t, &self.y, &mut self.dydt);
                evals.fcn += 1; // Count this evaluation

                self.status = $crate::Status::Solving;
                Ok(evals) // Return evals for this step
            }

            // --- Standard trait methods ---
            fn t(&self) -> T { self.t }
            fn y(&self) -> &V { &self.y }
            fn t_prev(&self) -> T { self.t_prev }
            fn y_prev(&self) -> &V { &self.y_prev }
            fn h(&self) -> T { self.h }
            fn set_h(&mut self, h: T) { self.h = h; }
            fn status(&self) -> &$crate::Status<T, V, D> { &self.status }
            fn set_status(&mut self, status: $crate::Status<T, V, D>) { self.status = status; }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                if self.t == self.t_prev { // Handle case before first step
                    if t_interp == self.t_prev {
                        return Ok(self.y_prev);
                    } else {
                        return Err($crate::Error::OutOfBounds { t_interp, t_prev: self.t_prev, t_curr: self.t });
                    }
                }

                // Check if t is within the bounds of the current step
                if t_interp < self.t_prev || t_interp > self.t {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t,
                    });
                }

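                // The cubic Hermite (dense-output) form assumed here, with
                //   s = (t_interp - t_prev) / (t - t_prev) and dt = t - t_prev:
                //   y(s) = (2s^3 - 3s^2 + 1)*y_prev + (s^3 - 2s^2 + s)*dt*dydt_prev
                //        + (-2s^3 + 3s^2)*y        + (s^3 - s^2)*dt*dydt
                // The exact formula is whatever `cubic_hermite_interpolate` implements.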
                // Cubic Hermite interpolation between (t_prev, y_prev, dydt_prev) and (t, y, dydt)
                let y_interp = $crate::interpolate::cubic_hermite_interpolate(
                    self.t_prev, self.t,
                    &self.y_prev, &self.y,
                    &self.dydt_prev, &self.dydt,
                    t_interp
                );

                Ok(y_interp)
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $name<T, V, D> {
            /// Create a new solver instance with the given step size and default settings.
            pub fn new(h: T) -> Self {
                $name {
                    h,
                    ..Default::default()
                }
            }

            /// Set the fixed step size `h`.
            pub fn h(mut self, h: T) -> Self {
                self.h = h;
                self
            }

            /// Set the maximum number of fixed-point iterations per step.
            pub fn max_iter(mut self, iter: usize) -> Self {
                self.max_iter = iter;
                self
            }

            /// Set the tolerance for fixed-point iteration convergence.
            pub fn tol(mut self, tol: T) -> Self {
                self.tol = tol;
                self
            }
        }
    };
}

implicit_runge_kutta_method!(
    /// Implicit Euler (Backward Euler) Method (1st Order)
    ///
    /// Solves `y_{n+1} = y_n + h * f(t_{n+1}, y_{n+1})`.
    /// The Butcher Tableau is:
    /// ```text
    /// 1 | 1
    /// -----
    ///   | 1
    /// ```
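    ///
    /// With this tableau the macro's fixed-point sweep reduces to (a sketch of what the
    /// generated `step` computes, not extra generated code):
    /// ```text
    /// k^(m+1) = f(t_n + h, y_n + h * k^(m)),    y_{n+1} = y_n + h * k
    /// ```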
    name: BackwardEuler,
    a: [[1.0]],
    b: [1.0],
    c: [1.0],
    order: 1,
    stages: 1
);

implicit_runge_kutta_method!(
    /// Crank-Nicolson Method (Trapezoidal Rule) (2nd Order)
    ///
    /// Solves `y_{n+1} = y_n + 0.5*h * (f(t_n, y_n) + f(t_{n+1}, y_{n+1}))`.
    /// This is often implemented as a 2-stage implicit method.
    /// Stage 1: `k1 = f(t_n, y_n)` (explicit)
    /// Stage 2: `k2 = f(t_{n+1}, y_n + 0.5*h*k1 + 0.5*h*k2)` (implicit)
    /// Update: `y_{n+1} = y_n + 0.5*h*k1 + 0.5*h*k2`
    /// The Butcher Tableau is:
    /// ```text
    /// 0   | 0   0
    /// 1   | 1/2 1/2
    /// --------------
    ///     | 1/2 1/2
    /// ```
    /// Note: the fixed-point solver in this macro iterates on *all* stages simultaneously.
    /// For Crank-Nicolson, k1 is explicit, yet it is still carried through the sweep and
    /// `f(t_n, y_n)` is re-evaluated on every iteration. This works, but it is less
    /// efficient than a specialized implementation that computes k1 once.
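    ///
    /// Concretely, each sweep of the generated solver computes (a sketch following the
    /// tableau above):
    /// ```text
    /// k1^(m+1) = f(t_n,     y_n)
    /// k2^(m+1) = f(t_n + h, y_n + 0.5*h*k1^(m) + 0.5*h*k2^(m))
    /// ```
    /// so `k1` is simply `f(t_n, y_n)` on every sweep, while `k2` is iterated to tolerance.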
    name: CrankNicolson,
    a: [[0.0, 0.0],
        [0.5, 0.5]],
    b: [0.5, 0.5],
    c: [0.0, 1.0],
    order: 2,
    stages: 2
);