// differential_equations/methods/dirk/fixed/ordinary.rs
//! Fixed-step DIRK (diagonally implicit Runge-Kutta) solver for ODEs.

3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: ODE<T, Y>,
21    {
22        let mut evals = Evals::new();
23
24        // Validate step size bounds
25        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26            Ok(h0) => self.h = h0,
27            Err(status) => return Err(status),
28        }
29
30        // Stats
31        self.stiffness_counter = 0;
32        self.newton_iterations = 0;
33        self.jacobian_evaluations = 0;
34        self.lu_decompositions = 0;
35
36        // State
37        self.t = t0;
38        self.y = *y0;
39        ode.diff(self.t, &self.y, &mut self.dydt);
40        evals.function += 1;
41
42        // Previous state
43        self.t_prev = self.t;
44        self.y_prev = self.y;
45        self.dydt_prev = self.dydt;
46
47        // Newton workspace
48        let dim = y0.len();
49        self.jacobian = Matrix::zeros(dim, dim);
50        self.z = *y0;
51        self.jacobian_age = 0;
52
53        // Status
54        self.status = Status::Initialized;
55
56        Ok(evals)
57    }
58
59    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
60    where
61        F: ODE<T, Y>,
62    {
63        let mut evals = Evals::new();
64
65        // Max steps guard
66        if self.steps >= self.max_steps {
67            self.status = Status::Error(Error::MaxSteps {
68                t: self.t,
69                y: self.y,
70            });
71            return Err(Error::MaxSteps {
72                t: self.t,
73                y: self.y,
74            });
75        }
76        self.steps += 1;
77
78        let dim = self.y.len();
79
80        // DIRK stage loop (sequential)
81        for stage in 0..self.stages {
82            // rhs = y_n + h Σ_{j<stage} a[stage][j] k[j]
83            let mut rhs = self.y;
84            for j in 0..stage {
85                rhs += self.k[j] * (self.a[stage][j] * self.h);
86            }
87
88            // Initial stage guess
89            self.z = self.y;
90
91            // Newton: solve z - rhs - h*a_ii f(t_i, z) = 0
92            let mut newton_converged = false;
93            let mut newton_iter = 0;
94            let mut increment_norm = T::infinity();
95
96            while !newton_converged && newton_iter < self.max_newton_iter {
97                newton_iter += 1;
98                self.newton_iterations += 1;
99                evals.newton += 1;
100
101                // Evaluate f at stage guess
102                let t_stage = self.t + self.c[stage] * self.h;
103                let mut f_stage = Y::zeros();
104                ode.diff(t_stage, &self.z, &mut f_stage);
105                evals.function += 1;
106
107                // Residual F(z)
108                let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
109
110                // Max-norm and RHS
111                let mut residual_norm = T::zero();
112                self.rhs_newton = -residual;
113                for i in 0..dim {
114                    residual_norm = residual_norm.max(residual.get(i).abs());
115                }
116
117                // Converged by residual
118                if residual_norm < self.newton_tol {
119                    newton_converged = true;
120                    break;
121                }
122
123                // Converged by increment
124                if newton_iter > 1 && increment_norm < self.newton_tol {
125                    newton_converged = true;
126                    break;
127                }
128
129                // Refresh Jacobian if needed
130                if newton_iter == 1 || self.jacobian_age > 3 {
131                    ode.jacobian(t_stage, &self.z, &mut self.jacobian);
132                    evals.jacobian += 1;
133                    self.jacobian_age = 0;
134
135                    // Newton matrix: I - h*a_ii J
136                    self.jacobian
137                        .component_mul_mut(-self.h * self.a[stage][stage]);
138                    self.jacobian += Matrix::identity(dim);
139                }
140                self.jacobian_age += 1;
141
142                // Solve (I - h*a_ii J) Δz = -F(z) using in-place LU
143                self.delta_z = self.jacobian.lin_solve(self.rhs_newton).unwrap();
144                self.lu_decompositions += 1;
145
146                // Update z and increment norm
147                increment_norm = T::zero();
148                self.z += self.delta_z;
149                for row_idx in 0..dim {
150                    // Calculate infinity norm of increment
151                    increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
152                }
153            }
154
155            // Newton failed for this stage
156            if !newton_converged {
157                self.status = Status::Error(Error::Stiffness {
158                    t: self.t,
159                    y: self.y,
160                });
161                return Err(Error::Stiffness {
162                    t: self.t,
163                    y: self.y,
164                });
165            }
166
167            // k_i from converged z
168            let t_stage = self.t + self.c[stage] * self.h;
169            ode.diff(t_stage, &self.z, &mut self.k[stage]);
170            evals.function += 1;
171        }
172
173        // y_{n+1} = y_n + h Σ b_i k_i
174        let mut y_new = self.y;
175        for i in 0..self.stages {
176            y_new += self.k[i] * (self.b[i] * self.h);
177        }
178
179        // Fixed step: always accept
180        self.status = Status::Solving;
181
182        // Advance state
183        self.t_prev = self.t;
184        self.y_prev = self.y;
185        self.dydt_prev = self.dydt;
186        self.h_prev = self.h;
187
188        self.t += self.h;
189        self.y = y_new;
190
191        // Next-step derivative
192        ode.diff(self.t, &self.y, &mut self.dydt);
193        evals.function += 1;
194
195        Ok(evals)
196    }
197
198    fn t(&self) -> T {
199        self.t
200    }
201    fn y(&self) -> &Y {
202        &self.y
203    }
204    fn t_prev(&self) -> T {
205        self.t_prev
206    }
207    fn y_prev(&self) -> &Y {
208        &self.y_prev
209    }
210    fn h(&self) -> T {
211        self.h
212    }
213    fn set_h(&mut self, h: T) {
214        self.h = h;
215    }
216    fn status(&self) -> &Status<T, Y> {
217        &self.status
218    }
219    fn set_status(&mut self, status: Status<T, Y>) {
220        self.status = status;
221    }
222}
223
224impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
225    for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
226{
227    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
228        // Check if t is within bounds
229        if t_interp < self.t_prev || t_interp > self.t {
230            return Err(Error::OutOfBounds {
231                t_interp,
232                t_prev: self.t_prev,
233                t_curr: self.t,
234            });
235        }
236
237        // Otherwise use cubic Hermite interpolation
238        let y_interp = cubic_hermite_interpolate(
239            self.t_prev,
240            self.t,
241            &self.y_prev,
242            &self.y,
243            &self.dydt_prev,
244            &self.dydt,
245            t_interp,
246        );
247
248        Ok(y_interp)
249    }
250}