differential_equations/ode/
ode_problem.rs

1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4    Error, Solution,
5    interpolate::Interpolation,
6    ode::{ODE, numerical_method::ODENumericalMethod, solve_ode},
7    solout::*,
8    traits::{CallBackData, Real, State},
9};
10
11/// Initial Value Problem for Ordinary Differential Equations (ODEs)
12///
13/// The Initial Value Problem takes the form:
14/// y' = f(t, y), a <= t <= b, y(a) = alpha
15///
16/// # Overview
17///
18/// The ODEProblem struct provides a simple interface for solving differential equations:
19///
20/// # Example
21///
22/// ```
23/// use differential_equations::prelude::*;
24///
25/// struct LinearEquation {
26///    pub a: f32,
27///    pub b: f32,
28/// }
29///
30/// impl ODE<f32, f32> for LinearEquation {
31///    fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
32///        *dydt = self.a + self.b * y;
33///   }
34/// }
35///
36/// // Create the ode and initial conditions
37/// let ode = LinearEquation { a: 1.0, b: 2.0 };
38/// let t0 = 0.0;
39/// let tf = 1.0;
40/// let y0 = 1.0;
41/// let mut solver = DOP853::new().rtol(1e-8).atol(1e-6);
42///
43/// // Basic usage:
44/// let problem = ODEProblem::new(ode, t0, tf, y0);
45/// let solution = problem.solve(&mut solver).unwrap();
46///
47/// // Advanced output control:
48/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
49/// ```
50///
51/// # Fields
52///
53/// * `ode` - ODE implementing the differential equation
54/// * `t0` - Initial time
55/// * `tf` - Final time
56/// * `y0` - Initial state vector
57///
58/// # Basic Usage
59///
60/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
61/// * `solve(&mut solver)` - Solve using default output (solver step points)
62///
63/// # Output Control Methods
64///
65/// These methods configure how solution points are generated and returned:
66///
67/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
68/// * `dense(n)` - Include `n` interpolated points between each solver step
69/// * `t_eval(points)` - Evaluate solution at specific time points
70/// * `solout(custom_solout)` - Use a custom output handler
71///
72/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
73///
74/// # Example 2
75///
76/// ```
77/// use differential_equations::prelude::*;
78/// use nalgebra::{SVector, vector};
79///
80/// struct HarmonicOscillator { k: f64 }
81///
82/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
83///     fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
84///         dydt[0] = y[1];
85///         dydt[1] = -self.k * y[0];
86///     }
87/// }
88///
89/// let ode = HarmonicOscillator { k: 1.0 };
90/// let mut method = DOP853::new().rtol(1e-12).atol(1e-12);
91///
92/// // Basic usage with default output points
93/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
94/// let results = problem.solve(&mut method).unwrap();
95///
96/// // Advanced: evenly spaced output with 0.1 time intervals
97/// let results = problem.dense(4).solve(&mut method).unwrap();
98/// ```
99#[derive(Clone, Debug)]
100pub struct ODEProblem<T, V, D, F>
101where
102    T: Real,
103    V: State<T>,
104    D: CallBackData,
105    F: ODE<T, V, D>,
106{
107    // Initial Value Problem Fields
108    pub ode: F, // ODE containing the Differential Equation and Optional Terminate Function.
109    pub t0: T,  // Initial Time.
110    pub tf: T,  // Final Time.
111    pub y0: V,  // Initial State Vector.
112
113    // Phantom Data for Users event output
114    _event_output_type: std::marker::PhantomData<D>,
115}
116
117impl<T, V, D, F> ODEProblem<T, V, D, F>
118where
119    T: Real,
120    V: State<T>,
121    D: CallBackData,
122    F: ODE<T, V, D>,
123{
124    /// Create a new Initial Value Problem
125    ///
126    /// # Arguments
127    /// * `ode`  - ODE containing the Differential Equation and Optional Terminate Function.
128    /// * `t0`      - Initial Time.
129    /// * `tf`      - Final Time.
130    /// * `y0`      - Initial State Vector.
131    ///
132    /// # Returns
133    /// * ODEProblem Problem ready to be solved.
134    ///
135    pub fn new(ode: F, t0: T, tf: T, y0: V) -> Self {
136        ODEProblem {
137            ode,
138            t0,
139            tf,
140            y0,
141            _event_output_type: std::marker::PhantomData,
142        }
143    }
144
145    /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
146    ///
147    /// # Returns
148    /// * `Result<Solution<T, V, D>, Status<T, V, D>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
149    ///
150    pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
151    where
152        S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
153    {
154        let mut default_solout = DefaultSolout::new(); // Default solout implementation
155        solve_ode(
156            solver,
157            &self.ode,
158            self.t0,
159            self.tf,
160            &self.y0,
161            &mut default_solout,
162        )
163    }
164
165    /// Returns an ODEProblem ODENumericalMethod with the provided solout function for outputting points.
166    ///
167    /// # Returns
168    /// * ODEProblem ODENumericalMethod with the provided solout function ready for .solve() method.
169    ///
170    pub fn solout<'a, O: Solout<T, V, D>>(
171        &'a self,
172        solout: &'a mut O,
173    ) -> ODEProblemMutRefSoloutPair<'a, T, V, D, F, O> {
174        ODEProblemMutRefSoloutPair::new(self, solout)
175    }
176
177    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
178    /// Note that this does not include the solution of the calculated steps.
179    ///
180    /// # Arguments
181    /// * `dt` - Interval between each output point.
182    ///
183    /// # Returns
184    /// * ODEProblem ODENumericalMethod with Even Solout function ready for .solve() method.
185    ///
186    pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, V, D, F, EvenSolout<T>> {
187        let even_solout = EvenSolout::new(dt, self.t0, self.tf); // Even solout implementation
188        ODEProblemSoloutPair::new(self, even_solout)
189    }
190
191    /// Uses the Dense Output method to output n number of interpolation points between each step.
192    /// Note this includes the solution of the calculated steps.
193    ///
194    /// # Arguments
195    /// * `n` - Number of interpolation points between each step.
196    ///
197    /// # Returns
198    /// * ODEProblem ODENumericalMethod with Dense Output function ready for .solve() method.
199    ///
200    pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, V, D, F, DenseSolout> {
201        let dense_solout = DenseSolout::new(n); // Dense solout implementation
202        ODEProblemSoloutPair::new(self, dense_solout)
203    }
204
205    /// Uses the provided time points for evaluation instead of the default method.
206    /// Note this does not include the solution of the calculated steps.
207    ///
208    /// # Arguments
209    /// * `points` - Custom output points.
210    ///
211    /// # Returns
212    /// * ODEProblem ODENumericalMethod with Custom Time Evaluation function ready for .solve() method.
213    ///
214    pub fn t_eval(&self, points: Vec<T>) -> ODEProblemSoloutPair<'_, T, V, D, F, TEvalSolout<T>> {
215        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf); // Custom time evaluation solout implementation
216        ODEProblemSoloutPair::new(self, t_eval_solout)
217    }
218
219    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
220    /// Note this does not include the solution of the calculated steps.
221    ///
222    /// # Arguments
223    /// * `component_idx` - Index of the component to monitor for crossing.
224    /// * `threshhold` - Value to cross.
225    /// * `direction` - Direction of crossing (positive or negative).
226    ///
227    /// # Returns
228    /// * ODEProblem ODENumericalMethod with CrossingSolout function ready for .solve() method.
229    ///
230    pub fn crossing(
231        &self,
232        component_idx: usize,
233        threshhold: T,
234        direction: CrossingDirection,
235    ) -> ODEProblemSoloutPair<'_, T, V, D, F, CrossingSolout<T>> {
236        let crossing_solout =
237            CrossingSolout::new(component_idx, threshhold).with_direction(direction); // Crossing solout implementation
238        ODEProblemSoloutPair::new(self, crossing_solout)
239    }
240
241    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
242    /// Note this does not include the solution of the calculated steps.
243    ///
244    /// # Arguments
245    /// * `point` - Point on the hyperplane.
246    /// * `normal` - Normal vector of the hyperplane.
247    /// * `extractor` - Function to extract the component from the state vector.
248    /// * `direction` - Direction of crossing (positive or negative).
249    ///
250    /// # Returns
251    /// * ODEProblem ODENumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
252    ///
253    pub fn hyperplane_crossing<V1>(
254        &self,
255        point: V1,
256        normal: V1,
257        extractor: fn(&V) -> V1,
258        direction: CrossingDirection,
259    ) -> ODEProblemSoloutPair<'_, T, V, D, F, HyperplaneCrossingSolout<T, V1, V>>
260    where
261        V1: State<T>,
262    {
263        let solout =
264            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
265
266        ODEProblemSoloutPair::new(self, solout)
267    }
268}
269
270/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
271pub struct ODEProblemMutRefSoloutPair<'a, T, V, D, F, O>
272where
273    T: Real,
274    V: State<T>,
275    D: CallBackData,
276    F: ODE<T, V, D>,
277    O: Solout<T, V, D>,
278{
279    pub problem: &'a ODEProblem<T, V, D, F>, // Reference to the ODEProblem struct
280    pub solout: &'a mut O,                   // Reference to the solout implementation
281}
282
283impl<'a, T, V, D, F, O> ODEProblemMutRefSoloutPair<'a, T, V, D, F, O>
284where
285    T: Real,
286    V: State<T>,
287    D: CallBackData,
288    F: ODE<T, V, D>,
289    O: Solout<T, V, D>,
290{
291    /// Create a new ODEProblemMutRefSoloutPair
292    ///
293    /// # Arguments
294    /// * `problem` - Reference to the ODEProblem struct
295    ///
296    pub fn new(problem: &'a ODEProblem<T, V, D, F>, solout: &'a mut O) -> Self {
297        ODEProblemMutRefSoloutPair { problem, solout }
298    }
299
300    /// Solve the ODEProblem using the provided solout
301    ///
302    /// # Arguments
303    /// * `solver` - ODENumericalMethod to use for solving the ODEProblem
304    ///
305    /// # Returns
306    /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
307    ///
308    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
309    where
310        S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
311    {
312        solve_ode(
313            solver,
314            &self.problem.ode,
315            self.problem.t0,
316            self.problem.tf,
317            &self.problem.y0,
318            self.solout,
319        )
320    }
321}
322
323/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
324#[derive(Clone, Debug)]
325pub struct ODEProblemSoloutPair<'a, T, V, D, F, O>
326where
327    T: Real,
328    V: State<T>,
329    D: CallBackData,
330    F: ODE<T, V, D>,
331    O: Solout<T, V, D>,
332{
333    pub problem: &'a ODEProblem<T, V, D, F>, // Reference to the ODEProblem struct
334    pub solout: O,                           // Solout implementation
335}
336
337impl<'a, T, V, D, F, O> ODEProblemSoloutPair<'a, T, V, D, F, O>
338where
339    T: Real,
340    V: State<T>,
341    D: CallBackData,
342    F: ODE<T, V, D>,
343    O: Solout<T, V, D>,
344{
345    /// Create a new ODEProblemSoloutPair
346    ///
347    /// # Arguments
348    /// * `problem` - Reference to the ODEProblem struct
349    /// * `solout` - Solout implementation
350    ///
351    pub fn new(problem: &'a ODEProblem<T, V, D, F>, solout: O) -> Self {
352        ODEProblemSoloutPair { problem, solout }
353    }
354
355    /// Solve the ODEProblem using the provided solout
356    ///
357    /// # Arguments
358    /// * `solver` - ODENumericalMethod to use for solving the ODEProblem
359    ///
360    /// # Returns
361    /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
362    ///
363    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
364    where
365        S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
366    {
367        solve_ode(
368            solver,
369            &self.problem.ode,
370            self.problem.t0,
371            self.problem.tf,
372            &self.problem.y0,
373            &mut self.solout,
374        )
375    }
376}