differential_equations/ode/methods/runge_kutta/implicit/fixed_step.rs
1//! Fixed-step implicit Runge-Kutta methods for solving ordinary differential equations.
2
/// Macro to create a fixed-step implicit Runge-Kutta solver from a Butcher tableau.
///
/// This macro generates the solver struct together with its `Default`,
/// `ODENumericalMethod`, `Interpolation`, and builder-style inherent implementations
/// for a fixed-step implicit Runge-Kutta method. The implicit stage equations are
/// solved with a simple fixed-point iteration.
///
/// # Arguments
///
/// * `name`: Name of the solver struct to create
/// * `a`: Matrix of stage coefficients (may be non-zero on the diagonal/upper triangle)
/// * `b`: Weights for the final summation
/// * `c`: Time offsets (nodes) for each stage
/// * `order`: Order of accuracy of the method; exposed as the associated constant `ORDER`
/// * `stages`: Number of stages in the method; exposed as the associated constant `STAGES`
///
/// # Note on Solver
/// The implicit stage equations `k_i = f(t_n + c_i*h, y_n + h * sum(a_{ij}*k_j))` are solved
/// by Jacobi-style fixed-point iteration: every stage of iterate `m+1` is computed from
/// iterate `m`, starting from `k_i = f(t_n, y_n)`. This is simple but may fail to converge
/// for stiff problems unless `h` is sufficiently small (`h * L < 1`, where `L` is the
/// Lipschitz constant). More robust solvers (like Newton's method) require Jacobians and
/// linear algebra. The iteration is accepted once every stage changes by less than `tol`
/// (Euclidean norm); a NaN residual never satisfies that test, so a diverged iteration
/// ends in an error after `max_iter` iterations instead of being accepted.
///
/// # Dense output
/// `interpolate` uses cubic Hermite interpolation between `(t_prev, y_prev, f(t_prev, y_prev))`
/// and `(t, y, f(t, y))`, for steps taken in either direction of time.
///
/// # Example
/// ```
/// use differential_equations::implicit_runge_kutta_method;
///
/// // Define Implicit Euler method
/// implicit_runge_kutta_method!(
///     /// Implicit Euler (Backward Euler) Method (1st Order)
///     name: ImplicitEulerExample,
///     a: [[1.0]],
///     b: [1.0],
///     c: [1.0],
///     order: 1,
///     stages: 1
/// );
/// ```
#[macro_export]
macro_rules! implicit_runge_kutta_method {
    (
        $(#[$attr:meta])*
        name: $name:ident,
        a: $a:expr,
        b: $b:expr,
        c: $c:expr,
        order: $order:expr,
        stages: $stages:expr
        $(,)? // Optional trailing comma
    ) => {

        $(#[$attr])*
        #[doc = "\n\n"]
        #[doc = "This fixed-step implicit solver was automatically generated using the `implicit_runge_kutta_method` macro."]
        #[doc = " It uses fixed-point iteration to solve the stage equations."]
        pub struct $name<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> {
            /// Fixed step size.
            pub h: T,

            // Current point and f(t, y) evaluated there.
            t: T,
            y: V,
            dydt: V,

            // Previous point and f(t_prev, y_prev). Together with the fields above these
            // are the end-point data used for dense output over the last step.
            t_prev: T,
            y_prev: V,
            dydt_prev: V,

            // Stage derivatives k_i (latest fixed-point iterate).
            k: [V; $stages],
            // Scratch: stage values y_n + h * sum_j a_ij * k_j.
            y_stage: [V; $stages],
            // Scratch: next fixed-point iterate of the stage derivatives.
            k_new: [V; $stages],

            // Butcher tableau converted to T.
            a: [[T; $stages]; $stages],
            b: [T; $stages],
            c: [T; $stages],

            /// Maximum number of fixed-point iterations per step.
            pub max_iter: usize,
            /// Fixed-point tolerance: every stage must satisfy `||k_new_i - k_i||_2 < tol`.
            pub tol: T,

            // Status & counters
            status: $crate::Status<T, V, D>,
            steps: usize,
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> Default for $name<T, V, D> {
            fn default() -> Self {
                // Convert the f64 Butcher tableau literals to T once, up front.
                let a_t: [[T; $stages]; $stages] = $a.map(|row| row.map(|x| T::from_f64(x).unwrap()));
                let b_t: [T; $stages] = $b.map(|x| T::from_f64(x).unwrap());
                let c_t: [T; $stages] = $c.map(|x| T::from_f64(x).unwrap());

                $name {
                    h: T::from_f64(0.01).unwrap(), // Default fixed step size
                    t: T::zero(),
                    y: V::zeros(),
                    dydt: V::zeros(),
                    t_prev: T::zero(),
                    y_prev: V::zeros(),
                    dydt_prev: V::zeros(),
                    k: [V::zeros(); $stages],
                    y_stage: [V::zeros(); $stages],
                    k_new: [V::zeros(); $stages],
                    a: a_t,
                    b: b_t,
                    c: c_t,
                    max_iter: 50,
                    tol: T::from_f64(1e-8).unwrap(),
                    status: $crate::Status::Uninitialized,
                    steps: 0,
                }
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::ode::ODENumericalMethod<T, V, D> for $name<T, V, D> {
            /// Set the initial point and evaluate `f(t0, y0)`.
            ///
            /// # Errors
            /// `Error::BadInput` if `h` is zero, or any error returned by
            /// `validate_step_size_parameters`.
            fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // A fixed-step method cannot choose its own step, so h must be set.
                if self.h == T::zero() {
                    return Err($crate::Error::BadInput {
                        msg: concat!(stringify!($name), " requires a non-zero fixed step size 'h' to be set.").to_string(),
                    });
                }
                self.h = $crate::utils::validate_step_size_parameters::<T, V, D>(self.h, T::zero(), T::infinity(), t0, tf)?;

                self.t = t0;
                self.y = *y0;
                self.t_prev = t0;
                self.y_prev = *y0;

                // f(t0, y0): initial stage guess for the first step and the
                // derivative at the only point known so far.
                ode.diff(t0, y0, &mut self.dydt);
                evals.fcn += 1;
                self.dydt_prev = self.dydt;

                self.steps = 0;
                self.status = $crate::Status::Initialized;
                Ok(evals)
            }

            /// Advance the solution by one fixed step of size `h`.
            ///
            /// # Errors
            /// `Error::StepSize` (also stored in the status) when the fixed-point
            /// iteration does not converge within `max_iter` iterations; the state
            /// is left unchanged in that case.
            fn step<F>(&mut self, ode: &F) -> Result<$crate::alias::Evals, $crate::Error<T, V>>
            where
                F: $crate::ode::ODE<T, V, D>,
            {
                let mut evals = $crate::alias::Evals::new();

                // Initial guess for every stage: k_i^(0) = f(t_n, y_n).
                self.k = [self.dydt; $stages];

                let mut converged = false;
                for _ in 0..self.max_iter {
                    // k_i^(m+1) = f(t_n + c_i*h, y_n + h * sum_j a_ij * k_j^(m)); only the
                    // previous iterate appears on the right-hand side.
                    for i in 0..$stages {
                        self.y_stage[i] = self.y;
                        for j in 0..$stages {
                            self.y_stage[i] += self.k[j] * (self.a[i][j] * self.h);
                        }
                        ode.diff(self.t + self.c[i] * self.h, &self.y_stage[i], &mut self.k_new[i]);
                        evals.fcn += 1;
                    }

                    // Converged when every stage moved by less than `tol`. `<` is false
                    // for NaN, so a stage whose residual became NaN (diverged iteration)
                    // never counts as converged.
                    let mut all_within_tol = true;
                    for i in 0..$stages {
                        let diff = self.k_new[i] - self.k[i];
                        let mut norm_sq = T::zero();
                        for idx in 0..diff.len() {
                            norm_sq += diff.get(idx) * diff.get(idx);
                        }
                        all_within_tol &= norm_sq.sqrt() < self.tol;
                        self.k[i] = self.k_new[i];
                    }

                    if all_within_tol {
                        converged = true;
                        break;
                    }
                }

                if !converged {
                    self.status = $crate::Status::Error($crate::Error::StepSize { t: self.t, y: self.y });
                    return Err($crate::Error::StepSize { t: self.t, y: self.y });
                }

                self.steps += 1;

                // Shift the current point, including its derivative, to "previous" so
                // that dense output over [t_prev, t] has derivatives at both ends.
                self.t_prev = self.t;
                self.y_prev = self.y;
                self.dydt_prev = self.dydt;

                // y_{n+1} = y_n + h * sum_i b_i * k_i
                let mut delta_y = V::zeros();
                for i in 0..$stages {
                    delta_y += self.k[i] * (self.b[i] * self.h);
                }
                self.y += delta_y;
                self.t += self.h;

                // f(t_{n+1}, y_{n+1}): right-end derivative for dense output and the
                // initial stage guess for the next step.
                ode.diff(self.t, &self.y, &mut self.dydt);
                evals.fcn += 1;

                self.status = $crate::Status::Solving;
                Ok(evals)
            }

            // --- Standard trait methods ---
            fn t(&self) -> T { self.t }
            fn y(&self) -> &V { &self.y }
            fn t_prev(&self) -> T { self.t_prev }
            fn y_prev(&self) -> &V { &self.y_prev }
            fn h(&self) -> T { self.h }
            fn set_h(&mut self, h: T) { self.h = h; }
            fn status(&self) -> &$crate::Status<T, V, D> { &self.status }
            fn set_status(&mut self, status: $crate::Status<T, V, D>) { self.status = status; }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $crate::interpolate::Interpolation<T, V> for $name<T, V, D> {
            /// Evaluate the dense-output polynomial of the last step at `t_interp`.
            ///
            /// # Errors
            /// `Error::OutOfBounds` if `t_interp` lies outside the last step.
            fn interpolate(&mut self, t_interp: T) -> Result<V, $crate::Error<T, V>> {
                // Before the first step the "interval" is the single initial point.
                if self.t == self.t_prev {
                    return if t_interp == self.t_prev {
                        Ok(self.y_prev)
                    } else {
                        Err($crate::Error::OutOfBounds { t_interp, t_prev: self.t_prev, t_curr: self.t })
                    };
                }

                // Accept either orientation of the step so that negative `h`
                // (integration backwards in time) is handled too.
                let (t_lo, t_hi) = if self.t_prev <= self.t {
                    (self.t_prev, self.t)
                } else {
                    (self.t, self.t_prev)
                };
                if t_interp < t_lo || t_interp > t_hi {
                    return Err($crate::Error::OutOfBounds {
                        t_interp,
                        t_prev: self.t_prev,
                        t_curr: self.t,
                    });
                }

                // Cubic Hermite between (t_prev, y_prev, f(t_prev, y_prev)) and (t, y, f(t, y)).
                let y_interp = $crate::interpolate::cubic_hermite_interpolate(
                    self.t_prev, self.t,
                    &self.y_prev, &self.y,
                    &self.dydt_prev, &self.dydt,
                    t_interp,
                );

                Ok(y_interp)
            }
        }

        impl<T: $crate::traits::Real, V: $crate::traits::State<T>, D: $crate::traits::CallBackData> $name<T, V, D> {
            /// Order of accuracy of the method, as given to the generating macro.
            pub const ORDER: usize = $order;
            /// Number of stages of the method.
            pub const STAGES: usize = $stages;

            /// Create a new solver instance with step size `h` and default settings.
            pub fn new(h: T) -> Self {
                $name {
                    h,
                    ..Default::default()
                }
            }

            /// Set the fixed step size `h`.
            pub fn h(mut self, h: T) -> Self {
                self.h = h;
                self
            }

            /// Set the maximum number of fixed-point iterations per step.
            pub fn max_iter(mut self, iter: usize) -> Self {
                self.max_iter = iter;
                self
            }

            /// Set the tolerance for fixed-point iteration convergence.
            pub fn tol(mut self, tol: T) -> Self {
                self.tol = tol;
                self
            }
        }
    };
}
301
302implicit_runge_kutta_method!(
303 /// Implicit Euler (Backward Euler) Method (1st Order)
304 ///
305 /// Solves `y_{n+1} = y_n + h * f(t_{n+1}, y_{n+1})`.
306 /// The Butcher Tableau is:
307 /// ```text
308 /// 1 | 1
309 /// -----
310 /// | 1
311 /// ```
312 name: BackwardEuler,
313 a: [[1.0]],
314 b: [1.0],
315 c: [1.0],
316 order: 1,
317 stages: 1
318);
319
320implicit_runge_kutta_method!(
321 /// Crank-Nicolson Method (Trapezoidal Rule) (2nd Order)
322 ///
323 /// Solves `y_{n+1} = y_n + 0.5*h * (f(t_n, y_n) + f(t_{n+1}, y_{n+1}))`.
324 /// This is often implemented as a 2-stage implicit method.
325 /// Stage 1: `k1 = f(t_n, y_n)` (explicit)
326 /// Stage 2: `k2 = f(t_{n+1}, y_n + 0.5*h*k1 + 0.5*h*k2)` (implicit)
327 /// Update: `y_{n+1} = y_n + 0.5*h*k1 + 0.5*h*k2`
328 /// The Butcher Tableau is:
329 /// ```text
330 /// 0 | 0 0
331 /// 1 | 1/2 1/2
332 /// --------------
333 /// | 1/2 1/2
334 /// ```
335 /// Note: The fixed-point solver in this macro solves for *all* stages simultaneously.
336 /// For Crank-Nicolson, k1 is explicit, but the solver treats it implicitly.
337 /// This works but is less efficient than a specialized implementation.
338 name: CrankNicolson,
339 a: [[0.0, 0.0],
340 [0.5, 0.5]],
341 b: [0.5, 0.5],
342 c: [0.0, 1.0],
343 order: 2,
344 stages: 2
345);