differential_equations/ode/problem.rs
1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 ode::{ODE, OrdinaryNumericalMethod, solve_ode},
7 solout::*,
8 solution::Solution,
9 traits::{CallBackData, Real, State},
10};
11
12/// Initial Value Problem for Ordinary Differential Equations (ODEs)
13///
14/// The Initial Value Problem takes the form:
15/// y' = f(t, y), a <= t <= b, y(a) = alpha
16///
17/// # Overview
18///
19/// The ODEProblem struct provides a simple interface for solving differential equations:
20///
21/// # Example
22///
23/// ```
24/// use differential_equations::prelude::*;
25///
26/// struct LinearEquation {
27/// pub a: f32,
28/// pub b: f32,
29/// }
30///
31/// impl ODE<f32, f32> for LinearEquation {
32/// fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
33/// *dydt = self.a + self.b * y;
34/// }
35/// }
36///
37/// // Create the ode and initial conditions
38/// let ode = LinearEquation { a: 1.0, b: 2.0 };
39/// let t0 = 0.0;
40/// let tf = 1.0;
41/// let y0 = 1.0;
42/// let mut solver = ExplicitRungeKutta::dop853().rtol(1e-8).atol(1e-6);
43///
44/// // Basic usage:
45/// let problem = ODEProblem::new(ode, t0, tf, y0);
46/// let solution = problem.solve(&mut solver).unwrap();
47///
48/// // Advanced output control:
49/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
50/// ```
51///
52/// # Fields
53///
54/// * `ode` - ODE implementing the differential equation
55/// * `t0` - Initial time
56/// * `tf` - Final time
57/// * `y0` - Initial state vector
58///
59/// # Basic Usage
60///
61/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
62/// * `solve(&mut solver)` - Solve using default output (solver step points)
63///
64/// # Output Control Methods
65///
66/// These methods configure how solution points are generated and returned:
67///
68/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
69/// * `dense(n)` - Include `n` interpolated points between each solver step
70/// * `t_eval(points)` - Evaluate solution at specific time points
71/// * `solout(custom_solout)` - Use a custom output handler
72///
73/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
74///
75/// # Example 2
76///
77/// ```
78/// use differential_equations::prelude::*;
79/// use nalgebra::{SVector, vector};
80///
81/// struct HarmonicOscillator { k: f64 }
82///
83/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
84/// fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
85/// dydt[0] = y[1];
86/// dydt[1] = -self.k * y[0];
87/// }
88/// }
89///
90/// let ode = HarmonicOscillator { k: 1.0 };
91/// let mut method = ExplicitRungeKutta::dop853().rtol(1e-12).atol(1e-12);
92///
93/// // Basic usage with default output points
94/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
95/// let results = problem.solve(&mut method).unwrap();
96///
97/// // Advanced: evenly spaced output with 0.1 time intervals
98/// let results = problem.dense(4).solve(&mut method).unwrap();
99/// ```
100#[derive(Clone, Debug)]
101pub struct ODEProblem<T, Y, D, F>
102where
103 T: Real,
104 Y: State<T>,
105 D: CallBackData,
106 F: ODE<T, Y, D>,
107{
108 // Initial Value Problem Fields
109 pub ode: F, // ODE containing the Differential Equation and Optional Terminate Function.
110 pub t0: T, // Initial Time.
111 pub tf: T, // Final Time.
112 pub y0: Y, // Initial State Vector.
113
114 // Phantom Data for Users event output
115 _event_output_type: std::marker::PhantomData<D>,
116}
117
118impl<T, Y, D, F> ODEProblem<T, Y, D, F>
119where
120 T: Real,
121 Y: State<T>,
122 D: CallBackData,
123 F: ODE<T, Y, D>,
124{
125 /// Create a new Initial Value Problem
126 ///
127 /// # Arguments
128 /// * `ode` - ODE containing the Differential Equation and Optional Terminate Function.
129 /// * `t0` - Initial Time.
130 /// * `tf` - Final Time.
131 /// * `y0` - Initial State Vector.
132 ///
133 /// # Returns
134 /// * ODEProblem Problem ready to be solved.
135 ///
136 pub fn new(ode: F, t0: T, tf: T, y0: Y) -> Self {
137 ODEProblem {
138 ode,
139 t0,
140 tf,
141 y0,
142 _event_output_type: std::marker::PhantomData,
143 }
144 }
145
146 /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
147 ///
148 /// # Returns
149 /// * `Result<Solution<T, Y, D>, Status<T, Y, D>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
150 ///
151 pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
152 where
153 S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
154 {
155 let mut default_solout = DefaultSolout::new();
156 solve_ode(
157 solver,
158 &self.ode,
159 self.t0,
160 self.tf,
161 &self.y0,
162 &mut default_solout,
163 )
164 }
165
166 /// Returns an ODEProblem OrdinaryNumericalMethod with the provided solout function for outputting points.
167 ///
168 /// # Returns
169 /// * ODEProblem OrdinaryNumericalMethod with the provided solout function ready for .solve() method.
170 ///
171 pub fn solout<'a, O: Solout<T, Y, D>>(
172 &'a self,
173 solout: &'a mut O,
174 ) -> ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O> {
175 ODEProblemMutRefSoloutPair::new(self, solout)
176 }
177
178 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
179 /// Note that this does not include the solution of the calculated steps.
180 ///
181 /// # Arguments
182 /// * `dt` - Interval between each output point.
183 ///
184 /// # Returns
185 /// * ODEProblem OrdinaryNumericalMethod with Even Solout function ready for .solve() method.
186 ///
187 pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, Y, D, F, EvenSolout<T>> {
188 let even_solout = EvenSolout::new(dt, self.t0, self.tf);
189 ODEProblemSoloutPair::new(self, even_solout)
190 }
191
192 /// Uses the Dense Output method to output n number of interpolation points between each step.
193 /// Note this includes the solution of the calculated steps.
194 ///
195 /// # Arguments
196 /// * `n` - Number of interpolation points between each step.
197 ///
198 /// # Returns
199 /// * ODEProblem OrdinaryNumericalMethod with Dense Output function ready for .solve() method.
200 ///
201 pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, Y, D, F, DenseSolout> {
202 let dense_solout = DenseSolout::new(n);
203 ODEProblemSoloutPair::new(self, dense_solout)
204 }
205
206 /// Uses the provided time points for evaluation instead of the default method.
207 /// Note this does not include the solution of the calculated steps.
208 ///
209 /// # Arguments
210 /// * `points` - Custom output points.
211 ///
212 /// # Returns
213 /// * ODEProblem OrdinaryNumericalMethod with Custom Time Evaluation function ready for .solve() method.
214 ///
215 pub fn t_eval(
216 &self,
217 points: impl AsRef<[T]>,
218 ) -> ODEProblemSoloutPair<'_, T, Y, D, F, TEvalSolout<T>> {
219 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
220 ODEProblemSoloutPair::new(self, t_eval_solout)
221 }
222
223 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
224 /// Note this does not include the solution of the calculated steps.
225 ///
226 /// # Arguments
227 /// * `component_idx` - Index of the component to monitor for crossing.
228 /// * `threshhold` - Value to cross.
229 /// * `direction` - Direction of crossing (positive or negative).
230 ///
231 /// # Returns
232 /// * ODEProblem OrdinaryNumericalMethod with CrossingSolout function ready for .solve() method.
233 ///
234 pub fn crossing(
235 &self,
236 component_idx: usize,
237 threshhold: T,
238 direction: CrossingDirection,
239 ) -> ODEProblemSoloutPair<'_, T, Y, D, F, CrossingSolout<T>> {
240 let crossing_solout =
241 CrossingSolout::new(component_idx, threshhold).with_direction(direction);
242 ODEProblemSoloutPair::new(self, crossing_solout)
243 }
244
245 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
246 /// Note this does not include the solution of the calculated steps.
247 ///
248 /// # Arguments
249 /// * `point` - Point on the hyperplane.
250 /// * `normal` - Normal vector of the hyperplane.
251 /// * `extractor` - Function to extract the component from the state vector.
252 /// * `direction` - Direction of crossing (positive or negative).
253 ///
254 /// # Returns
255 /// * ODEProblem OrdinaryNumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
256 ///
257 pub fn hyperplane_crossing<Y1>(
258 &self,
259 point: Y1,
260 normal: Y1,
261 extractor: fn(&Y) -> Y1,
262 direction: CrossingDirection,
263 ) -> ODEProblemSoloutPair<'_, T, Y, D, F, HyperplaneCrossingSolout<T, Y1, Y>>
264 where
265 Y1: State<T>,
266 {
267 let solout =
268 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
269
270 ODEProblemSoloutPair::new(self, solout)
271 }
272}
273
274/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
275pub struct ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
276where
277 T: Real,
278 Y: State<T>,
279 D: CallBackData,
280 F: ODE<T, Y, D>,
281 O: Solout<T, Y, D>,
282{
283 pub problem: &'a ODEProblem<T, Y, D, F>,
284 pub solout: &'a mut O,
285}
286
287impl<'a, T, Y, D, F, O> ODEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
288where
289 T: Real,
290 Y: State<T>,
291 D: CallBackData,
292 F: ODE<T, Y, D>,
293 O: Solout<T, Y, D>,
294{
295 /// Create a new ODEProblemMutRefSoloutPair
296 ///
297 /// # Arguments
298 /// * `problem` - Reference to the ODEProblem struct
299 ///
300 pub fn new(problem: &'a ODEProblem<T, Y, D, F>, solout: &'a mut O) -> Self {
301 ODEProblemMutRefSoloutPair { problem, solout }
302 }
303
304 /// Solve the ODEProblem using the provided solout
305 ///
306 /// # Arguments
307 /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
308 ///
309 /// # Returns
310 /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
311 ///
312 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
313 where
314 S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
315 {
316 solve_ode(
317 solver,
318 &self.problem.ode,
319 self.problem.t0,
320 self.problem.tf,
321 &self.problem.y0,
322 self.solout,
323 )
324 }
325}
326
327/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
328#[derive(Clone, Debug)]
329pub struct ODEProblemSoloutPair<'a, T, Y, D, F, O>
330where
331 T: Real,
332 Y: State<T>,
333 D: CallBackData,
334 F: ODE<T, Y, D>,
335 O: Solout<T, Y, D>,
336{
337 pub problem: &'a ODEProblem<T, Y, D, F>,
338 pub solout: O,
339}
340
341impl<'a, T, Y, D, F, O> ODEProblemSoloutPair<'a, T, Y, D, F, O>
342where
343 T: Real,
344 Y: State<T>,
345 D: CallBackData,
346 F: ODE<T, Y, D>,
347 O: Solout<T, Y, D>,
348{
349 /// Create a new ODEProblemSoloutPair
350 ///
351 /// # Arguments
352 /// * `problem` - Reference to the ODEProblem struct
353 /// * `solout` - Solout implementation
354 ///
355 pub fn new(problem: &'a ODEProblem<T, Y, D, F>, solout: O) -> Self {
356 ODEProblemSoloutPair { problem, solout }
357 }
358
359 /// Solve the ODEProblem using the provided solout
360 ///
361 /// # Arguments
362 /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
363 ///
364 /// # Returns
365 /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
366 ///
367 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
368 where
369 S: OrdinaryNumericalMethod<T, Y, D> + Interpolation<T, Y>,
370 {
371 solve_ode(
372 solver,
373 &self.problem.ode,
374 self.problem.t0,
375 self.problem.tf,
376 &self.problem.y0,
377 &mut self.solout,
378 )
379 }
380}