Skip to main content

differential_equations/methods/irk/fixed/
ordinary.rs

//! Fixed-step implicit Runge–Kutta (IRK) methods for ordinary differential equations (ODEs).
2
3use crate::{
4    error::Error,
5    interpolate::{Interpolation, cubic_hermite_interpolate},
6    linalg::Matrix,
7    methods::{Fixed, ImplicitRungeKutta, Ordinary},
8    ode::{ODE, OrdinaryNumericalMethod},
9    stats::Evals,
10    status::Status,
11    traits::{Real, State},
12    utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16    OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18    fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19    where
20        F: ODE<T, Y> + ?Sized,
21    {
22        let mut evals = Evals::new();
23
24        // Validate step size bounds
25        match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26            // Set the fixed step size
27            Ok(h0) => self.h = h0,
28            Err(status) => return Err(status),
29        }
30
31        // Stats
32        self.stiffness_counter = 0;
33        self.newton_iterations = 0;
34        self.jacobian_evaluations = 0;
35        self.lu_decompositions = 0;
36
37        // State
38        self.t = t0;
39        self.y = y0.clone();
40        self.dydt = y0.zeros_like();
41        self.y_prev = y0.clone();
42        self.dydt_prev = y0.zeros_like();
43        self.k = core::array::from_fn(|_| y0.zeros_like());
44        self.z = core::array::from_fn(|_| y0.zeros_like());
45        ode.diff(self.t, &self.y, &mut self.dydt);
46        evals.function += 1;
47
48        // Previous state
49        self.t_prev = self.t;
50        self.y_prev = self.y.clone();
51        self.dydt_prev = self.dydt.clone();
52
53        // Linear algebra workspace
54        let dim = y0.len();
55        let newton_system_size = self.stages * dim;
56        self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
57        self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
58        self.rhs_newton = vec![T::zero(); newton_system_size];
59        self.delta_k_vec = vec![T::zero(); newton_system_size];
60        self.jacobian_age = 0;
61
62        // Status
63        self.status = Status::Initialized;
64
65        Ok(evals)
66    }
67
68    fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
69    where
70        F: ODE<T, Y> + ?Sized,
71    {
72        let mut evals = Evals::new();
73
74        // Max steps guard
75        if self.steps >= self.max_steps {
76            self.status = Status::Error(Error::MaxSteps {
77                t: self.t,
78                y: self.y.clone(),
79            });
80            return Err(Error::MaxSteps {
81                t: self.t,
82                y: self.y.clone(),
83            });
84        }
85        self.steps += 1;
86
87        // Initial stage guesses: copy current state
88        let dim = self.y.len();
89        for i in 0..self.stages {
90            self.z[i] = self.y.clone();
91        }
92
93        // Newton solve for F(z) = z - y_n - h*A*f(z) = 0
94        let mut newton_converged = false;
95        let mut newton_iter = 0;
96
97        // Track increment norm
98        let mut increment_norm = T::infinity();
99
100        while !newton_converged && newton_iter < self.max_newton_iter {
101            newton_iter += 1;
102            self.newton_iterations += 1;
103            evals.newton += 1;
104
105            // Evaluate f at stage guesses
106            for i in 0..self.stages {
107                ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
108            }
109            evals.function += self.stages;
110
111            // Residual and max-norm
112            let mut residual_norm = T::zero();
113            for i in 0..self.stages {
114                // Start with z_i - y_n
115                let mut residual = self.z[i].minus(&self.y);
116
117                // Subtract h*sum(a_ij * f_j)
118                for j in 0..self.stages {
119                    residual.add_scaled(-(self.a[i][j] * self.h), &self.k[j]);
120                }
121
122                // Infinity norm and RHS
123                for row_idx in 0..dim {
124                    let res_val = residual.get_component(row_idx);
125                    residual_norm = residual_norm.max(res_val.abs());
126                    // Store residual in Newton RHS (negative for solving delta_z)
127                    self.rhs_newton[i * dim + row_idx] = -res_val;
128                }
129            }
130
131            // Converged by residual
132            if residual_norm < self.newton_tol {
133                newton_converged = true;
134                break;
135            }
136
137            // Converged by increment
138            if newton_iter > 1 && increment_norm < self.newton_tol {
139                newton_converged = true;
140                break;
141            }
142
143            // Refresh Jacobians if needed
144            if newton_iter == 1 || self.jacobian_age > 3 {
145                // Stage Jacobians
146                for i in 0..self.stages {
147                    ode.jacobian(
148                        self.t + self.c[i] * self.h,
149                        &self.z[i],
150                        &mut self.stage_jacobians[i],
151                    );
152                    evals.jacobian += 1;
153                }
154
155                // Build Newton matrix: I - h*(A ⊗ J)
156                let nsys = self.stages * dim;
157                let mut nm = Matrix::zeros(nsys, nsys);
158                for i in 0..self.stages {
159                    for j in 0..self.stages {
160                        let scale_factor = -self.h * self.a[i][j];
161                        // Use J from stage j
162                        for r in 0..dim {
163                            for c_col in 0..dim {
164                                nm[(i * dim + r, j * dim + c_col)] =
165                                    self.stage_jacobians[j][(r, c_col)] * scale_factor;
166                            }
167                        }
168                    }
169
170                    // Add identity per block
171                    for d_idx in 0..dim {
172                        let idx = i * dim + d_idx;
173                        nm[(idx, idx)] += T::one();
174                    }
175                }
176                self.newton_matrix = nm;
177
178                self.jacobian_age = 0;
179            }
180            self.jacobian_age += 1;
181
182            // Solve (I - h*A⊗J) Δz = -F(z) using in-place LU on our matrix
183            let mut rhs = self.rhs_newton.clone();
184            self.newton_matrix
185                .lin_solve_mut(&mut rhs[..])
186                .map_err(|e| crate::error::Error::LinearAlgebra {
187                    t: self.t,
188                    y: self.y.clone(),
189                    msg: e.to_string(),
190                })?;
191            evals.solves += 1;
192
193            // Update z_i and increment norm
194            increment_norm = T::zero();
195            for i in 0..self.stages {
196                for row_idx in 0..dim {
197                    let delta_val = rhs[i * dim + row_idx];
198                    let current_z = self.z[i].get_component(row_idx);
199                    self.z[i].set_component(row_idx, current_z + delta_val);
200                    // Calculate infinity norm of increment
201                    increment_norm = increment_norm.max(delta_val.abs());
202                }
203            }
204
205            // Next loop will re-check
206        }
207
208        // Newton failed to converge
209        if !newton_converged {
210            self.status = Status::Error(Error::Stiffness {
211                t: self.t,
212                y: self.y.clone(),
213            });
214            return Err(Error::Stiffness {
215                t: self.t,
216                y: self.y.clone(),
217            });
218        }
219
220        // Final stage derivatives
221        for i in 0..self.stages {
222            ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
223        }
224        evals.function += self.stages;
225
226        // y_{n+1} = y_n + h Σ b_i f_i
227        let mut y_new = self.y.clone();
228        for i in 0..self.stages {
229            y_new.add_scaled(self.b[i] * self.h, &self.k[i]);
230        }
231
232        // Fixed step: always accept
233        self.status = Status::Solving;
234
235        // Log previous
236        self.t_prev = self.t;
237        self.y_prev = self.y.clone();
238        self.dydt_prev = self.dydt.clone();
239        self.h_prev = self.h;
240
241        // Advance state
242        self.t += self.h;
243        self.y = y_new;
244
245        // Next-step derivative
246        ode.diff(self.t, &self.y, &mut self.dydt);
247        evals.function += 1;
248
249        Ok(evals)
250    }
251
252    fn t(&self) -> T {
253        self.t
254    }
255    fn y(&self) -> &Y {
256        &self.y
257    }
258    fn t_prev(&self) -> T {
259        self.t_prev
260    }
261    fn y_prev(&self) -> &Y {
262        &self.y_prev
263    }
264    fn h(&self) -> T {
265        self.h
266    }
267    fn set_h(&mut self, h: T) {
268        self.h = h;
269    }
270    fn status(&self) -> &Status<T, Y> {
271        &self.status
272    }
273    fn set_status(&mut self, status: Status<T, Y>) {
274        self.status = status;
275    }
276}
277
278impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
279    for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
280{
281    fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
282        // Check if t is within bounds
283        if t_interp < self.t_prev || t_interp > self.t {
284            return Err(Error::OutOfBounds {
285                t_interp,
286                t_prev: self.t_prev,
287                t_curr: self.t,
288            });
289        }
290
291        // Use cubic Hermite interpolation
292        let y_interp = cubic_hermite_interpolate(
293            self.t_prev,
294            self.t,
295            &self.y_prev,
296            &self.y,
297            &self.dydt_prev,
298            &self.dydt,
299            t_interp,
300        );
301
302        Ok(y_interp)
303    }
304}