// differential_equations/methods/irk/fixed/ordinary.rs

//! Fixed-step implicit Runge-Kutta (IRK) methods for ordinary differential equations.

use crate::{
    error::Error,
    interpolate::{Interpolation, cubic_hermite_interpolate},
    linalg::Matrix,
    methods::{Fixed, ImplicitRungeKutta, Ordinary},
    ode::{ODE, OrdinaryNumericalMethod},
    stats::Evals,
    status::Status,
    traits::{CallBackData, Real, State},
    utils::validate_step_size_parameters,
};

15impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y, D> for ImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
17{
18    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: ODE<T, Y, D>,
21    {
22        let mut evals = Evals::new();
23
24        // Validate step size bounds
25        match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
26            // Set the fixed step size
27            Ok(h0) => self.h = h0,
28            Err(status) => return Err(status),
29        }
30
31        // Stats
32        self.stiffness_counter = 0;
33        self.newton_iterations = 0;
34        self.jacobian_evaluations = 0;
35        self.lu_decompositions = 0;
36
37        // State
38        self.t = t0;
39        self.y = *y0;
40        ode.diff(self.t, &self.y, &mut self.dydt);
41        evals.function += 1;
42
43        // Previous state
44        self.t_prev = self.t;
45        self.y_prev = self.y;
46        self.dydt_prev = self.dydt;
47
48        // Linear algebra workspace
49        let dim = y0.len();
50        let newton_system_size = self.stages * dim;
51        self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim));
52        self.newton_matrix = Matrix::zeros(newton_system_size);
53        self.rhs_newton = vec![T::zero(); newton_system_size];
54        self.delta_k_vec = vec![T::zero(); newton_system_size];
55        self.jacobian_age = 0;
56
57        // Status
58        self.status = Status::Initialized;
59
60        Ok(evals)
61    }
62
63    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
64    where
65        F: ODE<T, Y, D>,
66    {
67        let mut evals = Evals::new();
68
69        // Max steps guard
70        if self.steps >= self.max_steps {
71            self.status = Status::Error(Error::MaxSteps {
72                t: self.t,
73                y: self.y,
74            });
75            return Err(Error::MaxSteps {
76                t: self.t,
77                y: self.y,
78            });
79        }
80        self.steps += 1;
81
82        // Initial stage guesses: copy current state
83        let dim = self.y.len();
84        for i in 0..self.stages {
85            self.z[i] = self.y;
86        }
87
88        // Newton solve for F(z) = z - y_n - h*A*f(z) = 0
89        let mut newton_converged = false;
90        let mut newton_iter = 0;
91
92        // Track increment norm
93        let mut increment_norm = T::infinity();
94
95        while !newton_converged && newton_iter < self.max_newton_iter {
96            newton_iter += 1;
97            self.newton_iterations += 1;
98            evals.newton += 1;
99
100            // Evaluate f at stage guesses
101            for i in 0..self.stages {
102                ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
103            }
104            evals.function += self.stages;
105
106            // Residual and max-norm
107            let mut residual_norm = T::zero();
108            for i in 0..self.stages {
109                // Start with z_i - y_n
110                let mut residual = self.z[i] - self.y;
111
112                // Subtract h*sum(a_ij * f_j)
113                for j in 0..self.stages {
114                    residual = residual - self.k[j] * (self.a[i][j] * self.h);
115                }
116
117                // Infinity norm and RHS
118                for row_idx in 0..dim {
119                    let res_val = residual.get(row_idx);
120                    residual_norm = residual_norm.max(res_val.abs());
121                    // Store residual in Newton RHS (negative for solving delta_z)
122                    self.rhs_newton[i * dim + row_idx] = -res_val;
123                }
124            }
125
126            // Converged by residual
127            if residual_norm < self.newton_tol {
128                newton_converged = true;
129                break;
130            }
131
132            // Converged by increment
133            if newton_iter > 1 && increment_norm < self.newton_tol {
134                newton_converged = true;
135                break;
136            }
137
138            // Refresh Jacobians if needed
139            if newton_iter == 1 || self.jacobian_age > 3 {
140                // Stage Jacobians
141                for i in 0..self.stages {
142                    ode.jacobian(
143                        self.t + self.c[i] * self.h,
144                        &self.z[i],
145                        &mut self.stage_jacobians[i],
146                    );
147                    evals.jacobian += 1;
148                }
149
150                // Build Newton matrix: I - h*(A ⊗ J)
151                let nsys = self.stages * dim;
152                let mut nm = Matrix::zeros(nsys);
153                for i in 0..self.stages {
154                    for j in 0..self.stages {
155                        let scale_factor = -self.h * self.a[i][j];
156                        // Use J from stage j
157                        for r in 0..dim {
158                            for c_col in 0..dim {
159                                nm[(i * dim + r, j * dim + c_col)] =
160                                    self.stage_jacobians[j][(r, c_col)] * scale_factor;
161                            }
162                        }
163                    }
164
165                    // Add identity per block
166                    for d_idx in 0..dim {
167                        let idx = i * dim + d_idx;
168                        nm[(idx, idx)] = nm[(idx, idx)] + T::one();
169                    }
170                }
171                self.newton_matrix = nm;
172
173                self.jacobian_age = 0;
174            }
175            self.jacobian_age += 1;
176
177            // Solve (I - h*A⊗J) Δz = -F(z) using in-place LU on our matrix
178            let mut rhs = self.rhs_newton.clone();
179            self.newton_matrix.lin_solve_mut(&mut rhs[..]);
180            for i in 0..self.delta_k_vec.len() {
181                self.delta_k_vec[i] = rhs[i];
182            }
183            self.lu_decompositions += 1;
184
185            // Update z_i and increment norm
186            increment_norm = T::zero();
187            for i in 0..self.stages {
188                for row_idx in 0..dim {
189                    let delta_val = self.delta_k_vec[i * dim + row_idx];
190                    let current_val = self.z[i].get(row_idx);
191                    self.z[i].set(row_idx, current_val + delta_val);
192                    // Calculate infinity norm of increment
193                    increment_norm = increment_norm.max(delta_val.abs());
194                }
195            }
196
197            // Next loop will re-check
198        }
199
200        // Newton failed to converge
201        if !newton_converged {
202            self.status = Status::Error(Error::Stiffness {
203                t: self.t,
204                y: self.y,
205            });
206            return Err(Error::Stiffness {
207                t: self.t,
208                y: self.y,
209            });
210        }
211
212        // Final stage derivatives
213        for i in 0..self.stages {
214            ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
215        }
216        evals.function += self.stages;
217
218        // y_{n+1} = y_n + h Σ b_i f_i
219        let mut y_new = self.y;
220        for i in 0..self.stages {
221            y_new += self.k[i] * (self.b[i] * self.h);
222        }
223
224        // Fixed step: always accept
225        self.status = Status::Solving;
226
227        // Log previous
228        self.t_prev = self.t;
229        self.y_prev = self.y;
230        self.dydt_prev = self.dydt;
231        self.h_prev = self.h;
232
233        // Advance state
234        self.t += self.h;
235        self.y = y_new;
236
237        // Next-step derivative
238        ode.diff(self.t, &self.y, &mut self.dydt);
239        evals.function += 1;
240
241        Ok(evals)
242    }
243
244    fn t(&self) -> T {
245        self.t
246    }
247    fn y(&self) -> &Y {
248        &self.y
249    }
250    fn t_prev(&self) -> T {
251        self.t_prev
252    }
253    fn y_prev(&self) -> &Y {
254        &self.y_prev
255    }
256    fn h(&self) -> T {
257        self.h
258    }
259    fn set_h(&mut self, h: T) {
260        self.h = h;
261    }
262    fn status(&self) -> &Status<T, Y, D> {
263        &self.status
264    }
265    fn set_status(&mut self, status: Status<T, Y, D>) {
266        self.status = status;
267    }
268}
269
270impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
271    Interpolation<T, Y> for ImplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
272{
273    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
274        // Check if t is within bounds
275        if t_interp < self.t_prev || t_interp > self.t {
276            return Err(Error::OutOfBounds {
277                t_interp,
278                t_prev: self.t_prev,
279                t_curr: self.t,
280            });
281        }
282
283        // Use cubic Hermite interpolation
284        let y_interp = cubic_hermite_interpolate(
285            self.t_prev,
286            self.t,
287            &self.y_prev,
288            &self.y,
289            &self.dydt_prev,
290            &self.dydt,
291            t_interp,
292        );
293
294        Ok(y_interp)
295    }
296}