differential_equations/methods/irk/fixed/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 linalg::Matrix,
7 methods::{Fixed, ImplicitRungeKutta, Ordinary},
8 ode::{ODE, OrdinaryNumericalMethod},
9 stats::Evals,
10 status::Status,
11 traits::{Real, State},
12 utils::validate_step_size_parameters,
13};
14
15impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
16 OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
17{
18 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
19 where
20 F: ODE<T, Y> + ?Sized,
21 {
22 let mut evals = Evals::new();
23
24 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
26 Ok(h0) => self.h = h0,
28 Err(status) => return Err(status),
29 }
30
31 self.stiffness_counter = 0;
33 self.newton_iterations = 0;
34 self.jacobian_evaluations = 0;
35 self.lu_decompositions = 0;
36
37 self.t = t0;
39 self.y = y0.clone();
40 self.dydt = y0.zeros_like();
41 self.y_prev = y0.clone();
42 self.dydt_prev = y0.zeros_like();
43 self.k = core::array::from_fn(|_| y0.zeros_like());
44 self.z = core::array::from_fn(|_| y0.zeros_like());
45 ode.diff(self.t, &self.y, &mut self.dydt);
46 evals.function += 1;
47
48 self.t_prev = self.t;
50 self.y_prev = self.y.clone();
51 self.dydt_prev = self.dydt.clone();
52
53 let dim = y0.len();
55 let newton_system_size = self.stages * dim;
56 self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
57 self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
58 self.rhs_newton = vec![T::zero(); newton_system_size];
59 self.delta_k_vec = vec![T::zero(); newton_system_size];
60 self.jacobian_age = 0;
61
62 self.status = Status::Initialized;
64
65 Ok(evals)
66 }
67
68 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
69 where
70 F: ODE<T, Y> + ?Sized,
71 {
72 let mut evals = Evals::new();
73
74 if self.steps >= self.max_steps {
76 self.status = Status::Error(Error::MaxSteps {
77 t: self.t,
78 y: self.y.clone(),
79 });
80 return Err(Error::MaxSteps {
81 t: self.t,
82 y: self.y.clone(),
83 });
84 }
85 self.steps += 1;
86
87 let dim = self.y.len();
89 for i in 0..self.stages {
90 self.z[i] = self.y.clone();
91 }
92
93 let mut newton_converged = false;
95 let mut newton_iter = 0;
96
97 let mut increment_norm = T::infinity();
99
100 while !newton_converged && newton_iter < self.max_newton_iter {
101 newton_iter += 1;
102 self.newton_iterations += 1;
103 evals.newton += 1;
104
105 for i in 0..self.stages {
107 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
108 }
109 evals.function += self.stages;
110
111 let mut residual_norm = T::zero();
113 for i in 0..self.stages {
114 let mut residual = self.z[i].minus(&self.y);
116
117 for j in 0..self.stages {
119 residual.add_scaled(-(self.a[i][j] * self.h), &self.k[j]);
120 }
121
122 for row_idx in 0..dim {
124 let res_val = residual.get_component(row_idx);
125 residual_norm = residual_norm.max(res_val.abs());
126 self.rhs_newton[i * dim + row_idx] = -res_val;
128 }
129 }
130
131 if residual_norm < self.newton_tol {
133 newton_converged = true;
134 break;
135 }
136
137 if newton_iter > 1 && increment_norm < self.newton_tol {
139 newton_converged = true;
140 break;
141 }
142
143 if newton_iter == 1 || self.jacobian_age > 3 {
145 for i in 0..self.stages {
147 ode.jacobian(
148 self.t + self.c[i] * self.h,
149 &self.z[i],
150 &mut self.stage_jacobians[i],
151 );
152 evals.jacobian += 1;
153 }
154
155 let nsys = self.stages * dim;
157 let mut nm = Matrix::zeros(nsys, nsys);
158 for i in 0..self.stages {
159 for j in 0..self.stages {
160 let scale_factor = -self.h * self.a[i][j];
161 for r in 0..dim {
163 for c_col in 0..dim {
164 nm[(i * dim + r, j * dim + c_col)] =
165 self.stage_jacobians[j][(r, c_col)] * scale_factor;
166 }
167 }
168 }
169
170 for d_idx in 0..dim {
172 let idx = i * dim + d_idx;
173 nm[(idx, idx)] += T::one();
174 }
175 }
176 self.newton_matrix = nm;
177
178 self.jacobian_age = 0;
179 }
180 self.jacobian_age += 1;
181
182 let mut rhs = self.rhs_newton.clone();
184 self.newton_matrix
185 .lin_solve_mut(&mut rhs[..])
186 .map_err(|e| crate::error::Error::LinearAlgebra {
187 t: self.t,
188 y: self.y.clone(),
189 msg: e.to_string(),
190 })?;
191 evals.solves += 1;
192
193 increment_norm = T::zero();
195 for i in 0..self.stages {
196 for row_idx in 0..dim {
197 let delta_val = rhs[i * dim + row_idx];
198 let current_z = self.z[i].get_component(row_idx);
199 self.z[i].set_component(row_idx, current_z + delta_val);
200 increment_norm = increment_norm.max(delta_val.abs());
202 }
203 }
204
205 }
207
208 if !newton_converged {
210 self.status = Status::Error(Error::Stiffness {
211 t: self.t,
212 y: self.y.clone(),
213 });
214 return Err(Error::Stiffness {
215 t: self.t,
216 y: self.y.clone(),
217 });
218 }
219
220 for i in 0..self.stages {
222 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
223 }
224 evals.function += self.stages;
225
226 let mut y_new = self.y.clone();
228 for i in 0..self.stages {
229 y_new.add_scaled(self.b[i] * self.h, &self.k[i]);
230 }
231
232 self.status = Status::Solving;
234
235 self.t_prev = self.t;
237 self.y_prev = self.y.clone();
238 self.dydt_prev = self.dydt.clone();
239 self.h_prev = self.h;
240
241 self.t += self.h;
243 self.y = y_new;
244
245 ode.diff(self.t, &self.y, &mut self.dydt);
247 evals.function += 1;
248
249 Ok(evals)
250 }
251
252 fn t(&self) -> T {
253 self.t
254 }
255 fn y(&self) -> &Y {
256 &self.y
257 }
258 fn t_prev(&self) -> T {
259 self.t_prev
260 }
261 fn y_prev(&self) -> &Y {
262 &self.y_prev
263 }
264 fn h(&self) -> T {
265 self.h
266 }
267 fn set_h(&mut self, h: T) {
268 self.h = h;
269 }
270 fn status(&self) -> &Status<T, Y> {
271 &self.status
272 }
273 fn set_status(&mut self, status: Status<T, Y>) {
274 self.status = status;
275 }
276}
277
278impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
279 for ImplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
280{
281 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
282 if t_interp < self.t_prev || t_interp > self.t {
284 return Err(Error::OutOfBounds {
285 t_interp,
286 t_prev: self.t_prev,
287 t_curr: self.t,
288 });
289 }
290
291 let y_interp = cubic_hermite_interpolate(
293 self.t_prev,
294 self.t,
295 &self.y_prev,
296 &self.y,
297 &self.dydt_prev,
298 &self.dydt,
299 t_interp,
300 );
301
302 Ok(y_interp)
303 }
304}