// differential_equations/methods/dirk/adaptive/ordinary.rs

//! Adaptive diagonally implicit Runge-Kutta (DIRK) solver for ordinary differential equations.

3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::h_init::InitialStepSize,
8    methods::{Adaptive, DiagonallyImplicitRungeKutta, Ordinary},
9    ode::{ODE, OrdinaryNumericalMethod},
10    stats::Evals,
11    status::Status,
12    traits::{Real, State},
13    utils::{constrain_step_size, validate_step_size_parameters},
14};
16impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
17    OrdinaryNumericalMethod<T, Y>
18    for DiagonallyImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
19{
20    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
21    where
22        F: ODE<T, Y>,
23    {
24        let mut evals = Evals::new();
25
26        // Compute h0 if not set
27        if self.h0 == T::zero() {
28            // Implicit initial step size heuristic
29            self.h0 = InitialStepSize::<Ordinary>::compute(
30                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
31                &mut evals,
32            );
33        }
34
35        // Validate step size bounds
36        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
37            Ok(h0) => self.h = h0,
38            Err(status) => return Err(status),
39        }
40
41        // Stats
42        self.stiffness_counter = 0;
43        self.newton_iterations = 0;
44        self.jacobian_evaluations = 0;
45        self.lu_decompositions = 0;
46
47        // State
48        self.t = t0;
49        self.y = *y0;
50        ode.diff(self.t, &self.y, &mut self.dydt);
51        evals.function += 1;
52
53        // Previous state
54        self.t_prev = self.t;
55        self.y_prev = self.y;
56        self.dydt_prev = self.dydt;
57
58        // Newton workspace
59        let dim = y0.len();
60        self.jacobian = Matrix::zeros(dim, dim);
61        self.z = *y0;
62        self.jacobian_age = 0;
63
64        // Status
65        self.status = Status::Initialized;
66
67        Ok(evals)
68    }
69
70    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
71    where
72        F: ODE<T, Y>,
73    {
74        let mut evals = Evals::new();
75
76        // Step size guard
77        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
78            self.status = Status::Error(Error::StepSize {
79                t: self.t,
80                y: self.y,
81            });
82            return Err(Error::StepSize {
83                t: self.t,
84                y: self.y,
85            });
86        }
87
88        // Max steps guard
89        if self.steps >= self.max_steps {
90            self.status = Status::Error(Error::MaxSteps {
91                t: self.t,
92                y: self.y,
93            });
94            return Err(Error::MaxSteps {
95                t: self.t,
96                y: self.y,
97            });
98        }
99        self.steps += 1;
100
101        let dim = self.y.len();
102
103        // DIRK stage loop (sequential)
104        for stage in 0..self.stages {
105            // rhs = y_n + h Σ_{j<stage} a[stage][j] k[j]
106            let mut rhs = self.y;
107            for j in 0..stage {
108                rhs += self.k[j] * (self.a[stage][j] * self.h);
109            }
110
111            // Initial stage guess
112            self.z = self.y;
113
114            // Newton: solve z - rhs - h*a_ii f(t_i, z) = 0
115            let mut newton_converged = false;
116            let mut newton_iter = 0;
117            let mut increment_norm = T::infinity();
118
119            while !newton_converged && newton_iter < self.max_newton_iter {
120                newton_iter += 1;
121                self.newton_iterations += 1;
122                evals.newton += 1;
123
124                // Evaluate f at stage guess
125                let t_stage = self.t + self.c[stage] * self.h;
126                let mut f_stage = Y::zeros();
127                ode.diff(t_stage, &self.z, &mut f_stage);
128                evals.function += 1;
129
130                // Residual F(z)
131                let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
132
133                // Max-norm and RHS
134                let mut residual_norm = T::zero();
135                self.rhs_newton = -residual;
136                for i in 0..dim {
137                    residual_norm = residual_norm.max(residual.get(i).abs());
138                }
139
140                // Converged by residual
141                if residual_norm < self.newton_tol {
142                    newton_converged = true;
143                    break;
144                }
145
146                // Converged by increment
147                if newton_iter > 1 && increment_norm < self.newton_tol {
148                    newton_converged = true;
149                    break;
150                }
151
152                // Refresh Jacobian if needed
153                if newton_iter == 1 || self.jacobian_age > 3 {
154                    ode.jacobian(t_stage, &self.z, &mut self.jacobian);
155                    evals.jacobian += 1;
156                    self.jacobian_age = 0;
157
158                    // Newton matrix: I - h*a_ii J
159                    self.jacobian
160                        .component_mul_mut(-self.h * self.a[stage][stage]);
161                    self.jacobian += Matrix::identity(dim);
162                }
163                self.jacobian_age += 1;
164
165                // Solve (I - h*a_ii J) Δz = -F(z) using in-place LU
166                self.delta_z = self.jacobian.lin_solve(self.rhs_newton).unwrap();
167                self.lu_decompositions += 1;
168
169                // Update z and increment norm
170                increment_norm = T::zero();
171                self.z += self.delta_z;
172                for row_idx in 0..dim {
173                    // Calculate infinity norm of increment
174                    increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
175                }
176            }
177
178            // Newton failed for this stage
179            if !newton_converged {
180                // Reduce h and retry later
181                self.h *= T::from_f64(0.25).unwrap();
182                self.h = constrain_step_size(self.h, self.h_min, self.h_max);
183                self.status = Status::RejectedStep;
184                self.stiffness_counter += 1;
185
186                if self.stiffness_counter >= self.max_rejects {
187                    self.status = Status::Error(Error::Stiffness {
188                        t: self.t,
189                        y: self.y,
190                    });
191                    return Err(Error::Stiffness {
192                        t: self.t,
193                        y: self.y,
194                    });
195                }
196                return Ok(evals);
197            }
198
199            // k_i from converged z
200            let t_stage = self.t + self.c[stage] * self.h;
201            ode.diff(t_stage, &self.z, &mut self.k[stage]);
202            evals.function += 1;
203        }
204
205        // y_{n+1} = y_n + h Σ b_i k_i
206        let mut y_new = self.y;
207        for i in 0..self.stages {
208            y_new += self.k[i] * (self.b[i] * self.h);
209        }
210
211        // Embedded error estimate (bh)
212        let mut err_norm = T::zero();
213        let bh = &self.bh.unwrap();
214
215        // Lower-order solution
216        let mut y_low = self.y;
217        for i in 0..self.stages {
218            y_low += self.k[i] * (bh[i] * self.h);
219        }
220
221        // err = y_high - y_low
222        let err = y_new - y_low;
223
224        // Weighted max-norm
225        for i in 0..self.y.len() {
226            let scale = self.atol[i] + self.rtol[i] * self.y.get(i).abs().max(y_new.get(i).abs());
227            if scale > T::zero() {
228                err_norm = err_norm.max((err.get(i) / scale).abs());
229            }
230        }
231
232        // Avoid vanishing error
233        err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
234
235        // Step scale factor
236        let order = T::from_usize(self.order).unwrap();
237        let error_exponent = T::one() / order;
238        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
239
240        // Clamp scale factor
241        scale = scale.max(self.min_scale).min(self.max_scale);
242
243        // Accept/reject
244        if err_norm <= T::one() {
245            // Accepted
246            self.status = Status::Solving;
247
248            // Log previous
249            self.t_prev = self.t;
250            self.y_prev = self.y;
251            self.dydt_prev = self.dydt;
252            self.h_prev = self.h;
253
254            // Advance state
255            self.t += self.h;
256            self.y = y_new;
257
258            // Next-step derivative
259            ode.diff(self.t, &self.y, &mut self.dydt);
260            evals.function += 1;
261
262            // If we were rejecting, limit growth
263            if let Status::RejectedStep = self.status {
264                self.stiffness_counter = 0;
265
266                // Avoid oscillations
267                scale = scale.min(T::one());
268            }
269        } else {
270            // Rejected
271            self.status = Status::RejectedStep;
272            self.stiffness_counter += 1;
273
274            // Too many rejections
275            if self.stiffness_counter >= self.max_rejects {
276                self.status = Status::Error(Error::Stiffness {
277                    t: self.t,
278                    y: self.y,
279                });
280                return Err(Error::Stiffness {
281                    t: self.t,
282                    y: self.y,
283                });
284            }
285        }
286
287        // Update h
288        self.h *= scale;
289
290        // Constrain h
291        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
292
293        Ok(evals)
294    }
295
296    fn t(&self) -> T {
297        self.t
298    }
299    fn y(&self) -> &Y {
300        &self.y
301    }
302    fn t_prev(&self) -> T {
303        self.t_prev
304    }
305    fn y_prev(&self) -> &Y {
306        &self.y_prev
307    }
308    fn h(&self) -> T {
309        self.h
310    }
311    fn set_h(&mut self, h: T) {
312        self.h = h;
313    }
314    fn status(&self) -> &Status<T, Y> {
315        &self.status
316    }
317    fn set_status(&mut self, status: Status<T, Y>) {
318        self.status = status;
319    }
320}
322impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
323    for DiagonallyImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
324{
325    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
326        // Check if t is within bounds
327        if t_interp < self.t_prev || t_interp > self.t {
328            return Err(Error::OutOfBounds {
329                t_interp,
330                t_prev: self.t_prev,
331                t_curr: self.t,
332            });
333        }
334
335        // Use cubic Hermite interpolation
336        let y_interp = cubic_hermite_interpolate(
337            self.t_prev,
338            self.t,
339            &self.y_prev,
340            &self.y,
341            &self.dydt_prev,
342            &self.dydt,
343            t_interp,
344        );
345
346        Ok(y_interp)
347    }
348}