1use crate::dense::DenseOutput;
10use crate::error::SolverError;
11use crate::events::{Event, EventFunction};
12use crate::problem::OdeSystem;
13use core::fmt;
14use numra_core::Scalar;
15use std::sync::Arc;
16
17pub struct SolverOptions<S: Scalar> {
21 pub rtol: S,
23 pub atol: S,
25 pub h0: Option<S>,
27 pub h_max: S,
29 pub h_min: S,
31 pub max_steps: usize,
33 pub t_eval: Option<Vec<S>>,
38 pub dense_output: bool,
40 pub events: Vec<Arc<dyn EventFunction<S>>>,
42}
43
44impl<S: Scalar> Clone for SolverOptions<S> {
45 fn clone(&self) -> Self {
46 Self {
47 rtol: self.rtol,
48 atol: self.atol,
49 h0: self.h0,
50 h_max: self.h_max,
51 h_min: self.h_min,
52 max_steps: self.max_steps,
53 t_eval: self.t_eval.clone(),
54 dense_output: self.dense_output,
55 events: self.events.clone(),
56 }
57 }
58}
59
60impl<S: Scalar> fmt::Debug for SolverOptions<S> {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("SolverOptions")
63 .field("rtol", &self.rtol)
64 .field("atol", &self.atol)
65 .field("h0", &self.h0)
66 .field("h_max", &self.h_max)
67 .field("h_min", &self.h_min)
68 .field("max_steps", &self.max_steps)
69 .field("t_eval", &self.t_eval)
70 .field("dense_output", &self.dense_output)
71 .field("events", &format!("[{} event(s)]", self.events.len()))
72 .finish()
73 }
74}
75
76impl<S: Scalar> Default for SolverOptions<S> {
77 fn default() -> Self {
78 Self {
79 rtol: S::from_f64(1e-6),
80 atol: S::from_f64(1e-9),
81 h0: None,
82 h_max: S::INFINITY,
83 h_min: S::EPSILON * S::from_f64(100.0),
88 max_steps: 100_000,
89 t_eval: None,
90 dense_output: false,
91 events: Vec::new(),
92 }
93 }
94}
95
96impl<S: Scalar> SolverOptions<S> {
97 pub fn rtol(mut self, rtol: S) -> Self {
99 self.rtol = rtol;
100 self
101 }
102
103 pub fn atol(mut self, atol: S) -> Self {
105 self.atol = atol;
106 self
107 }
108
109 pub fn h0(mut self, h0: S) -> Self {
111 self.h0 = Some(h0);
112 self
113 }
114
115 pub fn h_max(mut self, h_max: S) -> Self {
117 self.h_max = h_max;
118 self
119 }
120
121 pub fn t_eval(mut self, t_eval: Vec<S>) -> Self {
123 self.t_eval = Some(t_eval);
124 self
125 }
126
127 pub fn dense(mut self) -> Self {
129 self.dense_output = true;
130 self
131 }
132
133 pub fn max_steps(mut self, max_steps: usize) -> Self {
135 self.max_steps = max_steps;
136 self
137 }
138
139 pub fn h_min(mut self, h_min: S) -> Self {
141 self.h_min = h_min;
142 self
143 }
144
145 pub fn event(mut self, event: Box<dyn EventFunction<S>>) -> Self {
149 self.events.push(Arc::from(event));
150 self
151 }
152}
153
/// Work counters attached to a [`SolverResult`]; all start at zero by default.
#[derive(Clone, Debug, Default)]
pub struct SolverStats {
    /// Function (right-hand side) evaluations.
    pub n_eval: usize,
    /// Jacobian evaluations.
    pub n_jac: usize,
    /// Accepted steps.
    pub n_accept: usize,
    /// Rejected steps.
    pub n_reject: usize,
    /// LU factorisations.
    pub n_lu: usize,
}
168
169impl SolverStats {
170 pub fn new() -> Self {
171 Self::default()
172 }
173}
174
175#[derive(Clone, Debug)]
177pub struct SolverResult<S: Scalar> {
178 pub t: Vec<S>,
180 pub y: Vec<S>,
182 pub dim: usize,
184 pub stats: SolverStats,
186 pub success: bool,
188 pub message: String,
190 pub events: Vec<Event<S>>,
192 pub terminated_by_event: bool,
194 pub dense_output: Option<DenseOutput<S>>,
196}
197
198impl<S: Scalar> SolverResult<S> {
199 pub fn new(t: Vec<S>, y: Vec<S>, dim: usize, stats: SolverStats) -> Self {
201 Self {
202 t,
203 y,
204 dim,
205 stats,
206 success: true,
207 message: String::new(),
208 events: Vec::new(),
209 terminated_by_event: false,
210 dense_output: None,
211 }
212 }
213
214 pub fn failed(message: String, stats: SolverStats) -> Self {
216 Self {
217 t: Vec::new(),
218 y: Vec::new(),
219 dim: 0,
220 stats,
221 success: false,
222 message,
223 events: Vec::new(),
224 terminated_by_event: false,
225 dense_output: None,
226 }
227 }
228
229 pub fn len(&self) -> usize {
231 self.t.len()
232 }
233
234 pub fn is_empty(&self) -> bool {
236 self.t.is_empty()
237 }
238
239 pub fn t_final(&self) -> Option<S> {
241 self.t.last().copied()
242 }
243
244 pub fn y_final(&self) -> Option<Vec<S>> {
246 if self.t.is_empty() {
247 None
248 } else {
249 let start = (self.t.len() - 1) * self.dim;
250 Some(self.y[start..start + self.dim].to_vec())
251 }
252 }
253
254 pub fn y_at(&self, i: usize) -> &[S] {
256 let start = i * self.dim;
257 &self.y[start..start + self.dim]
258 }
259
260 pub fn n_steps(&self) -> usize {
262 self.y.len().checked_div(self.dim).unwrap_or(0)
263 }
264
265 pub fn component(&self, j: usize) -> Option<Vec<S>> {
271 if j >= self.dim {
272 return None;
273 }
274 Some(
275 (0..self.n_steps())
276 .map(|i| self.y[i * self.dim + j])
277 .collect(),
278 )
279 }
280
281 pub fn iter(&self) -> impl Iterator<Item = (S, &[S])> {
283 self.t
284 .iter()
285 .enumerate()
286 .map(move |(i, &t)| (t, self.y_at(i)))
287 }
288}
289
290pub trait Solver<S: Scalar> {
292 fn solve<Sys: OdeSystem<S>>(
294 problem: &Sys,
295 t0: S,
296 tf: S,
297 y0: &[S],
298 options: &SolverOptions<S>,
299 ) -> Result<SolverResult<S>, SolverError>;
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_solver_options_default() {
        let opts: SolverOptions<f64> = SolverOptions::default();
        assert!((opts.rtol - 1e-6).abs() < 1e-10);
        assert!((opts.atol - 1e-9).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_builder() {
        let opts: SolverOptions<f64> = SolverOptions::default().rtol(1e-8).atol(1e-10).h0(0.01);
        assert!((opts.rtol - 1e-8).abs() < 1e-15);
        assert!((opts.atol - 1e-10).abs() < 1e-15);
        assert!((opts.h0.unwrap() - 0.01).abs() < 1e-15);
    }

    #[test]
    fn test_solver_options_remaining_setters() {
        let opts: SolverOptions<f64> = SolverOptions::default()
            .h_max(0.5)
            .h_min(1e-12)
            .max_steps(42)
            .t_eval(vec![0.0, 1.0])
            .dense();
        assert_eq!(opts.h_max, 0.5);
        assert_eq!(opts.h_min, 1e-12);
        assert_eq!(opts.max_steps, 42);
        assert_eq!(opts.t_eval.as_deref(), Some(&[0.0, 1.0][..]));
        assert!(opts.dense_output);
        assert!(opts.events.is_empty());
    }

    #[test]
    fn test_solver_options_clone_and_debug() {
        let opts: SolverOptions<f64> = SolverOptions::default().rtol(1e-4);
        let copy = opts.clone();
        assert_eq!(copy.rtol, opts.rtol);
        assert_eq!(copy.max_steps, 100_000);
        let text = format!("{:?}", copy);
        assert!(text.starts_with("SolverOptions"));
        assert!(text.contains("0 event(s)"));
    }

    #[test]
    fn test_solver_result() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());

        assert_eq!(result.len(), 3);
        assert!((result.t_final().unwrap() - 1.0).abs() < 1e-10);

        let y_final = result.y_final().unwrap();
        assert!((y_final[0] - 0.2).abs() < 1e-10);
        assert!((y_final[1] - 1.0).abs() < 1e-10);

        assert_eq!(result.y_at(0), &[1.0, 2.0]);
        assert_eq!(result.y_at(1), &[0.5, 1.5]);
    }

    #[test]
    fn test_failed_result_is_empty() {
        let result = SolverResult::<f64>::failed("err".to_string(), SolverStats::new());
        assert!(!result.success);
        assert_eq!(result.message, "err");
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert!(result.t_final().is_none());
        assert!(result.y_final().is_none());
    }

    #[test]
    fn test_iter_pairs_times_with_states() {
        let t = vec![0.0, 1.0];
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        let pairs: Vec<(f64, Vec<f64>)> =
            result.iter().map(|(t, state)| (t, state.to_vec())).collect();
        assert_eq!(pairs, vec![(0.0, vec![1.0, 2.0]), (1.0, vec![3.0, 4.0])]);
    }

    #[test]
    fn test_n_steps() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        assert_eq!(result.n_steps(), 3);

        let empty = SolverResult::<f64>::failed("err".to_string(), SolverStats::new());
        assert_eq!(empty.n_steps(), 0);
    }

    #[test]
    fn test_component() {
        let t = vec![0.0, 0.5, 1.0];
        let y = vec![1.0, 2.0, 0.5, 1.5, 0.2, 1.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());

        let comp0 = result.component(0).unwrap();
        assert_eq!(comp0, vec![1.0, 0.5, 0.2]);

        let comp1 = result.component(1).unwrap();
        assert_eq!(comp1, vec![2.0, 1.5, 1.0]);
    }

    #[test]
    fn test_component_out_of_bounds() {
        let t = vec![0.0];
        let y = vec![1.0, 2.0];
        let result = SolverResult::new(t, y, 2, SolverStats::new());
        assert!(result.component(2).is_none());
    }
}