differential_equations/ode/
problem.rs

1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4    error::Error,
5    interpolate::Interpolation,
6    ode::{ODE, OrdinaryNumericalMethod, solve_ode},
7    solout::*,
8    solution::Solution,
9    traits::{Real, State},
10};
11
12/// Initial Value Problem for Ordinary Differential Equations (ODEs)
13///
14/// The Initial Value Problem takes the form:
15/// y' = f(t, y), a <= t <= b, y(a) = alpha
16///
17/// # Overview
18///
19/// The ODEProblem struct provides a simple interface for solving differential equations:
20///
21/// # Example
22///
23/// ```
24/// use differential_equations::prelude::*;
25///
26/// struct LinearEquation {
27///    pub a: f32,
28///    pub b: f32,
29/// }
30///
31/// impl ODE<f32, f32> for LinearEquation {
32///    fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
33///        *dydt = self.a + self.b * y;
34///   }
35/// }
36///
37/// // Create the ode and initial conditions
38/// let ode = LinearEquation { a: 1.0, b: 2.0 };
39/// let t0 = 0.0;
40/// let tf = 1.0;
41/// let y0 = 1.0;
42/// let mut solver = ExplicitRungeKutta::dop853().rtol(1e-8).atol(1e-6);
43///
44/// // Basic usage:
45/// let problem = ODEProblem::new(ode, t0, tf, y0);
46/// let solution = problem.solve(&mut solver).unwrap();
47///
48/// // Advanced output control:
49/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
50/// ```
51///
52/// # Fields
53///
54/// * `ode` - ODE implementing the differential equation
55/// * `t0` - Initial time
56/// * `tf` - Final time
57/// * `y0` - Initial state vector
58///
59/// # Basic Usage
60///
61/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
62/// * `solve(&mut solver)` - Solve using default output (solver step points)
63///
64/// # Output Control Methods
65///
66/// These methods configure how solution points are generated and returned:
67///
68/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
69/// * `dense(n)` - Include `n` interpolated points between each solver step
70/// * `t_eval(points)` - Evaluate solution at specific time points
71/// * `solout(custom_solout)` - Use a custom output handler
72///
73/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
74///
75/// # Example 2
76///
77/// ```
78/// use differential_equations::prelude::*;
79/// use nalgebra::{SVector, vector};
80///
81/// struct HarmonicOscillator { k: f64 }
82///
83/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
84///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
85///         dydt[0] = y[1];
86///         dydt[1] = -self.k * y[0];
87///     }
88/// }
89///
90/// let ode = HarmonicOscillator { k: 1.0 };
91/// let mut method = ExplicitRungeKutta::dop853().rtol(1e-12).atol(1e-12);
92///
93/// // Basic usage with default output points
94/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
95/// let results = problem.solve(&mut method).unwrap();
96///
97/// // Advanced: evenly spaced output with 0.1 time intervals
98/// let results = problem.dense(4).solve(&mut method).unwrap();
99/// ```
100#[derive(Clone, Debug)]
101pub struct ODEProblem<'a, T, Y, F>
102where
103    T: Real,
104    Y: State<T>,
105    F: ODE<T, Y>,
106{
107    /// ODE object implementing [`ODE`](crate::ode::ODE) trait
108    pub ode: &'a F,
109    /// Initial Time
110    pub t0: T,
111    /// Final Time
112    pub tf: T,
113    /// Initial State Vector
114    pub y0: Y,
115}
116
117impl<'a, T, Y, F> ODEProblem<'a, T, Y, F>
118where
119    T: Real,
120    Y: State<T>,
121    F: ODE<T, Y>,
122{
123    /// Create a new Initial Value Problem
124    ///
125    /// # Arguments
126    /// * `ode`  - ODE containing the Differential Equation and Optional Terminate Function.
127    /// * `t0`      - Initial Time.
128    /// * `tf`      - Final Time.
129    /// * `y0`      - Initial State Vector.
130    ///
131    /// # Returns
132    /// * ODEProblem Problem ready to be solved.
133    ///
134    pub fn new(ode: &'a F, t0: T, tf: T, y0: Y) -> Self {
135        ODEProblem { ode, t0, tf, y0 }
136    }
137
138    /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
139    ///
140    /// # Returns
141    /// * `Result<Solution<T, Y>, Status<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
142    ///
143    pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
144    where
145        S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
146    {
147        let mut default_solout = DefaultSolout::new();
148        solve_ode(
149            solver,
150            self.ode,
151            self.t0,
152            self.tf,
153            &self.y0,
154            &mut default_solout,
155        )
156    }
157
158    /// Returns an ODEProblem OrdinaryNumericalMethod with the provided solout function for outputting points.
159    ///
160    /// # Returns
161    /// * ODEProblem OrdinaryNumericalMethod with the provided solout function ready for .solve() method.
162    ///
163    pub fn solout<O: Solout<T, Y>>(
164        &'a self,
165        solout: &'a mut O,
166    ) -> ODEProblemMutRefSoloutPair<'a, T, Y, F, O> {
167        ODEProblemMutRefSoloutPair::new(self, solout)
168    }
169
170    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
171    /// Note that this does not include the solution of the calculated steps.
172    ///
173    /// # Arguments
174    /// * `dt` - Interval between each output point.
175    ///
176    /// # Returns
177    /// * ODEProblem OrdinaryNumericalMethod with Even Solout function ready for .solve() method.
178    ///
179    pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, Y, F, EvenSolout<T>> {
180        let even_solout = EvenSolout::new(dt, self.t0, self.tf);
181        ODEProblemSoloutPair::new(self, even_solout)
182    }
183
184    /// Uses the Dense Output method to output n number of interpolation points between each step.
185    /// Note this includes the solution of the calculated steps.
186    ///
187    /// # Arguments
188    /// * `n` - Number of interpolation points between each step.
189    ///
190    /// # Returns
191    /// * ODEProblem OrdinaryNumericalMethod with Dense Output function ready for .solve() method.
192    ///
193    pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, Y, F, DenseSolout> {
194        let dense_solout = DenseSolout::new(n);
195        ODEProblemSoloutPair::new(self, dense_solout)
196    }
197
198    /// Uses the provided time points for evaluation instead of the default method.
199    /// Note this does not include the solution of the calculated steps.
200    ///
201    /// # Arguments
202    /// * `points` - Custom output points.
203    ///
204    /// # Returns
205    /// * ODEProblem OrdinaryNumericalMethod with Custom Time Evaluation function ready for .solve() method.
206    ///
207    pub fn t_eval(
208        &self,
209        points: impl AsRef<[T]>,
210    ) -> ODEProblemSoloutPair<'_, T, Y, F, TEvalSolout<T>> {
211        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
212        ODEProblemSoloutPair::new(self, t_eval_solout)
213    }
214
215    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
216    /// Note this does not include the solution of the calculated steps.
217    ///
218    /// # Arguments
219    /// * `component_idx` - Index of the component to monitor for crossing.
220    /// * `threshhold` - Value to cross.
221    /// * `direction` - Direction of crossing (positive or negative).
222    ///
223    /// # Returns
224    /// * ODEProblem OrdinaryNumericalMethod with CrossingSolout function ready for .solve() method.
225    ///
226    pub fn crossing(
227        &self,
228        component_idx: usize,
229        threshhold: T,
230        direction: CrossingDirection,
231    ) -> ODEProblemSoloutPair<'_, T, Y, F, CrossingSolout<T>> {
232        let crossing_solout =
233            CrossingSolout::new(component_idx, threshhold).with_direction(direction);
234        ODEProblemSoloutPair::new(self, crossing_solout)
235    }
236
237    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
238    /// Note this does not include the solution of the calculated steps.
239    ///
240    /// # Arguments
241    /// * `point` - Point on the hyperplane.
242    /// * `normal` - Normal vector of the hyperplane.
243    /// * `extractor` - Function to extract the component from the state vector.
244    /// * `direction` - Direction of crossing (positive or negative).
245    ///
246    /// # Returns
247    /// * ODEProblem OrdinaryNumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
248    ///
249    pub fn hyperplane_crossing<Y1>(
250        &self,
251        point: Y1,
252        normal: Y1,
253        extractor: fn(&Y) -> Y1,
254        direction: CrossingDirection,
255    ) -> ODEProblemSoloutPair<'_, T, Y, F, HyperplaneCrossingSolout<T, Y1, Y>>
256    where
257        Y1: State<T>,
258    {
259        let solout =
260            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
261
262        ODEProblemSoloutPair::new(self, solout)
263    }
264
265    /// Uses an `EventSolout` to detect zero crossings of a user-defined event function (SciPy style).
266    /// The provided event implements the `Event` trait returning a scalar function g(t,y) whose
267    /// roots are sought. Each detected event point (t*, y*) is appended to the solution. Optional
268    /// termination after N events can be configured in the Event implementation via `config()`.
269    ///
270    /// # Arguments
271    /// * `event` - Object implementing `Event<T, Y>` whose zero crossings are desired.
272    ///
273    /// # Returns
274    /// * `ODEProblemSoloutPair` with `EventSolout` ready for `.solve(&mut solver)`.
275    ///
276    /// # Example
277    /// ```
278    /// use differential_equations::prelude::*;
279    /// use nalgebra::{Vector2, vector};
280    ///
281    /// struct SHO; // Simple harmonic oscillator
282    /// impl ODE<f64, Vector2<f64>> for SHO {
283    ///     fn diff(&self, _t: f64, y: &Vector2<f64>, dydt: &mut Vector2<f64>) {
284    ///         dydt[0]=y[1];
285    ///         dydt[1]=-y[0];
286    ///     }
287    /// }
288    ///
289    /// // Event: detect when position crosses zero going positive (like SciPy event)
290    /// struct ZeroUp;
291    /// impl Event<f64, Vector2<f64>> for ZeroUp {
292    ///     fn config(&self) -> EventConfig {
293    ///         // Force only positive crossings
294    ///         EventConfig::default().direction(CrossingDirection::Positive)
295    ///     }
296    /// 
297    ///     fn event(&self, _t: f64, y: &Vector2<f64>) -> f64 {
298    ///         y[0]
299    ///     }
300    /// }
301    ///
302    /// let osc = SHO; 
303    /// let t0 = 0.0; 
304    /// let tf = 10.0; 
305    /// let y0 = vector![1.0, 0.0];
306    /// let problem = ODEProblem::new(&osc, t0, tf, y0);
307    /// let mut solver = ExplicitRungeKutta::dop853();
308    /// let solution = problem.event(&ZeroUp).solve(&mut solver).unwrap();
309    /// // solution.t now contains zero-up crossing times
310    /// ```
311    pub fn event<E>(&'a self, event: &'a E) -> ODEProblemSoloutPair<'a, T, Y, F, EventSolout<'a, T, Y, E>>
312    where
313        E: Event<T, Y>,
314    {
315        let solout = EventSolout::new(event, self.t0, self.tf);
316        ODEProblemSoloutPair::new(self, solout)
317    }
318}
319
320/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
321pub struct ODEProblemMutRefSoloutPair<'a, T, Y, F, O>
322where
323    T: Real,
324    Y: State<T>,
325    F: ODE<T, Y>,
326    O: Solout<T, Y>,
327{
328    pub problem: &'a ODEProblem<'a, T, Y, F>,
329    pub solout: &'a mut O,
330}
331
332impl<'a, T, Y, F, O> ODEProblemMutRefSoloutPair<'a, T, Y, F, O>
333where
334    T: Real,
335    Y: State<T>,
336    F: ODE<T, Y>,
337    O: Solout<T, Y>,
338{
339    /// Create a new ODEProblemMutRefSoloutPair
340    ///
341    /// # Arguments
342    /// * `problem` - Reference to the ODEProblem struct
343    ///
344    pub fn new(problem: &'a ODEProblem<T, Y, F>, solout: &'a mut O) -> Self {
345        ODEProblemMutRefSoloutPair { problem, solout }
346    }
347
348    /// Solve the ODEProblem using the provided solout
349    ///
350    /// # Arguments
351    /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
352    ///
353    /// # Returns
354    /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
355    ///
356    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
357    where
358        S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
359    {
360        solve_ode(
361            solver,
362            self.problem.ode,
363            self.problem.t0,
364            self.problem.tf,
365            &self.problem.y0,
366            self.solout,
367        )
368    }
369}
370
371/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
372#[derive(Clone, Debug)]
373pub struct ODEProblemSoloutPair<'a, T, Y, F, O>
374where
375    T: Real,
376    Y: State<T>,
377    F: ODE<T, Y>,
378    O: Solout<T, Y>,
379{
380    pub problem: &'a ODEProblem<'a, T, Y, F>,
381    pub solout: O,
382}
383
384impl<'a, T, Y, F, O> ODEProblemSoloutPair<'a, T, Y, F, O>
385where
386    T: Real,
387    Y: State<T>,
388    F: ODE<T, Y>,
389    O: Solout<T, Y>,
390{
391    /// Create a new ODEProblemSoloutPair
392    ///
393    /// # Arguments
394    /// * `problem` - Reference to the ODEProblem struct
395    /// * `solout` - Solout implementation
396    ///
397    pub fn new(problem: &'a ODEProblem<T, Y, F>, solout: O) -> Self {
398        ODEProblemSoloutPair { problem, solout }
399    }
400
401    /// Solve the ODEProblem using the provided solout
402    ///
403    /// # Arguments
404    /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
405    ///
406    /// # Returns
407    /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
408    ///
409    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
410    where
411        S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
412    {
413        solve_ode(
414            solver,
415            self.problem.ode,
416            self.problem.t0,
417            self.problem.tf,
418            &self.problem.y0,
419            &mut self.solout,
420        )
421    }
422
423    /// Wrap current solout with event detection while preserving original output strategy.
424    pub fn event<E>(
425        self,
426        event: &'a E,
427    ) -> ODEProblemSoloutPair<'a, T, Y, F, EventWrappedSolout<'a, T, Y, O, E>>
428    where
429        E: Event<T, Y>,
430    {
431        let wrapped = EventWrappedSolout::new(self.solout, event, self.problem.t0, self.problem.tf);
432        ODEProblemSoloutPair {
433            problem: self.problem,
434            solout: wrapped,
435        }
436    }
437}