1use faer::{ComplexField, Conjugate, SimpleEntity};
73use numra_core::Scalar;
74use numra_linalg::{DenseMatrix, LUFactorization, Matrix};
75
76use crate::error::SolverError;
77use crate::problem::OdeSystem;
78use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
79use crate::t_eval::{validate_grid, TEvalEmitter};
80
/// Three-stage Radau IIA implicit Runge–Kutta solver (order 5) for stiff ODEs
/// and for DAEs written with a constant mass matrix.
///
/// The type carries no state; all work happens in `Solver::solve`.
#[derive(Clone, Debug, Default)]
pub struct Radau5;

impl Radau5 {
    /// Creates the (stateless) solver; equivalent to `Radau5::default()`.
    pub fn new() -> Self {
        Radau5
    }
}
91
/// Coefficients of the three-stage Radau IIA method (order 5), with the values
/// used by Hairer & Wanner's RADAU5 code.
mod coefficients {
    /// √6, which appears in the collocation nodes and error weights.
    pub const SQRT6: f64 = 2.449489742783178;

    // Collocation nodes c1 < c2 < c3 = 1.
    pub const C1: f64 = (4.0 - SQRT6) / 10.0;
    pub const C2: f64 = (4.0 + SQRT6) / 10.0;
    #[allow(dead_code)] // kept for completeness; the last node is exactly 1
    pub const C3: f64 = 1.0;

    // Weights of the embedded error estimator, applied to the stage increments z1, z2, z3.
    pub const DD1: f64 = -(13.0 + 7.0 * SQRT6) / 3.0;
    pub const DD2: f64 = (-13.0 + 7.0 * SQRT6) / 3.0;
    pub const DD3: f64 = -1.0 / 3.0;

    // 81^(1/3), 9^(1/3) and √3: building blocks of the eigenvalues of the
    // Radau IIA coefficient matrix A.
    const CBRT_81: f64 = 4.3267487109222245;
    const CBRT_9: f64 = 2.080083823051904;
    const SQRT3: f64 = 1.7320508075688772;

    /// Real eigenvalue of A.
    const GAMMA_RAW: f64 = (6.0 + CBRT_81 - CBRT_9) / 30.0;
    /// Real eigenvalue of A⁻¹ (reciprocal of `GAMMA_RAW`).
    pub const U1: f64 = 1.0 / GAMMA_RAW;

    // Complex eigenvalue pair of A is ALPH_RAW ± i·BETA_RAW; dividing by the
    // squared modulus gives the real/imaginary magnitudes for A⁻¹.
    const ALPH_RAW: f64 = (12.0 - CBRT_81 + CBRT_9) / 60.0;
    const BETA_RAW: f64 = (CBRT_81 + CBRT_9) * SQRT3 / 60.0;
    const CNO: f64 = ALPH_RAW * ALPH_RAW + BETA_RAW * BETA_RAW;
    pub const ALPH: f64 = ALPH_RAW / CNO;
    pub const BETA: f64 = BETA_RAW / CNO;

    // T: maps transformed Newton variables back to stage increments, Z = T·W.
    pub const T11: f64 = 9.1232394870892942792e-02;
    pub const T12: f64 = -0.14125529502095420843;
    pub const T13: f64 = -3.0029194105147424492e-02;
    pub const T21: f64 = 0.24171793270710701896;
    pub const T22: f64 = 0.20412935229379993199;
    pub const T23: f64 = 0.38294211275726193779;
    pub const T31: f64 = 0.96604818261509293619;
    pub const T32: f64 = 1.0;
    #[allow(dead_code)] // zero entry, skipped in the products that use T
    pub const T33: f64 = 0.0;

    // TI = T⁻¹: W = TI·Z.
    pub const TI11: f64 = 4.3255798900631553510;
    pub const TI12: f64 = 0.33919925181580986954;
    pub const TI13: f64 = 0.54177053993587487119;
    pub const TI21: f64 = -4.1787185915519047273;
    pub const TI22: f64 = -0.32768282076106238708;
    pub const TI23: f64 = 0.47662355450055045196;
    pub const TI31: f64 = -0.50287263494578687595;
    pub const TI32: f64 = 2.5719269498556054292;
    pub const TI33: f64 = -0.59603920482822492497;

    // P: with q_j = Σ_k z_k·P_kj, the cubic q0·x + q1·x² + q2·x³ takes the
    // value z_k at x = c_k (rows 1 and 2 sum to 0, row 3 sums to 1).
    pub const P11: f64 = 13.0 / 3.0 + 7.0 * SQRT6 / 3.0;
    pub const P12: f64 = -23.0 / 3.0 - 22.0 * SQRT6 / 3.0;
    pub const P13: f64 = 10.0 / 3.0 + 5.0 * SQRT6;
    pub const P21: f64 = 13.0 / 3.0 - 7.0 * SQRT6 / 3.0;
    pub const P22: f64 = -23.0 / 3.0 + 22.0 * SQRT6 / 3.0;
    pub const P23: f64 = 10.0 / 3.0 - 5.0 * SQRT6;
    pub const P31: f64 = 1.0 / 3.0;
    pub const P32: f64 = -8.0 / 3.0;
    pub const P33: f64 = 10.0 / 3.0;
}
168
/// Maximum number of simplified Newton iterations attempted per step.
/// Also enters the step-size safety factor computed after each step.
const MAX_NEWTON_ITER: usize = 7;
171
172impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Radau5 {
173 fn solve<Sys: OdeSystem<S>>(
174 problem: &Sys,
175 t0: S,
176 tf: S,
177 y0: &[S],
178 options: &SolverOptions<S>,
179 ) -> Result<SolverResult<S>, SolverError> {
180 let dim = problem.dim();
181 if y0.len() != dim {
182 return Err(SolverError::DimensionMismatch {
183 expected: dim,
184 actual: y0.len(),
185 });
186 }
187
188 let mut t = t0;
189 let mut y = y0.to_vec();
190
191 let direction_init = if tf > t0 { S::ONE } else { -S::ONE };
192 if let Some(grid) = options.t_eval.as_deref() {
193 validate_grid(grid, t0, tf)?;
194 }
195 let mut grid_emitter = options
196 .t_eval
197 .as_deref()
198 .map(|g| TEvalEmitter::new(g, direction_init));
199 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
200 (Vec::new(), Vec::new())
201 } else {
202 (vec![t0], y0.to_vec())
203 };
204 let mut dy_old_buf = vec![S::ZERO; dim];
206
207 let mut f0 = vec![S::ZERO; dim];
209 let mut z1 = vec![S::ZERO; dim];
210 let mut z2 = vec![S::ZERO; dim];
211 let mut z3 = vec![S::ZERO; dim];
212 let mut w1 = vec![S::ZERO; dim];
213 let mut w2 = vec![S::ZERO; dim];
214 let mut w3 = vec![S::ZERO; dim];
215 let mut cont = vec![S::ZERO; dim];
216 let mut scal = vec![S::ZERO; dim];
217 let mut y_new = vec![S::ZERO; dim];
218 let mut err = vec![S::ZERO; dim];
219 let mut jac_data = vec![S::ZERO; dim * dim];
220
221 let mut z1_prev = vec![S::ZERO; dim];
224 let mut z2_prev = vec![S::ZERO; dim];
225 let mut z3_prev = vec![S::ZERO; dim];
226 let mut h_prev: S = S::ONE; let mut have_prev = false;
228
229 let mut h_abs_old: Option<S> = None;
231 let mut err_norm_old: Option<S> = None;
232
233 let has_mass = problem.has_mass_matrix();
235 let mass_data = if has_mass {
236 let mut m = vec![S::ZERO; dim * dim];
237 problem.mass_matrix(&mut m);
238 Some(m)
239 } else {
240 None
241 };
242 let mass_ref = mass_data.as_deref();
243
244 let mut stats = SolverStats::default();
245
246 for i in 0..dim {
248 scal[i] = options.atol + options.rtol * y[i].abs();
249 }
250
251 problem.rhs(t, &y, &mut f0);
255 stats.n_eval += 1;
256
257 let mut h = Self::initial_step_size(&y, &f0, options, dim);
259 let h_min = options.h_min;
260 let h_max = (tf - t0).abs() * S::from_f64(0.5);
261
262 let mut lu_real: Option<LUFactorization<S>> = None;
264 let mut lu_complex: Option<LUFactorization<S>> = None;
265 let mut need_jac = true;
266
267 let mut first = true;
268 let mut reject = false;
269 let mut step_count = 0usize;
270 let direction = if tf > t0 { S::ONE } else { -S::ONE };
271
272 while (tf - t) * direction > S::ZERO {
273 if step_count >= options.max_steps {
274 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
275 }
276
277 if (t + h - tf) * direction > S::ZERO {
279 h = tf - t;
280 }
281
282 if need_jac {
294 problem.jacobian(t, &y, &mut jac_data);
295 stats.n_jac += 1;
296 need_jac = false;
297 lu_real = None;
299 lu_complex = None;
300 }
301
302 if lu_real.is_none() {
306 let (e1, e2) = Self::form_transformed_matrices(&jac_data, h, dim, mass_ref);
307 lu_real = Some(LUFactorization::new(&e1)?);
308 lu_complex = Some(LUFactorization::new(&e2)?);
309 stats.n_lu += 2;
310 }
311
312 for i in 0..dim {
314 scal[i] = options.atol + options.rtol * y[i].abs();
315 }
316
317 let use_extrapolation = !first && !reject && have_prev;
322 if !use_extrapolation {
323 for i in 0..dim {
324 z1[i] = S::ZERO;
325 z2[i] = S::ZERO;
326 z3[i] = S::ZERO;
327 w1[i] = S::ZERO;
328 w2[i] = S::ZERO;
329 w3[i] = S::ZERO;
330 }
331 } else {
332 let p11 = S::from_f64(coefficients::P11);
337 let p12 = S::from_f64(coefficients::P12);
338 let p13 = S::from_f64(coefficients::P13);
339 let p21 = S::from_f64(coefficients::P21);
340 let p22 = S::from_f64(coefficients::P22);
341 let p23 = S::from_f64(coefficients::P23);
342 let p31 = S::from_f64(coefficients::P31);
343 let p32 = S::from_f64(coefficients::P32);
344 let p33 = S::from_f64(coefficients::P33);
345
346 let c1 = S::from_f64(coefficients::C1);
347 let c2 = S::from_f64(coefficients::C2);
348 let c3 = S::ONE;
349
350 let r1 = h * c1 / h_prev;
351 let r2 = h * c2 / h_prev;
352 let r3 = h * c3 / h_prev;
353
354 for i in 0..dim {
355 let q0 = z1_prev[i] * p11 + z2_prev[i] * p21 + z3_prev[i] * p31;
356 let q1 = z1_prev[i] * p12 + z2_prev[i] * p22 + z3_prev[i] * p32;
357 let q2 = z1_prev[i] * p13 + z2_prev[i] * p23 + z3_prev[i] * p33;
358
359 z1[i] = q0 * r1 + q1 * r1 * r1 + q2 * r1 * r1 * r1;
360 z2[i] = q0 * r2 + q1 * r2 * r2 + q2 * r2 * r2 * r2;
361 z3[i] = q0 * r3 + q1 * r3 * r3 + q2 * r3 * r3 * r3;
362 }
363
364 let ti11 = S::from_f64(coefficients::TI11);
367 let ti12 = S::from_f64(coefficients::TI12);
368 let ti13 = S::from_f64(coefficients::TI13);
369 let ti21 = S::from_f64(coefficients::TI21);
370 let ti22 = S::from_f64(coefficients::TI22);
371 let ti23 = S::from_f64(coefficients::TI23);
372 let ti31 = S::from_f64(coefficients::TI31);
373 let ti32 = S::from_f64(coefficients::TI32);
374 let ti33 = S::from_f64(coefficients::TI33);
375
376 for i in 0..dim {
377 w1[i] = ti11 * z1[i] + ti12 * z2[i] + ti13 * z3[i];
378 w2[i] = ti21 * z1[i] + ti22 * z2[i] + ti23 * z3[i];
379 w3[i] = ti31 * z1[i] + ti32 * z2[i] + ti33 * z3[i];
380 }
381 }
382
383 let newton_result = Self::newton_iteration(
385 problem,
386 t,
387 h,
388 &y,
389 &scal,
390 &mut z1,
391 &mut z2,
392 &mut z3,
393 &mut w1,
394 &mut w2,
395 &mut w3,
396 &mut cont,
397 lu_real.as_ref().unwrap(),
398 lu_complex.as_ref().unwrap(),
399 mass_ref,
400 &mut stats,
401 dim,
402 options,
403 );
404
405 let (newton_converged, newt_iter) = match newton_result {
406 Ok((converged, iter)) => (converged, iter),
407 Err(_) => (false, MAX_NEWTON_ITER),
408 };
409
410 if !newton_converged {
411 h = h * S::from_f64(0.5);
413 stats.n_reject += 1;
414 reject = true;
415 need_jac = true;
416
417 if h.abs() < h_min {
418 return Err(SolverError::StepSizeTooSmall {
419 t: t.to_f64(),
420 h: h.to_f64(),
421 h_min: h_min.to_f64(),
422 });
423 }
424 continue;
425 }
426
427 for i in 0..dim {
429 y_new[i] = y[i] + z3[i];
430 }
431
432 let err_norm = Self::error_estimate(
435 problem,
436 t,
437 &f0,
438 &z1,
439 &z2,
440 &z3,
441 &y,
442 &y_new,
443 h,
444 options,
445 lu_real.as_ref().unwrap(),
446 &mut err,
447 dim,
448 first,
449 reject,
450 &mut stats,
451 mass_ref,
452 );
453
454 let safety = Self::safety_factor::<S>(newt_iter, MAX_NEWTON_ITER);
458 let pred = Self::predict_factor(h.abs(), h_abs_old, err_norm, err_norm_old);
459 let factor = (safety * pred).max(S::from_f64(0.2)).min(S::from_f64(8.0));
460
461 if err_norm < S::ONE {
462 stats.n_accept += 1;
464
465 z1_prev.copy_from_slice(&z1);
467 z2_prev.copy_from_slice(&z2);
468 z3_prev.copy_from_slice(&z3);
469 h_prev = h;
470 have_prev = true;
471
472 h_abs_old = Some(h.abs());
474 err_norm_old = Some(err_norm);
475
476 let t_new = t + h;
477 dy_old_buf.copy_from_slice(&f0);
481 problem.rhs(t_new, &y_new, &mut f0);
482 stats.n_eval += 1;
483
484 if let Some(ref mut emitter) = grid_emitter {
485 emitter.emit_step(
486 t,
487 &y,
488 &dy_old_buf,
489 t_new,
490 &y_new,
491 &f0,
492 &mut t_out,
493 &mut y_out,
494 );
495 } else {
496 t_out.push(t_new);
497 y_out.extend_from_slice(&y_new);
498 }
499
500 t = t_new;
501 y.copy_from_slice(&y_new);
502
503 first = false;
504 reject = false;
505
506 if factor < S::from_f64(1.2) {
510 } else {
512 let h_proposed = h * factor;
513 let h_capped = if h_proposed.abs() > h_max {
514 if h_proposed > S::ZERO {
515 h_max
516 } else {
517 -h_max
518 }
519 } else {
520 h_proposed
521 };
522 h = h_capped;
523 lu_real = None;
524 lu_complex = None;
525 }
526 } else {
527 stats.n_reject += 1;
529 reject = true;
530
531 h = h * factor;
533 lu_real = None;
534 lu_complex = None;
535
536 if h.abs() < h_min {
537 return Err(SolverError::StepSizeTooSmall {
538 t: t.to_f64(),
539 h: h.to_f64(),
540 h_min: h_min.to_f64(),
541 });
542 }
543 }
544
545 step_count += 1;
546 }
547
548 Ok(SolverResult::new(t_out, y_out, dim, stats))
549 }
550}
551
552impl Radau5 {
553 fn initial_step_size<S: Scalar>(y: &[S], f: &[S], options: &SolverOptions<S>, dim: usize) -> S {
555 let mut d0 = S::ZERO;
556 let mut d1 = S::ZERO;
557
558 for i in 0..dim {
559 let sc = options.atol + options.rtol * y[i].abs();
560 d0 = d0 + (y[i] / sc) * (y[i] / sc);
561 d1 = d1 + (f[i] / sc) * (f[i] / sc);
562 }
563
564 let d0 = (d0 / S::from_usize(dim)).sqrt();
565 let d1 = (d1 / S::from_usize(dim)).sqrt();
566
567 let h0 = if d0 < S::from_f64(1e-5) || d1 < S::from_f64(1e-5) {
568 S::from_f64(1e-6)
569 } else {
570 S::from_f64(0.01) * d0 / d1
571 };
572
573 h0.min(options.h_max).max(options.h_min)
574 }
575
576 fn form_transformed_matrices<S>(
588 jac: &[S],
589 h: S,
590 dim: usize,
591 mass: Option<&[S]>,
592 ) -> (DenseMatrix<S>, DenseMatrix<S>)
593 where
594 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
595 {
596 let fac1 = S::from_f64(coefficients::U1) / h;
597 let mut e1 = DenseMatrix::zeros(dim, dim);
598 for i in 0..dim {
599 for j in 0..dim {
600 let jij = jac[i * dim + j];
601 let mij = match mass {
602 Some(m) => m[i * dim + j],
603 None => {
604 if i == j {
605 S::ONE
606 } else {
607 S::ZERO
608 }
609 }
610 };
611 e1.set(i, j, fac1 * mij - jij);
612 }
613 }
614
615 let alphn = S::from_f64(coefficients::ALPH) / h;
620 let betan = S::from_f64(coefficients::BETA) / h;
621 let mut e2 = DenseMatrix::zeros(2 * dim, 2 * dim);
622
623 for i in 0..dim {
624 for j in 0..dim {
625 let jij = jac[i * dim + j];
626 let mij = match mass {
627 Some(m) => m[i * dim + j],
628 None => {
629 if i == j {
630 S::ONE
631 } else {
632 S::ZERO
633 }
634 }
635 };
636 e2.set(i, j, alphn * mij - jij);
637 e2.set(i, dim + j, -betan * mij);
638 e2.set(dim + i, j, betan * mij);
639 e2.set(dim + i, dim + j, alphn * mij - jij);
640 }
641 }
642
643 (e1, e2)
644 }
645
646 fn safety_factor<S: Scalar>(n_iter: usize, max_iter: usize) -> S {
653 let num = 0.9 * (2.0 * max_iter as f64 + 1.0);
654 let den = 2.0 * max_iter as f64 + n_iter as f64;
655 S::from_f64(num / den)
656 }
657
658 fn predict_factor<S: Scalar>(
671 h_abs: S,
672 h_abs_old: Option<S>,
673 err_norm: S,
674 err_norm_old: Option<S>,
675 ) -> S {
676 let multiplier = match (h_abs_old, err_norm_old) {
677 (Some(h_old), Some(err_old)) if err_norm > S::ZERO && h_old > S::ZERO => {
678 (h_abs / h_old) * (err_old / err_norm).powf(S::from_f64(0.25))
679 }
680 _ => S::ONE,
681 };
682 multiplier.min(S::ONE) * err_norm.powf(S::from_f64(-0.25))
683 }
684
685 #[allow(clippy::too_many_arguments)]
697 fn newton_iteration<S, Sys>(
698 problem: &Sys,
699 t: S,
700 h: S,
701 y: &[S],
702 scal: &[S],
703 z1: &mut [S],
704 z2: &mut [S],
705 z3: &mut [S],
706 w1: &mut [S],
707 w2: &mut [S],
708 w3: &mut [S],
709 cont: &mut [S],
710 lu_real: &LUFactorization<S>,
711 lu_complex: &LUFactorization<S>,
712 mass: Option<&[S]>,
713 stats: &mut SolverStats,
714 dim: usize,
715 options: &SolverOptions<S>,
716 ) -> Result<(bool, usize), SolverError>
717 where
718 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
719 Sys: OdeSystem<S>,
720 {
721 let c1 = S::from_f64(coefficients::C1);
722 let c2 = S::from_f64(coefficients::C2);
723
724 let uround = S::from_f64(1e-16);
727 let fnewt = (S::from_f64(10.0) * uround / options.rtol)
728 .max(S::from_f64(0.03).min(options.rtol.sqrt()));
729
730 let fac1 = S::from_f64(coefficients::U1) / h;
732 let alphn = S::from_f64(coefficients::ALPH) / h;
733 let betan = S::from_f64(coefficients::BETA) / h;
734
735 let ti11 = S::from_f64(coefficients::TI11);
737 let ti12 = S::from_f64(coefficients::TI12);
738 let ti13 = S::from_f64(coefficients::TI13);
739 let ti21 = S::from_f64(coefficients::TI21);
740 let ti22 = S::from_f64(coefficients::TI22);
741 let ti23 = S::from_f64(coefficients::TI23);
742 let ti31 = S::from_f64(coefficients::TI31);
743 let ti32 = S::from_f64(coefficients::TI32);
744 let ti33 = S::from_f64(coefficients::TI33);
745
746 let t11 = S::from_f64(coefficients::T11);
747 let t12 = S::from_f64(coefficients::T12);
748 let t13 = S::from_f64(coefficients::T13);
749 let t21 = S::from_f64(coefficients::T21);
750 let t22 = S::from_f64(coefficients::T22);
751 let t23 = S::from_f64(coefficients::T23);
752 let t31 = S::from_f64(coefficients::T31);
753 let t32 = S::from_f64(coefficients::T32);
754 let mut dynold: S = uround;
758 let mut thqold: S = S::ONE;
759 let mut faccon: S = S::ONE;
760
761 let n3 = S::from_usize(3 * dim);
762
763 let mut f2_temp = vec![S::ZERO; dim];
765 let mut f3_temp = vec![S::ZERO; dim];
766 let mut z1_orig = vec![S::ZERO; dim];
767 let mut z2_orig = vec![S::ZERO; dim];
768 let mut z3_orig = vec![S::ZERO; dim];
769 let mut mz1_buf = vec![S::ZERO; dim];
770 let mut mz2_buf = vec![S::ZERO; dim];
771 let mut mz3_buf = vec![S::ZERO; dim];
772 let mut rhs1 = vec![S::ZERO; dim];
773 let mut rhs2 = vec![S::ZERO; dim];
774 let mut rhs3 = vec![S::ZERO; dim];
775 let mut rhs_complex = vec![S::ZERO; 2 * dim];
776
777 for newt in 0..MAX_NEWTON_ITER {
778 for i in 0..dim {
780 cont[i] = y[i] + z1[i];
781 }
782 problem.rhs(t + c1 * h, cont, z1); for i in 0..dim {
785 cont[i] = y[i] + z2[i];
786 }
787 problem.rhs(t + c2 * h, cont, &mut f2_temp);
788
789 for i in 0..dim {
790 cont[i] = y[i] + z3[i];
791 }
792 problem.rhs(t + h, cont, &mut f3_temp);
793 stats.n_eval += 3;
794
795 for i in 0..dim {
797 z1_orig[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
798 z2_orig[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
799 z3_orig[i] = t31 * w1[i] + t32 * w2[i]; }
801
802 if let Some(m) = mass {
804 for i in 0..dim {
805 mz1_buf[i] = S::ZERO;
806 mz2_buf[i] = S::ZERO;
807 mz3_buf[i] = S::ZERO;
808 }
809 for i in 0..dim {
810 for j in 0..dim {
811 let mij = m[i * dim + j];
812 mz1_buf[i] = mz1_buf[i] + mij * z1_orig[j];
813 mz2_buf[i] = mz2_buf[i] + mij * z2_orig[j];
814 mz3_buf[i] = mz3_buf[i] + mij * z3_orig[j];
815 }
816 }
817 } else {
818 mz1_buf.copy_from_slice(&z1_orig);
819 mz2_buf.copy_from_slice(&z2_orig);
820 mz3_buf.copy_from_slice(&z3_orig);
821 }
822
823 for i in 0..dim {
828 let a1 = z1[i]; let a2 = f2_temp[i];
830 let a3 = f3_temp[i];
831 let tf1 = ti11 * a1 + ti12 * a2 + ti13 * a3;
832 let tf2 = ti21 * a1 + ti22 * a2 + ti23 * a3;
833 let tf3 = ti31 * a1 + ti32 * a2 + ti33 * a3;
834
835 let tmz1 = ti11 * mz1_buf[i] + ti12 * mz2_buf[i] + ti13 * mz3_buf[i];
836 let tmz2 = ti21 * mz1_buf[i] + ti22 * mz2_buf[i] + ti23 * mz3_buf[i];
837 let tmz3 = ti31 * mz1_buf[i] + ti32 * mz2_buf[i] + ti33 * mz3_buf[i];
838
839 rhs1[i] = tf1 - fac1 * tmz1;
840 rhs2[i] = tf2 - alphn * tmz2 + betan * tmz3;
841 rhs3[i] = tf3 - alphn * tmz3 - betan * tmz2;
842 }
843
844 let dw1 = lu_real.solve(&rhs1)?;
846
847 for i in 0..dim {
848 rhs_complex[i] = rhs2[i];
849 rhs_complex[dim + i] = rhs3[i];
850 }
851 let dw_complex = lu_complex.solve(&rhs_complex)?;
852
853 let mut dyno = S::ZERO;
855 for i in 0..dim {
856 let denom = scal[i];
857 dyno = dyno
858 + (dw1[i] / denom) * (dw1[i] / denom)
859 + (dw_complex[i] / denom) * (dw_complex[i] / denom)
860 + (dw_complex[dim + i] / denom) * (dw_complex[dim + i] / denom);
861 }
862 dyno = (dyno / n3).sqrt();
863
864 if (1..MAX_NEWTON_ITER - 1).contains(&newt) {
868 let thq = dyno / dynold;
869 let theta = if newt == 1 {
870 thq
871 } else {
872 (thq * thqold).sqrt()
873 };
874 thqold = thq;
875
876 if theta < S::from_f64(0.99) {
877 faccon = theta / (S::ONE - theta);
878 let dyth =
879 faccon * dyno * theta.powf(S::from_usize(MAX_NEWTON_ITER - 1 - newt))
880 / fnewt;
881 if dyth >= S::ONE {
882 return Ok((false, newt + 1));
883 }
884 } else {
885 return Ok((false, newt + 1));
886 }
887 }
888 dynold = dyno.max(uround);
889
890 for i in 0..dim {
892 w1[i] = w1[i] + dw1[i];
893 w2[i] = w2[i] + dw_complex[i];
894 w3[i] = w3[i] + dw_complex[dim + i];
895 }
896
897 for i in 0..dim {
899 z1[i] = t11 * w1[i] + t12 * w2[i] + t13 * w3[i];
900 z2[i] = t21 * w1[i] + t22 * w2[i] + t23 * w3[i];
901 z3[i] = t31 * w1[i] + t32 * w2[i]; }
903
904 if faccon * dyno <= fnewt {
906 return Ok((true, newt + 1));
907 }
908 }
909
910 Ok((false, MAX_NEWTON_ITER))
911 }
912
913 #[allow(clippy::too_many_arguments)]
926 fn error_estimate<S, Sys>(
927 problem: &Sys,
928 t: S,
929 f0: &[S],
930 z1: &[S],
931 z2: &[S],
932 z3: &[S],
933 y: &[S],
934 y_new: &[S],
935 h: S,
936 options: &SolverOptions<S>,
937 lu_real: &LUFactorization<S>,
938 err: &mut [S],
939 dim: usize,
940 first: bool,
941 reject: bool,
942 stats: &mut SolverStats,
943 mass: Option<&[S]>,
944 ) -> S
945 where
946 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
947 Sys: OdeSystem<S>,
948 {
949 let dd1 = S::from_f64(coefficients::DD1);
950 let dd2 = S::from_f64(coefficients::DD2);
951 let dd3 = S::from_f64(coefficients::DD3);
952
953 let mut f2 = vec![S::ZERO; dim];
955 for i in 0..dim {
956 f2[i] = dd1 * z1[i] + dd2 * z2[i] + dd3 * z3[i];
957 }
958
959 let mut cont = vec![S::ZERO; dim];
960 if let Some(m) = mass {
961 let mut mf2 = vec![S::ZERO; dim];
963 for i in 0..dim {
964 for j in 0..dim {
965 mf2[i] = mf2[i] + m[i * dim + j] * f2[j];
966 }
967 }
968 for i in 0..dim {
969 cont[i] = mf2[i] / h + f0[i]; }
971 for i in 0..dim {
973 f2[i] = mf2[i] / h;
974 }
975 } else {
976 for i in 0..dim {
978 f2[i] = f2[i] / h;
979 cont[i] = f2[i] + f0[i]; }
981 }
982
983 let solved = match lu_real.solve(&cont) {
985 Ok(s) => s,
986 Err(_) => return S::from_f64(1e6),
987 };
988
989 let mut err_norm = S::ZERO;
991 for i in 0..dim {
992 err[i] = solved[i];
993 let y_max = y[i].abs().max(y_new[i].abs());
994 let scale = options.atol + options.rtol * y_max;
995 let r = solved[i] / scale;
996 err_norm = err_norm + r * r;
997 }
998 let err_norm = (err_norm / S::from_usize(dim)).sqrt();
999 let err_norm = err_norm.max(S::from_f64(1e-10));
1000
1001 if err_norm >= S::ONE && (first || reject) {
1005 for i in 0..dim {
1006 cont[i] = y[i] + solved[i];
1007 }
1008 let mut f1 = vec![S::ZERO; dim];
1009 problem.rhs(t, &cont, &mut f1);
1010 stats.n_eval += 1;
1011
1012 for i in 0..dim {
1013 cont[i] = f1[i] + f2[i];
1014 }
1015 let solved2 = match lu_real.solve(&cont) {
1016 Ok(s) => s,
1017 Err(_) => return S::from_f64(1e6),
1018 };
1019
1020 let mut err_norm2 = S::ZERO;
1021 for i in 0..dim {
1022 err[i] = solved2[i];
1023 let y_max = y[i].abs().max(y_new[i].abs());
1024 let scale = options.atol + options.rtol * y_max;
1025 let r = solved2[i] / scale;
1026 err_norm2 = err_norm2 + r * r;
1027 }
1028 let err_norm2 = (err_norm2 / S::from_usize(dim)).sqrt();
1029 return err_norm2.max(S::from_f64(1e-10));
1030 }
1031
1032 err_norm
1033 }
1034}
1035
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::{DaeProblem, OdeProblem};

    #[test]
    fn test_radau5_stiff_decay() {
        // y' = -100 y on [0, 0.1]; the exact value at tf is e^{-10}.
        let (t0, tf) = (0.0, 0.1);
        let y0 = [1.0];
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-2).atol(1e-4);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options).unwrap();
        assert!(result.success);

        let y_final = result.y_final().unwrap();
        let error = (y_final[0] - (-10.0_f64).exp()).abs();
        assert!(error < 1e-4, "Error: {}", error);
    }

    #[test]
    fn test_radau5_exponential() {
        // y' = y on [0, 1]; the exact value at tf is e.
        let (t0, tf) = (0.0, 1.0);
        let y0 = [1.0];
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[0];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options).unwrap();
        assert!(result.success);

        let y_final = result.y_final().unwrap();
        let exact = 1.0_f64.exp();
        assert!((y_final[0] - exact).abs() < 1e-5);
    }

    #[test]
    fn test_radau5_linear_2d() {
        // Damped rotation: eigenvalues -1 ± i.
        let (t0, tf) = (0.0, 1.0);
        let y0 = [1.0, 0.0];
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = -y[0] - y[1];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options).unwrap();
        assert!(result.success);
    }

    #[test]
    fn test_radau5_van_der_pol_mild() {
        let mu = 10.0;
        let (t0, tf) = (0.0, 2.0);
        let y0 = [2.0, 0.0];
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_radau5_van_der_pol_stiff() {
        let mu = 100.0;
        let (t0, tf) = (0.0, 20.0);
        let y0 = [2.0, 0.0];
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options);
        assert!(
            result.is_ok(),
            "Van der Pol μ=100 failed: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_radau5_step_efficiency() {
        // A well-behaved Radau5 should cross the stiff Van der Pol problem
        // in far fewer steps than an explicit method.
        let mu = 100.0;
        let (t0, tf) = (0.0, 20.0);
        let y0 = [2.0, 0.0];
        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = mu * (1.0 - y[0] * y[0]) * y[1] - y[0];
            },
            t0,
            tf,
            y0.to_vec(),
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);

        let result = Radau5::solve(&problem, t0, tf, &y0, &options).unwrap();
        let accepted = result.stats.n_accept;
        assert!(
            accepted < 200,
            "Too many accepted steps: {} (expected < 200, ~15 typical)",
            accepted
        );
        assert!(result.success);
    }

    #[test]
    fn test_radau5_simple_dae() {
        // y1' = -y1 + y2 with the algebraic constraint 0 = y1 - y2;
        // starting on the constraint, the solution stays at (1, 1).
        let (t0, tf) = (0.0, 1.0);
        let y0 = [1.0, 1.0];
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0] + y[1];
                dydt[1] = y[0] - y[1];
            },
            |mass: &mut [f64]| {
                mass[..4].fill(0.0);
                mass[0] = 1.0;
            },
            t0,
            tf,
            y0.to_vec(),
            vec![1],
        );
        let options = SolverOptions::default()
            .rtol(1e-4)
            .atol(1e-6)
            .max_steps(500_000);

        let result = Radau5::solve(&dae, t0, tf, &y0, &options);
        assert!(result.is_ok(), "DAE solve failed: {:?}", result.err());

        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let (y1, y2) = (yf[0], yf[1]);
        assert!(
            (y1 - 1.0).abs() < 1e-4,
            "y1 deviated: {} (expected 1.0)",
            y1
        );
        assert!(
            (y2 - 1.0).abs() < 1e-4,
            "y2 deviated: {} (expected 1.0)",
            y2
        );
        let constraint = y1 - y2;
        assert!(
            constraint.abs() < 1e-4,
            "Constraint violated: {} (y1={}, y2={})",
            constraint,
            y1,
            y2
        );
    }

    #[test]
    fn test_radau5_dae_with_mass_identity() {
        // M = [1]: reduces to y' = -y, exact value at tf is e^{-1}.
        let (t0, tf) = (0.0, 1.0);
        let y0 = [1.0];
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 1.0;
            },
            t0,
            tf,
            y0.to_vec(),
            vec![],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);

        let result = Radau5::solve(&dae, t0, tf, &y0, &options);
        assert!(
            result.is_ok(),
            "DAE with identity mass failed: {:?}",
            result.err()
        );

        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-1.0_f64).exp();
        let error = (yf[0] - exact).abs();
        assert!(
            error < 1e-5,
            "Error: {} (expected {}, got {})",
            error,
            exact,
            yf[0]
        );
    }

    #[test]
    fn test_radau5_dae_scaled_mass() {
        // M = [2]: 2 y' = -y, exact value at tf is e^{-1/2}.
        let (t0, tf) = (0.0, 1.0);
        let y0 = [1.0];
        let dae = DaeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            |mass: &mut [f64]| {
                mass[0] = 2.0;
            },
            t0,
            tf,
            y0.to_vec(),
            vec![],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let result = Radau5::solve(&dae, t0, tf, &y0, &options);
        assert!(
            result.is_ok(),
            "DAE with scaled mass failed: {:?}",
            result.err()
        );

        let sol = result.unwrap();
        let yf = sol.y_final().unwrap();
        let exact = (-0.5_f64).exp();
        let error = (yf[0] - exact).abs();
        assert!(
            error < 1e-3,
            "Error: {} (expected {}, got {})",
            error,
            exact,
            yf[0]
        );
    }
}