1use faer::{ComplexField, Conjugate, SimpleEntity};
31use numra_core::Scalar;
32
33use crate::bdf::Bdf;
34use crate::error::SolverError;
35use crate::esdirk::Esdirk54;
36use crate::problem::OdeSystem;
37use crate::radau5::Radau5;
38use crate::solver::{Solver, SolverOptions, SolverResult};
39use crate::tsit5::Tsit5;
40use crate::verner::{Vern6, Vern8};
41
/// Stiffness class used by the automatic solver selector to pick an integrator family.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Stiffness {
    /// Not stiff: explicit Runge–Kutta methods (`Tsit5`, `Vern6`, `Vern8`) are used.
    NonStiff,
    /// Mildly stiff: routed to the `Esdirk54` solver.
    ModeratelyStiff,
    /// Strongly stiff: routed to `Bdf` or `Radau5` depending on the accuracy class.
    VeryStiff,
    /// Not classified (detection disabled, or a zero-dimensional system); the selector
    /// tries an explicit solver first and may fall back to `Esdirk54`.
    Unknown,
}

/// Accuracy class; when not given as a hint it is derived from `rtol`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Accuracy {
    /// `rtol >= 1e-3`.
    Low,
    /// `1e-7 <= rtol < 1e-3`.
    Standard,
    /// `1e-11 <= rtol < 1e-7`.
    High,
    /// `rtol < 1e-11`.
    VeryHigh,
}

/// Optional guidance for the automatic solver selector.
///
/// Build with [`SolverHints::new`] (or, equivalently, `SolverHints::default()`) and
/// the chained builder methods. Both constructors enable automatic stiffness detection.
#[derive(Clone, Debug)]
pub struct SolverHints {
    /// Known stiffness class; `None` lets the selector classify the problem itself.
    pub stiffness: Option<Stiffness>,
    /// Desired accuracy class; `None` derives it from the relative tolerance.
    pub accuracy: Option<Accuracy>,
    /// When set, non-stiff and unclassified problems are sent to an implicit solver.
    pub prefer_implicit: bool,
    /// Whether to probe the Jacobian for stiffness when `stiffness` is `None`.
    pub detect_stiffness: bool,
}

impl SolverHints {
    /// Creates hints with nothing forced and stiffness detection enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            stiffness: None,
            accuracy: None,
            prefer_implicit: false,
            detect_stiffness: true,
        }
    }

    /// Forces the stiffness class, skipping automatic detection.
    #[must_use]
    pub fn stiffness(mut self, stiffness: Stiffness) -> Self {
        self.stiffness = Some(stiffness);
        self
    }

    /// Forces the accuracy class instead of deriving it from `rtol`.
    #[must_use]
    pub fn accuracy(mut self, accuracy: Accuracy) -> Self {
        self.accuracy = Some(accuracy);
        self
    }

    /// Requests an implicit method even for problems not classified as stiff.
    #[must_use]
    pub fn implicit(mut self) -> Self {
        self.prefer_implicit = true;
        self
    }

    /// Enables or disables the finite-difference stiffness probe.
    #[must_use]
    pub fn detect_stiffness(mut self, detect: bool) -> Self {
        self.detect_stiffness = detect;
        self
    }
}

impl Default for SolverHints {
    /// Identical to [`SolverHints::new`]; a derived `Default` would have silently
    /// turned stiffness detection off.
    fn default() -> Self {
        Self::new()
    }
}
116
117fn classify_accuracy<S: Scalar>(options: &SolverOptions<S>) -> Accuracy {
119 let rtol = options.rtol.to_f64();
120 if rtol >= 1e-3 {
121 Accuracy::Low
122 } else if rtol >= 1e-7 {
123 Accuracy::Standard
124 } else if rtol >= 1e-11 {
125 Accuracy::High
126 } else {
127 Accuracy::VeryHigh
128 }
129}
130
131fn detect_stiffness<S, Sys>(problem: &Sys, t: S, y: &[S], _options: &SolverOptions<S>) -> Stiffness
133where
134 S: Scalar,
135 Sys: OdeSystem<S>,
136{
137 let dim = problem.dim();
138 if dim == 0 {
139 return Stiffness::Unknown;
140 }
141
142 let h_factor = S::EPSILON.sqrt();
144 let mut f0 = vec![S::ZERO; dim];
145 let mut f1 = vec![S::ZERO; dim];
146 let _jv = vec![S::ZERO; dim];
147
148 problem.rhs(t, y, &mut f0);
149
150 let mut max_jac = S::ZERO;
152 let mut min_jac = S::INFINITY;
153 let mut y_pert = y.to_vec();
154
155 for j in 0..dim.min(10) {
156 let yj = y[j];
158 let h = h_factor * (S::ONE + yj.abs());
159 y_pert[j] = yj + h;
160 problem.rhs(t, &y_pert, &mut f1);
161 y_pert[j] = yj;
162
163 for i in 0..dim {
164 let jij = ((f1[i] - f0[i]) / h).abs();
165 if jij > S::from_f64(1e-15) {
166 max_jac = max_jac.max(jij);
167 min_jac = min_jac.min(jij);
168 }
169 }
170 }
171
172 if max_jac < S::from_f64(1e-10) {
174 return Stiffness::NonStiff;
175 }
176
177 let ratio = max_jac / min_jac.max(S::from_f64(1e-15));
178 let ratio_f64 = ratio.to_f64();
179
180 if ratio_f64 > 1e4 {
181 Stiffness::VeryStiff
182 } else if ratio_f64 > 100.0 {
183 Stiffness::ModeratelyStiff
184 } else {
185 Stiffness::NonStiff
186 }
187}
188
189pub fn auto_solve<S, Sys>(
191 problem: &Sys,
192 t0: S,
193 tf: S,
194 y0: &[S],
195 options: &SolverOptions<S>,
196) -> Result<SolverResult<S>, SolverError>
197where
198 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
199 Sys: OdeSystem<S>,
200{
201 auto_solve_with_hints(problem, t0, tf, y0, options, &SolverHints::new())
202}
203
204pub fn auto_solve_with_hints<S, Sys>(
206 problem: &Sys,
207 t0: S,
208 tf: S,
209 y0: &[S],
210 options: &SolverOptions<S>,
211 hints: &SolverHints,
212) -> Result<SolverResult<S>, SolverError>
213where
214 S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField,
215 Sys: OdeSystem<S>,
216{
217 let accuracy = hints.accuracy.unwrap_or_else(|| classify_accuracy(options));
219
220 let stiffness = hints.stiffness.unwrap_or_else(|| {
222 if hints.detect_stiffness {
223 detect_stiffness(problem, t0, y0, options)
224 } else {
225 Stiffness::Unknown
226 }
227 });
228
229 match (stiffness, accuracy, hints.prefer_implicit) {
231 (Stiffness::NonStiff, Accuracy::Low, false)
233 | (Stiffness::NonStiff, Accuracy::Standard, false) => {
234 Tsit5::solve(problem, t0, tf, y0, options)
235 }
236 (Stiffness::NonStiff, Accuracy::High, false) => Vern6::solve(problem, t0, tf, y0, options),
237 (Stiffness::NonStiff, Accuracy::VeryHigh, false) => {
238 Vern8::solve(problem, t0, tf, y0, options)
239 }
240
241 (Stiffness::ModeratelyStiff, _, _) => Esdirk54::solve(problem, t0, tf, y0, options),
243
244 (Stiffness::VeryStiff, Accuracy::Low, _)
246 | (Stiffness::VeryStiff, Accuracy::Standard, _) => Bdf::solve(problem, t0, tf, y0, options),
247 (Stiffness::VeryStiff, Accuracy::High, _)
248 | (Stiffness::VeryStiff, Accuracy::VeryHigh, _) => {
249 Radau5::solve(problem, t0, tf, y0, options)
250 }
251
252 (_, _, true) => Esdirk54::solve(problem, t0, tf, y0, options),
254
255 (Stiffness::Unknown, _, _) => {
257 if let Ok(result) = Tsit5::solve(problem, t0, tf, y0, options) {
259 if result.stats.n_reject < result.stats.n_accept {
261 return Ok(result);
262 }
263 }
264
265 Esdirk54::solve(problem, t0, tf, y0, options)
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    #[test]
    fn test_auto_nonstiff() {
        // y' = -y, y(0) = 1  =>  y(t) = exp(-t)
        let (t0, tf) = (0.0_f64, 5.0_f64);
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-6);

        let sol = auto_solve(&decay, t0, tf, &[1.0], &opts).unwrap();
        assert!(sol.success);

        let y_end = sol.y_final().unwrap();
        let exact = (-5.0_f64).exp();
        assert!((y_end[0] - exact).abs() < 1e-4);
    }

    #[test]
    fn test_auto_stiff() {
        // Fast decay y' = -100 y, with the stiffness class forced via hints.
        let (t0, tf) = (0.0_f64, 0.1_f64);
        let fast_decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -100.0 * y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let opts = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let hints = SolverHints::new().stiffness(Stiffness::ModeratelyStiff);

        let sol = auto_solve_with_hints(&fast_decay, t0, tf, &[1.0], &opts, &hints).unwrap();
        assert!(sol.success);

        let y_end = sol.y_final().unwrap();
        let exact = (-10.0_f64).exp();
        let err = (y_end[0] - exact).abs();
        assert!(err < 0.05, "stiff: got {}, expected {}", y_end[0], exact);
    }

    #[test]
    fn test_auto_high_accuracy() {
        // Harmonic oscillator: x'' = -x, x(0) = 1, x'(0) = 0  =>  x(t) = cos(t)
        let (t0, tf) = (0.0_f64, 10.0_f64);
        let oscillator = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            t0,
            tf,
            vec![1.0, 0.0],
        );
        let opts = SolverOptions::default().rtol(1e-5).atol(1e-7);
        let hints = SolverHints::new().stiffness(Stiffness::NonStiff);

        let sol =
            auto_solve_with_hints(&oscillator, t0, tf, &[1.0, 0.0], &opts, &hints).unwrap();
        assert!(sol.success);

        let y_end = sol.y_final().unwrap();
        let exact = 10.0_f64.cos();
        assert!(
            (y_end[0] - exact).abs() < 1e-3,
            "high accuracy: got {}, expected {}",
            y_end[0],
            exact
        );
    }

    #[test]
    fn test_auto_detect_stiffness() {
        let options = SolverOptions::default();

        // Single uniform decay: one Jacobian entry, no magnitude spread.
        let uniform = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        assert_eq!(
            detect_stiffness(&uniform, 0.0, &[1.0], &options),
            Stiffness::NonStiff
        );

        // Widely separated rates with weak coupling: large entry spread.
        let separated = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -1000.0 * y[0] + 0.01 * y[1];
                dydt[1] = 0.01 * y[0] - y[1];
            },
            0.0,
            1.0,
            vec![1.0, 1.0],
        );
        let class = detect_stiffness(&separated, 0.0, &[1.0, 1.0], &options);
        assert!(matches!(
            class,
            Stiffness::VeryStiff | Stiffness::ModeratelyStiff
        ));
    }

    #[test]
    fn test_accuracy_classification() {
        let cases = [
            (1e-2, Accuracy::Low),
            (1e-6, Accuracy::Standard),
            (1e-10, Accuracy::High),
            (1e-13, Accuracy::VeryHigh),
        ];
        for (rtol, expected) in cases {
            let opts: SolverOptions<f64> = SolverOptions::default().rtol(rtol);
            assert_eq!(classify_accuracy(&opts), expected);
        }
    }

    #[test]
    fn test_auto_convenience() {
        let (t0, tf) = (0.0_f64, 2.0_f64);
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            t0,
            tf,
            vec![1.0],
        );
        let opts = SolverOptions::default();

        let sol = auto_solve(&decay, t0, tf, &[1.0], &opts).unwrap();
        assert!(sol.success);
    }
}