1use crate::{
4 Error, Status,
5 alias::Evals,
6 interpolate::Interpolation,
7 ode::{ODENumericalMethod, ODE, methods::h_init},
8 traits::{CallBackData, Real, State},
9 utils::{constrain_step_size, validate_step_size_parameters},
10};
11
/// `sqrt(6)` to full `f64` precision. Every Radau IIA coefficient below is
/// derived from it, so truncating it would perturb all of them (and the order
/// conditions they must satisfy).
const SQRT6: f64 = 2.449_489_742_783_178;

// Collocation nodes c_i of the 3-stage Radau IIA method (order 5).
const C0: f64 = (4.0 - SQRT6) / 10.0;
const C1: f64 = (4.0 + SQRT6) / 10.0;
const C2: f64 = 1.0;

// Butcher matrix A of the 3-stage Radau IIA method; row i sums to c_i.
const A11: f64 = (88.0 - 7.0 * SQRT6) / 360.0;
const A12: f64 = (296.0 - 169.0 * SQRT6) / 1800.0;
const A13: f64 = (-2.0 + 3.0 * SQRT6) / 225.0;
const A21: f64 = (296.0 + 169.0 * SQRT6) / 1800.0;
const A22: f64 = (88.0 + 7.0 * SQRT6) / 360.0;
const A23: f64 = (-2.0 - 3.0 * SQRT6) / 225.0;
const A31: f64 = (16.0 - SQRT6) / 36.0;
const A32: f64 = (16.0 + SQRT6) / 36.0;
const A33: f64 = 1.0 / 9.0;

// Solution weights b_i. They equal the last row of A (the method is stiffly
// accurate), so y_{n+1} coincides with the last stage value.
const B0: f64 = (16.0 - SQRT6) / 36.0;
const B1: f64 = (16.0 + SQRT6) / 36.0;
const B2: f64 = 1.0 / 9.0;

// Comparison weights used only for the local error estimate: the b_i shifted by
// (-0.01, -0.01, +0.02), which keeps their sum equal to one.
// NOTE(review): this is an ad-hoc estimator, not the classical Radau5 error
// estimate — confirm it is intended.
const BHAT0: f64 = (16.0 - SQRT6) / 36.0 - 0.01;
const BHAT1: f64 = (16.0 + SQRT6) / 36.0 - 0.01;
const BHAT2: f64 = 1.0 / 9.0 + 0.02;
40
41pub struct Radau5<T: Real, V: State<T>, D: CallBackData> {
56 pub h0: T,
58 h: T,
60 h_prev_step: T,
62
63 t: T,
65 y: V,
66 dydt: V, t_prev: T,
70 y_prev: V,
71 dydt_prev: V, k: [V; 3],
75 y_stage: [V; 3],
77 f_at_stages: [V; 3],
79
80 a: [[T; 3]; 3],
82 b_higher: [T; 3], b_lower: [T; 3], c: [T; 3], cont: [V; 4],
88
89 pub rtol: T,
91 pub atol: T,
92 pub h_max: T,
93 pub h_min: T,
94 pub max_steps: usize,
95 pub max_rejects: usize,
96 pub safety_factor: T,
97 pub min_scale: T,
98 pub max_scale: T,
99
100 pub max_iter: usize, pub tol: T, fd_epsilon_sqrt: T, reject: bool,
107 n_stiff: usize,
108 steps: usize,
109 status: Status<T, V, D>,
110
111 jacobian_matrix: nalgebra::DMatrix<T>, newton_matrix: nalgebra::DMatrix<T>, rhs_newton: nalgebra::DVector<T>, delta_k_vec: nalgebra::DVector<T>, }
117
118impl<T: Real, V: State<T>, D: CallBackData> Default for Radau5<T, V, D> {
119 fn default() -> Self {
120 let a_coeffs: [[f64; 3]; 3] = [
122 [A11, A12, A13],
123 [A21, A22, A23],
124 [A31, A32, A33]
125 ];
126 let b_coeffs: [f64; 3] = [B0, B1, B2];
127 let b_hat_coeffs: [f64; 3] = [BHAT0, BHAT1, BHAT2];
128 let c_coeffs: [f64; 3] = [C0, C1, C2];
129
130 let a_t: [[T; 3]; 3] = a_coeffs.map(|row| row.map(|x| T::from_f64(x).unwrap()));
131 let b_higher_t: [T; 3] = b_coeffs.map(|x| T::from_f64(x).unwrap());
132 let b_lower_t: [T; 3] = b_hat_coeffs.map(|x| T::from_f64(x).unwrap());
133 let c_t: [T; 3] = c_coeffs.map(|x| T::from_f64(x).unwrap());
134
135 Radau5 {
136 h0: T::zero(),
137 h: T::zero(),
138 h_prev_step: T::zero(),
139 t: T::zero(),
140 y: V::zeros(),
141 dydt: V::zeros(),
142 t_prev: T::zero(),
143 y_prev: V::zeros(),
144 dydt_prev: V::zeros(),
145 k: [V::zeros(); 3],
146 y_stage: [V::zeros(); 3],
147 f_at_stages: [V::zeros(); 3],
148 a: a_t,
149 b_higher: b_higher_t,
150 b_lower: b_lower_t,
151 c: c_t,
152 cont: [V::zeros(); 4],
153 rtol: T::from_f64(1.0e-6).unwrap(),
155 atol: T::from_f64(1.0e-6).unwrap(),
156 h_max: T::infinity(),
157 h_min: T::zero(),
158 max_steps: 10000,
159 max_rejects: 100,
160 safety_factor: T::from_f64(0.9).unwrap(),
161 min_scale: T::from_f64(0.2).unwrap(),
162 max_scale: T::from_f64(10.0).unwrap(),
163 max_iter: 50,
165 tol: T::from_f64(1e-8).unwrap(),
166 fd_epsilon_sqrt: T::zero(),
167 reject: false,
169 n_stiff: 0,
170 steps: 0,
171 status: Status::Uninitialized,
172 jacobian_matrix: nalgebra::DMatrix::zeros(0, 0),
174 newton_matrix: nalgebra::DMatrix::zeros(0, 0),
175 rhs_newton: nalgebra::DVector::zeros(0),
176 delta_k_vec: nalgebra::DVector::zeros(0),
177 }
178 }
179}
180
181impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for Radau5<T, V, D> {
182 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
183 where
184 F: ODE<T, V, D>,
185 {
186 let mut evals = Evals::new();
187
188 let mut initial_dydt = V::zeros();
190 ode.diff(t0, y0, &mut initial_dydt);
191 evals.fcn += 1;
192
193 if self.h0 == T::zero() {
195 self.h0 = h_init(ode, t0, tf, y0, 5, self.rtol, self.atol, self.h_min, self.h_max); }
197
198 self.h = validate_step_size_parameters::<T, V, D>(self.h0, self.h_min, self.h_max, t0, tf)?;
200 self.h_prev_step = self.h;
201
202 self.reject = false;
204 self.n_stiff = 0;
205 self.steps = 0;
206
207 self.t = t0;
209 self.y = *y0;
210 self.dydt = initial_dydt; self.t_prev = t0;
214 self.y_prev = *y0;
215 self.dydt_prev = initial_dydt;
216
217 self.fd_epsilon_sqrt = T::default_epsilon().sqrt();
219
220 self.status = Status::Initialized;
222
223 let dim = y0.len();
225 self.jacobian_matrix = nalgebra::DMatrix::zeros(dim, dim);
226 let newton_system_size = 3 * dim; self.newton_matrix = nalgebra::DMatrix::zeros(newton_system_size, newton_system_size);
228 self.rhs_newton = nalgebra::DVector::zeros(newton_system_size);
229 self.delta_k_vec = nalgebra::DVector::zeros(newton_system_size);
230 self.f_at_stages = [V::zeros(); 3];
231
232 self.cont = [V::zeros(); 4];
234
235 Ok(evals)
236 }
237
238 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
239 where
240 F: ODE<T, V, D>,
241 {
242 let mut evals = Evals::new();
243
244 if self.h.abs() < self.h_min || self.h.abs() < T::default_epsilon() {
246 self.status = Status::Error(Error::StepSize { t: self.t, y: self.y });
247 return Err(Error::StepSize { t: self.t, y: self.y });
248 }
249
250 if self.steps >= self.max_steps {
252 self.status = Status::Error(Error::MaxSteps { t: self.t, y: self.y });
253 return Err(Error::MaxSteps { t: self.t, y: self.y });
254 }
255 self.steps += 1;
256
257 for i in 0..3 {
260 self.k[i] = self.dydt;
261 }
262
263 ode.jacobian(self.t, &self.y, &mut self.jacobian_matrix);
265 evals.jac += 1;
266
267 let newton_converged = self.newton_iteration(ode, &mut evals)?;
269
270 if !newton_converged {
271 self.h *= T::from_f64(0.25).unwrap();
272 self.h = constrain_step_size(self.h, self.h_min, self.h_max);
273 self.reject = true;
274 self.n_stiff += 1;
275 evals.fcn += 1;
276
277 if self.n_stiff >= self.max_rejects {
278 self.status = Status::Error(Error::Stiffness { t: self.t, y: self.y });
279 return Err(Error::Stiffness { t: self.t, y: self.y });
280 }
281 return Ok(evals); }
283
284 let mut delta_y_high = V::zeros();
286 for i in 0..3 {
287 delta_y_high += self.k[i] * (self.b_higher[i] * self.h);
288 }
289 let y_high = self.y + delta_y_high;
290
291 let mut delta_y_low = V::zeros();
292 for i in 0..3 {
293 delta_y_low += self.k[i] * (self.b_lower[i] * self.h);
294 }
295 let y_low = self.y + delta_y_low;
296
297 let err_vec = y_high - y_low;
298
299 let mut err_norm = T::zero();
301 for n in 0..self.y.len() {
302 let scale = self.atol + self.rtol * self.y.get(n).abs().max(y_high.get(n).abs());
303 if scale > T::zero() {
304 err_norm = err_norm.max((err_vec.get(n) / scale).abs());
305 }
306 }
307 err_norm = err_norm.max(T::default_epsilon() * T::from_f64(100.0).unwrap());
308
309 let order_p1 = T::from_usize(5 + 1).unwrap(); let mut scale = self.safety_factor * err_norm.powf(-T::one() / order_p1);
312 scale = scale.max(self.min_scale).min(self.max_scale);
313 let h_new = self.h * scale;
314
315 if err_norm <= T::one() {
316 self.status = Status::Solving;
318
319 self.h_prev_step = self.h;
321 self.t_prev = self.t;
322 self.y_prev = self.y;
323 self.dydt_prev = self.dydt;
324
325 self.t += self.h;
327 self.y = y_high;
328
329 ode.diff(self.t, &self.y, &mut self.dydt);
331 evals.fcn += 1;
332
333 self.compute_dense_output_coeffs();
335
336 if self.reject {
337 self.n_stiff = 0;
338 self.reject = false;
339 }
340 self.h = constrain_step_size(h_new, self.h_min, self.h_max);
341 } else {
342 self.status = Status::RejectedStep;
344 self.reject = true;
345 self.n_stiff += 1;
346
347 if self.n_stiff >= self.max_rejects {
348 self.status = Status::Error(Error::Stiffness { t: self.t, y: self.y });
349 return Err(Error::Stiffness { t: self.t, y: self.y });
350 }
351
352 self.h = constrain_step_size(h_new, self.h_min, self.h_max);
353 return Ok(evals); }
355
356 Ok(evals)
357 }
358
359 fn t(&self) -> T { self.t }
360 fn y(&self) -> &V { &self.y }
361 fn t_prev(&self) -> T { self.t_prev }
362 fn y_prev(&self) -> &V { &self.y_prev }
363 fn h(&self) -> T { self.h }
364 fn set_h(&mut self, h: T) { self.h = h; }
365 fn status(&self) -> &Status<T, V, D> { &self.status }
366 fn set_status(&mut self, status: Status<T, V, D>) { self.status = status; }
367}
368
369impl<T: Real, V: State<T>, D: CallBackData> Radau5<T, V, D> {
370 fn newton_iteration<F>(&mut self, ode: &F, evals: &mut Evals) -> Result<bool, Error<T, V>>
372 where
373 F: ODE<T, V, D>,
374 {
375 let dim = self.y.len();
376
377 for i in 0..3 { for l in 0..3 { let scale_factor = -self.h * self.a[i][l];
382 for r_idx in 0..dim { for c_idx in 0..dim { self.newton_matrix[(i * dim + r_idx, l * dim + c_idx)] =
385 self.jacobian_matrix[(r_idx, c_idx)] * scale_factor;
386 }
387 }
388 if i == l { for d_idx in 0..dim {
390 self.newton_matrix[(i * dim + d_idx, l * dim + d_idx)] += T::one();
391 }
392 }
393 }
394 }
395
396 let mut converged = false;
397
398 for _iter in 0..self.max_iter {
399 for i in 0..3 {
403 self.y_stage[i] = self.y;
404 for j in 0..3 {
405 self.y_stage[i] += self.k[j] * (self.a[i][j] * self.h);
406 }
407
408 ode.diff(self.t + self.c[i] * self.h, &self.y_stage[i], &mut self.f_at_stages[i]);
409 evals.fcn += 1;
410
411 for row_idx in 0..dim {
412 self.rhs_newton[i * dim + row_idx] = self.f_at_stages[i].get(row_idx) - self.k[i].get(row_idx);
413 }
414 }
415
416 let lu_decomp = nalgebra::LU::new(self.newton_matrix.clone());
418 if let Some(solution) = lu_decomp.solve(&self.rhs_newton) {
419 self.delta_k_vec.copy_from(&solution);
420 } else {
421 return Ok(false); }
423
424 let mut norm_delta_k_sq = T::zero();
426 for i in 0..3 {
427 for row_idx in 0..dim {
428 let delta_val = self.delta_k_vec[i * dim + row_idx];
429 let current_val = self.k[i].get(row_idx);
430 self.k[i].set(row_idx, current_val + delta_val);
431 norm_delta_k_sq += delta_val * delta_val;
432 }
433 }
434
435 let dyno = norm_delta_k_sq.sqrt();
437 if dyno < self.tol {
438 converged = true;
439 break;
440 }
441 }
442
443 Ok(converged)
444 }
445
446 fn compute_dense_output_coeffs(&mut self) {
449 self.cont[0] = self.y;
451
452 let c1_f = self.c[0];
457 let c2_f = self.c[1];
458
459 let c1m1 = c1_f - T::one(); let c2m1 = c2_f - T::one(); let c1mc2 = c1_f - c2_f; let z1_val = self.y_stage[0] - self.y_prev;
469 let z2_val = self.y_stage[1] - self.y_prev;
470 let z3_val = self.y_stage[2] - self.y_prev;
471
472 self.cont[1] = (z2_val - z3_val) / c2m1;
486
487 let ak = (z1_val - z2_val) / c1mc2;
488
489 let mut acont3_temp = z1_val / c1_f;
490
491 acont3_temp = (ak - acont3_temp) / c2_f;
492
493 self.cont[2] = (ak - self.cont[1]) / c1m1;
494
495 self.cont[3] = self.cont[2] - acont3_temp;
496 }
497}
498
499impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for Radau5<T, V, D> {
500 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
501 if t_interp < self.t_prev || t_interp > self.t {
502 return Err(Error::OutOfBounds {
503 t_interp,
504 t_prev: self.t_prev,
505 t_curr: self.t,
506 });
507 }
508
509 let s = (t_interp - self.t) / self.h_prev_step;
512
513 let c1_f = self.c[0]; let c2_f = self.c[1]; let c1m1 = c1_f - T::one();
523 let c2m1 = c2_f - T::one();
524
525 let y_interp = self.cont[0] + (self.cont[1] + (self.cont[2] + self.cont[3] * (s - c1m1)) * (s - c2m1)) * s;
529
530 Ok(y_interp)
531 }
532}
533
534impl<
536 T: Real,
537 V: State<T>,
538 D: CallBackData,
539> Radau5<T, V, D> {
540 pub fn new() -> Self {
541 Self::default()
542 }
543 pub fn h0(mut self, h0: T) -> Self { self.h0 = h0; self }
544 pub fn rtol(mut self, rtol: T) -> Self { self.rtol = rtol; self }
545 pub fn atol(mut self, atol: T) -> Self { self.atol = atol; self }
546 pub fn h_min(mut self, h_min: T) -> Self { self.h_min = h_min; self }
547 pub fn h_max(mut self, h_max: T) -> Self { self.h_max = h_max; self }
548 pub fn max_steps(mut self, max_steps: usize) -> Self { self.max_steps = max_steps; self }
549 pub fn max_rejects(mut self, max_rejects: usize) -> Self { self.max_rejects = max_rejects; self }
550 pub fn safety_factor(mut self, safety_factor: T) -> Self { self.safety_factor = safety_factor; self }
551 pub fn min_scale(mut self, min_scale: T) -> Self { self.min_scale = min_scale; self }
552 pub fn max_scale(mut self, max_scale: T) -> Self { self.max_scale = max_scale; self }
553 pub fn max_iter(mut self, iter: usize) -> Self { self.max_iter = iter; self }
554 pub fn tol(mut self, tol: T) -> Self { self.tol = tol; self }
555}