Skip to main content

differential_equations/
ivp.rs

1//! Unified builder for initial value problems.
2//!
3//! The high-level API owns the numerical method and output handler. Solvers mutate
4//! their internal interpolation/history state during integration, so ownership makes
5//! each `IVP::solve` call single-use and avoids accidental solver reuse after a run.
6//! Call the low-level `solve_*` functions directly when reference-based control is
7//! required.
8
9use crate::{
10    dae::{AlgebraicNumericalMethod, DAE, solve_dae},
11    dde::{DDE, DelayNumericalMethod, solve_dde},
12    error::Error,
13    interpolate::Interpolation,
14    methods::ToleranceConfig,
15    ode::{ODE, OrdinaryNumericalMethod, solve_ode},
16    sde::{SDE, StochasticNumericalMethod, solve_sde},
17    solout::{
18        CrossingDirection, CrossingSolout, DefaultSolout, DenseSolout, EvenSolout, Event,
19        EventWrappedSolout, HyperplaneCrossingSolout, Solout, TEvalSolout,
20    },
21    solution::Solution,
22    tolerance::Tolerance,
23    traits::{Real, State},
24};
25
26/// Unified builder for initial value problems (IVPs).
27///
28/// Consolidates solver configurations, output configurations, and events.
29#[derive(Clone, Debug)]
30pub struct IVP<EqType, T: Real, Y: State<T>, Method, SoloutType> {
31    equation: EqType,
32    t0: T,
33    tf: T,
34    y0: Y,
35    method: Method,
36    solout: SoloutType,
37}
38
/// Equation marker that borrows an ordinary differential equation system.
///
/// `Copy`/`Clone` are written by hand instead of derived: a derive would add
/// `F: Copy`/`F: Clone` bounds, while copying a shared reference needs neither.
#[derive(Debug)]
pub struct OdeEq<'a, F> {
    ode: &'a F,
}

impl<F> Copy for OdeEq<'_, F> {}

impl<F> Clone for OdeEq<'_, F> {
    fn clone(&self) -> Self {
        *self
    }
}
52
/// Equation marker that owns its ordinary differential equation system
/// (as constructed by `IVP::ode_from_fn`).
///
/// The derives yield `Clone` when `F: Clone` and `Copy` when `F: Copy`.
#[derive(Clone, Copy, Debug)]
pub struct OdeEqOwned<F> {
    ode: F,
}
68
/// Equation marker that borrows a differential algebraic equation system.
///
/// `Copy`/`Clone` are written by hand so that they hold for every `F`; the
/// marker only contains a shared reference.
#[derive(Debug)]
pub struct DaeEq<'a, F> {
    dae: &'a F,
}

impl<F> Copy for DaeEq<'_, F> {}

impl<F> Clone for DaeEq<'_, F> {
    fn clone(&self) -> Self {
        *self
    }
}
82
/// Equation marker that owns its differential algebraic equation system
/// (as constructed by `IVP::dae_from_fn`).
///
/// The derives yield `Clone` when `F: Clone` and `Copy` when `F: Copy`.
#[derive(Clone, Copy, Debug)]
pub struct DaeEqOwned<F> {
    dae: F,
}
98
/// Marker for stochastic differential equations.
///
/// Holds an exclusive (`&mut`) borrow of the system rather than a shared one.
/// `SDE::noise` takes `&mut self` (see the `SdeFnWrapper` impl in this module),
/// and `solve` hands this mutable reference to `solve_sde`. Because of the
/// exclusive borrow, this marker is neither `Clone` nor `Copy`.
#[derive(Debug)]
pub struct SdeEq<'a, F> {
    sde: &'a mut F,
}
104
/// Marker for owned stochastic differential equations (from closure).
///
/// Owning the system lets `solve` pass `&mut` access to it, which
/// `SDE::noise(&mut self, ..)` requires. Unlike the other `*EqOwned` markers,
/// no `Clone` impl is provided.
// NOTE(review): presumably intentional, so that a cloned builder does not
// replay the same noise stream. Confirm before adding `Clone`.
#[derive(Debug)]
pub struct SdeEqOwned<F> {
    sde: F,
}
110
/// Equation marker that borrows a delay differential equation system and
/// owns its history function `H`.
///
/// `L` is the number of delays, matching the `DDE<L, T, Y>` trait. `Clone` is
/// written by hand so that only the history has to be `Clone`: the borrowed
/// system is shared between clones, not duplicated.
#[derive(Debug)]
pub struct DdeEq<'a, const L: usize, F, H> {
    dde: &'a F,
    history: H,
}

impl<const L: usize, F, H: Clone> Clone for DdeEq<'_, L, F, H> {
    fn clone(&self) -> Self {
        let history = self.history.clone();
        Self {
            dde: self.dde,
            history,
        }
    }
}
126
/// Equation marker that owns both a delay differential equation system and
/// its history function (as constructed by `IVP::dde_from_fn`).
///
/// The derive yields `Clone` whenever both `F` and `H` are `Clone`.
#[derive(Clone, Debug)]
pub struct DdeEqOwned<const L: usize, F, H> {
    dde: F,
    history: H,
}
142
/// Internal wrapper that adapts a closure to the `ODE` trait (used by
/// `IVP::ode_from_fn`).
///
/// `Clone`/`Copy` are derived so that a builder created from a closure can be
/// cloned like one created from a borrowed system. `OdeEqOwned`, and therefore
/// the `Clone`-deriving `IVP`, is only `Clone` when the wrapped equation is.
/// Without these derives, `IVP::ode_from_fn(..).clone()` could never compile.
#[derive(Clone, Copy, Debug)]
pub struct OdeFnWrapper<F> {
    f: F,
}
148
149impl<T, Y, F> ODE<T, Y> for OdeFnWrapper<F>
150where
151    T: Real,
152    Y: State<T>,
153    F: Fn(T, &Y, &mut Y),
154{
155    fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
156        (self.f)(t, y, dydt)
157    }
158}
159
/// Internal wrapper that adapts a right-hand-side closure `f` and a
/// mass-matrix closure `m` to the `DAE` trait (used by `IVP::dae_from_fn`).
///
/// `Clone`/`Copy` are derived so that `DaeEqOwned`, and hence the
/// `Clone`-deriving `IVP`, can be cloned when both closures are cloneable.
/// Previously the missing impl made `IVP::dae_from_fn(..).clone()` impossible.
#[derive(Clone, Copy, Debug)]
pub struct DaeFnWrapper<F, M> {
    f: F,
    m: M,
}
166
167impl<T, Y, F, M> DAE<T, Y> for DaeFnWrapper<F, M>
168where
169    T: Real,
170    Y: State<T>,
171    F: Fn(T, &Y, &mut Y),
172    M: Fn(&mut crate::linalg::Matrix<T>),
173{
174    fn diff(&self, t: T, y: &Y, f: &mut Y) {
175        (self.f)(t, y, f)
176    }
177
178    fn mass(&self, m: &mut crate::linalg::Matrix<T>) {
179        (self.m)(m)
180    }
181}
182
/// Internal wrapper for `sde_from_fn`.
///
/// Bundles the drift, diffusion and noise closures so that together they
/// implement `SDE`. In that impl the noise closure is bound by `FnMut` and
/// called through `&mut self`, so it may carry mutable state between calls.
/// A random number generator is the likely use, but that is an assumption
/// about the caller's closure.
#[derive(Debug)]
pub struct SdeFnWrapper<Drift, Diff, Noise> {
    // Called by `SDE::drift(t, y, dydt)`.
    drift_fn: Drift,
    // Called by `SDE::diffusion(t, y, dydw)`.
    diffusion_fn: Diff,
    // Called by `SDE::noise(dt, dw)`.
    noise_fn: Noise,
}
190
191impl<T, Y, Drift, Diff, Noise> SDE<T, Y> for SdeFnWrapper<Drift, Diff, Noise>
192where
193    T: Real,
194    Y: State<T>,
195    Drift: Fn(T, &Y, &mut Y),
196    Diff: Fn(T, &Y, &mut Y),
197    Noise: FnMut(T, &mut Y),
198{
199    fn drift(&self, t: T, y: &Y, dydt: &mut Y) {
200        (self.drift_fn)(t, y, dydt)
201    }
202
203    fn diffusion(&self, t: T, y: &Y, dydw: &mut Y) {
204        (self.diffusion_fn)(t, y, dydw)
205    }
206
207    fn noise(&mut self, dt: T, dw: &mut Y) {
208        (self.noise_fn)(dt, dw)
209    }
210}
211
/// Internal wrapper that adapts a right-hand-side closure and a lag closure to
/// the `DDE<L, ..>` trait (used by `IVP::dde_from_fn`).
///
/// `Clone`/`Copy` are derived so that `DdeEqOwned`, and hence the
/// `Clone`-deriving `IVP`, can be cloned when the closures are cloneable.
/// Previously the missing impl made `IVP::dde_from_fn(..).clone()` impossible.
#[derive(Clone, Copy, Debug)]
pub struct DdeFnWrapper<const L: usize, Diff, Lags> {
    diff_fn: Diff,
    lags_fn: Lags,
}
218
219impl<const L: usize, T, Y, Diff, Lags> DDE<L, T, Y> for DdeFnWrapper<L, Diff, Lags>
220where
221    T: Real,
222    Y: State<T>,
223    Diff: Fn(T, &Y, &[Y; L], &mut Y),
224    Lags: Fn(T, &Y, &mut [T; L]),
225{
226    fn diff(&self, t: T, y: &Y, yd: &[Y; L], dydt: &mut Y) {
227        (self.diff_fn)(t, y, yd, dydt)
228    }
229
230    fn lags(&self, t: T, y: &Y, lags: &mut [T; L]) {
231        (self.lags_fn)(t, y, lags)
232    }
233}
234
235impl<'a, F, T: Real, Y: State<T>> IVP<OdeEq<'a, F>, T, Y, (), DefaultSolout> {
236    /// Create a new initial value problem for an ordinary differential equation.
237    pub fn ode(system: &'a F, t0: T, tf: T, y0: Y) -> Self {
238        Self {
239            equation: OdeEq { ode: system },
240            t0,
241            tf,
242            y0,
243            method: (),
244            solout: DefaultSolout::new(),
245        }
246    }
247}
248
249impl<F, T: Real, Y: State<T>> IVP<OdeEqOwned<OdeFnWrapper<F>>, T, Y, (), DefaultSolout>
250where
251    F: Fn(T, &Y, &mut Y),
252{
253    /// Create a new initial value problem for an ordinary differential equation from a closure.
254    ///
255    /// # Example
256    /// ```rust
257    /// use differential_equations::prelude::*;
258    /// let t0 = 0.0;
259    /// let tf = 1.0;
260    /// let y0 = 1.0;
261    /// let ivp = IVP::ode_from_fn(|t, y, dydt| { *dydt = t * y; }, t0, tf, y0);
262    /// ```
263    pub fn ode_from_fn(f: F, t0: T, tf: T, y0: Y) -> Self {
264        Self {
265            equation: OdeEqOwned {
266                ode: OdeFnWrapper { f },
267            },
268            t0,
269            tf,
270            y0,
271            method: (),
272            solout: DefaultSolout::new(),
273        }
274    }
275}
276
277impl<'a, F, T: Real, Y: State<T>> IVP<DaeEq<'a, F>, T, Y, (), DefaultSolout> {
278    /// Create a new initial value problem for a differential algebraic equation.
279    pub fn dae(system: &'a F, t0: T, tf: T, y0: Y) -> Self {
280        Self {
281            equation: DaeEq { dae: system },
282            t0,
283            tf,
284            y0,
285            method: (),
286            solout: DefaultSolout::new(),
287        }
288    }
289}
290
291impl<F, M, T: Real, Y: State<T>> IVP<DaeEqOwned<DaeFnWrapper<F, M>>, T, Y, (), DefaultSolout>
292where
293    F: Fn(T, &Y, &mut Y),
294    M: Fn(&mut crate::linalg::Matrix<T>),
295{
296    /// Create a new initial value problem for a differential algebraic equation from closures.
297    pub fn dae_from_fn(f: F, m: M, t0: T, tf: T, y0: Y) -> Self {
298        Self {
299            equation: DaeEqOwned {
300                dae: DaeFnWrapper { f, m },
301            },
302            t0,
303            tf,
304            y0,
305            method: (),
306            solout: DefaultSolout::new(),
307        }
308    }
309}
310
311impl<'a, F, T: Real, Y: State<T>> IVP<SdeEq<'a, F>, T, Y, (), DefaultSolout> {
312    /// Create a new initial value problem for a stochastic differential equation.
313    pub fn sde(system: &'a mut F, t0: T, tf: T, y0: Y) -> Self {
314        Self {
315            equation: SdeEq { sde: system },
316            t0,
317            tf,
318            y0,
319            method: (),
320            solout: DefaultSolout::new(),
321        }
322    }
323}
324
325impl<Drift, Diff, Noise, T: Real, Y: State<T>>
326    IVP<SdeEqOwned<SdeFnWrapper<Drift, Diff, Noise>>, T, Y, (), DefaultSolout>
327where
328    Drift: Fn(T, &Y, &mut Y),
329    Diff: Fn(T, &Y, &mut Y),
330    Noise: FnMut(T, &mut Y),
331{
332    /// Create a new initial value problem for a stochastic differential equation from closures.
333    pub fn sde_from_fn(drift: Drift, diffusion: Diff, noise: Noise, t0: T, tf: T, y0: Y) -> Self {
334        Self {
335            equation: SdeEqOwned {
336                sde: SdeFnWrapper {
337                    drift_fn: drift,
338                    diffusion_fn: diffusion,
339                    noise_fn: noise,
340                },
341            },
342            t0,
343            tf,
344            y0,
345            method: (),
346            solout: DefaultSolout::new(),
347        }
348    }
349}
350
351impl<'a, F, H, T: Real, Y: State<T>, const L: usize>
352    IVP<DdeEq<'a, L, F, H>, T, Y, (), DefaultSolout>
353{
354    /// Create a new initial value problem for a delay differential equation.
355    pub fn dde(system: &'a F, t0: T, tf: T, y0: Y, history_function: H) -> Self {
356        Self {
357            equation: DdeEq {
358                dde: system,
359                history: history_function,
360            },
361            t0,
362            tf,
363            y0,
364            method: (),
365            solout: DefaultSolout::new(),
366        }
367    }
368}
369
370impl<const L: usize, Diff, Lags, H, T: Real, Y: State<T>>
371    IVP<DdeEqOwned<L, DdeFnWrapper<L, Diff, Lags>, H>, T, Y, (), DefaultSolout>
372where
373    Diff: Fn(T, &Y, &[Y; L], &mut Y),
374    Lags: Fn(T, &Y, &mut [T; L]),
375    H: Fn(T) -> Y + Clone,
376{
377    /// Create a new initial value problem for a delay differential equation from closures.
378    pub fn dde_from_fn(diff: Diff, lags: Lags, t0: T, tf: T, y0: Y, history_function: H) -> Self {
379        Self {
380            equation: DdeEqOwned {
381                dde: DdeFnWrapper {
382                    diff_fn: diff,
383                    lags_fn: lags,
384                },
385                history: history_function,
386            },
387            t0,
388            tf,
389            y0,
390            method: (),
391            solout: DefaultSolout::new(),
392        }
393    }
394}
395
396impl<EqType, T: Real, Y: State<T>, Method, SoloutType> IVP<EqType, T, Y, Method, SoloutType> {
397    fn with_method<NextMethod>(
398        self,
399        method: NextMethod,
400    ) -> IVP<EqType, T, Y, NextMethod, SoloutType> {
401        IVP {
402            equation: self.equation,
403            t0: self.t0,
404            tf: self.tf,
405            y0: self.y0,
406            method,
407            solout: self.solout,
408        }
409    }
410
411    fn map_method<NextMethod>(
412        self,
413        map: impl FnOnce(Method) -> NextMethod,
414    ) -> IVP<EqType, T, Y, NextMethod, SoloutType> {
415        IVP {
416            equation: self.equation,
417            t0: self.t0,
418            tf: self.tf,
419            y0: self.y0,
420            method: map(self.method),
421            solout: self.solout,
422        }
423    }
424
425    fn with_solout<NextSolout>(self, solout: NextSolout) -> IVP<EqType, T, Y, Method, NextSolout> {
426        IVP {
427            equation: self.equation,
428            t0: self.t0,
429            tf: self.tf,
430            y0: self.y0,
431            method: self.method,
432            solout,
433        }
434    }
435
436    /// Set the numerical method to be used.
437    ///
438    /// The builder owns the method because solving mutates method state. Construct
439    /// a fresh method for each solve, or use the low-level `solve_*` functions when
440    /// you need to manage a mutable solver reference directly.
441    pub fn method<SNew>(self, method: SNew) -> IVP<EqType, T, Y, SNew, SoloutType> {
442        self.with_method(method)
443    }
444
445    /// Set a custom solout function.
446    pub fn solout<ONew>(self, solout: ONew) -> IVP<EqType, T, Y, Method, ONew> {
447        self.with_solout(solout)
448    }
449
450    /// Output evenly spaced points between the initial and final time.
451    /// Note that this does not include the solution of the calculated steps.
452    pub fn even(self, dt: T) -> IVP<EqType, T, Y, Method, EvenSolout<T>> {
453        let solout = EvenSolout::new(dt, self.t0, self.tf);
454        self.with_solout(solout)
455    }
456
457    /// Use the Dense Output method to output n number of interpolation points between each step.
458    /// Note this includes the solution of the calculated steps.
459    pub fn dense(self, n: usize) -> IVP<EqType, T, Y, Method, DenseSolout> {
460        self.with_solout(DenseSolout::new(n))
461    }
462
463    /// Use the provided time points for evaluation instead of the default method.
464    /// Note this does not include the solution of the calculated steps.
465    pub fn t_eval(self, points: impl AsRef<[T]>) -> IVP<EqType, T, Y, Method, TEvalSolout<T>> {
466        let solout = TEvalSolout::new(points, self.t0, self.tf);
467        self.with_solout(solout)
468    }
469
470    /// Wrap current solout with event detection while preserving original output strategy.
471    pub fn event<'a, E>(
472        self,
473        event: &'a E,
474    ) -> IVP<EqType, T, Y, Method, EventWrappedSolout<'a, T, Y, SoloutType, E>>
475    where
476        E: Event<T, Y> + ?Sized,
477        SoloutType: Solout<T, Y>,
478    {
479        IVP {
480            equation: self.equation,
481            t0: self.t0,
482            tf: self.tf,
483            y0: self.y0,
484            method: self.method,
485            solout: EventWrappedSolout::new(self.solout, event, self.t0, self.tf),
486        }
487    }
488
489    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
490    /// Note this does not include the solution of the calculated steps.
491    pub fn crossing(
492        self,
493        component_idx: usize,
494        threshold: T,
495        direction: CrossingDirection,
496    ) -> IVP<EqType, T, Y, Method, CrossingSolout<T>> {
497        let crossing_solout =
498            CrossingSolout::new(component_idx, threshold).with_direction(direction);
499        self.with_solout(crossing_solout)
500    }
501
502    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
503    /// Note this does not include the solution of the calculated steps.
504    pub fn hyperplane_crossing<Y1: State<T>>(
505        self,
506        point: Y1,
507        normal: Y1,
508        extractor: fn(&Y) -> Y1,
509        direction: CrossingDirection,
510    ) -> IVP<EqType, T, Y, Method, HyperplaneCrossingSolout<T, Y1, Y>> {
511        let solout =
512            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
513        self.with_solout(solout)
514    }
515}
516
517impl<EqType, T: Real, Y: State<T>, Method, SoloutType> IVP<EqType, T, Y, Method, SoloutType>
518where
519    Method: ToleranceConfig<T>,
520{
521    /// Set relative tolerance on the underlying solver.
522    pub fn rtol<V: Into<Tolerance<T>>>(self, rtol: V) -> Self {
523        self.map_method(|method| method.rtol(rtol))
524    }
525
526    /// Set absolute tolerance on the underlying solver.
527    pub fn atol<V: Into<Tolerance<T>>>(self, atol: V) -> Self {
528        self.map_method(|method| method.atol(atol))
529    }
530}
531
532impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<OdeEq<'a, F>, T, Y, Method, SoloutType>
533where
534    F: ODE<T, Y>,
535    Method: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
536    SoloutType: Solout<T, Y>,
537{
538    /// Solve the ODE initial value problem.
539    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
540        solve_ode(
541            &mut self.method,
542            self.equation.ode,
543            self.t0,
544            self.tf,
545            &self.y0,
546            &mut self.solout,
547        )
548    }
549}
550
551impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<OdeEqOwned<F>, T, Y, Method, SoloutType>
552where
553    F: ODE<T, Y>,
554    Method: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
555    SoloutType: Solout<T, Y>,
556{
557    /// Solve the ODE initial value problem.
558    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
559        solve_ode(
560            &mut self.method,
561            &self.equation.ode,
562            self.t0,
563            self.tf,
564            &self.y0,
565            &mut self.solout,
566        )
567    }
568}
569
570impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<DaeEq<'a, F>, T, Y, Method, SoloutType>
571where
572    F: DAE<T, Y>,
573    Method: AlgebraicNumericalMethod<T, Y> + Interpolation<T, Y>,
574    SoloutType: Solout<T, Y>,
575{
576    /// Solve the DAE initial value problem.
577    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
578        solve_dae(
579            &mut self.method,
580            self.equation.dae,
581            self.t0,
582            self.tf,
583            &self.y0,
584            &mut self.solout,
585        )
586    }
587}
588
589impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<DaeEqOwned<F>, T, Y, Method, SoloutType>
590where
591    F: DAE<T, Y>,
592    Method: AlgebraicNumericalMethod<T, Y> + Interpolation<T, Y>,
593    SoloutType: Solout<T, Y>,
594{
595    /// Solve the DAE initial value problem.
596    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
597        solve_dae(
598            &mut self.method,
599            &self.equation.dae,
600            self.t0,
601            self.tf,
602            &self.y0,
603            &mut self.solout,
604        )
605    }
606}
607
608impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<SdeEq<'a, F>, T, Y, Method, SoloutType>
609where
610    F: SDE<T, Y>,
611    Method: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
612    SoloutType: Solout<T, Y>,
613{
614    /// Solve the SDE initial value problem.
615    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
616        solve_sde(
617            &mut self.method,
618            self.equation.sde,
619            self.t0,
620            self.tf,
621            &self.y0,
622            &mut self.solout,
623        )
624    }
625}
626
627impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<SdeEqOwned<F>, T, Y, Method, SoloutType>
628where
629    F: SDE<T, Y>,
630    Method: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
631    SoloutType: Solout<T, Y>,
632{
633    /// Solve the SDE initial value problem.
634    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
635        solve_sde(
636            &mut self.method,
637            &mut self.equation.sde,
638            self.t0,
639            self.tf,
640            &self.y0,
641            &mut self.solout,
642        )
643    }
644}
645
646impl<'a, const L: usize, F, H, T: Real, Y: State<T>, Method, SoloutType>
647    IVP<DdeEq<'a, L, F, H>, T, Y, Method, SoloutType>
648where
649    F: DDE<L, T, Y>,
650    H: Fn(T) -> Y + Clone,
651    Method: DelayNumericalMethod<L, T, Y, H> + Interpolation<T, Y>,
652    SoloutType: Solout<T, Y>,
653{
654    /// Solve the DDE initial value problem.
655    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
656        solve_dde(
657            &mut self.method,
658            self.equation.dde,
659            self.t0,
660            self.tf,
661            &self.y0,
662            self.equation.history.clone(),
663            &mut self.solout,
664        )
665    }
666}
667
668impl<const L: usize, F, H, T: Real, Y: State<T>, Method, SoloutType>
669    IVP<DdeEqOwned<L, F, H>, T, Y, Method, SoloutType>
670where
671    F: DDE<L, T, Y>,
672    H: Fn(T) -> Y + Clone,
673    Method: DelayNumericalMethod<L, T, Y, H> + Interpolation<T, Y>,
674    SoloutType: Solout<T, Y>,
675{
676    /// Solve the DDE initial value problem.
677    pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
678        solve_dde(
679            &mut self.method,
680            &self.equation.dde,
681            self.t0,
682            self.tf,
683            &self.y0,
684            self.equation.history.clone(),
685            &mut self.solout,
686        )
687    }
688}