differential_equations/methods/dirk/fixed/
ordinary.rs

//! Fixed-step diagonally implicit Runge-Kutta (DIRK) stepping for ordinary differential equations (ODEs).
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::{DiagonallyImplicitRungeKutta, Fixed, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{CallBackData, Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y, D>
17    for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
18{
19    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
20    where
21        F: ODE<T, Y, D>,
22    {
23        let mut evals = Evals::new();
24
25        // Validate step size bounds
26        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
27            Ok(h0) => self.h = h0,
28            Err(status) => return Err(status),
29        }
30
31        // Stats
32        self.stiffness_counter = 0;
33        self.newton_iterations = 0;
34        self.jacobian_evaluations = 0;
35        self.lu_decompositions = 0;
36
37        // State
38        self.t = t0;
39        self.y = *y0;
40        ode.diff(self.t, &self.y, &mut self.dydt);
41        evals.function += 1;
42
43        // Previous state
44        self.t_prev = self.t;
45        self.y_prev = self.y;
46        self.dydt_prev = self.dydt;
47
48        // Newton workspace
49        let dim = y0.len();
50        self.jacobian = Matrix::zeros(dim);
51        self.z = *y0;
52        self.jacobian_age = 0;
53
54        // Status
55        self.status = Status::Initialized;
56
57        Ok(evals)
58    }
59
60    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
61    where
62        F: ODE<T, Y, D>,
63    {
64        let mut evals = Evals::new();
65
66        // Max steps guard
67        if self.steps >= self.max_steps {
68            self.status = Status::Error(Error::MaxSteps {
69                t: self.t,
70                y: self.y,
71            });
72            return Err(Error::MaxSteps {
73                t: self.t,
74                y: self.y,
75            });
76        }
77        self.steps += 1;
78
79        let dim = self.y.len();
80
81        // DIRK stage loop (sequential)
82        for stage in 0..self.stages {
83            // rhs = y_n + h Σ_{j<stage} a[stage][j] k[j]
84            let mut rhs = self.y;
85            for j in 0..stage {
86                rhs += self.k[j] * (self.a[stage][j] * self.h);
87            }
88
89            // Initial stage guess
90            self.z = self.y;
91
92            // Newton: solve z - rhs - h*a_ii f(t_i, z) = 0
93            let mut newton_converged = false;
94            let mut newton_iter = 0;
95            let mut increment_norm = T::infinity();
96
97            while !newton_converged && newton_iter < self.max_newton_iter {
98                newton_iter += 1;
99                self.newton_iterations += 1;
100                evals.newton += 1;
101
102                // Evaluate f at stage guess
103                let t_stage = self.t + self.c[stage] * self.h;
104                let mut f_stage = Y::zeros();
105                ode.diff(t_stage, &self.z, &mut f_stage);
106                evals.function += 1;
107
108                // Residual F(z)
109                let residual = self.z - rhs - f_stage * (self.a[stage][stage] * self.h);
110
111                // Max-norm and RHS
112                let mut residual_norm = T::zero();
113                self.rhs_newton = -residual;
114                for i in 0..dim {
115                    residual_norm = residual_norm.max(residual.get(i).abs());
116                }
117
118                // Converged by residual
119                if residual_norm < self.newton_tol {
120                    newton_converged = true;
121                    break;
122                }
123
124                // Converged by increment
125                if newton_iter > 1 && increment_norm < self.newton_tol {
126                    newton_converged = true;
127                    break;
128                }
129
130                // Refresh Jacobian if needed
131                if newton_iter == 1 || self.jacobian_age > 3 {
132                    ode.jacobian(t_stage, &self.z, &mut self.jacobian);
133                    evals.jacobian += 1;
134                    self.jacobian_age = 0;
135
136                    // Newton matrix: I - h*a_ii J
137                    self.jacobian
138                        .component_mul_mut(-self.h * self.a[stage][stage]);
139                    self.jacobian += Matrix::identity(dim);
140                }
141                self.jacobian_age += 1;
142
143                // Solve (I - h*a_ii J) Δz = -F(z) using in-place LU
144                self.delta_z = self.jacobian.lin_solve(self.rhs_newton);
145                self.lu_decompositions += 1;
146
147                // Update z and increment norm
148                increment_norm = T::zero();
149                self.z = self.z + self.delta_z;
150                for row_idx in 0..dim {
151                    // Calculate infinity norm of increment
152                    increment_norm = increment_norm.max(self.delta_z.get(row_idx).abs());
153                }
154            }
155
156            // Newton failed for this stage
157            if !newton_converged {
158                self.status = Status::Error(Error::Stiffness {
159                    t: self.t,
160                    y: self.y,
161                });
162                return Err(Error::Stiffness {
163                    t: self.t,
164                    y: self.y,
165                });
166            }
167
168            // k_i from converged z
169            let t_stage = self.t + self.c[stage] * self.h;
170            ode.diff(t_stage, &self.z, &mut self.k[stage]);
171            evals.function += 1;
172        }
173
174        // y_{n+1} = y_n + h Σ b_i k_i
175        let mut y_new = self.y;
176        for i in 0..self.stages {
177            y_new += self.k[i] * (self.b[i] * self.h);
178        }
179
180        // Fixed step: always accept
181        self.status = Status::Solving;
182
183        // Advance state
184        self.t_prev = self.t;
185        self.y_prev = self.y;
186        self.dydt_prev = self.dydt;
187        self.h_prev = self.h;
188
189        self.t += self.h;
190        self.y = y_new;
191
192        // Next-step derivative
193        ode.diff(self.t, &self.y, &mut self.dydt);
194        evals.function += 1;
195
196        Ok(evals)
197    }
198
199    fn t(&self) -> T {
200        self.t
201    }
202    fn y(&self) -> &Y {
203        &self.y
204    }
205    fn t_prev(&self) -> T {
206        self.t_prev
207    }
208    fn y_prev(&self) -> &Y {
209        &self.y_prev
210    }
211    fn h(&self) -> T {
212        self.h
213    }
214    fn set_h(&mut self, h: T) {
215        self.h = h;
216    }
217    fn status(&self) -> &Status<T, Y, D> {
218        &self.status
219    }
220    fn set_status(&mut self, status: Status<T, Y, D>) {
221        self.status = status;
222    }
223}
224
225impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
226    Interpolation<T, Y> for DiagonallyImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
227{
228    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
229        // Check if t is within bounds
230        if t_interp < self.t_prev || t_interp > self.t {
231            return Err(Error::OutOfBounds {
232                t_interp,
233                t_prev: self.t_prev,
234                t_curr: self.t,
235            });
236        }
237
238        // Otherwise use cubic Hermite interpolation
239        let y_interp = cubic_hermite_interpolate(
240            self.t_prev,
241            self.t,
242            &self.y_prev,
243            &self.y,
244            &self.dydt_prev,
245            &self.dydt,
246            t_interp,
247        );
248
249        Ok(y_interp)
250    }
251}