differential_equations/ode/ode_problem.rs
1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4 Error, Solution,
5 interpolate::Interpolation,
6 ode::{ODE, numerical_method::ODENumericalMethod, solve_ode},
7 solout::*,
8 traits::{CallBackData, Real, State},
9};
10
11/// Initial Value Problem for Ordinary Differential Equations (ODEs)
12///
13/// The Initial Value Problem takes the form:
14/// y' = f(t, y), a <= t <= b, y(a) = alpha
15///
16/// # Overview
17///
18/// The ODEProblem struct provides a simple interface for solving differential equations:
19///
20/// # Example
21///
22/// ```
23/// use differential_equations::prelude::*;
24///
25/// struct LinearEquation {
26/// pub a: f32,
27/// pub b: f32,
28/// }
29///
30/// impl ODE<f32, f32> for LinearEquation {
31/// fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
32/// *dydt = self.a + self.b * y;
33/// }
34/// }
35///
36/// // Create the ode and initial conditions
37/// let ode = LinearEquation { a: 1.0, b: 2.0 };
38/// let t0 = 0.0;
39/// let tf = 1.0;
40/// let y0 = 1.0;
41/// let mut solver = DOP853::new().rtol(1e-8).atol(1e-6);
42///
43/// // Basic usage:
44/// let problem = ODEProblem::new(ode, t0, tf, y0);
45/// let solution = problem.solve(&mut solver).unwrap();
46///
47/// // Advanced output control:
48/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
49/// ```
50///
51/// # Fields
52///
53/// * `ode` - ODE implementing the differential equation
54/// * `t0` - Initial time
55/// * `tf` - Final time
56/// * `y0` - Initial state vector
57///
58/// # Basic Usage
59///
60/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
61/// * `solve(&mut solver)` - Solve using default output (solver step points)
62///
63/// # Output Control Methods
64///
65/// These methods configure how solution points are generated and returned:
66///
67/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
68/// * `dense(n)` - Include `n` interpolated points between each solver step
69/// * `t_eval(points)` - Evaluate solution at specific time points
70/// * `solout(custom_solout)` - Use a custom output handler
71///
72/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
73///
74/// # Example 2
75///
76/// ```
77/// use differential_equations::prelude::*;
78/// use nalgebra::{SVector, vector};
79///
80/// struct HarmonicOscillator { k: f64 }
81///
82/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
83/// fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
84/// dydt[0] = y[1];
85/// dydt[1] = -self.k * y[0];
86/// }
87/// }
88///
89/// let ode = HarmonicOscillator { k: 1.0 };
90/// let mut method = DOP853::new().rtol(1e-12).atol(1e-12);
91///
92/// // Basic usage with default output points
93/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
94/// let results = problem.solve(&mut method).unwrap();
95///
96/// // Advanced: evenly spaced output with 0.1 time intervals
97/// let results = problem.dense(4).solve(&mut method).unwrap();
98/// ```
99#[derive(Clone, Debug)]
100pub struct ODEProblem<T, V, D, F>
101where
102 T: Real,
103 V: State<T>,
104 D: CallBackData,
105 F: ODE<T, V, D>,
106{
107 // Initial Value Problem Fields
108 pub ode: F, // ODE containing the Differential Equation and Optional Terminate Function.
109 pub t0: T, // Initial Time.
110 pub tf: T, // Final Time.
111 pub y0: V, // Initial State Vector.
112
113 // Phantom Data for Users event output
114 _event_output_type: std::marker::PhantomData<D>,
115}
116
117impl<T, V, D, F> ODEProblem<T, V, D, F>
118where
119 T: Real,
120 V: State<T>,
121 D: CallBackData,
122 F: ODE<T, V, D>,
123{
124 /// Create a new Initial Value Problem
125 ///
126 /// # Arguments
127 /// * `ode` - ODE containing the Differential Equation and Optional Terminate Function.
128 /// * `t0` - Initial Time.
129 /// * `tf` - Final Time.
130 /// * `y0` - Initial State Vector.
131 ///
132 /// # Returns
133 /// * ODEProblem Problem ready to be solved.
134 ///
135 pub fn new(ode: F, t0: T, tf: T, y0: V) -> Self {
136 ODEProblem {
137 ode,
138 t0,
139 tf,
140 y0,
141 _event_output_type: std::marker::PhantomData,
142 }
143 }
144
145 /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
146 ///
147 /// # Returns
148 /// * `Result<Solution<T, V, D>, Status<T, V, D>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
149 ///
150 pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
151 where
152 S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
153 {
154 let mut default_solout = DefaultSolout::new(); // Default solout implementation
155 solve_ode(
156 solver,
157 &self.ode,
158 self.t0,
159 self.tf,
160 &self.y0,
161 &mut default_solout,
162 )
163 }
164
165 /// Returns an ODEProblem ODENumericalMethod with the provided solout function for outputting points.
166 ///
167 /// # Returns
168 /// * ODEProblem ODENumericalMethod with the provided solout function ready for .solve() method.
169 ///
170 pub fn solout<'a, O: Solout<T, V, D>>(
171 &'a self,
172 solout: &'a mut O,
173 ) -> ODEProblemMutRefSoloutPair<'a, T, V, D, F, O> {
174 ODEProblemMutRefSoloutPair::new(self, solout)
175 }
176
177 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
178 /// Note that this does not include the solution of the calculated steps.
179 ///
180 /// # Arguments
181 /// * `dt` - Interval between each output point.
182 ///
183 /// # Returns
184 /// * ODEProblem ODENumericalMethod with Even Solout function ready for .solve() method.
185 ///
186 pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, V, D, F, EvenSolout<T>> {
187 let even_solout = EvenSolout::new(dt, self.t0, self.tf); // Even solout implementation
188 ODEProblemSoloutPair::new(self, even_solout)
189 }
190
191 /// Uses the Dense Output method to output n number of interpolation points between each step.
192 /// Note this includes the solution of the calculated steps.
193 ///
194 /// # Arguments
195 /// * `n` - Number of interpolation points between each step.
196 ///
197 /// # Returns
198 /// * ODEProblem ODENumericalMethod with Dense Output function ready for .solve() method.
199 ///
200 pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, V, D, F, DenseSolout> {
201 let dense_solout = DenseSolout::new(n); // Dense solout implementation
202 ODEProblemSoloutPair::new(self, dense_solout)
203 }
204
205 /// Uses the provided time points for evaluation instead of the default method.
206 /// Note this does not include the solution of the calculated steps.
207 ///
208 /// # Arguments
209 /// * `points` - Custom output points.
210 ///
211 /// # Returns
212 /// * ODEProblem ODENumericalMethod with Custom Time Evaluation function ready for .solve() method.
213 ///
214 pub fn t_eval(&self, points: Vec<T>) -> ODEProblemSoloutPair<'_, T, V, D, F, TEvalSolout<T>> {
215 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf); // Custom time evaluation solout implementation
216 ODEProblemSoloutPair::new(self, t_eval_solout)
217 }
218
219 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
220 /// Note this does not include the solution of the calculated steps.
221 ///
222 /// # Arguments
223 /// * `component_idx` - Index of the component to monitor for crossing.
224 /// * `threshhold` - Value to cross.
225 /// * `direction` - Direction of crossing (positive or negative).
226 ///
227 /// # Returns
228 /// * ODEProblem ODENumericalMethod with CrossingSolout function ready for .solve() method.
229 ///
230 pub fn crossing(
231 &self,
232 component_idx: usize,
233 threshhold: T,
234 direction: CrossingDirection,
235 ) -> ODEProblemSoloutPair<'_, T, V, D, F, CrossingSolout<T>> {
236 let crossing_solout =
237 CrossingSolout::new(component_idx, threshhold).with_direction(direction); // Crossing solout implementation
238 ODEProblemSoloutPair::new(self, crossing_solout)
239 }
240
241 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
242 /// Note this does not include the solution of the calculated steps.
243 ///
244 /// # Arguments
245 /// * `point` - Point on the hyperplane.
246 /// * `normal` - Normal vector of the hyperplane.
247 /// * `extractor` - Function to extract the component from the state vector.
248 /// * `direction` - Direction of crossing (positive or negative).
249 ///
250 /// # Returns
251 /// * ODEProblem ODENumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
252 ///
253 pub fn hyperplane_crossing<V1>(
254 &self,
255 point: V1,
256 normal: V1,
257 extractor: fn(&V) -> V1,
258 direction: CrossingDirection,
259 ) -> ODEProblemSoloutPair<'_, T, V, D, F, HyperplaneCrossingSolout<T, V1, V>>
260 where
261 V1: State<T>,
262 {
263 let solout =
264 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
265
266 ODEProblemSoloutPair::new(self, solout)
267 }
268}
269
270/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
271pub struct ODEProblemMutRefSoloutPair<'a, T, V, D, F, O>
272where
273 T: Real,
274 V: State<T>,
275 D: CallBackData,
276 F: ODE<T, V, D>,
277 O: Solout<T, V, D>,
278{
279 pub problem: &'a ODEProblem<T, V, D, F>, // Reference to the ODEProblem struct
280 pub solout: &'a mut O, // Reference to the solout implementation
281}
282
283impl<'a, T, V, D, F, O> ODEProblemMutRefSoloutPair<'a, T, V, D, F, O>
284where
285 T: Real,
286 V: State<T>,
287 D: CallBackData,
288 F: ODE<T, V, D>,
289 O: Solout<T, V, D>,
290{
291 /// Create a new ODEProblemMutRefSoloutPair
292 ///
293 /// # Arguments
294 /// * `problem` - Reference to the ODEProblem struct
295 ///
296 pub fn new(problem: &'a ODEProblem<T, V, D, F>, solout: &'a mut O) -> Self {
297 ODEProblemMutRefSoloutPair { problem, solout }
298 }
299
300 /// Solve the ODEProblem using the provided solout
301 ///
302 /// # Arguments
303 /// * `solver` - ODENumericalMethod to use for solving the ODEProblem
304 ///
305 /// # Returns
306 /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
307 ///
308 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
309 where
310 S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
311 {
312 solve_ode(
313 solver,
314 &self.problem.ode,
315 self.problem.t0,
316 self.problem.tf,
317 &self.problem.y0,
318 self.solout,
319 )
320 }
321}
322
323/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
324#[derive(Clone, Debug)]
325pub struct ODEProblemSoloutPair<'a, T, V, D, F, O>
326where
327 T: Real,
328 V: State<T>,
329 D: CallBackData,
330 F: ODE<T, V, D>,
331 O: Solout<T, V, D>,
332{
333 pub problem: &'a ODEProblem<T, V, D, F>, // Reference to the ODEProblem struct
334 pub solout: O, // Solout implementation
335}
336
337impl<'a, T, V, D, F, O> ODEProblemSoloutPair<'a, T, V, D, F, O>
338where
339 T: Real,
340 V: State<T>,
341 D: CallBackData,
342 F: ODE<T, V, D>,
343 O: Solout<T, V, D>,
344{
345 /// Create a new ODEProblemSoloutPair
346 ///
347 /// # Arguments
348 /// * `problem` - Reference to the ODEProblem struct
349 /// * `solout` - Solout implementation
350 ///
351 pub fn new(problem: &'a ODEProblem<T, V, D, F>, solout: O) -> Self {
352 ODEProblemSoloutPair { problem, solout }
353 }
354
355 /// Solve the ODEProblem using the provided solout
356 ///
357 /// # Arguments
358 /// * `solver` - ODENumericalMethod to use for solving the ODEProblem
359 ///
360 /// # Returns
361 /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
362 ///
363 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
364 where
365 S: ODENumericalMethod<T, V, D> + Interpolation<T, V>,
366 {
367 solve_ode(
368 solver,
369 &self.problem.ode,
370 self.problem.t0,
371 self.problem.tf,
372 &self.problem.y0,
373 &mut self.solout,
374 )
375 }
376}