1use crate::dense::DenseOutput;
10use crate::error::SolverError;
11use crate::events::{Event, EventFunction};
12use crate::problem::OdeSystem;
13use core::fmt;
14use numra_core::Scalar;
15use std::sync::Arc;
16
/// Configuration passed to every [`Solver`]: tolerances, step-size limits,
/// output sampling, order limits and event functions.
///
/// Start from [`SolverOptions::default`] and adjust with the chained setter
/// methods (`rtol`, `atol`, `h0`, `dense`, `event`, ...).
pub struct SolverOptions<S: Scalar> {
    /// Relative tolerance (default `1e-6`).
    pub rtol: S,
    /// Absolute tolerance (default `1e-9`).
    pub atol: S,
    /// Explicit initial step size; `None` by default.
    /// NOTE(review): presumably `None` lets each solver pick its own first step — confirm.
    pub h0: Option<S>,
    /// Maximum step size (default `S::INFINITY`, i.e. unbounded).
    pub h_max: S,
    /// Minimum step size (default `100 * S::EPSILON`).
    pub h_min: S,
    /// Step budget (default `100_000`).
    /// NOTE(review): presumably the solver stops with an error once exceeded — confirm.
    pub max_steps: usize,
    /// Optional output times; `None` by default.
    /// NOTE(review): presumably, when set, the solution is reported only at these times — confirm.
    pub t_eval: Option<Vec<S>>,
    /// Whether dense (continuous) output is requested; `false` by default, enabled by `dense()`.
    pub dense_output: bool,
    /// Optional upper bound on method order; `None` by default.
    /// NOTE(review): presumably only read by variable-order solvers — confirm.
    pub max_order: Option<usize>,
    /// Optional lower bound on method order; `None` by default.
    /// NOTE(review): presumably only read by variable-order solvers — confirm.
    pub min_order: Option<usize>,
    /// Event functions registered via `event()`. Stored as `Arc` so cloning the
    /// options shares the functions instead of duplicating them.
    pub events: Vec<Arc<dyn EventFunction<S>>>,
}
57
58impl<S: Scalar> Clone for SolverOptions<S> {
59 fn clone(&self) -> Self {
60 Self {
61 rtol: self.rtol,
62 atol: self.atol,
63 h0: self.h0,
64 h_max: self.h_max,
65 h_min: self.h_min,
66 max_steps: self.max_steps,
67 t_eval: self.t_eval.clone(),
68 dense_output: self.dense_output,
69 max_order: self.max_order,
70 min_order: self.min_order,
71 events: self.events.clone(),
72 }
73 }
74}
75
76impl<S: Scalar> fmt::Debug for SolverOptions<S> {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 f.debug_struct("SolverOptions")
79 .field("rtol", &self.rtol)
80 .field("atol", &self.atol)
81 .field("h0", &self.h0)
82 .field("h_max", &self.h_max)
83 .field("h_min", &self.h_min)
84 .field("max_steps", &self.max_steps)
85 .field("t_eval", &self.t_eval)
86 .field("dense_output", &self.dense_output)
87 .field("max_order", &self.max_order)
88 .field("min_order", &self.min_order)
89 .field("events", &format!("[{} event(s)]", self.events.len()))
90 .finish()
91 }
92}
93
94impl<S: Scalar> Default for SolverOptions<S> {
95 fn default() -> Self {
96 Self {
97 rtol: S::from_f64(1e-6),
98 atol: S::from_f64(1e-9),
99 h0: None,
100 h_max: S::INFINITY,
101 h_min: S::EPSILON * S::from_f64(100.0),
106 max_steps: 100_000,
107 t_eval: None,
108 dense_output: false,
109 max_order: None,
110 min_order: None,
111 events: Vec::new(),
112 }
113 }
114}
115
116impl<S: Scalar> SolverOptions<S> {
117 pub fn rtol(mut self, rtol: S) -> Self {
119 self.rtol = rtol;
120 self
121 }
122
123 pub fn atol(mut self, atol: S) -> Self {
125 self.atol = atol;
126 self
127 }
128
129 pub fn h0(mut self, h0: S) -> Self {
131 self.h0 = Some(h0);
132 self
133 }
134
135 pub fn h_max(mut self, h_max: S) -> Self {
137 self.h_max = h_max;
138 self
139 }
140
141 pub fn t_eval(mut self, t_eval: Vec<S>) -> Self {
143 self.t_eval = Some(t_eval);
144 self
145 }
146
147 pub fn dense(mut self) -> Self {
149 self.dense_output = true;
150 self
151 }
152
153 pub fn max_steps(mut self, max_steps: usize) -> Self {
155 self.max_steps = max_steps;
156 self
157 }
158
159 pub fn h_min(mut self, h_min: S) -> Self {
161 self.h_min = h_min;
162 self
163 }
164
165 pub fn max_order(mut self, n: usize) -> Self {
171 self.max_order = Some(n);
172 self
173 }
174
175 pub fn min_order(mut self, n: usize) -> Self {
182 self.min_order = Some(n);
183 self
184 }
185
186 pub fn event(mut self, event: Box<dyn EventFunction<S>>) -> Self {
190 self.events.push(Arc::from(event));
191 self
192 }
193}
194
/// Counters collected while a solver runs.
///
/// Every counter starts at zero; solvers increment the ones relevant to them.
#[derive(Clone, Debug, Default)]
pub struct SolverStats {
    /// Number of function evaluations.
    pub n_eval: usize,
    /// Number of Jacobian evaluations.
    pub n_jac: usize,
    /// Number of accepted steps.
    pub n_accept: usize,
    /// Number of rejected steps.
    pub n_reject: usize,
    /// Number of LU decompositions.
    pub n_lu: usize,
}

impl SolverStats {
    /// Returns a statistics record with every counter set to zero.
    pub fn new() -> Self {
        Default::default()
    }
}
215
/// Outcome of a solve: the recorded trajectory plus bookkeeping.
///
/// The trajectory is stored flat. `t[i]` is the `i`-th recorded time, and the
/// matching state occupies `y[i * dim .. (i + 1) * dim]` (see `SolverResult::y_at`).
#[derive(Clone, Debug)]
pub struct SolverResult<S: Scalar> {
    /// Recorded time points.
    pub t: Vec<S>,
    /// Recorded states, flattened row by row (`dim` values per time point).
    pub y: Vec<S>,
    /// Number of state components per time point; `0` for results built with `failed`.
    pub dim: usize,
    /// Solver statistics (evaluation and step counters).
    pub stats: SolverStats,
    /// Set to `true` by `new` and to `false` by `failed`.
    pub success: bool,
    /// Failure description passed to `failed`; empty for results built with `new`.
    pub message: String,
    /// Events recorded during the solve; empty when built by `new`/`failed`.
    /// NOTE(review): presumably filled in by the solver after construction — confirm.
    pub events: Vec<Event<S>>,
    /// Whether integration ended because of an event; `false` from the constructors.
    pub terminated_by_event: bool,
    /// Dense (continuous) output when available; `None` from the constructors.
    pub dense_output: Option<DenseOutput<S>>,
}
238
239impl<S: Scalar> SolverResult<S> {
240 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SolverStats) -> Self {
242 Self {
243 t,
244 y,
245 dim,
246 stats,
247 success: true,
248 message: String::new(),
249 events: Vec::new(),
250 terminated_by_event: false,
251 dense_output: None,
252 }
253 }
254
255 pub fn failed(message: String, stats: SolverStats) -> Self {
257 Self {
258 t: Vec::new(),
259 y: Vec::new(),
260 dim: 0,
261 stats,
262 success: false,
263 message,
264 events: Vec::new(),
265 terminated_by_event: false,
266 dense_output: None,
267 }
268 }
269
270 pub fn len(&self) -> usize {
272 self.t.len()
273 }
274
275 pub fn is_empty(&self) -> bool {
277 self.t.is_empty()
278 }
279
280 pub fn t_final(&self) -> Option<S> {
282 self.t.last().copied()
283 }
284
285 pub fn y_final(&self) -> Option<Vec<S>> {
287 if self.t.is_empty() {
288 None
289 } else {
290 let start = (self.t.len() - 1) * self.dim;
291 Some(self.y[start..start + self.dim].to_vec())
292 }
293 }
294
295 pub fn y_at(&self, i: usize) -> &[S] {
297 let start = i * self.dim;
298 &self.y[start..start + self.dim]
299 }
300
301 pub fn n_steps(&self) -> usize {
303 self.y.len().checked_div(self.dim).unwrap_or(0)
304 }
305
306 pub fn component(&self, j: usize) -> Option<Vec<S>> {
312 if j >= self.dim {
313 return None;
314 }
315 Some(
316 (0..self.n_steps())
317 .map(|i| self.y[i * self.dim + j])
318 .collect(),
319 )
320 }
321
322 pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
324 self.t
325 .iter()
326 .enumerate()
327 .map(move |(i, &t)| (t, self.y_at(i)))
328 }
329}
330
/// Common interface for the ODE integrators.
pub trait Solver<S: Scalar> {
    /// Integrates `problem` from `t0` to `tf`, starting from the state `y0`,
    /// with the settings in `options`.
    ///
    /// This is an associated function (no `self`), so it is called on the
    /// solver type itself: `MySolver::solve(&problem, t0, tf, &y0, &opts)`.
    ///
    /// # Errors
    /// Returns a [`SolverError`] when the integration cannot be carried out.
    /// NOTE(review): implementations may also signal failure through a
    /// `SolverResult` with `success == false` — confirm per solver.
    fn solve<Sys: OdeSystem<S>>(
        problem: &Sys,
        t0: S,
        tf: S,
        y0: &[S],
        options: &SolverOptions<S>,
    ) -> Result<SolverResult<S>, SolverError>;
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3-point, 2-component trajectory shared by several tests.
    fn sample_result() -> SolverResult<f64> {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        SolverResult::new(t, y, 2, SolverStats::new())
    }

    #[test]
    fn test_solver_options_default() {
        let opts: SolverOptions<f64> = SolverOptions::default();
        assert!((opts.rtol - 1e-6).abs() < 1e-10);
        assert!((opts.atol - 1e-9).abs() < 1e-15);
        assert!(opts.h0.is_none());
        assert!(opts.h_max.is_infinite() && opts.h_max > 0.0);
        assert!(opts.h_min > 0.0);
        assert_eq!(opts.max_steps, 100_000);
        assert!(opts.t_eval.is_none());
        assert!(!opts.dense_output);
        assert!(opts.max_order.is_none());
        assert!(opts.min_order.is_none());
        assert!(opts.events.is_empty());
    }

    #[test]
    fn test_solver_options_builder() {
        let opts: SolverOptions<f64> = SolverOptions::default().rtol(1e-8).atol(1e-10).h0(0.01);
        assert!((opts.rtol - 1e-8).abs() < 1e-15);
        assert!((opts.atol - 1e-10).abs() < 1e-15);
        assert!((opts.h0.unwrap() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_builder_remaining_setters() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .h_max(0.5)
            .h_min(1e-12)
            .max_steps(42)
            .t_eval(vec![0.0, 0.5, 1.0])
            .dense()
            .max_order(5)
            .min_order(1);
        assert_eq!(opts.h_max, 0.5);
        assert_eq!(opts.h_min, 1e-12);
        assert_eq!(opts.max_steps, 42);
        assert_eq!(opts.t_eval.as_deref(), Some(&[0.0, 0.5, 1.0][..]));
        assert!(opts.dense_output);
        assert_eq!(opts.max_order, Some(5));
        assert_eq!(opts.min_order, Some(1));
    }

    #[test]
    fn test_solver_options_clone_and_debug() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .t_eval(vec![1.0, 2.0])
            .max_order(3);
        let copy = opts.clone();
        assert_eq!(copy.t_eval, opts.t_eval);
        assert_eq!(copy.max_order, Some(3));
        assert_eq!(copy.events.len(), opts.events.len());

        let dbg = format!("{:?}", opts);
        assert!(dbg.starts_with("SolverOptions"));
        assert!(dbg.contains("[0 event(s)]"));
    }

    #[test]
    fn test_solver_result() {
        let result = sample_result();

        assert_eq!(result.len(), 3);
        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-10);

        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 0.2).abs() < 1e-10);
        assert!((y_final[1] - 1.0).abs() < 1e-10);

        assert_eq!(result.y_at(0), &[1.0, 2.0]);
        assert_eq!(result.y_at(1), &[0.5, 1.5]);
    }

    #[test]
    fn test_solver_result_iter() {
        let result = sample_result();
        let rows: Vec<(f64, &[f64])> = result.iter().collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], (0.0, &[1.0, 2.0][..]));
        assert_eq!(rows[2], (1.0, &[0.2, 1.0][..]));
        assert!(!result.is_empty());
        assert!(result.success);
        assert!(result.message.is_empty());
    }

    #[test]
    fn test_failed_result_is_empty() {
        let failed = SolverResult::<f64>::failed("boom".to_string(), SolverStats::new());
        assert!(!failed.success);
        assert_eq!(failed.message, "boom");
        assert!(failed.is_empty());
        assert_eq!(failed.len(), 0);
        assert!(failed.t_final().is_none());
        assert!(failed.y_final().is_none());
        assert!(failed.component(0).is_none());
        assert_eq!(failed.iter().count(), 0);
        assert!(failed.events.is_empty());
        assert!(!failed.terminated_by_event);
        assert!(failed.dense_output.is_none());
    }

    #[test]
    fn test_n_steps() {
        let result = sample_result();
        assert_eq!(result.n_steps(), 3);

        let empty = SolverResult::<f64>::failed("err".to_string(), SolverStats::new());
        assert_eq!(empty.n_steps(), 0);
    }

    #[test]
    fn test_component() {
        let result = sample_result();

        let comp0 = result.component(0).unwrap();
        assert_eq!(comp0, vec![1.0, 0.5, 0.2]);

        let comp1 = result.component(1).unwrap();
        assert_eq!(comp1, vec![2.0, 1.5, 1.0]);
    }

    #[test]
    fn test_component_out_of_bounds() {
        let t = vec![0.0];
        let y = vec![1.0, 2.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        assert!(result.component(2).is_none());
    }
}