differential_equations/ode/problem.rs
1//! Initial Value Problem Struct and Constructors
2
3use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 ode::{ODE, OrdinaryNumericalMethod, solve_ode},
7 solout::*,
8 solution::Solution,
9 traits::{Real, State},
10};
11
12/// Initial Value Problem for Ordinary Differential Equations (ODEs)
13///
14/// The Initial Value Problem takes the form:
15/// y' = f(t, y), a <= t <= b, y(a) = alpha
16///
17/// # Overview
18///
19/// The ODEProblem struct provides a simple interface for solving differential equations:
20///
21/// # Example
22///
23/// ```
24/// use differential_equations::prelude::*;
25///
26/// struct LinearEquation {
27/// pub a: f32,
28/// pub b: f32,
29/// }
30///
31/// impl ODE<f32, f32> for LinearEquation {
32/// fn diff(&self, _t: f32, y: &f32, dydt: &mut f32) {
33/// *dydt = self.a + self.b * y;
34/// }
35/// }
36///
37/// // Create the ode and initial conditions
38/// let ode = LinearEquation { a: 1.0, b: 2.0 };
39/// let t0 = 0.0;
40/// let tf = 1.0;
41/// let y0 = 1.0;
42/// let mut solver = ExplicitRungeKutta::dop853().rtol(1e-8).atol(1e-6);
43///
44/// // Basic usage:
45/// let problem = ODEProblem::new(ode, t0, tf, y0);
46/// let solution = problem.solve(&mut solver).unwrap();
47///
48/// // Advanced output control:
49/// let solution = problem.even(0.1).solve(&mut solver).unwrap();
50/// ```
51///
52/// # Fields
53///
54/// * `ode` - ODE implementing the differential equation
55/// * `t0` - Initial time
56/// * `tf` - Final time
57/// * `y0` - Initial state vector
58///
59/// # Basic Usage
60///
61/// * `new(ode, t0, tf, y0)` - Create a new ODEProblem
62/// * `solve(&mut solver)` - Solve using default output (solver step points)
63///
64/// # Output Control Methods
65///
66/// These methods configure how solution points are generated and returned:
67///
68/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
69/// * `dense(n)` - Include `n` interpolated points between each solver step
70/// * `t_eval(points)` - Evaluate solution at specific time points
71/// * `solout(custom_solout)` - Use a custom output handler
72///
73/// Each returns a solver configuration that can be executed with `.solve(&mut solver)`.
74///
75/// # Example 2
76///
77/// ```
78/// use differential_equations::prelude::*;
79/// use nalgebra::{SVector, vector};
80///
81/// struct HarmonicOscillator { k: f64 }
82///
83/// impl ODE<f64, SVector<f64, 2>> for HarmonicOscillator {
84/// fn diff(&self, _t: f64, y: &SVector<f64, 2>, dydt: &mut SVector<f64, 2>) {
85/// dydt[0] = y[1];
86/// dydt[1] = -self.k * y[0];
87/// }
88/// }
89///
90/// let ode = HarmonicOscillator { k: 1.0 };
91/// let mut method = ExplicitRungeKutta::dop853().rtol(1e-12).atol(1e-12);
92///
93/// // Basic usage with default output points
94/// let problem = ODEProblem::new(ode, 0.0, 10.0, vector![1.0, 0.0]);
95/// let results = problem.solve(&mut method).unwrap();
96///
97/// // Advanced: evenly spaced output with 0.1 time intervals
98/// let results = problem.dense(4).solve(&mut method).unwrap();
99/// ```
100#[derive(Clone, Debug)]
101pub struct ODEProblem<'a, T, Y, F>
102where
103 T: Real,
104 Y: State<T>,
105 F: ODE<T, Y>,
106{
107 /// ODE object implementing [`ODE`](crate::ode::ODE) trait
108 pub ode: &'a F,
109 /// Initial Time
110 pub t0: T,
111 /// Final Time
112 pub tf: T,
113 /// Initial State Vector
114 pub y0: Y,
115}
116
117impl<'a, T, Y, F> ODEProblem<'a, T, Y, F>
118where
119 T: Real,
120 Y: State<T>,
121 F: ODE<T, Y>,
122{
123 /// Create a new Initial Value Problem
124 ///
125 /// # Arguments
126 /// * `ode` - ODE containing the Differential Equation and Optional Terminate Function.
127 /// * `t0` - Initial Time.
128 /// * `tf` - Final Time.
129 /// * `y0` - Initial State Vector.
130 ///
131 /// # Returns
132 /// * ODEProblem Problem ready to be solved.
133 ///
134 pub fn new(ode: &'a F, t0: T, tf: T, y0: Y) -> Self {
135 ODEProblem { ode, t0, tf, y0 }
136 }
137
138 /// Solve the ODEProblem using a default solout, e.g. outputting solutions at calculated steps.
139 ///
140 /// # Returns
141 /// * `Result<Solution<T, Y>, Status<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Status)` if an errors or issues such as stiffness are encountered.
142 ///
143 pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
144 where
145 S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
146 {
147 let mut default_solout = DefaultSolout::new();
148 solve_ode(
149 solver,
150 self.ode,
151 self.t0,
152 self.tf,
153 &self.y0,
154 &mut default_solout,
155 )
156 }
157
158 /// Returns an ODEProblem OrdinaryNumericalMethod with the provided solout function for outputting points.
159 ///
160 /// # Returns
161 /// * ODEProblem OrdinaryNumericalMethod with the provided solout function ready for .solve() method.
162 ///
163 pub fn solout<O: Solout<T, Y>>(
164 &'a self,
165 solout: &'a mut O,
166 ) -> ODEProblemMutRefSoloutPair<'a, T, Y, F, O> {
167 ODEProblemMutRefSoloutPair::new(self, solout)
168 }
169
170 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time.
171 /// Note that this does not include the solution of the calculated steps.
172 ///
173 /// # Arguments
174 /// * `dt` - Interval between each output point.
175 ///
176 /// # Returns
177 /// * ODEProblem OrdinaryNumericalMethod with Even Solout function ready for .solve() method.
178 ///
179 pub fn even(&self, dt: T) -> ODEProblemSoloutPair<'_, T, Y, F, EvenSolout<T>> {
180 let even_solout = EvenSolout::new(dt, self.t0, self.tf);
181 ODEProblemSoloutPair::new(self, even_solout)
182 }
183
184 /// Uses the Dense Output method to output n number of interpolation points between each step.
185 /// Note this includes the solution of the calculated steps.
186 ///
187 /// # Arguments
188 /// * `n` - Number of interpolation points between each step.
189 ///
190 /// # Returns
191 /// * ODEProblem OrdinaryNumericalMethod with Dense Output function ready for .solve() method.
192 ///
193 pub fn dense(&self, n: usize) -> ODEProblemSoloutPair<'_, T, Y, F, DenseSolout> {
194 let dense_solout = DenseSolout::new(n);
195 ODEProblemSoloutPair::new(self, dense_solout)
196 }
197
198 /// Uses the provided time points for evaluation instead of the default method.
199 /// Note this does not include the solution of the calculated steps.
200 ///
201 /// # Arguments
202 /// * `points` - Custom output points.
203 ///
204 /// # Returns
205 /// * ODEProblem OrdinaryNumericalMethod with Custom Time Evaluation function ready for .solve() method.
206 ///
207 pub fn t_eval(
208 &self,
209 points: impl AsRef<[T]>,
210 ) -> ODEProblemSoloutPair<'_, T, Y, F, TEvalSolout<T>> {
211 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
212 ODEProblemSoloutPair::new(self, t_eval_solout)
213 }
214
215 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold.
216 /// Note this does not include the solution of the calculated steps.
217 ///
218 /// # Arguments
219 /// * `component_idx` - Index of the component to monitor for crossing.
220 /// * `threshhold` - Value to cross.
221 /// * `direction` - Direction of crossing (positive or negative).
222 ///
223 /// # Returns
224 /// * ODEProblem OrdinaryNumericalMethod with CrossingSolout function ready for .solve() method.
225 ///
226 pub fn crossing(
227 &self,
228 component_idx: usize,
229 threshhold: T,
230 direction: CrossingDirection,
231 ) -> ODEProblemSoloutPair<'_, T, Y, F, CrossingSolout<T>> {
232 let crossing_solout =
233 CrossingSolout::new(component_idx, threshhold).with_direction(direction);
234 ODEProblemSoloutPair::new(self, crossing_solout)
235 }
236
237 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed.
238 /// Note this does not include the solution of the calculated steps.
239 ///
240 /// # Arguments
241 /// * `point` - Point on the hyperplane.
242 /// * `normal` - Normal vector of the hyperplane.
243 /// * `extractor` - Function to extract the component from the state vector.
244 /// * `direction` - Direction of crossing (positive or negative).
245 ///
246 /// # Returns
247 /// * ODEProblem OrdinaryNumericalMethod with HyperplaneCrossingSolout function ready for .solve() method.
248 ///
249 pub fn hyperplane_crossing<Y1>(
250 &self,
251 point: Y1,
252 normal: Y1,
253 extractor: fn(&Y) -> Y1,
254 direction: CrossingDirection,
255 ) -> ODEProblemSoloutPair<'_, T, Y, F, HyperplaneCrossingSolout<T, Y1, Y>>
256 where
257 Y1: State<T>,
258 {
259 let solout =
260 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
261
262 ODEProblemSoloutPair::new(self, solout)
263 }
264
265 /// Uses an `EventSolout` to detect zero crossings of a user-defined event function (SciPy style).
266 /// The provided event implements the `Event` trait returning a scalar function g(t,y) whose
267 /// roots are sought. Each detected event point (t*, y*) is appended to the solution. Optional
268 /// termination after N events can be configured in the Event implementation via `config()`.
269 ///
270 /// # Arguments
271 /// * `event` - Object implementing `Event<T, Y>` whose zero crossings are desired.
272 ///
273 /// # Returns
274 /// * `ODEProblemSoloutPair` with `EventSolout` ready for `.solve(&mut solver)`.
275 ///
276 /// # Example
277 /// ```
278 /// use differential_equations::prelude::*;
279 /// use nalgebra::{Vector2, vector};
280 ///
281 /// struct SHO; // Simple harmonic oscillator
282 /// impl ODE<f64, Vector2<f64>> for SHO {
283 /// fn diff(&self, _t: f64, y: &Vector2<f64>, dydt: &mut Vector2<f64>) {
284 /// dydt[0]=y[1];
285 /// dydt[1]=-y[0];
286 /// }
287 /// }
288 ///
289 /// // Event: detect when position crosses zero going positive (like SciPy event)
290 /// struct ZeroUp;
291 /// impl Event<f64, Vector2<f64>> for ZeroUp {
292 /// fn config(&self) -> EventConfig {
293 /// // Force only positive crossings
294 /// EventConfig::default().direction(CrossingDirection::Positive)
295 /// }
296 ///
297 /// fn event(&self, _t: f64, y: &Vector2<f64>) -> f64 {
298 /// y[0]
299 /// }
300 /// }
301 ///
302 /// let osc = SHO;
303 /// let t0 = 0.0;
304 /// let tf = 10.0;
305 /// let y0 = vector![1.0, 0.0];
306 /// let problem = ODEProblem::new(&osc, t0, tf, y0);
307 /// let mut solver = ExplicitRungeKutta::dop853();
308 /// let solution = problem.event(&ZeroUp).solve(&mut solver).unwrap();
309 /// // solution.t now contains zero-up crossing times
310 /// ```
311 pub fn event<E>(&'a self, event: &'a E) -> ODEProblemSoloutPair<'a, T, Y, F, EventSolout<'a, T, Y, E>>
312 where
313 E: Event<T, Y>,
314 {
315 let solout = EventSolout::new(event, self.t0, self.tf);
316 ODEProblemSoloutPair::new(self, solout)
317 }
318}
319
320/// ODEProblemMutRefSoloutPair serves as a intermediate between the ODEProblem struct and a custom solout provided by the user.
321pub struct ODEProblemMutRefSoloutPair<'a, T, Y, F, O>
322where
323 T: Real,
324 Y: State<T>,
325 F: ODE<T, Y>,
326 O: Solout<T, Y>,
327{
328 pub problem: &'a ODEProblem<'a, T, Y, F>,
329 pub solout: &'a mut O,
330}
331
332impl<'a, T, Y, F, O> ODEProblemMutRefSoloutPair<'a, T, Y, F, O>
333where
334 T: Real,
335 Y: State<T>,
336 F: ODE<T, Y>,
337 O: Solout<T, Y>,
338{
339 /// Create a new ODEProblemMutRefSoloutPair
340 ///
341 /// # Arguments
342 /// * `problem` - Reference to the ODEProblem struct
343 ///
344 pub fn new(problem: &'a ODEProblem<T, Y, F>, solout: &'a mut O) -> Self {
345 ODEProblemMutRefSoloutPair { problem, solout }
346 }
347
348 /// Solve the ODEProblem using the provided solout
349 ///
350 /// # Arguments
351 /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
352 ///
353 /// # Returns
354 /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
355 ///
356 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
357 where
358 S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
359 {
360 solve_ode(
361 solver,
362 self.problem.ode,
363 self.problem.t0,
364 self.problem.tf,
365 &self.problem.y0,
366 self.solout,
367 )
368 }
369}
370
371/// ODEProblemSoloutPair serves as a intermediate between the ODEProblem struct and solve_ode when a predefined solout is used.
372#[derive(Clone, Debug)]
373pub struct ODEProblemSoloutPair<'a, T, Y, F, O>
374where
375 T: Real,
376 Y: State<T>,
377 F: ODE<T, Y>,
378 O: Solout<T, Y>,
379{
380 pub problem: &'a ODEProblem<'a, T, Y, F>,
381 pub solout: O,
382}
383
384impl<'a, T, Y, F, O> ODEProblemSoloutPair<'a, T, Y, F, O>
385where
386 T: Real,
387 Y: State<T>,
388 F: ODE<T, Y>,
389 O: Solout<T, Y>,
390{
391 /// Create a new ODEProblemSoloutPair
392 ///
393 /// # Arguments
394 /// * `problem` - Reference to the ODEProblem struct
395 /// * `solout` - Solout implementation
396 ///
397 pub fn new(problem: &'a ODEProblem<T, Y, F>, solout: O) -> Self {
398 ODEProblemSoloutPair { problem, solout }
399 }
400
401 /// Solve the ODEProblem using the provided solout
402 ///
403 /// # Arguments
404 /// * `solver` - OrdinaryNumericalMethod to use for solving the ODEProblem
405 ///
406 /// # Returns
407 /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if an errors or issues such as stiffness are encountered.
408 ///
409 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
410 where
411 S: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
412 {
413 solve_ode(
414 solver,
415 self.problem.ode,
416 self.problem.t0,
417 self.problem.tf,
418 &self.problem.y0,
419 &mut self.solout,
420 )
421 }
422
423 /// Wrap current solout with event detection while preserving original output strategy.
424 pub fn event<E>(
425 self,
426 event: &'a E,
427 ) -> ODEProblemSoloutPair<'a, T, Y, F, EventWrappedSolout<'a, T, Y, O, E>>
428 where
429 E: Event<T, Y>,
430 {
431 let wrapped = EventWrappedSolout::new(self.solout, event, self.problem.t0, self.problem.tf);
432 ODEProblemSoloutPair {
433 problem: self.problem,
434 solout: wrapped,
435 }
436 }
437}