differential_equations/methods/irk/fixed/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 linalg::Matrix,
7 methods::{Fixed, ImplicitRungeKutta, Ordinary},
8 ode::{ODE, OrdinaryNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16 OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: ODE<T, Y>,
21 {
22 let mut evals = Evals::new();
23
24 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26 Ok(h0) => self.h = h0,
28 Err(status) => return Err(status),
29 }
30
31 self.stiffness_counter = 0;
33 self.newton_iterations = 0;
34 self.jacobian_evaluations = 0;
35 self.lu_decompositions = 0;
36
37 self.t = t0;
39 self.y = *y0;
40 ode.diff(self.t, &self.y, &mut self.dydt);
41 evals.function += 1;
42
43 self.t_prev = self.t;
45 self.y_prev = self.y;
46 self.dydt_prev = self.dydt;
47
48 let dim = y0.len();
50 let newton_system_size = self.stages * dim;
51 self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
52 self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
53 self.rhs_newton = vec![T::zero(); newton_system_size];
54 self.delta_k_vec = vec![T::zero(); newton_system_size];
55 self.jacobian_age = 0;
56
57 self.status = Status::Initialized;
59
60 Ok(evals)
61 }
62
63 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
64 where
65 F: ODE<T, Y>,
66 {
67 let mut evals = Evals::new();
68
69 if self.steps >= self.max_steps {
71 self.status = Status::Error(Error::MaxSteps {
72 t: self.t,
73 y: self.y,
74 });
75 return Err(Error::MaxSteps {
76 t: self.t,
77 y: self.y,
78 });
79 }
80 self.steps += 1;
81
82 let dim = self.y.len();
84 for i in 0..self.stages {
85 self.z[i] = self.y;
86 }
87
88 let mut newton_converged = false;
90 let mut newton_iter = 0;
91
92 let mut increment_norm = T::infinity();
94
95 while !newton_converged && newton_iter < self.max_newton_iter {
96 newton_iter += 1;
97 self.newton_iterations += 1;
98 evals.newton += 1;
99
100 for i in 0..self.stages {
102 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
103 }
104 evals.function += self.stages;
105
106 let mut residual_norm = T::zero();
108 for i in 0..self.stages {
109 let mut residual = self.z[i] - self.y;
111
112 for j in 0..self.stages {
114 residual = residual - self.k[j] * (self.a[i][j] * self.h);
115 }
116
117 for row_idx in 0..dim {
119 let res_val = residual.get(row_idx);
120 residual_norm = residual_norm.max(res_val.abs());
121 self.rhs_newton[i * dim + row_idx] = -res_val;
123 }
124 }
125
126 if residual_norm < self.newton_tol {
128 newton_converged = true;
129 break;
130 }
131
132 if newton_iter > 1 && increment_norm < self.newton_tol {
134 newton_converged = true;
135 break;
136 }
137
138 if newton_iter == 1 || self.jacobian_age > 3 {
140 for i in 0..self.stages {
142 ode.jacobian(
143 self.t + self.c[i] * self.h,
144 &self.z[i],
145 &mut self.stage_jacobians[i],
146 );
147 evals.jacobian += 1;
148 }
149
150 let nsys = self.stages * dim;
152 let mut nm = Matrix::zeros(nsys, nsys);
153 for i in 0..self.stages {
154 for j in 0..self.stages {
155 let scale_factor = -self.h * self.a[i][j];
156 for r in 0..dim {
158 for c_col in 0..dim {
159 nm[(i * dim + r, j * dim + c_col)] =
160 self.stage_jacobians[j][(r, c_col)] * scale_factor;
161 }
162 }
163 }
164
165 for d_idx in 0..dim {
167 let idx = i * dim + d_idx;
168 nm[(idx, idx)] += T::one();
169 }
170 }
171 self.newton_matrix = nm;
172
173 self.jacobian_age = 0;
174 }
175 self.jacobian_age += 1;
176
177 let mut rhs = self.rhs_newton.clone();
179 self.newton_matrix.lin_solve_mut(&mut rhs[..]);
180 for i in 0..self.delta_k_vec.len() {
181 self.delta_k_vec[i] = rhs[i];
182 }
183 self.lu_decompositions += 1;
184
185 increment_norm = T::zero();
187 for i in 0..self.stages {
188 for row_idx in 0..dim {
189 let delta_val = self.delta_k_vec[i * dim + row_idx];
190 let current_val = self.z[i].get(row_idx);
191 self.z[i].set(row_idx, current_val + delta_val);
192 increment_norm = increment_norm.max(delta_val.abs());
194 }
195 }
196
197 }
199
200 if !newton_converged {
202 self.status = Status::Error(Error::Stiffness {
203 t: self.t,
204 y: self.y,
205 });
206 return Err(Error::Stiffness {
207 t: self.t,
208 y: self.y,
209 });
210 }
211
212 for i in 0..self.stages {
214 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
215 }
216 evals.function += self.stages;
217
218 let mut y_new = self.y;
220 for i in 0..self.stages {
221 y_new += self.k[i] * (self.b[i] * self.h);
222 }
223
224 self.status = Status::Solving;
226
227 self.t_prev = self.t;
229 self.y_prev = self.y;
230 self.dydt_prev = self.dydt;
231 self.h_prev = self.h;
232
233 self.t += self.h;
235 self.y = y_new;
236
237 ode.diff(self.t, &self.y, &mut self.dydt);
239 evals.function += 1;
240
241 Ok(evals)
242 }
243
244 fn t(&self) -> T {
245 self.t
246 }
247 fn y(&self) -> &Y {
248 &self.y
249 }
250 fn t_prev(&self) -> T {
251 self.t_prev
252 }
253 fn y_prev(&self) -> &Y {
254 &self.y_prev
255 }
256 fn h(&self) -> T {
257 self.h
258 }
259 fn set_h(&mut self, h: T) {
260 self.h = h;
261 }
262 fn status(&self) -> &Status<T, Y> {
263 &self.status
264 }
265 fn set_status(&mut self, status: Status<T, Y>) {
266 self.status = status;
267 }
268}
269
270impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
271 for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
272{
273 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
274 if t_interp < self.t_prev || t_interp > self.t {
276 return Err(Error::OutOfBounds {
277 t_interp,
278 t_prev: self.t_prev,
279 t_curr: self.t,
280 });
281 }
282
283 let y_interp = cubic_hermite_interpolate(
285 self.t_prev,
286 self.t,
287 &self.y_prev,
288 &self.y,
289 &self.dydt_prev,
290 &self.dydt,
291 t_interp,
292 );
293
294 Ok(y_interp)
295 }
296}