differential_equations/ode/
problem.rs

1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4    error::Error,
5    interpolate::Interpolation,
6    ode::{ODE, OrdinaryNumericalMethod, solve_ode},
7    solout::*,
8    solution::Solution,
9    traits::{CallBackData, Real, State},
10};
11
12/// Initial Value Problem for Ordinary Differential Equations (ODEs)
13///
14/// The Initial Value Problem takes the form:
15/// y' = f(t, y), a <= t <= b, y(a) = alpha
16///
17/// # Overview
18///
19/// The ODEProblem struct provides a simple interface for solving differential equations:
20///
21/// # Example
22///
23/// ```
24/// use differential_equations::prelude::*;
25///
26/// struct LinearEquation {
27///    pub a: f32,
28///    pub b: f32,
29/// }
30///
31/// impl ODE<f32, f32> for LinearEquation {
32///    fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
33///        *dydt = self.a + self.b * y;
34///   }
35/// }
36///
37/// // Create the ode and initial conditions
38/// let ode = LinearEquation { a: 1.0, b: 2.0 };
39/// let t0 = 0.0;
40/// let tf = 1.0;
41/// let y0 = 1.0;
42/// let mut solver = ExplicitRungeKutta::dop853().rtol(1e-8).atol(1e-6);
43///
44/// // Basic usage:
45/// let problem = ODEProblem::new(ode, t0, tf, y0);
46/// let solution = problem.solve(&mut solver).unwrap();
47///
48/// // Advanced output control:
49/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
50/// ```
51///
52/// # Fields
53///
54/// * `ode` - ODE implementing the differential equation
55/// * `t0` - Initial time
56/// * `tf` - Final time
57/// * `y0` - Initial state vector
58///
59/// # Basic Usage
60///
61/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
62/// * `solve(&mut solver)` - Solve using default output (solver step points)
63///
64/// # Output Control Methods
65///
66/// These methods configure how solution points are generated and returned:
67///
68/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
69/// * `dense(n)` - Include `n` interpolated points between each solver step
70/// * `t_eval(points)` - Evaluate solution at specific time points
71/// * `solout(custom_solout)` - Use a custom output handler
72///
73/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
74///
75/// # Example 2
76///
77/// ```
78/// use differential_equations::prelude::*;
79/// use nalgebra::{SVector, vector};
80///
81/// struct HarmonicOscillator { k: f64 }
82///
83/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
84///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
85///         dydt[0] = y[1];
86///         dydt[1] = -self.k * y[0];
87///     }
88/// }
89///
90/// let ode = HarmonicOscillator { k: 1.0 };
91/// let mut method = ExplicitRungeKutta::dop853().rtol(1e-12).atol(1e-12);
92///
93/// // Basic usage with default output points
94/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
95/// let results = problem.solve(&mut method).unwrap();
96///
97/// // Advanced: evenly spaced output with 0.1 time intervals
98/// let results = problem.dense(4).solve(&mut method).unwrap();
99/// ```
100#[derive(Clone, Debug)]
101pub struct ODEProblem<T, Y, D, F>
102where
103    T: Real,
104    Y: State<T>,
105    D: CallBackData,
106    F: ODE<T, Y, D>,
107{
108    // Initial Value Problem Fields
109    pub ode: F, // ODE containing the Differential Equation and Optional Terminate Function.
110    pub t0: T,  // Initial Time.
111    pub tf: T,  // Final Time.
112    pub y0: Y,  // Initial State Vector.
113
114    // Phantom Data for Users event output
115    _event_output_type: std::marker::PhantomData<D>,
116}
117
118impl<T, Y, D, F> ODEProblem<T, Y, D, F>
119where
120    T: Real,
121    Y: State<T>,
122    D: CallBackData,
123    F: ODE<T, Y, D>,
124{
125    /// Create a new Initial Value Problem
126    ///
127    /// # Arguments
128    /// * `ode`  - ODE containing the Differential Equation and Optional Terminate Function.
129    /// * `t0`      - Initial Time.
130    /// * `tf`      - Final Time.
131    /// * `y0`      - Initial State Vector.
132    ///
133    /// # Returns
134    /// * ODEProblem Problem ready to be solved.
135    ///
136    pub fn new(ode: F, t0: T, tf: T, y0: Y) -> Self {
137        ODEProblem {
138            ode,
139            t0,
140            tf,
141            y0,
142            _event_output_type: std::marker::PhantomData,
143        }
144    }
145
146    /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
147    ///
148    /// # Returns
149    /// * `Result<Solution<T, Y, D>, Status<T, Y, D>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
150    ///
151    pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
152    where
153        S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
154    {
155        let mut default_solout = DefaultSolout::new();
156        solve_ode(
157            solver,
158            &self.ode,
159            self.t0,
160            self.tf,
161            &self.y0,
162            &mut default_solout,
163        )
164    }
165
166    /// Returns an ODEProblem OrdinaryNumericalMethod with the provided solout function for outputting points.
167    ///
168    /// # Returns
169    /// * ODEProblem OrdinaryNumericalMethod with the provided solout function ready for .solve() method.
170    ///
171    pub fn solout<'a, O: Solout<T, Y, D>>(
172        &'a self,
173        solout: &'a mut O,
174    ) -> ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O> {
175        ODEProblemMutRefSoloutPair::new(self, solout)
176    }
177
178    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
179    /// Note that this does not include the solution of the calculated steps.
180    ///
181    /// # Arguments
182    /// * `dt` - Interval between each output point.
183    ///
184    /// # Returns
185    /// * ODEProblem OrdinaryNumericalMethod with Even Solout function ready for .solve() method.
186    ///
187    pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, Y, D, F, EvenSolout<T>> {
188        let even_solout = EvenSolout::new(dt, self.t0, self.tf);
189        ODEProblemSoloutPair::new(self, even_solout)
190    }
191
192    /// Uses the Dense Output method to output n number of interpolation points between each step.
193    /// Note this includes the solution of the calculated steps.
194    ///
195    /// # Arguments
196    /// * `n` - Number of interpolation points between each step.
197    ///
198    /// # Returns
199    /// * ODEProblem OrdinaryNumericalMethod with Dense Output function ready for .solve() method.
200    ///
201    pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, Y, D, F, DenseSolout> {
202        let dense_solout = DenseSolout::new(n);
203        ODEProblemSoloutPair::new(self, dense_solout)
204    }
205
206    /// Uses the provided time points for evaluation instead of the default method.
207    /// Note this does not include the solution of the calculated steps.
208    ///
209    /// # Arguments
210    /// * `points` - Custom output points.
211    ///
212    /// # Returns
213    /// * ODEProblem OrdinaryNumericalMethod with Custom Time Evaluation function ready for .solve() method.
214    ///
215    pub fn t_eval(
216        &self,
217        points: impl AsRef<[T]>,
218    ) -> ODEProblemSoloutPair<'_, T, Y, D, F, TEvalSolout<T>> {
219        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
220        ODEProblemSoloutPair::new(self, t_eval_solout)
221    }
222
223    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
224    /// Note this does not include the solution of the calculated steps.
225    ///
226    /// # Arguments
227    /// * `component_idx` - Index of the component to monitor for crossing.
228    /// * `threshhold` - Value to cross.
229    /// * `direction` - Direction of crossing (positive or negative).
230    ///
231    /// # Returns
232    /// * ODEProblem OrdinaryNumericalMethod with CrossingSolout function ready for .solve() method.
233    ///
234    pub fn crossing(
235        &self,
236        component_idx: usize,
237        threshhold: T,
238        direction: CrossingDirection,
239    ) -> ODEProblemSoloutPair<'_, T, Y, D, F, CrossingSolout<T>> {
240        let crossing_solout =
241            CrossingSolout::new(component_idx, threshhold).with_direction(direction);
242        ODEProblemSoloutPair::new(self, crossing_solout)
243    }
244
245    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
246    /// Note this does not include the solution of the calculated steps.
247    ///
248    /// # Arguments
249    /// * `point` - Point on the hyperplane.
250    /// * `normal` - Normal vector of the hyperplane.
251    /// * `extractor` - Function to extract the component from the state vector.
252    /// * `direction` - Direction of crossing (positive or negative).
253    ///
254    /// # Returns
255    /// * ODEProblem OrdinaryNumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
256    ///
257    pub fn hyperplane_crossing<Y1>(
258        &self,
259        point: Y1,
260        normal: Y1,
261        extractor: fn(&Y) -> Y1,
262        direction: CrossingDirection,
263    ) -> ODEProblemSoloutPair<'_, T, Y, D, F, HyperplaneCrossingSolout<T, Y1, Y>>
264    where
265        Y1: State<T>,
266    {
267        let solout =
268            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
269
270        ODEProblemSoloutPair::new(self, solout)
271    }
272}
273
274/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
275pub struct ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
276where
277    T: Real,
278    Y: State<T>,
279    D: CallBackData,
280    F: ODE<T, Y, D>,
281    O: Solout<T, Y, D>,
282{
283    pub problem: &'a ODEProblem<T, Y, D, F>,
284    pub solout: &'a mut O,
285}
286
287impl<'a, T, Y, D, F, O> ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
288where
289    T: Real,
290    Y: State<T>,
291    D: CallBackData,
292    F: ODE<T, Y, D>,
293    O: Solout<T, Y, D>,
294{
295    /// Create a new ODEProblemMutRefSoloutPair
296    ///
297    /// # Arguments
298    /// * `problem` - Reference to the ODEProblem struct
299    ///
300    pub fn new(problem: &'a ODEProblem<T, Y, D, F>, solout: &'a mut O) -> Self {
301        ODEProblemMutRefSoloutPair { problem, solout }
302    }
303
304    /// Solve the ODEProblem using the provided solout
305    ///
306    /// # Arguments
307    /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
308    ///
309    /// # Returns
310    /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
311    ///
312    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
313    where
314        S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
315    {
316        solve_ode(
317            solver,
318            &self.problem.ode,
319            self.problem.t0,
320            self.problem.tf,
321            &self.problem.y0,
322            self.solout,
323        )
324    }
325}
326
327/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
328#[derive(Clone, Debug)]
329pub struct ODEProblemSoloutPair<'a, T, Y, D, F, O>
330where
331    T: Real,
332    Y: State<T>,
333    D: CallBackData,
334    F: ODE<T, Y, D>,
335    O: Solout<T, Y, D>,
336{
337    pub problem: &'a ODEProblem<T, Y, D, F>,
338    pub solout: O,
339}
340
341impl<'a, T, Y, D, F, O> ODEProblemSoloutPair<'a, T, Y, D, F, O>
342where
343    T: Real,
344    Y: State<T>,
345    D: CallBackData,
346    F: ODE<T, Y, D>,
347    O: Solout<T, Y, D>,
348{
349    /// Create a new ODEProblemSoloutPair
350    ///
351    /// # Arguments
352    /// * `problem` - Reference to the ODEProblem struct
353    /// * `solout` - Solout implementation
354    ///
355    pub fn new(problem: &'a ODEProblem<T, Y, D, F>, solout: O) -> Self {
356        ODEProblemSoloutPair { problem, solout }
357    }
358
359    /// Solve the ODEProblem using the provided solout
360    ///
361    /// # Arguments
362    /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
363    ///
364    /// # Returns
365    /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
366    ///
367    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
368    where
369        S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
370    {
371        solve_ode(
372            solver,
373            &self.problem.ode,
374            self.problem.t0,
375            self.problem.tf,
376            &self.problem.y0,
377            &mut self.solout,
378        )
379    }
380}