1use crate::{BVector, Dimension, DimensionError};
8use nalgebra::{allocator::Allocator, ComplexField, DefaultAllocator, Dim, RealField, U1};
9use num_traits::{FromPrimitive, Zero};
10use std::{error::Error, marker::PhantomData};
11use thiserror::Error;
12
13pub mod adams;
14pub mod bdf;
15pub mod rk;
16
17#[derive(Error, Clone, Debug)]
21pub enum IVPStatus<T: Error> {
22 #[error("the solver needs the step to be re-done")]
23 Redo,
24 #[error("the solver is complete")]
25 Done,
26 #[error("unspecified solver error: {0}")]
27 Failure(#[from] T),
28}
29
/// Boxed error returned by user-supplied derivative functions.
pub type UserError = Box<dyn Error + 'static>;
32
33pub trait Derivative<N: ComplexField + Copy, D: Dim, T: Clone>:
35 FnMut(N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>
36where
37 DefaultAllocator: Allocator<N, D>,
38{
39}
40
41impl<N, D: Dim, T, F> Derivative<N, D, T> for F
42where
43 N: ComplexField + Copy,
44 T: Clone,
45 F: FnMut(N::RealField, &[N], &mut T) -> Result<BVector<N, D>, UserError>,
46 DefaultAllocator: Allocator<N, D>,
47{
48}
49
50#[derive(Error, Debug)]
51pub enum IVPError {
52 #[error("the solver does not have all required parameters set")]
53 MissingParameters,
54 #[error("the solver hit an error from the user-provided derivative function: {0}")]
55 UserError(#[from] UserError),
56 #[error("the provided tolerance was out-of-bounds")]
57 ToleranceOOB,
58 #[error("the provided time delta was out-of-bounds")]
59 TimeDeltaOOB,
60 #[error("the provided ending time was before the provided starting time")]
61 TimeEndOOB,
62 #[error("the provided starting time was after the provided ending time")]
63 TimeStartOOB,
64 #[error("a conversion from a necessary primitive failed")]
65 FromPrimitiveFailure,
66 #[error("the time step fell below the paramater minimum allowed value")]
67 MinimumTimeDeltaExceeded,
68 #[error("the number of iterations exceeded the maximum allowable")]
69 MaximumIterationsExceeded,
70 #[error("a matrix was unable to be inverted")]
71 SingularMatrix,
72 #[error("attempted to build a dynamic solver with static dimension")]
73 DynamicOnStatic,
74 #[error("attempted to build a static solver with dynamic dimension")]
75 StaticOnDynamic,
76}
77
78impl From<UserError> for IVPStatus<IVPError> {
79 fn from(value: UserError) -> Self {
80 Self::Failure(IVPError::UserError(value))
81 }
82}
83
84impl From<DimensionError> for IVPError {
85 fn from(value: DimensionError) -> Self {
86 match value {
87 DimensionError::DynamicOnStatic => Self::DynamicOnStatic,
88 DimensionError::StaticOnDynamic => Self::StaticOnDynamic,
89 }
90 }
91}
92
93pub type Step<R, C, D, E> = Result<(R, BVector<C, D>), IVPStatus<E>>;
97
98pub trait IVPStepper<D: Dimension>: Sized
103where
104 DefaultAllocator: Allocator<Self::Field, D>,
105{
106 type Error: Error + From<IVPError>;
108 type Field: ComplexField + Copy;
110 type RealField: RealField;
112 type UserData: Clone;
117
118 fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error>;
122
123 fn time(&self) -> Self::RealField;
125}
126
127pub trait IVPSolver<'a, D: Dimension>: Sized
135where
136 DefaultAllocator: Allocator<Self::Field, D>,
137{
138 type Error: Error + From<IVPError>;
140 type Field: ComplexField + Copy;
142 type RealField: RealField;
144 type UserData: Clone;
146 type Derivative: Derivative<Self::Field, D, Self::UserData> + 'a;
148 type Solver: IVPStepper<
150 D,
151 Error = Self::Error,
152 Field = Self::Field,
153 RealField = Self::RealField,
154 UserData = Self::UserData,
155 >;
156
157 fn new() -> Result<Self, Self::Error>;
160
161 fn new_dyn(size: usize) -> Result<Self, Self::Error>;
164
165 fn dim(&self) -> D;
167
168 fn with_tolerance(self, tol: Self::RealField) -> Result<Self, Self::Error>;
170
171 fn with_maximum_dt(self, max: Self::RealField) -> Result<Self, Self::Error>;
172 fn with_minimum_dt(self, min: Self::RealField) -> Result<Self, Self::Error>;
173 fn with_initial_time(self, initial: Self::RealField) -> Result<Self, Self::Error>;
174 fn with_ending_time(self, ending: Self::RealField) -> Result<Self, Self::Error>;
175
176 fn with_initial_conditions_slice(self, start: &[Self::Field]) -> Result<Self, Self::Error> {
178 let svector = BVector::from_column_slice_generic(self.dim(), U1::from_usize(1), start);
179 self.with_initial_conditions(svector)
180 }
181
182 fn with_initial_conditions(self, start: BVector<Self::Field, D>) -> Result<Self, Self::Error>;
184
185 fn with_derivative(self, derivative: Self::Derivative) -> Self;
187
188 fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error>;
190}
191
192pub struct IVPIterator<D: Dimension, T: IVPStepper<D>>
193where
194 DefaultAllocator: Allocator<T::Field, D>,
195{
196 solver: T,
197 finished: bool,
198 _dim: PhantomData<D>,
199}
200
201pub type Path<R, C, D, E> = Result<Vec<(R, BVector<C, D>)>, E>;
204
205impl<D: Dimension, T: IVPStepper<D>> IVPIterator<D, T>
206where
207 DefaultAllocator: Allocator<T::Field, D>,
208{
209 pub fn collect_vec(self) -> Path<T::RealField, T::Field, D, T::Error> {
210 self.collect::<Result<Vec<_>, _>>()
211 }
212}
213
214impl<D: Dimension, T: IVPStepper<D>> Iterator for IVPIterator<D, T>
215where
216 DefaultAllocator: Allocator<T::Field, D>,
217{
218 type Item = Result<(T::RealField, BVector<T::Field, D>), T::Error>;
219
220 fn next(&mut self) -> Option<Self::Item> {
221 use IVPStatus as IE;
222
223 if self.finished {
224 return None;
225 }
226
227 loop {
228 match self.solver.step() {
229 Ok(vec) => break Some(Ok(vec)),
230 Err(IE::Done) => break None,
231 Err(IE::Redo) => continue,
232 Err(IE::Failure(e)) => {
233 self.finished = true;
234 break Some(Err(e));
235 }
236 }
237 }
238 }
239}
240
241pub struct Euler<'a, N, D, T, F>
270where
271 N: ComplexField + Copy,
272 D: Dimension,
273 T: Clone,
274 F: Derivative<N, D, T> + 'a,
275 DefaultAllocator: Allocator<N, D>,
276{
277 init_dt: Option<N::RealField>,
278 init_time: Option<N::RealField>,
279 init_end: Option<N::RealField>,
280 init_state: Option<BVector<N, D>>,
281 init_derivative: Option<F>,
282 dim: D,
283 _data: PhantomData<&'a T>,
284}
285
286pub struct EulerSolver<'a, N, D, T, F>
290where
291 N: ComplexField + Copy,
292 D: Dimension,
293 T: Clone,
294 F: Derivative<N, D, T> + 'a,
295 DefaultAllocator: Allocator<N, D>,
296{
297 dt: N,
298 time: N,
299 end: N,
300 state: BVector<N, D>,
301 derivative: F,
302 data: T,
303 _lifetime: PhantomData<&'a ()>,
304}
305
306impl<'a, N, D, T, F> IVPStepper<D> for EulerSolver<'a, N, D, T, F>
307where
308 N: ComplexField + Copy,
309 D: Dimension,
310 T: Clone,
311 F: Derivative<N, D, T> + 'a,
312 DefaultAllocator: Allocator<N, D>,
313{
314 type Error = IVPError;
315 type Field = N;
316 type RealField = N::RealField;
317 type UserData = T;
318
319 fn step(
320 &mut self,
321 ) -> Result<(Self::RealField, BVector<Self::Field, D>), IVPStatus<Self::Error>> {
322 if self.time.real() >= self.end.real() {
323 return Err(IVPStatus::Done);
324 }
325 if (self.time + self.dt).real() >= self.end.real() {
326 self.dt = self.end - self.time;
327 }
328
329 let derivative = (self.derivative)(self.time.real(), self.state.as_slice(), &mut self.data)
330 .map_err(IVPError::UserError)?;
331
332 let old_time = self.time.real();
333 let old_state = self.state.clone();
334
335 self.state += derivative * self.dt;
336 self.time += self.dt;
337
338 Ok((old_time, old_state))
339 }
340
341 fn time(&self) -> Self::RealField {
342 self.time.real()
343 }
344}
345
346impl<'a, N, D, T, F> IVPSolver<'a, D> for Euler<'a, N, D, T, F>
347where
348 N: ComplexField + Copy,
349 D: Dimension,
350 T: Clone,
351 F: Derivative<N, D, T> + 'a,
352 DefaultAllocator: Allocator<N, D>,
353{
354 type Error = IVPError;
355 type Field = N;
356 type RealField = N::RealField;
357 type Derivative = F;
358 type UserData = T;
359 type Solver = EulerSolver<'a, N, D, T, F>;
360
361 fn new() -> Result<Self, Self::Error> {
362 Ok(Self {
363 init_dt: None,
364 init_time: None,
365 init_end: None,
366 init_state: None,
367 init_derivative: None,
368 dim: D::dim()?,
369 _data: PhantomData,
370 })
371 }
372
373 fn new_dyn(size: usize) -> Result<Self, Self::Error> {
374 Ok(Self {
375 init_dt: None,
376 init_time: None,
377 init_end: None,
378 init_state: None,
379 init_derivative: None,
380 dim: D::dim_dyn(size)?,
381 _data: PhantomData,
382 })
383 }
384
385 fn dim(&self) -> D {
386 self.dim
387 }
388
389 fn with_tolerance(self, _tol: Self::RealField) -> Result<Self, Self::Error> {
391 Ok(self)
392 }
393
394 fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
397 if max <= <Self::RealField as Zero>::zero() {
398 return Err(IVPError::TimeDeltaOOB);
399 }
400
401 self.init_dt = if let Some(dt) = self.init_dt {
402 Some((dt + max) / Self::RealField::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?)
403 } else {
404 Some(max)
405 };
406 Ok(self)
407 }
408
409 fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
412 if min <= <Self::RealField as Zero>::zero() {
413 return Err(IVPError::TimeDeltaOOB);
414 }
415
416 self.init_dt = if let Some(dt) = self.init_dt {
417 Some((dt + min) / Self::RealField::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?)
418 } else {
419 Some(min)
420 };
421 Ok(self)
422 }
423
424 fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
425 self.init_time = Some(initial.clone());
426
427 if let Some(end) = self.init_end.as_ref() {
428 if *end <= initial {
429 return Err(IVPError::TimeStartOOB);
430 }
431 }
432
433 Ok(self)
434 }
435
436 fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
437 self.init_end = Some(ending.clone());
438
439 if let Some(initial) = self.init_time.as_ref() {
440 if *initial >= ending {
441 return Err(IVPError::TimeEndOOB);
442 }
443 }
444
445 Ok(self)
446 }
447
448 fn with_initial_conditions(
449 mut self,
450 start: BVector<Self::Field, D>,
451 ) -> Result<Self, Self::Error> {
452 self.init_state = Some(start);
453 Ok(self)
454 }
455
456 fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
457 self.init_derivative = Some(derivative);
458 self
459 }
460
461 fn solve(mut self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
462 let dt = self.init_dt.ok_or(IVPError::MissingParameters)?;
463 let time = self.init_time.ok_or(IVPError::MissingParameters)?;
464 let end = self.init_end.ok_or(IVPError::MissingParameters)?;
465 let state = self.init_state.take().ok_or(IVPError::MissingParameters)?;
466 let derivative = self
467 .init_derivative
468 .take()
469 .ok_or(IVPError::MissingParameters)?;
470
471 Ok(IVPIterator {
472 solver: EulerSolver {
473 dt: N::from_real(dt),
474 time: N::from_real(time),
475 end: N::from_real(end),
476 state,
477 derivative,
478 data,
479 _lifetime: PhantomData,
480 },
481 finished: false,
482 _dim: PhantomData,
483 })
484 }
485}
486
#[cfg(test)]
mod test {
    use super::*;
    use crate::BSVector;
    use nalgebra::{DimName, Dyn};

    type Path<D> = Vec<(f64, BVector<f64, D>)>;

    /// Solves an IVP with [`Euler`] on `[initial, end]` and collects every point.
    fn solve_ivp<'a, D, F>(
        (initial, end): (f64, f64),
        dt: f64,
        initial_conds: &[f64],
        derivative: F,
    ) -> Result<Path<D>, IVPError>
    where
        D: Dimension,
        F: Derivative<f64, D, ()> + 'a,
        DefaultAllocator: Allocator<f64, D>,
    {
        let ivp = Euler::new()?
            .with_initial_time(initial)?
            .with_ending_time(end)?
            .with_maximum_dt(dt)?
            .with_initial_conditions_slice(initial_conds)?
            .with_derivative(derivative);
        ivp.solve(())?.collect()
    }

    /// Checks that a solved path is plausible and that its first state component matches
    /// `exact` to within `0.01` at every reported time.
    ///
    /// Besides the pointwise comparison, it asserts that the path is non-empty, starts
    /// exactly at `t_initial`, and ends within one step of `t_final`. Without these
    /// checks, a solver that produced nothing, or stopped early, would pass.
    fn assert_tracks<D>(
        path: &[(f64, BVector<f64, D>)],
        (t_initial, t_final): (f64, f64),
        dt: f64,
        exact: impl Fn(f64) -> f64,
    ) where
        D: Dimension,
        DefaultAllocator: Allocator<f64, D>,
    {
        let (first_time, _) = path.first().expect("the solver produced no steps");
        assert_eq!(*first_time, t_initial);

        let (last_time, _) = path.last().expect("the solver produced no steps");
        // The small slack absorbs rounding accumulated by repeatedly adding `dt`.
        assert!(
            t_final - *last_time <= dt + 1e-9,
            "the path stopped at t = {last_time}, more than one step before {t_final}"
        );

        for (t, state) in path {
            let computed = state.column(0)[0];
            let expected = exact(*t);
            assert!(
                approx_eq!(f64, computed, expected, epsilon = 0.01),
                "at t = {t}: computed {computed}, expected {expected}"
            );
        }
    }

    fn exp_deriv(_: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_column_slice(y))
    }

    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_column_slice(&[-2.0 * t]))
    }

    fn sine_deriv(t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        Ok(BSVector::from_iterator(y.iter().map(|_| t.cos())))
    }

    fn cos_deriv(_t: f64, y: &[f64], _: &mut ()) -> Result<BSVector<f64, 2>, UserError> {
        Ok(BSVector::from_column_slice(&[y[1], -y[0]]))
    }

    fn dynamic_cos_deriv(_t: f64, y: &[f64], _: &mut ()) -> Result<BVector<f64, Dyn>, UserError> {
        Ok(BVector::from_column_slice_generic(
            Dyn::from_usize(y.len()),
            U1::name(),
            &[y[1], -y[0]],
        ))
    }

    #[test]
    #[should_panic]
    fn euler_dynamic_cos_panics() {
        // `solve_ivp` goes through the static constructor `Euler::new`. With a `Dyn`
        // dimension this presumably yields a dimension error, and the `unwrap` panics.
        let interval = (0.0, 1.0);
        let dt = 0.01;

        let path = solve_ivp(interval, dt, &[1.0, 0.0], dynamic_cos_deriv).unwrap();

        assert_tracks(&path, interval, dt, f64::cos);
    }

    #[test]
    fn euler_dynamic_cos() {
        let t_initial = 0.0;
        let t_final = 1.0;
        let dt = 0.01;

        let ivp = Euler::new_dyn(2)
            .unwrap()
            .with_initial_time(t_initial)
            .unwrap()
            .with_ending_time(t_final)
            .unwrap()
            .with_maximum_dt(dt)
            .unwrap()
            .with_initial_conditions_slice(&[1.0, 0.0])
            .unwrap()
            .with_derivative(dynamic_cos_deriv)
            .solve(())
            .unwrap();
        let path = ivp.collect_vec().unwrap();

        assert_tracks(&path, (t_initial, t_final), dt, f64::cos);
    }

    #[test]
    fn euler_cos() {
        let interval = (0.0, 1.0);
        let dt = 0.01;

        let path = solve_ivp(interval, dt, &[1.0, 0.0], cos_deriv).unwrap();

        assert_tracks(&path, interval, dt, f64::cos);
    }

    #[test]
    fn euler_exp() {
        let interval = (0.0, 1.0);
        let dt = 0.005;

        let path = solve_ivp(interval, dt, &[1.0], exp_deriv).unwrap();

        assert_tracks(&path, interval, dt, f64::exp);
    }

    #[test]
    fn euler_quadratic() {
        let interval = (0.0, 1.0);
        let dt = 0.01;

        let path = solve_ivp(interval, dt, &[1.0], quadratic_deriv).unwrap();

        assert_tracks(&path, interval, dt, |t: f64| 1.0 - t.powi(2));
    }

    #[test]
    fn euler_sin() {
        let interval = (0.0, 1.0);
        let dt = 0.01;

        let path = solve_ivp(interval, dt, &[0.0], sine_deriv).unwrap();

        assert_tracks(&path, interval, dt, f64::sin);
    }
}