Skip to main content

differential_equations/methods/dirk/fixed/ordinary.rs

//! Fixed-step diagonally implicit Runge–Kutta (DIRK) solver for ordinary differential equations (ODEs).
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: ODE<T, Y> + ?Sized,
21    {
22        let mut evals = Evals::new();
23
24        // Validate step size bounds
25        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26            Ok(h0) => self.h = h0,
27            Err(status) => return Err(status),
28        }
29
30        // Stats
31        self.stiffness_counter = 0;
32        self.newton_iterations = 0;
33        self.jacobian_evaluations = 0;
34        self.lu_decompositions = 0;
35
36        // State
37        self.t = t0;
38        self.y = y0.clone();
39        self.dydt = y0.zeros_like();
40        self.y_prev = y0.clone();
41        self.dydt_prev = y0.zeros_like();
42        self.k = core::array::from_fn(|_| y0.zeros_like());
43        self.z = y0.clone();
44        self.rhs_newton = y0.zeros_like();
45        self.delta_z = y0.zeros_like();
46        ode.diff(self.t, &self.y, &mut self.dydt);
47        evals.function += 1;
48
49        // Previous state
50        self.t_prev = self.t;
51        self.y_prev = self.y.clone();
52        self.dydt_prev = self.dydt.clone();
53
54        // Newton workspace
55        let dim = y0.len();
56        self.jacobian = Matrix::zeros(dim, dim);
57        self.z = y0.clone();
58        self.jacobian_age = 0;
59
60        // Status
61        self.status = Status::Initialized;
62
63        Ok(evals)
64    }
65
66    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
67    where
68        F: ODE<T, Y> + ?Sized,
69    {
70        let mut evals = Evals::new();
71
72        // Max steps guard
73        if self.steps >= self.max_steps {
74            self.status = Status::Error(Error::MaxSteps {
75                t: self.t,
76                y: self.y.clone(),
77            });
78            return Err(Error::MaxSteps {
79                t: self.t,
80                y: self.y.clone(),
81            });
82        }
83        self.steps += 1;
84
85        let dim = self.y.len();
86
87        // DIRK stage loop (sequential)
88        for stage in 0..self.stages {
89            // rhs = y_n + h Σ_{j<stage} a[stage][j] k[j]
90            let mut rhs = self.y.clone();
91            for j in 0..stage {
92                rhs.add_scaled(self.a[stage][j] * self.h, &self.k[j]);
93            }
94
95            // Initial stage guess
96            self.z = self.y.clone();
97
98            // Newton: solve z - rhs - h*a_ii f(t_i, z) = 0
99            let mut newton_converged = false;
100            let mut newton_iter = 0;
101            let mut increment_norm = T::infinity();
102
103            while !newton_converged && newton_iter < self.max_newton_iter {
104                newton_iter += 1;
105                self.newton_iterations += 1;
106                evals.newton += 1;
107
108                // Evaluate f at stage guess
109                let t_stage = self.t + self.c[stage] * self.h;
110                let mut f_stage = self.y.zeros_like();
111                ode.diff(t_stage, &self.z, &mut f_stage);
112                evals.function += 1;
113
114                // Residual F(z)
115                let residual = self.z.plus_linear_combination(&[
116                    (&rhs, -T::one()),
117                    (&f_stage, -(self.a[stage][stage] * self.h)),
118                ]);
119
120                // Max-norm and RHS
121                self.rhs_newton = residual.scaled(-T::one());
122                let residual_norm = residual.max_norm();
123
124                // Converged by residual
125                if residual_norm < self.newton_tol {
126                    newton_converged = true;
127                    break;
128                }
129
130                // Converged by increment
131                if newton_iter > 1 && increment_norm < self.newton_tol {
132                    newton_converged = true;
133                    break;
134                }
135
136                // Refresh Jacobian if needed
137                if newton_iter == 1 || self.jacobian_age > 3 {
138                    ode.jacobian(t_stage, &self.z, &mut self.jacobian);
139                    evals.jacobian += 1;
140                    self.jacobian_age = 0;
141
142                    // Newton matrix: I - h*a_ii J
143                    self.jacobian
144                        .component_mul_mut(-self.h * self.a[stage][stage]);
145                    self.jacobian += Matrix::identity(dim);
146                }
147                self.jacobian_age += 1;
148
149                // Solve (I - h*a_ii J) Δz = -F(z) using in-place LU
150                match self.jacobian.lin_solve(self.rhs_newton.clone()) {
151                    Ok(dz) => self.delta_z = dz,
152                    Err(e) => {
153                        let mapped_err = Error::LinearAlgebra {
154                            t: self.t,
155                            y: self.y.clone(),
156                            msg: e.to_string(),
157                        };
158                        self.status = Status::Error(mapped_err.clone());
159                        return Err(mapped_err);
160                    }
161                }
162                evals.solves += 1;
163
164                // Update z and increment norm
165                self.z.add_scaled(T::one(), &self.delta_z);
166                increment_norm = self.delta_z.max_norm();
167            }
168
169            // Newton failed for this stage
170            if !newton_converged {
171                self.status = Status::Error(Error::Stiffness {
172                    t: self.t,
173                    y: self.y.clone(),
174                });
175                return Err(Error::Stiffness {
176                    t: self.t,
177                    y: self.y.clone(),
178                });
179            }
180
181            // k_i from converged z
182            let t_stage = self.t + self.c[stage] * self.h;
183            ode.diff(t_stage, &self.z, &mut self.k[stage]);
184            evals.function += 1;
185        }
186
187        // y_{n+1} = y_n + h Σ b_i k_i
188        let mut y_new = self.y.clone();
189        for i in 0..self.stages {
190            y_new.add_scaled(self.b[i] * self.h, &self.k[i]);
191        }
192
193        // Fixed step: always accept
194        self.status = Status::Solving;
195
196        // Advance state
197        self.t_prev = self.t;
198        self.y_prev = self.y.clone();
199        self.dydt_prev = self.dydt.clone();
200        self.h_prev = self.h;
201
202        self.t += self.h;
203        self.y = y_new;
204
205        // Next-step derivative
206        ode.diff(self.t, &self.y, &mut self.dydt);
207        evals.function += 1;
208
209        Ok(evals)
210    }
211
212    fn t(&self) -> T {
213        self.t
214    }
215    fn y(&self) -> &Y {
216        &self.y
217    }
218    fn t_prev(&self) -> T {
219        self.t_prev
220    }
221    fn y_prev(&self) -> &Y {
222        &self.y_prev
223    }
224    fn h(&self) -> T {
225        self.h
226    }
227    fn set_h(&mut self, h: T) {
228        self.h = h;
229    }
230    fn status(&self) -> &Status<T, Y> {
231        &self.status
232    }
233    fn set_status(&mut self, status: Status<T, Y>) {
234        self.status = status;
235    }
236}
237
238impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
239    for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
240{
241    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
242        // Check if t is within bounds
243        if t_interp < self.t_prev || t_interp > self.t {
244            return Err(Error::OutOfBounds {
245                t_interp,
246                t_prev: self.t_prev,
247                t_curr: self.t,
248            });
249        }
250
251        // Otherwise use cubic Hermite interpolation
252        let y_interp = cubic_hermite_interpolate(
253            self.t_prev,
254            self.t,
255            &self.y_prev,
256            &self.y,
257            &self.dydt_prev,
258            &self.dydt,
259            t_interp,
260        );
261
262        Ok(y_interp)
263    }
264}