Skip to main content

differential_equations/methods/irk/adaptive/
ordinary.rs

1//! Adaptive IRK for ODEs
2use crate::{
3    error::Error,
4    interpolate::{Interpolation, cubic_hermite_interpolate},
5    linalg::Matrix,
6    methods::h_init::InitialStepSize,
7    methods::{Adaptive, ImplicitRungeKutta, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
17{
18    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: ODE<T, Y> + ?Sized,
21    {
22        let mut evals = Evals::new();
23
24        // Compute h0 if not set
25        if self.h0 == T::zero() {
26            // Implicit initial step size heuristic
27            self.h0 = InitialStepSize::<Ordinary>::compute(
28                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
29                &mut evals,
30            );
31        }
32
33        // Validate step size bounds
34        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
35            Ok(h0) => self.h = (self.filter)(h0),
36            Err(status) => return Err(status),
37        }
38
39        // Stats
40        self.stiffness_counter = 0;
41        self.newton_iterations = 0;
42        self.jacobian_evaluations = 0;
43        self.lu_decompositions = 0;
44
45        // State
46        self.t = t0;
47        self.y = y0.clone();
48        self.dydt = y0.zeros_like();
49        self.y_prev = y0.clone();
50        self.dydt_prev = y0.zeros_like();
51        self.k = core::array::from_fn(|_| y0.zeros_like());
52        self.z = core::array::from_fn(|_| y0.zeros_like());
53        ode.diff(self.t, &self.y, &mut self.dydt);
54        evals.function += 1;
55
56        // Previous state
57        self.t_prev = self.t;
58        self.y_prev = self.y.clone();
59        self.dydt_prev = self.dydt.clone();
60
61        // Linear algebra workspace
62        let dim = y0.len();
63        let newton_system_size = self.stages * dim;
64        self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
65        self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
66        // Use State<T> storage for RHS and solution vectors
67        self.rhs_newton = vec![T::zero(); newton_system_size];
68        self.delta_k_vec = vec![T::zero(); newton_system_size];
69        self.jacobian_age = 0;
70
71        // Status
72        self.status = Status::Initialized;
73
74        Ok(evals)
75    }
76
77    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
78    where
79        F: ODE<T, Y> + ?Sized,
80    {
81        let mut evals = Evals::new();
82
83        // Step size guard
84        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
85            self.status = Status::Error(Error::StepSize {
86                t: self.t,
87                y: self.y.clone(),
88            });
89            return Err(Error::StepSize {
90                t: self.t,
91                y: self.y.clone(),
92            });
93        }
94
95        // Max steps guard
96        if self.steps >= self.max_steps {
97            self.status = Status::Error(Error::MaxSteps {
98                t: self.t,
99                y: self.y.clone(),
100            });
101            return Err(Error::MaxSteps {
102                t: self.t,
103                y: self.y.clone(),
104            });
105        }
106        self.steps += 1;
107
108        // Initial stage guesses: copy current state
109        let dim = self.y.len();
110        for i in 0..self.stages {
111            self.z[i] = self.y.clone();
112        }
113
114        // Newton solve for F(z) = z - y_n - h*A*f(z) = 0
115        let mut newton_converged = false;
116        let mut newton_iter = 0;
117
118        // Track increment norm
119        let mut increment_norm = T::infinity();
120
121        while !newton_converged && newton_iter < self.max_newton_iter {
122            newton_iter += 1;
123            self.newton_iterations += 1;
124            evals.newton += 1;
125
126            // Evaluate f at stage guesses
127            for i in 0..self.stages {
128                ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
129            }
130            evals.function += self.stages;
131
132            // Residual and max-norm
133            let mut residual_norm = T::zero();
134            for i in 0..self.stages {
135                // Start with z_i - y_n
136                let mut residual = self.z[i].minus(&self.y);
137
138                // Subtract h*sum(a_ij * f_j)
139                for j in 0..self.stages {
140                    residual.add_scaled(-(self.a[i][j] * self.h), &self.k[j]);
141                }
142
143                // Infinity norm and RHS
144                for row_idx in 0..dim {
145                    let res_val = residual.get_component(row_idx);
146                    residual_norm = residual_norm.max(res_val.abs());
147                    // Store residual in Newton RHS (negative for solving delta_z)
148                    self.rhs_newton[i * dim + row_idx] = -res_val;
149                }
150            }
151
152            // Converged by residual
153            if residual_norm < self.newton_tol {
154                newton_converged = true;
155                break;
156            }
157
158            // Converged by increment
159            if newton_iter > 1 && increment_norm < self.newton_tol {
160                newton_converged = true;
161                break;
162            }
163
164            // Refresh Jacobians if needed
165            if newton_iter == 1 || self.jacobian_age > 3 {
166                // Stage Jacobians
167                for i in 0..self.stages {
168                    ode.jacobian(
169                        self.t + self.c[i] * self.h,
170                        &self.z[i],
171                        &mut self.stage_jacobians[i],
172                    );
173                    evals.jacobian += 1;
174                }
175
176                // Build Newton matrix: I - h*(A ⊗ J)
177                // Zero the block Newton matrix (ensure Full storage)
178                let nsys = self.stages * dim;
179                let mut nm = Matrix::zeros(nsys, nsys);
180                // Fill blocks
181                for i in 0..self.stages {
182                    for j in 0..self.stages {
183                        let scale = -self.h * self.a[i][j];
184                        for r in 0..dim {
185                            for c_col in 0..dim {
186                                nm[(i * dim + r, j * dim + c_col)] =
187                                    self.stage_jacobians[j][(r, c_col)] * scale;
188                            }
189                        }
190                    }
191                    // Add identity on block diagonal
192                    for d_idx in 0..dim {
193                        let idx = i * dim + d_idx;
194                        nm[(idx, idx)] += T::one();
195                    }
196                }
197                self.newton_matrix = nm;
198
199                self.jacobian_age = 0;
200            }
201            self.jacobian_age += 1;
202
203            // Solve (I - h*A⊗J) Δz = -F(z) using our LU in-place over a flat slice
204            for i in 0..self.delta_k_vec.len() {
205                self.delta_k_vec[i] = self.rhs_newton[i];
206            }
207            self.newton_matrix
208                .lin_solve_mut(&mut self.delta_k_vec[..])
209                .map_err(|e| crate::error::Error::LinearAlgebra {
210                    t: self.t,
211                    y: self.y.clone(),
212                    msg: e.to_string(),
213                })?;
214            self.lu_decompositions += 1;
215
216            // Update z_i and increment norm
217            increment_norm = T::zero();
218            for i in 0..self.stages {
219                for row_idx in 0..dim {
220                    let delta_val = self.delta_k_vec[i * dim + row_idx];
221                    let current_z = self.z[i].get_component(row_idx);
222                    self.z[i].set_component(row_idx, current_z + delta_val);
223                    // Calculate infinity norm of increment
224                    increment_norm = increment_norm.max(delta_val.abs());
225                }
226            }
227
228            // Next loop will re-check
229        }
230
231        // Newton failed to converge
232        if !newton_converged {
233            // Reduce h and retry later
234            self.h *= T::from_f64(0.25).unwrap();
235            self.h = constrain_step_size(self.h, self.h_min, self.h_max);
236            self.h = (self.filter)(self.h);
237            self.status = Status::RejectedStep;
238            self.stiffness_counter += 1;
239
240            if self.stiffness_counter >= self.max_rejects {
241                self.status = Status::Error(Error::Stiffness {
242                    t: self.t,
243                    y: self.y.clone(),
244                });
245                return Err(Error::Stiffness {
246                    t: self.t,
247                    y: self.y.clone(),
248                });
249            }
250            return Ok(evals);
251        }
252
253        // Final stage derivatives
254        for i in 0..self.stages {
255            ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
256        }
257        evals.function += self.stages;
258
259        // y_{n+1} = y_n + h Σ b_i f_i
260        let mut y_new = self.y.clone();
261        for i in 0..self.stages {
262            y_new.add_scaled(self.b[i] * self.h, &self.k[i]);
263        }
264
265        // Embedded error estimate (bh)
266        let mut err_norm = T::zero();
267        let bh = &self.bh.unwrap();
268
269        // Lower-order solution
270        let mut y_low = self.y.clone();
271        for i in 0..self.stages {
272            y_low.add_scaled(bh[i] * self.h, &self.k[i]);
273        }
274
275        // Weighted max-norm
276        let dim = self.y.len();
277        for n in 0..dim {
278            let scale = self.atol[n]
279                + self.rtol[n]
280                    * self
281                        .y
282                        .get_component(n)
283                        .abs()
284                        .max(y_new.get_component(n).abs());
285            if scale > T::zero() {
286                err_norm =
287                    err_norm.max(((y_new.get_component(n) - y_low.get_component(n)) / scale).abs());
288            }
289        }
290
291        // Avoid vanishing error
292        err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
293
294        // Step scale factor
295        let order = T::from_usize(self.order).unwrap();
296        let error_exponent = T::one() / order;
297        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
298
299        // Clamp scale factor
300        scale = scale.max(self.min_scale).min(self.max_scale);
301
302        // Accept/reject
303        if err_norm <= T::one() {
304            // Accepted
305            self.status = Status::Solving;
306
307            // Log previous
308            self.t_prev = self.t;
309            self.y_prev = self.y.clone();
310            self.dydt_prev = self.dydt.clone();
311            self.h_prev = self.h;
312
313            // Advance state
314            self.t += self.h;
315            self.y = y_new;
316
317            // Next-step derivative
318            ode.diff(self.t, &self.y, &mut self.dydt);
319            evals.function += 1;
320
321            // If we were rejecting, limit growth
322            if let Status::RejectedStep = self.status {
323                self.stiffness_counter = 0;
324
325                // Avoid oscillations
326                scale = scale.min(T::one());
327            }
328        } else {
329            // Rejected
330            self.status = Status::RejectedStep;
331            self.stiffness_counter += 1;
332
333            // Too many rejections
334            if self.stiffness_counter >= self.max_rejects {
335                self.status = Status::Error(Error::Stiffness {
336                    t: self.t,
337                    y: self.y.clone(),
338                });
339                return Err(Error::Stiffness {
340                    t: self.t,
341                    y: self.y.clone(),
342                });
343            }
344        }
345
346        // Update h
347        self.h *= scale;
348
349        // Constrain h
350        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
351
352        // Apply step size filter
353        self.h = (self.filter)(self.h);
354
355        Ok(evals)
356    }
357
358    fn t(&self) -> T {
359        self.t
360    }
361    fn y(&self) -> &Y {
362        &self.y
363    }
364    fn t_prev(&self) -> T {
365        self.t_prev
366    }
367    fn y_prev(&self) -> &Y {
368        &self.y_prev
369    }
370    fn h(&self) -> T {
371        self.h
372    }
373    fn set_h(&mut self, h: T) {
374        self.h = (self.filter)(h);
375    }
376    fn status(&self) -> &Status<T, Y> {
377        &self.status
378    }
379    fn set_status(&mut self, status: Status<T, Y>) {
380        self.status = status;
381    }
382}
383
384impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
385    for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
386{
387    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
388        // Check if t is within bounds
389        if t_interp < self.t_prev || t_interp > self.t {
390            return Err(Error::OutOfBounds {
391                t_interp,
392                t_prev: self.t_prev,
393                t_curr: self.t,
394            });
395        }
396
397        // Use cubic Hermite interpolation
398        let y_interp = cubic_hermite_interpolate(
399            self.t_prev,
400            self.t,
401            &self.y_prev,
402            &self.y,
403            &self.dydt_prev,
404            &self.dydt,
405            t_interp,
406        );
407
408        Ok(y_interp)
409    }
410}