differential_equations/methods/irk/adaptive/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 linalg::Matrix,
7 methods::h_init::InitialStepSize,
8 methods::{Adaptive, ImplicitRungeKutta, Ordinary},
9 ode::{ODE, OrdinaryNumericalMethod},
10 stats::Evals,
11 status::Status,
12 traits::{Real, State},
13 utils::{constrain_step_size, validate_step_size_parameters},
14};
15
16impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
17 OrdinaryNumericalMethod<T, Y> for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
18{
19 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
20 where
21 F: ODE<T, Y>,
22 {
23 let mut evals = Evals::new();
24
25 if self.h0 == T::zero() {
27 self.h0 = InitialStepSize::<Ordinary>::compute(
29 ode, t0, tf, y0, self.order, &self.rtol, &self.atol, self.h_min, self.h_max,
30 &mut evals,
31 );
32 }
33
34 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
36 Ok(h0) => self.h = h0,
37 Err(status) => return Err(status),
38 }
39
40 self.stiffness_counter = 0;
42 self.newton_iterations = 0;
43 self.jacobian_evaluations = 0;
44 self.lu_decompositions = 0;
45
46 self.t = t0;
48 self.y = *y0;
49 ode.diff(self.t, &self.y, &mut self.dydt);
50 evals.function += 1;
51
52 self.t_prev = self.t;
54 self.y_prev = self.y;
55 self.dydt_prev = self.dydt;
56
57 let dim = y0.len();
59 let newton_system_size = self.stages * dim;
60 self.stage_jacobians = core::array::from_fn(|_| Matrix::zeros(dim, dim));
61 self.newton_matrix = Matrix::zeros(newton_system_size, newton_system_size);
62 self.rhs_newton = vec![T::zero(); newton_system_size];
64 self.delta_k_vec = vec![T::zero(); newton_system_size];
65 self.jacobian_age = 0;
66
67 self.status = Status::Initialized;
69
70 Ok(evals)
71 }
72
73 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
74 where
75 F: ODE<T, Y>,
76 {
77 let mut evals = Evals::new();
78
79 if self.h.abs() < self.h_prev.abs() * T::from_f64(1e-14).unwrap() {
81 self.status = Status::Error(Error::StepSize {
82 t: self.t,
83 y: self.y,
84 });
85 return Err(Error::StepSize {
86 t: self.t,
87 y: self.y,
88 });
89 }
90
91 if self.steps >= self.max_steps {
93 self.status = Status::Error(Error::MaxSteps {
94 t: self.t,
95 y: self.y,
96 });
97 return Err(Error::MaxSteps {
98 t: self.t,
99 y: self.y,
100 });
101 }
102 self.steps += 1;
103
104 let dim = self.y.len();
106 for i in 0..self.stages {
107 self.z[i] = self.y;
108 }
109
110 let mut newton_converged = false;
112 let mut newton_iter = 0;
113
114 let mut increment_norm = T::infinity();
116
117 while !newton_converged && newton_iter < self.max_newton_iter {
118 newton_iter += 1;
119 self.newton_iterations += 1;
120 evals.newton += 1;
121
122 for i in 0..self.stages {
124 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
125 }
126 evals.function += self.stages;
127
128 let mut residual_norm = T::zero();
130 for i in 0..self.stages {
131 let mut residual = self.z[i] - self.y;
133
134 for j in 0..self.stages {
136 residual = residual - self.k[j] * (self.a[i][j] * self.h);
137 }
138
139 for row_idx in 0..dim {
141 let res_val = residual.get(row_idx);
142 residual_norm = residual_norm.max(res_val.abs());
143 self.rhs_newton[i * dim + row_idx] = -res_val;
145 }
146 }
147
148 if residual_norm < self.newton_tol {
150 newton_converged = true;
151 break;
152 }
153
154 if newton_iter > 1 && increment_norm < self.newton_tol {
156 newton_converged = true;
157 break;
158 }
159
160 if newton_iter == 1 || self.jacobian_age > 3 {
162 for i in 0..self.stages {
164 ode.jacobian(
165 self.t + self.c[i] * self.h,
166 &self.z[i],
167 &mut self.stage_jacobians[i],
168 );
169 evals.jacobian += 1;
170 }
171
172 let nsys = self.stages * dim;
175 let mut nm = Matrix::zeros(nsys, nsys);
176 for i in 0..self.stages {
178 for j in 0..self.stages {
179 let scale = -self.h * self.a[i][j];
180 for r in 0..dim {
181 for c_col in 0..dim {
182 nm[(i * dim + r, j * dim + c_col)] =
183 self.stage_jacobians[j][(r, c_col)] * scale;
184 }
185 }
186 }
187 for d_idx in 0..dim {
189 let idx = i * dim + d_idx;
190 nm[(idx, idx)] += T::one();
191 }
192 }
193 self.newton_matrix = nm;
194
195 self.jacobian_age = 0;
196 }
197 self.jacobian_age += 1;
198
199 let mut rhs = self.rhs_newton.clone();
201 self.newton_matrix.lin_solve_mut(&mut rhs[..]);
202 for i in 0..self.delta_k_vec.len() {
203 self.delta_k_vec[i] = rhs[i];
204 }
205 self.lu_decompositions += 1;
206
207 increment_norm = T::zero();
209 for i in 0..self.stages {
210 for row_idx in 0..dim {
211 let delta_val = self.delta_k_vec[i * dim + row_idx];
212 let current_val = self.z[i].get(row_idx);
213 self.z[i].set(row_idx, current_val + delta_val);
214 increment_norm = increment_norm.max(delta_val.abs());
216 }
217 }
218
219 }
221
222 if !newton_converged {
224 self.h *= T::from_f64(0.25).unwrap();
226 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
227 self.status = Status::RejectedStep;
228 self.stiffness_counter += 1;
229
230 if self.stiffness_counter >= self.max_rejects {
231 self.status = Status::Error(Error::Stiffness {
232 t: self.t,
233 y: self.y,
234 });
235 return Err(Error::Stiffness {
236 t: self.t,
237 y: self.y,
238 });
239 }
240 return Ok(evals);
241 }
242
243 for i in 0..self.stages {
245 ode.diff(self.t + self.c[i] * self.h, &self.z[i], &mut self.k[i]);
246 }
247 evals.function += self.stages;
248
249 let mut y_new = self.y;
251 for i in 0..self.stages {
252 y_new += self.k[i] * (self.b[i] * self.h);
253 }
254
255 let mut err_norm = T::zero();
257 let bh = &self.bh.unwrap();
258
259 let mut y_low = self.y;
261 for i in 0..self.stages {
262 y_low += self.k[i] * (bh[i] * self.h);
263 }
264
265 let err = y_new - y_low;
267
268 for n in 0..self.y.len() {
270 let scale = self.atol[n] + self.rtol[n] * self.y.get(n).abs().max(y_new.get(n).abs());
271 if scale > T::zero() {
272 err_norm = err_norm.max((err.get(n) / scale).abs());
273 }
274 }
275
276 err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
278
279 let order = T::from_usize(self.order).unwrap();
281 let error_exponent = T::one() / order;
282 let mut scale = self.safety_factor * err_norm.powf(-error_exponent);
283
284 scale = scale.max(self.min_scale).min(self.max_scale);
286
287 if err_norm <= T::one() {
289 self.status = Status::Solving;
291
292 self.t_prev = self.t;
294 self.y_prev = self.y;
295 self.dydt_prev = self.dydt;
296 self.h_prev = self.h;
297
298 self.t += self.h;
300 self.y = y_new;
301
302 ode.diff(self.t, &self.y, &mut self.dydt);
304 evals.function += 1;
305
306 if let Status::RejectedStep = self.status {
308 self.stiffness_counter = 0;
309
310 scale = scale.min(T::one());
312 }
313 } else {
314 self.status = Status::RejectedStep;
316 self.stiffness_counter += 1;
317
318 if self.stiffness_counter >= self.max_rejects {
320 self.status = Status::Error(Error::Stiffness {
321 t: self.t,
322 y: self.y,
323 });
324 return Err(Error::Stiffness {
325 t: self.t,
326 y: self.y,
327 });
328 }
329 }
330
331 self.h *= scale;
333
334 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
336
337 Ok(evals)
338 }
339
340 fn t(&self) -> T {
341 self.t
342 }
343 fn y(&self) -> &Y {
344 &self.y
345 }
346 fn t_prev(&self) -> T {
347 self.t_prev
348 }
349 fn y_prev(&self) -> &Y {
350 &self.y_prev
351 }
352 fn h(&self) -> T {
353 self.h
354 }
355 fn set_h(&mut self, h: T) {
356 self.h = h;
357 }
358 fn status(&self) -> &Status<T, Y> {
359 &self.status
360 }
361 fn set_status(&mut self, status: Status<T, Y>) {
362 self.status = status;
363 }
364}
365
366impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
367 for ImplicitRungeKutta<Ordinary, Adaptive, T, Y, O, S, I>
368{
369 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
370 if t_interp < self.t_prev || t_interp > self.t {
372 return Err(Error::OutOfBounds {
373 t_interp,
374 t_prev: self.t_prev,
375 t_curr: self.t,
376 });
377 }
378
379 let y_interp = cubic_hermite_interpolate(
381 self.t_prev,
382 self.t,
383 &self.y_prev,
384 &self.y,
385 &self.dydt_prev,
386 &self.dydt,
387 t_interp,
388 );
389
390 Ok(y_interp)
391 }
392}