1use faer::{ComplexField, Conjugate, SimpleEntity};
31use numra_core::Scalar;
32
33use crate::bdf::Bdf;
34use crate::error::SolverError;
35use crate::esdirk::Esdirk54;
36use crate::problem::OdeSystem;
37use crate::radau5::Radau5;
38use crate::solver::{Solver, SolverOptions, SolverResult};
39use crate::tsit5::Tsit5;
40use crate::verner::{Vern6, Vern8};
41
/// Stiffness class used by [`Auto`] to choose between explicit and implicit solvers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stiffness {
    /// Solved with an explicit Runge–Kutta method (`Tsit5`, `Vern6` or `Vern8`,
    /// depending on accuracy) unless an implicit solver is preferred.
    NonStiff,
    /// Always solved with `Esdirk54`.
    ModeratelyStiff,
    /// Solved with `Bdf` (low/standard accuracy) or `Radau5` (high/very-high accuracy).
    VeryStiff,
    /// Not classified: `Auto` tries `Tsit5` first and falls back to `Esdirk54`
    /// (or goes straight to `Esdirk54` when an implicit solver is preferred).
    Unknown,
}
54
/// Accuracy class used by [`Auto`]; either supplied as a hint or derived from `rtol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Accuracy {
    /// Derived when `rtol >= 1e-3`.
    Low,
    /// Derived when `1e-7 <= rtol < 1e-3`.
    Standard,
    /// Derived when `1e-11 <= rtol < 1e-7`.
    High,
    /// Derived when `rtol < 1e-11`.
    VeryHigh,
}
67
68#[derive(Clone, Debug, Default)]
70pub struct SolverHints {
71 pub stiffness: Option<Stiffness>,
73 pub accuracy: Option<Accuracy>,
75 pub prefer_implicit: bool,
77 pub detect_stiffness: bool,
79}
80
81impl SolverHints {
82 pub fn new() -> Self {
84 Self {
85 stiffness: None,
86 accuracy: None,
87 prefer_implicit: false,
88 detect_stiffness: true,
89 }
90 }
91
92 pub fn stiffness(mut self, stiffness: Stiffness) -> Self {
94 self.stiffness = Some(stiffness);
95 self
96 }
97
98 pub fn accuracy(mut self, accuracy: Accuracy) -> Self {
100 self.accuracy = Some(accuracy);
101 self
102 }
103
104 pub fn implicit(mut self) -> Self {
106 self.prefer_implicit = true;
107 self
108 }
109
110 pub fn detect_stiffness(mut self, detect: bool) -> Self {
112 self.detect_stiffness = detect;
113 self
114 }
115}
116
117#[derive(Clone, Debug, Default)]
119pub struct Auto {
120 #[allow(dead_code)]
121 hints: SolverHints,
122}
123
124impl Auto {
125 pub fn new() -> Self {
127 Self {
128 hints: SolverHints::new(),
129 }
130 }
131
132 pub fn with_hints(hints: SolverHints) -> Self {
134 Self { hints }
135 }
136
137 fn classify_accuracy<S: Scalar>(options: &SolverOptions<S>) -> Accuracy {
139 let rtol = options.rtol.to_f64();
140 if rtol >= 1e-3 {
141 Accuracy::Low
142 } else if rtol >= 1e-7 {
143 Accuracy::Standard
144 } else if rtol >= 1e-11 {
145 Accuracy::High
146 } else {
147 Accuracy::VeryHigh
148 }
149 }
150
151 fn detect_stiffness<S, Sys>(
153 problem: &Sys,
154 t: S,
155 y: &[S],
156 _options: &SolverOptions<S>,
157 ) -> Stiffness
158 where
159 S: Scalar,
160 Sys: OdeSystem<S>,
161 {
162 let dim = problem.dim();
163 if dim == 0 {
164 return Stiffness::Unknown;
165 }
166
167 let eps = S::from_f64(1e-8);
169 let mut f0 = vec![S::ZERO; dim];
170 let mut f1 = vec![S::ZERO; dim];
171 let _jv = vec![S::ZERO; dim];
172
173 problem.rhs(t, y, &mut f0);
174
175 let mut max_jac = S::ZERO;
177 let mut min_jac = S::INFINITY;
178 let mut y_pert = y.to_vec();
179
180 for j in 0..dim.min(10) {
181 let yj = y[j];
183 let h = eps * (S::ONE + yj.abs());
184 y_pert[j] = yj + h;
185 problem.rhs(t, &y_pert, &mut f1);
186 y_pert[j] = yj;
187
188 for i in 0..dim {
189 let jij = ((f1[i] - f0[i]) / h).abs();
190 if jij > S::from_f64(1e-15) {
191 max_jac = max_jac.max(jij);
192 min_jac = min_jac.min(jij);
193 }
194 }
195 }
196
197 if max_jac < S::from_f64(1e-10) {
199 return Stiffness::NonStiff;
200 }
201
202 let ratio = max_jac / min_jac.max(S::from_f64(1e-15));
203 let ratio_f64 = ratio.to_f64();
204
205 if ratio_f64 > 1e4 {
206 Stiffness::VeryStiff
207 } else if ratio_f64 > 100.0 {
208 Stiffness::ModeratelyStiff
209 } else {
210 Stiffness::NonStiff
211 }
212 }
213
214 pub fn solve_with_hints<S, Sys>(
216 problem: &Sys,
217 t0: S,
218 tf: S,
219 y0: &[S],
220 options: &SolverOptions<S>,
221 hints: &SolverHints,
222 ) -> Result<SolverResult<S>, SolverError>
223 where
224 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
225 Sys: OdeSystem<S>,
226 {
227 let accuracy = hints
229 .accuracy
230 .unwrap_or_else(|| Self::classify_accuracy(options));
231
232 let stiffness = hints.stiffness.unwrap_or_else(|| {
234 if hints.detect_stiffness {
235 Self::detect_stiffness(problem, t0, y0, options)
236 } else {
237 Stiffness::Unknown
238 }
239 });
240
241 match (stiffness, accuracy, hints.prefer_implicit) {
243 (Stiffness::NonStiff, Accuracy::Low, false)
245 | (Stiffness::NonStiff, Accuracy::Standard, false) => {
246 Tsit5::solve(problem, t0, tf, y0, options)
247 }
248 (Stiffness::NonStiff, Accuracy::High, false) => {
249 Vern6::solve(problem, t0, tf, y0, options)
250 }
251 (Stiffness::NonStiff, Accuracy::VeryHigh, false) => {
252 Vern8::solve(problem, t0, tf, y0, options)
253 }
254
255 (Stiffness::ModeratelyStiff, _, _) => Esdirk54::solve(problem, t0, tf, y0, options),
257
258 (Stiffness::VeryStiff, Accuracy::Low, _)
260 | (Stiffness::VeryStiff, Accuracy::Standard, _) => {
261 Bdf::solve(problem, t0, tf, y0, options)
262 }
263 (Stiffness::VeryStiff, Accuracy::High, _)
264 | (Stiffness::VeryStiff, Accuracy::VeryHigh, _) => {
265 Radau5::solve(problem, t0, tf, y0, options)
266 }
267
268 (_, _, true) => Esdirk54::solve(problem, t0, tf, y0, options),
270
271 (Stiffness::Unknown, _, _) => {
273 match Tsit5::solve(problem, t0, tf, y0, options) {
275 Ok(result) => {
276 if result.stats.n_reject < result.stats.n_accept {
278 return Ok(result);
279 }
280 }
281 Err(_) => {}
282 }
283
284 Esdirk54::solve(problem, t0, tf, y0, options)
286 }
287 }
288 }
289}
290
291impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> Solver<S> for Auto {
292 fn solve<Sys: OdeSystem<S>>(
293 problem: &Sys,
294 t0: S,
295 tf: S,
296 y0: &[S],
297 options: &SolverOptions<S>,
298 ) -> Result<SolverResult<S>, SolverError> {
299 let hints = SolverHints::new();
300 Self::solve_with_hints(problem, t0, tf, y0, options, &hints)
301 }
302}
303
304pub fn auto_solve<S, Sys>(
306 problem: &Sys,
307 t0: S,
308 tf: S,
309 y0: &[S],
310 options: &SolverOptions<S>,
311) -> Result<SolverResult<S>, SolverError>
312where
313 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
314 Sys: OdeSystem<S>,
315{
316 Auto::solve(problem, t0, tf, y0, options)
317}
318
319pub fn auto_solve_with_hints<S, Sys>(
321 problem: &Sys,
322 t0: S,
323 tf: S,
324 y0: &[S],
325 options: &SolverOptions<S>,
326 hints: &SolverHints,
327) -> Result<SolverResult<S>, SolverError>
328where
329 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
330 Sys: OdeSystem<S>,
331{
332 Auto::solve_with_hints(problem, t0, tf, y0, options, hints)
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    #[test]
    fn test_auto_nonstiff() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6);
        let result = Auto::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-5.0_f64).exp();
        assert!((y_final[0] - expected).abs() < 1e-4);
    }

    #[test]
    fn test_auto_stiff() {
        // y' = -100 y, forced to the moderately-stiff (Esdirk54) branch.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            0.0,
            0.1,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let hints = SolverHints::new().stiffness(Stiffness::ModeratelyStiff);

        let result = Auto::solve_with_hints(&problem, 0.0, 0.1, &[1.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-10.0_f64).exp();
        assert!(
            (y_final[0] - expected).abs() < 0.05,
            "Auto stiff: got {}, expected {}",
            y_final[0],
            expected
        );
    }

    #[test]
    fn test_auto_high_accuracy() {
        // Harmonic oscillator with y(0) = (1, 0), so y0(t) = cos t.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-5).atol(1e-7);
        // rtol = 1e-5 on its own classifies as `Standard` (Tsit5); request `High`
        // explicitly so this test actually exercises the Vern6 branch.
        let hints = SolverHints::new()
            .stiffness(Stiffness::NonStiff)
            .accuracy(Accuracy::High);

        let result =
            Auto::solve_with_hints(&problem, 0.0, 10.0, &[1.0, 0.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!(
            (y_final[0] - 10.0_f64.cos()).abs() < 1e-3,
            "Auto high accuracy: got {}, expected {}",
            y_final[0],
            10.0_f64.cos()
        );
    }

    #[test]
    fn test_auto_unknown_stiffness_path() {
        // With detection disabled and no forced class, stiffness is `Unknown`:
        // Tsit5 is tried first with Esdirk54 as fallback; either must solve y' = -y.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6);
        let hints = SolverHints::new().detect_stiffness(false);

        let result = Auto::solve_with_hints(&problem, 0.0, 1.0, &[1.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - (-1.0_f64).exp()).abs() < 1e-4);
    }

    #[test]
    fn test_auto_prefer_implicit() {
        // `implicit()` routes a non-stiff problem to Esdirk54.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6);
        let hints = SolverHints::new().stiffness(Stiffness::NonStiff).implicit();

        let result = Auto::solve_with_hints(&problem, 0.0, 1.0, &[1.0], &options, &hints).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - (-1.0_f64).exp()).abs() < 1e-3);
    }

    #[test]
    fn test_solver_hints_new_defaults() {
        let hints = SolverHints::new();
        assert!(hints.stiffness.is_none());
        assert!(hints.accuracy.is_none());
        assert!(!hints.prefer_implicit);
        assert!(hints.detect_stiffness);
    }

    #[test]
    fn test_auto_detect_stiffness() {
        let problem1 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options = SolverOptions::default();
        let stiffness1 = Auto::detect_stiffness(&problem1, 0.0, &[1.0], &options);
        assert_eq!(stiffness1, Stiffness::NonStiff);

        // Widely separated scales (|J| entries from 0.01 to 1000).
        let problem2 = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -1000.0 * y[0] + 0.01 * y[1];
                dydt[1] = 0.01 * y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
        );
        let stiffness2 = Auto::detect_stiffness(&problem2, 0.0, &[1.0, 1.0], &options);
        assert!(stiffness2 == Stiffness::VeryStiff || stiffness2 == Stiffness::ModeratelyStiff);
    }

    #[test]
    fn test_accuracy_classification() {
        let opts_low: SolverOptions<f64> = SolverOptions::default().rtol(1e-2);
        let opts_std: SolverOptions<f64> = SolverOptions::default().rtol(1e-6);
        let opts_high: SolverOptions<f64> = SolverOptions::default().rtol(1e-10);
        let opts_vhigh: SolverOptions<f64> = SolverOptions::default().rtol(1e-13);

        assert_eq!(Auto::classify_accuracy(&opts_low), Accuracy::Low);
        assert_eq!(Auto::classify_accuracy(&opts_std), Accuracy::Standard);
        assert_eq!(Auto::classify_accuracy(&opts_high), Accuracy::High);
        assert_eq!(Auto::classify_accuracy(&opts_vhigh), Accuracy::VeryHigh);
    }

    #[test]
    fn test_auto_convenience() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );
        let options = SolverOptions::default();

        let result = auto_solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
        assert!(result.success);
    }
}