Skip to main content

differential_equations/methods/dirk/adaptive/
ordinary.rs

1//! Adaptive DIRK for ODEs
2use crate::{
3    error::Error,
4    interpolate::{Interpolation, cubic_hermite_interpolate},
5    linalg::Matrix,
6    methods::h_init::InitialStepSize,
7    methods::{Adaptive, DiagonallyImplicitRungeKutta, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::{constrain_step_size, validate_step_size_parameters},
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y>
17    for DiagonallyImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
18{
19    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
20    where
21        F: ODE<T, Y> + ?Sized,
22    {
23        let mut evals = Evals::new();
24
25        // Compute h0 if not set
26        if self.h0 == T::zero() {
27            // Implicit initial step size heuristic
28            self.h0 = InitialStepSize::<Ordinary>::compute(
29                ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
30                &mut evals,
31            );
32        }
33
34        // Validate step size bounds
35        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
36            Ok(h0) => self.h = (self.filter)(h0),
37            Err(status) => return Err(status),
38        }
39
40        // Stats
41        self.stiffness_counter = 0;
42        self.newton_iterations = 0;
43        self.jacobian_evaluations = 0;
44        self.lu_decompositions = 0;
45
46        // State
47        self.t = t0;
48        self.y = y0.clone();
49        self.dydt = y0.zeros_like();
50        self.y_prev = y0.clone();
51        self.dydt_prev = y0.zeros_like();
52        self.k = core::array::from_fn(|_| y0.zeros_like());
53        self.z = y0.clone();
54        self.rhs_newton = y0.zeros_like();
55        self.delta_z = y0.zeros_like();
56        ode.diff(self.t, &self.y, &mut self.dydt);
57        evals.function += 1;
58
59        // Previous state
60        self.t_prev = self.t;
61        self.y_prev = self.y.clone();
62        self.dydt_prev = self.dydt.clone();
63
64        // Newton workspace
65        let dim = y0.len();
66        self.jacobian = Matrix::zeros(dim, dim);
67        self.z = y0.clone();
68        self.jacobian_age = 0;
69
70        // Status
71        self.status = Status::Initialized;
72
73        Ok(evals)
74    }
75
76    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
77    where
78        F: ODE<T, Y> + ?Sized,
79    {
80        let mut evals = Evals::new();
81
82        // Step size guard
83        if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
84            self.status = Status::Error(Error::StepSize {
85                t: self.t,
86                y: self.y.clone(),
87            });
88            return Err(Error::StepSize {
89                t: self.t,
90                y: self.y.clone(),
91            });
92        }
93
94        // Max steps guard
95        if self.steps >= self.max_steps {
96            self.status = Status::Error(Error::MaxSteps {
97                t: self.t,
98                y: self.y.clone(),
99            });
100            return Err(Error::MaxSteps {
101                t: self.t,
102                y: self.y.clone(),
103            });
104        }
105        self.steps += 1;
106
107        let dim = self.y.len();
108
109        // DIRK stage loop (sequential)
110        for stage in 0..self.stages {
111            // rhs = y_n + h Σ_{j<stage} a[stage][j] k[j]
112            let mut rhs = self.y.clone();
113            for j in 0..stage {
114                rhs.add_scaled(self.a[stage][j] * self.h, &self.k[j]);
115            }
116
117            // Initial stage guess
118            self.z = self.y.clone();
119
120            // Newton: solve z - rhs - h*a_ii f(t_i, z) = 0
121            let mut newton_converged = false;
122            let mut newton_iter = 0;
123            let mut increment_norm = T::infinity();
124
125            while !newton_converged && newton_iter < self.max_newton_iter {
126                newton_iter += 1;
127                self.newton_iterations += 1;
128                evals.newton += 1;
129
130                // Evaluate f at stage guess
131                let t_stage = self.t + self.c[stage] * self.h;
132                let mut f_stage = self.y.zeros_like();
133                ode.diff(t_stage, &self.z, &mut f_stage);
134                evals.function += 1;
135
136                // Residual F(z)
137                let residual = self.z.plus_linear_combination(&[
138                    (&rhs, -T::one()),
139                    (&f_stage, -(self.a[stage][stage] * self.h)),
140                ]);
141
142                // Max-norm and RHS
143                self.rhs_newton = residual.scaled(-T::one());
144                let residual_norm = residual.max_norm();
145
146                // Converged by residual
147                if residual_norm < self.newton_tol {
148                    newton_converged = true;
149                    break;
150                }
151
152                // Converged by increment
153                if newton_iter > 1 && increment_norm < self.newton_tol {
154                    newton_converged = true;
155                    break;
156                }
157
158                // Refresh Jacobian if needed
159                if newton_iter == 1 || self.jacobian_age > 3 {
160                    ode.jacobian(t_stage, &self.z, &mut self.jacobian);
161                    evals.jacobian += 1;
162                    self.jacobian_age = 0;
163
164                    // Newton matrix: I - h*a_ii J
165                    self.jacobian
166                        .component_mul_mut(-self.h * self.a[stage][stage]);
167                    self.jacobian += Matrix::identity(dim);
168                }
169                self.jacobian_age += 1;
170
171                // Solve (I - h*a_ii J) Δz = -F(z) using in-place LU
172                match self.jacobian.lin_solve(self.rhs_newton.clone()) {
173                    Ok(dz) => self.delta_z = dz,
174                    Err(e) => {
175                        let mapped_err = Error::LinearAlgebra {
176                            t: self.t,
177                            y: self.y.clone(),
178                            msg: e.to_string(),
179                        };
180                        self.status = Status::Error(mapped_err.clone());
181                        return Err(mapped_err);
182                    }
183                }
184                evals.solves += 1;
185
186                // Update z and increment norm
187                self.z.add_scaled(T::one(), &self.delta_z);
188                increment_norm = self.delta_z.max_norm();
189            }
190
191            // Newton failed for this stage
192            if !newton_converged {
193                // Reduce h and retry later
194                self.h *= T::from_f64(0.25).unwrap();
195                self.h = constrain_step_size(self.h, self.h_min, self.h_max);
196                self.h = (self.filter)(self.h);
197                self.status = Status::RejectedStep;
198                self.stiffness_counter += 1;
199
200                if self.stiffness_counter >= self.max_rejects {
201                    self.status = Status::Error(Error::Stiffness {
202                        t: self.t,
203                        y: self.y.clone(),
204                    });
205                    return Err(Error::Stiffness {
206                        t: self.t,
207                        y: self.y.clone(),
208                    });
209                }
210                return Ok(evals);
211            }
212
213            // k_i from converged z
214            let t_stage = self.t + self.c[stage] * self.h;
215            ode.diff(t_stage, &self.z, &mut self.k[stage]);
216            evals.function += 1;
217        }
218
219        // y_{n+1} = y_n + h Σ b_i k_i
220        let mut y_new = self.y.clone();
221        for i in 0..self.stages {
222            y_new.add_scaled(self.b[i] * self.h, &self.k[i]);
223        }
224
225        // Embedded error estimate (bh)
226        let mut err_norm = T::zero();
227        let bh = &self.bh.unwrap();
228
229        // Lower-order solution
230        let mut y_low = self.y.clone();
231        for i in 0..self.stages {
232            y_low.add_scaled(bh[i] * self.h, &self.k[i]);
233        }
234
235        // Weighted max-norm
236        let dim = self.y.len();
237        for i in 0..dim {
238            let scale = self.atol[i]
239                + self.rtol[i]
240                    * self
241                        .y
242                        .get_component(i)
243                        .abs()
244                        .max(y_new.get_component(i).abs());
245            if scale > T::zero() {
246                err_norm =
247                    err_norm.max(((y_new.get_component(i) - y_low.get_component(i)) / scale).abs());
248            }
249        }
250
251        // Avoid vanishing error
252        err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
253
254        // Step scale factor
255        let order = T::from_usize(self.order).unwrap();
256        let error_exponent = T::one() / order;
257        let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
258
259        // Clamp scale factor
260        scale = scale.max(self.min_scale).min(self.max_scale);
261
262        // Accept/reject
263        if err_norm <= T::one() {
264            // Accepted
265            self.status = Status::Solving;
266
267            // Log previous
268            self.t_prev = self.t;
269            self.y_prev = self.y.clone();
270            self.dydt_prev = self.dydt.clone();
271            self.h_prev = self.h;
272
273            // Advance state
274            self.t += self.h;
275            self.y = y_new;
276
277            // Next-step derivative
278            ode.diff(self.t, &self.y, &mut self.dydt);
279            evals.function += 1;
280
281            // If we were rejecting, limit growth
282            if let Status::RejectedStep = self.status {
283                self.stiffness_counter = 0;
284
285                // Avoid oscillations
286                scale = scale.min(T::one());
287            }
288        } else {
289            // Rejected
290            self.status = Status::RejectedStep;
291            self.stiffness_counter += 1;
292
293            // Too many rejections
294            if self.stiffness_counter >= self.max_rejects {
295                self.status = Status::Error(Error::Stiffness {
296                    t: self.t,
297                    y: self.y.clone(),
298                });
299                return Err(Error::Stiffness {
300                    t: self.t,
301                    y: self.y.clone(),
302                });
303            }
304        }
305
306        // Update h
307        self.h *= scale;
308
309        // Constrain h
310        self.h = constrain_step_size(self.h, self.h_min, self.h_max);
311
312        // Apply step size filter
313        self.h = (self.filter)(self.h);
314
315        Ok(evals)
316    }
317
318    fn t(&self) -> T {
319        self.t
320    }
321    fn y(&self) -> &Y {
322        &self.y
323    }
324    fn t_prev(&self) -> T {
325        self.t_prev
326    }
327    fn y_prev(&self) -> &Y {
328        &self.y_prev
329    }
330    fn h(&self) -> T {
331        self.h
332    }
333    fn set_h(&mut self, h: T) {
334        self.h = (self.filter)(h);
335    }
336    fn status(&self) -> &Status<T, Y> {
337        &self.status
338    }
339    fn set_status(&mut self, status: Status<T, Y>) {
340        self.status = status;
341    }
342}
343
344impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
345    for DiagonallyImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
346{
347    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
348        // Check if t is within bounds
349        if t_interp < self.t_prev || t_interp > self.t {
350            return Err(Error::OutOfBounds {
351                t_interp,
352                t_prev: self.t_prev,
353                t_curr: self.t,
354            });
355        }
356
357        // Use cubic Hermite interpolation
358        let y_interp = cubic_hermite_interpolate(
359            self.t_prev,
360            self.t,
361            &self.y_prev,
362            &self.y,
363            &self.dydt_prev,
364            &self.dydt,
365            t_interp,
366        );
367
368        Ok(y_interp)
369    }
370}