differential_equations/sde/
problem.rs

1//! SDE Problem Struct and Constructors
2
3use crate::{
4    error::Error,
5    interpolate::Interpolation,
6    sde::{SDE, StochasticNumericalMethod, solve_sde},
7    solout::*,
8    solution::Solution,
9    traits::{Real, State},
10};
11
12/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
13///
14/// The Initial Value Problem takes the form:
15/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
16///
17/// where:
18/// - a(t, Y) is the drift term (deterministic part)
19/// - b(t, Y) is the diffusion term (stochastic part)
20/// - dW represents a Wiener process increment
21///
22/// # Overview
23///
24/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
25///
26/// # Example
27///
28/// ```
29/// use differential_equations::prelude::*;
30/// use nalgebra::SVector;
31/// use rand::SeedableRng;
32/// use rand_distr::{Distribution, Normal};
33///
34/// struct GBM {
35///     rng: rand::rngs::StdRng,
36/// }
37///
38/// impl GBM {
39///     fn new(seed: u64) -> Self {
40///         Self {
41///             rng: rand::rngs::StdRng::seed_from_u64(seed),
42///         }
43///     }
44/// }
45///
46/// impl SDE<f64, SVector<f64, 1>> for GBM {
47///     fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
48///         dydt[0] = 0.1 * y[0]; // μS
49///     }
50///     
51///     fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
52///         dydw[0] = 0.2 * y[0]; // σS
53///     }
54///     
55///     fn noise(&mut self, dt: f64, dw: &mut SVector<f64, 1>) {
56///         let normal = Normal::new(0.0, dt.sqrt()).unwrap();
57///         dw[0] = normal.sample(&mut self.rng);
58///     }
59/// }
60///
61/// let t0 = 0.0;
62/// let tf = 1.0;
63/// let y0 = SVector::<f64, 1>::new(100.0);
64/// let mut solver = ExplicitRungeKutta::three_eighths(0.01);
65/// let gbm = GBM::new(42);
66/// let mut gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
67///
68/// // Solve the SDE
69/// let result = gbm_problem.solve(&mut solver);
70/// ```
71///
72/// # Fields
73///
74/// * `sde` - SDE implementing the stochastic differential equation
75/// * `t0` - Initial time
76/// * `tf` - Final time
77/// * `y0` - Initial state vector
78///
79/// # Basic Usage
80///
81/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
82/// * `solve(&mut solver)` - Solve using default output (solver step points)
83///
84/// # Output Control Methods
85///
86/// These methods configure how solution points are generated and returned:
87///
88/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
89/// * `dense(n)` - Include `n` interpolated points between each solver step
90/// * `t_eval(points)` - Evaluate solution at specific time points
91/// * `solout(custom_solout)` - Use a custom output handler
92/// * `seed(u64)` - Set a specific random seed for reproducible simulations
93///
94#[derive(Debug)]
95pub struct SDEProblem<'a, T, Y, F>
96where
97    T: Real,
98    Y: State<T>,
99    F: SDE<T, Y>,
100{
101    // SDE Problem Fields
102    pub sde: &'a mut F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
103    pub t0: T,          // Initial Time
104    pub tf: T,          // Final Time
105    pub y0: Y,          // Initial State Vector
106}
107
108impl<'a, T, Y, F> SDEProblem<'a, T, Y, F>
109where
110    T: Real,
111    Y: State<T>,
112    F: SDE<T, Y>,
113{
114    /// Create a new Stochastic Differential Equation Problem
115    ///
116    /// # Arguments
117    /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
118    /// * `t0` - Initial Time
119    /// * `tf` - Final Time
120    /// * `y0` - Initial State Vector
121    ///
122    /// # Returns
123    /// * SDE Problem ready to be solved
124    ///
125    pub fn new(sde: &'a mut F, t0: T, tf: T, y0: Y) -> Self {
126        SDEProblem { sde, t0, tf, y0 }
127    }
128
129    /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
130    ///
131    /// # Returns
132    /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
133    ///
134    pub fn solve<S>(&mut self, solver: &'a mut S) -> Result<Solution<T, Y>, Error<T, Y>>
135    where
136        S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
137    {
138        let mut default_solout = DefaultSolout::new(); // Default solout implementation
139        solve_sde(
140            solver,
141            self.sde,
142            self.t0,
143            self.tf,
144            &self.y0,
145            &mut default_solout,
146        )
147    }
148
149    /// Returns an SDE Problem with the provided solout function for outputting points
150    ///
151    /// # Returns
152    /// * SDE Problem with the provided solout function ready for .solve() method
153    ///
154    pub fn solout<O: Solout<T, Y>>(
155        &'a mut self,
156        solout: &'a mut O,
157    ) -> SDEProblemMutRefSoloutPair<'a, T, Y, F, O> {
158        SDEProblemMutRefSoloutPair::new(self, solout)
159    }
160
161    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
162    /// Note that this does not include the solution of the calculated steps
163    ///
164    /// # Arguments
165    /// * `dt` - Interval between each output point
166    ///
167    /// # Returns
168    /// * SDE Problem with Even Solout function ready for .solve() method
169    ///
170    pub fn even(&'a mut self, dt: T) -> SDEProblemSoloutPair<'a, T, Y, F, EvenSolout<T>> {
171        let even_solout = EvenSolout::new(dt, self.t0, self.tf);
172        SDEProblemSoloutPair::new(self, even_solout)
173    }
174
175    /// Uses the Dense Output method to output n number of interpolation points between each step
176    /// Note this includes the solution of the calculated steps
177    ///
178    /// # Arguments
179    /// * `n` - Number of interpolation points between each step
180    ///
181    /// # Returns
182    /// * SDE Problem with Dense Output function ready for .solve() method
183    ///
184    pub fn dense(&'a mut self, n: usize) -> SDEProblemSoloutPair<'a, T, Y, F, DenseSolout> {
185        let dense_solout = DenseSolout::new(n);
186        SDEProblemSoloutPair::new(self, dense_solout)
187    }
188
189    /// Uses the provided time points for evaluation instead of the default method
190    /// Note this does not include the solution of the calculated steps
191    ///
192    /// # Arguments
193    /// * `points` - Custom output points
194    ///
195    /// # Returns
196    /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
197    ///
198    pub fn t_eval(
199        &'a mut self,
200        points: impl AsRef<[T]>,
201    ) -> SDEProblemSoloutPair<'a, T, Y, F, TEvalSolout<T>> {
202        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
203        SDEProblemSoloutPair::new(self, t_eval_solout)
204    }
205
206    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
207    /// Note this does not include the solution of the calculated steps
208    ///
209    /// # Arguments
210    /// * `component_idx` - Index of the component to monitor for crossing
211    /// * `threshold` - Value to cross
212    /// * `direction` - Direction of crossing (positive or negative)
213    ///
214    /// # Returns
215    /// * SDE Problem with CrossingSolout function ready for .solve() method
216    ///
217    pub fn crossing(
218        &'a mut self,
219        component_idx: usize,
220        threshold: T,
221        direction: CrossingDirection,
222    ) -> SDEProblemSoloutPair<'a, T, Y, F, CrossingSolout<T>> {
223        let crossing_solout =
224            CrossingSolout::new(component_idx, threshold).with_direction(direction);
225        SDEProblemSoloutPair::new(self, crossing_solout)
226    }
227
228    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
229    /// Note this does not include the solution of the calculated steps
230    ///
231    /// # Arguments
232    /// * `point` - Point on the hyperplane
233    /// * `normal` - Normal vector of the hyperplane
234    /// * `extractor` - Function to extract the component from the state vector
235    /// * `direction` - Direction of crossing (positive or negative)
236    ///
237    /// # Returns
238    /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
239    ///
240    pub fn hyperplane_crossing<Y1>(
241        &'a mut self,
242        point: Y1,
243        normal: Y1,
244        extractor: fn(&Y) -> Y1,
245        direction: CrossingDirection,
246    ) -> SDEProblemSoloutPair<'a, T, Y, F, HyperplaneCrossingSolout<T, Y1, Y>>
247    where
248        Y1: State<T>,
249    {
250        let solout =
251            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
252
253        SDEProblemSoloutPair::new(self, solout)
254    }
255
256    /// Uses an `EventSolout` to capture zero crossings of a user-defined event function (SciPy style).
257    /// The event implements `Event<T,Y>` returning g(t,y); roots are located with Brent-Dekker.
258    pub fn event<E>(
259        &'a mut self,
260        event: &'a E,
261    ) -> SDEProblemSoloutPair<'a, T, Y, F, EventSolout<'a, T, Y, E>>
262    where
263        E: Event<T, Y>,
264    {
265        let solout = EventSolout::new(event, self.t0, self.tf);
266        SDEProblemSoloutPair::new(self, solout)
267    }
268}
269
270/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
271pub struct SDEProblemMutRefSoloutPair<'a, T, Y, F, O>
272where
273    T: Real,
274    Y: State<T>,
275    F: SDE<T, Y>,
276    O: Solout<T, Y>,
277{
278    pub sde_problem: &'a mut SDEProblem<'a, T, Y, F>,
279    pub solout: &'a mut O,
280}
281
282impl<'a, T, Y, F, O> SDEProblemMutRefSoloutPair<'a, T, Y, F, O>
283where
284    T: Real,
285    Y: State<T>,
286    F: SDE<T, Y>,
287    O: Solout<T, Y>,
288{
289    /// Create a new SDEProblemMutRefSoloutPair
290    ///
291    /// # Arguments
292    /// * `sde_problem` - Reference to the SDE Problem struct
293    /// * `solout` - Reference to the solout implementation
294    ///
295    pub fn new(sde_problem: &'a mut SDEProblem<'a, T, Y, F>, solout: &'a mut O) -> Self {
296        SDEProblemMutRefSoloutPair {
297            sde_problem,
298            solout,
299        }
300    }
301
302    /// Solve the SDE Problem using the provided solout
303    ///
304    /// # Arguments
305    /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
306    ///
307    /// # Returns
308    /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
309    ///
310    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
311    where
312        S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
313    {
314        solve_sde(
315            solver,
316            self.sde_problem.sde,
317            self.sde_problem.t0,
318            self.sde_problem.tf,
319            &self.sde_problem.y0,
320            self.solout,
321        )
322    }
323}
324
325/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
326#[derive(Debug)]
327pub struct SDEProblemSoloutPair<'a, T, Y, F, O>
328where
329    T: Real,
330    Y: State<T>,
331    F: SDE<T, Y>,
332    O: Solout<T, Y>,
333{
334    pub sde_problem: &'a mut SDEProblem<'a, T, Y, F>,
335    pub solout: O,
336}
337
338impl<'a, T, Y, F, O> SDEProblemSoloutPair<'a, T, Y, F, O>
339where
340    T: Real,
341    Y: State<T>,
342    F: SDE<T, Y>,
343    O: Solout<T, Y>,
344{
345    /// Create a new SDEProblemSoloutPair
346    ///
347    /// # Arguments
348    /// * `sde_problem` - Reference to the SDE Problem struct
349    /// * `solout` - Solout implementation
350    ///
351    pub fn new(sde_problem: &'a mut SDEProblem<'a, T, Y, F>, solout: O) -> Self {
352        SDEProblemSoloutPair {
353            sde_problem,
354            solout,
355        }
356    }
357
358    /// Solve the SDE Problem using the provided solout
359    ///
360    /// # Arguments
361    /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
362    ///
363    /// # Returns
364    /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
365    ///
366    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
367    where
368        S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
369    {
370        solve_sde(
371            solver,
372            self.sde_problem.sde,
373            self.sde_problem.t0,
374            self.sde_problem.tf,
375            &self.sde_problem.y0,
376            &mut self.solout,
377        )
378    }
379
380    /// Wrap current solout with event detection while preserving original output strategy.
381    pub fn event<E>(
382        self,
383        event: &'a E,
384    ) -> SDEProblemSoloutPair<'a, T, Y, F, EventWrappedSolout<'a, T, Y, O, E>>
385    where
386        E: Event<T, Y>,
387    {
388        let wrapped = EventWrappedSolout::new(self.solout, event, self.sde_problem.t0, self.sde_problem.tf);
389        SDEProblemSoloutPair::new(self.sde_problem, wrapped)
390    }
391}