differential_equations/methods/irk/adaptive/ordinary.rs

1//! Adaptive IRK for ODEs
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::h_init::InitialStepSize,
8    methods::{Adaptive, ImplicitRungeKutta, Ordinary},
9    ode::{ODE, OrdinaryNumericalMethod},
10    stats::Evals,
11    status::Status,
12    traits::{Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
17    OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
18{
19    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
20    where
21        F: ODE<T, Y>,
22    {
23        let mut evals = Evals::new();
24
25        // Compute h0 if not set
26        if self.h0 == T::zero() {
27            // Implicit initial step size heuristic
28            self.h0 = InitialStepSize::<Ordinary>::compute(
29                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
30                &mut evals,
31            );
32        }
33
34        // Validate step size bounds
35        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
36            Ok(h0) => self.h = h0,
37            Err(status) => return Err(status),
38        }
39
40        // Stats
41        self.stiffness_counter = 0;
42        self.newton_iterations = 0;
43        self.jacobian_evaluations = 0;
44        self.lu_decompositions = 0;
45
46        // State
47        self.t = t0;
48        self.y = *y0;
49        ode.diff(self.t, &self.y, &mut self.dydt);
50        evals.function += 1;
51
52        // Previous state
53        self.t_prev = self.t;
54        self.y_prev = self.y;
55        self.dydt_prev = self.dydt;
56
57        // Linear algebra workspace
58        let dim = y0.len();
59        let newton_system_size = self.stages * dim;
60        self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
61        self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
62        // Use State<T> storage for RHS and solution vectors
63        self.rhs_newton = vec![T::zero(); newton_system_size];
64        self.delta_k_vec = vec![T::zero(); newton_system_size];
65        self.jacobian_age = 0;
66
67        // Status
68        self.status = Status::Initialized;
69
70        Ok(evals)
71    }
72
73    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
74    where
75        F: ODE<T, Y>,
76    {
77        let mut evals = Evals::new();
78
79        // Step size guard
80        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
81            self.status = Status::Error(Error::StepSize {
82                t: self.t,
83                y: self.y,
84            });
85            return Err(Error::StepSize {
86                t: self.t,
87                y: self.y,
88            });
89        }
90
91        // Max steps guard
92        if self.steps >= self.max_steps {
93            self.status = Status::Error(Error::MaxSteps {
94                t: self.t,
95                y: self.y,
96            });
97            return Err(Error::MaxSteps {
98                t: self.t,
99                y: self.y,
100            });
101        }
102        self.steps += 1;
103
104        // Initial stage guesses: copy current state
105        let dim = self.y.len();
106        for i in 0..self.stages {
107            self.z[i] = self.y;
108        }
109
110        // Newton solve for F(z) = z - y_n - h*A*f(z) = 0
111        let mut newton_converged = false;
112        let mut newton_iter = 0;
113
114        // Track increment norm
115        let mut increment_norm = T::infinity();
116
117        while !newton_converged && newton_iter < self.max_newton_iter {
118            newton_iter += 1;
119            self.newton_iterations += 1;
120            evals.newton += 1;
121
122            // Evaluate f at stage guesses
123            for i in 0..self.stages {
124                ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
125            }
126            evals.function += self.stages;
127
128            // Residual and max-norm
129            let mut residual_norm = T::zero();
130            for i in 0..self.stages {
131                // Start with z_i - y_n
132                let mut residual = self.z[i] - self.y;
133
134                // Subtract h*sum(a_ij * f_j)
135                for j in 0..self.stages {
136                    residual = residual - self.k[j] * (self.a[i][j] * self.h);
137                }
138
139                // Infinity norm and RHS
140                for row_idx in 0..dim {
141                    let res_val = residual.get(row_idx);
142                    residual_norm = residual_norm.max(res_val.abs());
143                    // Store residual in Newton RHS (negative for solving delta_z)
144                    self.rhs_newton[i * dim + row_idx] = -res_val;
145                }
146            }
147
148            // Converged by residual
149            if residual_norm < self.newton_tol {
150                newton_converged = true;
151                break;
152            }
153
154            // Converged by increment
155            if newton_iter > 1 && increment_norm < self.newton_tol {
156                newton_converged = true;
157                break;
158            }
159
160            // Refresh Jacobians if needed
161            if newton_iter == 1 || self.jacobian_age > 3 {
162                // Stage Jacobians
163                for i in 0..self.stages {
164                    ode.jacobian(
165                        self.t + self.c[i] * self.h,
166                        &self.z[i],
167                        &mut self.stage_jacobians[i],
168                    );
169                    evals.jacobian += 1;
170                }
171
172                // Build Newton matrix: I - h*(A ⊗ J)
173                // Zero the block Newton matrix (ensure Full storage)
174                let nsys = self.stages * dim;
175                let mut nm = Matrix::zeros(nsys, nsys);
176                // Fill blocks
177                for i in 0..self.stages {
178                    for j in 0..self.stages {
179                        let scale = -self.h * self.a[i][j];
180                        for r in 0..dim {
181                            for c_col in 0..dim {
182                                nm[(i * dim + r, j * dim + c_col)] =
183                                    self.stage_jacobians[j][(r, c_col)] * scale;
184                            }
185                        }
186                    }
187                    // Add identity on block diagonal
188                    for d_idx in 0..dim {
189                        let idx = i * dim + d_idx;
190                        nm[(idx, idx)] += T::one();
191                    }
192                }
193                self.newton_matrix = nm;
194
195                self.jacobian_age = 0;
196            }
197            self.jacobian_age += 1;
198
199            // Solve (I - h*A⊗J) Δz = -F(z) using our LU in-place over a flat slice
200            let mut rhs = self.rhs_newton.clone();
201            self.newton_matrix.lin_solve_mut(&mut rhs[..]);
202            for i in 0..self.delta_k_vec.len() {
203                self.delta_k_vec[i] = rhs[i];
204            }
205            self.lu_decompositions += 1;
206
207            // Update z_i and increment norm
208            increment_norm = T::zero();
209            for i in 0..self.stages {
210                for row_idx in 0..dim {
211                    let delta_val = self.delta_k_vec[i * dim + row_idx];
212                    let current_val = self.z[i].get(row_idx);
213                    self.z[i].set(row_idx, current_val + delta_val);
214                    // Calculate infinity norm of increment
215                    increment_norm = increment_norm.max(delta_val.abs());
216                }
217            }
218
219            // Next loop will re-check
220        }
221
222        // Newton failed to converge
223        if !newton_converged {
224            // Reduce h and retry later
225            self.h *= T::from_f64(0.25).unwrap();
226            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
227            self.status = Status::RejectedStep;
228            self.stiffness_counter += 1;
229
230            if self.stiffness_counter >= self.max_rejects {
231                self.status = Status::Error(Error::Stiffness {
232                    t: self.t,
233                    y: self.y,
234                });
235                return Err(Error::Stiffness {
236                    t: self.t,
237                    y: self.y,
238                });
239            }
240            return Ok(evals);
241        }
242
243        // Final stage derivatives
244        for i in 0..self.stages {
245            ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
246        }
247        evals.function += self.stages;
248
249        // y_{n+1} = y_n + h Σ b_i f_i
250        let mut y_new = self.y;
251        for i in 0..self.stages {
252            y_new += self.k[i] * (self.b[i] * self.h);
253        }
254
255        // Embedded error estimate (bh)
256        let mut err_norm = T::zero();
257        let bh = &self.bh.unwrap();
258
259        // Lower-order solution
260        let mut y_low = self.y;
261        for i in 0..self.stages {
262            y_low += self.k[i] * (bh[i] * self.h);
263        }
264
265        // err = y_high - y_low
266        let err = y_new - y_low;
267
268        // Weighted max-norm
269        for n in 0..self.y.len() {
270            let scale = self.atol[n] + self.rtol[n] * self.y.get(n).abs().max(y_new.get(n).abs());
271            if scale > T::zero() {
272                err_norm = err_norm.max((err.get(n) / scale).abs());
273            }
274        }
275
276        // Avoid vanishing error
277        err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
278
279        // Step scale factor
280        let order = T::from_usize(self.order).unwrap();
281        let error_exponent = T::one() / order;
282        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
283
284        // Clamp scale factor
285        scale = scale.max(self.min_scale).min(self.max_scale);
286
287        // Accept/reject
288        if err_norm <= T::one() {
289            // Accepted
290            self.status = Status::Solving;
291
292            // Log previous
293            self.t_prev = self.t;
294            self.y_prev = self.y;
295            self.dydt_prev = self.dydt;
296            self.h_prev = self.h;
297
298            // Advance state
299            self.t += self.h;
300            self.y = y_new;
301
302            // Next-step derivative
303            ode.diff(self.t, &self.y, &mut self.dydt);
304            evals.function += 1;
305
306            // If we were rejecting, limit growth
307            if let Status::RejectedStep = self.status {
308                self.stiffness_counter = 0;
309
310                // Avoid oscillations
311                scale = scale.min(T::one());
312            }
313        } else {
314            // Rejected
315            self.status = Status::RejectedStep;
316            self.stiffness_counter += 1;
317
318            // Too many rejections
319            if self.stiffness_counter >= self.max_rejects {
320                self.status = Status::Error(Error::Stiffness {
321                    t: self.t,
322                    y: self.y,
323                });
324                return Err(Error::Stiffness {
325                    t: self.t,
326                    y: self.y,
327                });
328            }
329        }
330
331        // Update h
332        self.h *= scale;
333
334        // Constrain h
335        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
336
337        Ok(evals)
338    }
339
340    fn t(&self) -> T {
341        self.t
342    }
343    fn y(&self) -> &Y {
344        &self.y
345    }
346    fn t_prev(&self) -> T {
347        self.t_prev
348    }
349    fn y_prev(&self) -> &Y {
350        &self.y_prev
351    }
352    fn h(&self) -> T {
353        self.h
354    }
355    fn set_h(&mut self, h: T) {
356        self.h = h;
357    }
358    fn status(&self) -> &Status<T, Y> {
359        &self.status
360    }
361    fn set_status(&mut self, status: Status<T, Y>) {
362        self.status = status;
363    }
364}
365
366impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
367    for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
368{
369    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
370        // Check if t is within bounds
371        if t_interp < self.t_prev || t_interp > self.t {
372            return Err(Error::OutOfBounds {
373                t_interp,
374                t_prev: self.t_prev,
375                t_curr: self.t,
376            });
377        }
378
379        // Use cubic Hermite interpolation
380        let y_interp = cubic_hermite_interpolate(
381            self.t_prev,
382            self.t,
383            &self.y_prev,
384            &self.y,
385            &self.dydt_prev,
386            &self.dydt,
387            t_interp,
388        );
389
390        Ok(y_interp)
391    }
392}