// differential_equations/sde/problem.rs

//! SDE Problem Struct and Constructors

3use crate::{
4    error::Error,
5    interpolate::Interpolation,
6    sde::{SDE, StochasticNumericalMethod, solve_sde},
7    solout::*,
8    solution::Solution,
9    traits::{CallBackData, Real, State},
10};
11
12/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
13///
14/// The Initial Value Problem takes the form:
15/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
16///
17/// where:
18/// - a(t, Y) is the drift term (deterministic part)
19/// - b(t, Y) is the diffusion term (stochastic part)
20/// - dW represents a Wiener process increment
21///
22/// # Overview
23///
24/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
25///
26/// # Example
27///
28/// ```
29/// use differential_equations::prelude::*;
30/// use nalgebra::SVector;
31/// use rand::SeedableRng;
32/// use rand_distr::{Distribution, Normal};
33///
34/// struct GBM {
35///     rng: rand::rngs::StdRng,
36/// }
37///
38/// impl GBM {
39///     fn new(seed: u64) -> Self {
40///         Self {
41///             rng: rand::rngs::StdRng::seed_from_u64(seed),
42///         }
43///     }
44/// }
45///
46/// impl SDE<f64, SVector<f64, 1>> for GBM {
47///     fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
48///         dydt[0] = 0.1 * y[0]; // μS
49///     }
50///     
51///     fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
52///         dydw[0] = 0.2 * y[0]; // σS
53///     }
54///     
55///     fn noise(&mut self, dt: f64, dw: &mut SVector<f64, 1>) {
56///         let normal = Normal::new(0.0, dt.sqrt()).unwrap();
57///         dw[0] = normal.sample(&mut self.rng);
58///     }
59/// }
60///
61/// let t0 = 0.0;
62/// let tf = 1.0;
63/// let y0 = SVector::<f64, 1>::new(100.0);
64/// let mut solver = ExplicitRungeKutta::three_eighths(0.01);
65/// let gbm = GBM::new(42);
66/// let mut gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
67///
68/// // Solve the SDE
69/// let result = gbm_problem.solve(&mut solver);
70/// ```
71///
72/// # Fields
73///
74/// * `sde` - SDE implementing the stochastic differential equation
75/// * `t0` - Initial time
76/// * `tf` - Final time
77/// * `y0` - Initial state vector
78///
79/// # Basic Usage
80///
81/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
82/// * `solve(&mut solver)` - Solve using default output (solver step points)
83///
84/// # Output Control Methods
85///
86/// These methods configure how solution points are generated and returned:
87///
88/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
89/// * `dense(n)` - Include `n` interpolated points between each solver step
90/// * `t_eval(points)` - Evaluate solution at specific time points
91/// * `solout(custom_solout)` - Use a custom output handler
92/// * `seed(u64)` - Set a specific random seed for reproducible simulations
93///
94#[derive(Clone, Debug)]
95pub struct SDEProblem<T, Y, D, F>
96where
97    T: Real,
98    Y: State<T>,
99    D: CallBackData,
100    F: SDE<T, Y, D>,
101{
102    // SDE Problem Fields
103    pub sde: F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
104    pub t0: T,  // Initial Time
105    pub tf: T,  // Final Time
106    pub y0: Y,  // Initial State Vector
107
108    // Phantom Data for Users event output
109    _event_output_type: std::marker::PhantomData<D>,
110}
111
112impl<T, Y, D, F> SDEProblem<T, Y, D, F>
113where
114    T: Real,
115    Y: State<T>,
116    D: CallBackData,
117    F: SDE<T, Y, D>,
118{
119    /// Create a new Stochastic Differential Equation Problem
120    ///
121    /// # Arguments
122    /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
123    /// * `t0` - Initial Time
124    /// * `tf` - Final Time
125    /// * `y0` - Initial State Vector
126    ///
127    /// # Returns
128    /// * SDE Problem ready to be solved
129    ///
130    pub fn new(sde: F, t0: T, tf: T, y0: Y) -> Self {
131        SDEProblem {
132            sde,
133            t0,
134            tf,
135            y0,
136            _event_output_type: std::marker::PhantomData,
137        }
138    }
139
140    /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
141    ///
142    /// # Returns
143    /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
144    ///
145    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
146    where
147        S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
148    {
149        let mut default_solout = DefaultSolout::new(); // Default solout implementation
150        solve_sde(
151            solver,
152            &mut self.sde,
153            self.t0,
154            self.tf,
155            &self.y0,
156            &mut default_solout,
157        )
158    }
159
160    /// Returns an SDE Problem with the provided solout function for outputting points
161    ///
162    /// # Returns
163    /// * SDE Problem with the provided solout function ready for .solve() method
164    ///
165    pub fn solout<'a, O: Solout<T, Y, D>>(
166        &'a mut self,
167        solout: &'a mut O,
168    ) -> SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O> {
169        SDEProblemMutRefSoloutPair::new(self, solout)
170    }
171
172    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
173    /// Note that this does not include the solution of the calculated steps
174    ///
175    /// # Arguments
176    /// * `dt` - Interval between each output point
177    ///
178    /// # Returns
179    /// * SDE Problem with Even Solout function ready for .solve() method
180    ///
181    pub fn even(&mut self, dt: T) -> SDEProblemSoloutPair<'_, T, Y, D, F, EvenSolout<T>> {
182        let even_solout = EvenSolout::new(dt, self.t0, self.tf);
183        SDEProblemSoloutPair::new(self, even_solout)
184    }
185
186    /// Uses the Dense Output method to output n number of interpolation points between each step
187    /// Note this includes the solution of the calculated steps
188    ///
189    /// # Arguments
190    /// * `n` - Number of interpolation points between each step
191    ///
192    /// # Returns
193    /// * SDE Problem with Dense Output function ready for .solve() method
194    ///
195    pub fn dense(&mut self, n: usize) -> SDEProblemSoloutPair<'_, T, Y, D, F, DenseSolout> {
196        let dense_solout = DenseSolout::new(n);
197        SDEProblemSoloutPair::new(self, dense_solout)
198    }
199
200    /// Uses the provided time points for evaluation instead of the default method
201    /// Note this does not include the solution of the calculated steps
202    ///
203    /// # Arguments
204    /// * `points` - Custom output points
205    ///
206    /// # Returns
207    /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
208    ///
209    pub fn t_eval(
210        &mut self,
211        points: impl AsRef<[T]>,
212    ) -> SDEProblemSoloutPair<'_, T, Y, D, F, TEvalSolout<T>> {
213        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
214        SDEProblemSoloutPair::new(self, t_eval_solout)
215    }
216
217    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
218    /// Note this does not include the solution of the calculated steps
219    ///
220    /// # Arguments
221    /// * `component_idx` - Index of the component to monitor for crossing
222    /// * `threshold` - Value to cross
223    /// * `direction` - Direction of crossing (positive or negative)
224    ///
225    /// # Returns
226    /// * SDE Problem with CrossingSolout function ready for .solve() method
227    ///
228    pub fn crossing(
229        &mut self,
230        component_idx: usize,
231        threshold: T,
232        direction: CrossingDirection,
233    ) -> SDEProblemSoloutPair<'_, T, Y, D, F, CrossingSolout<T>> {
234        let crossing_solout =
235            CrossingSolout::new(component_idx, threshold).with_direction(direction);
236        SDEProblemSoloutPair::new(self, crossing_solout)
237    }
238
239    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
240    /// Note this does not include the solution of the calculated steps
241    ///
242    /// # Arguments
243    /// * `point` - Point on the hyperplane
244    /// * `normal` - Normal vector of the hyperplane
245    /// * `extractor` - Function to extract the component from the state vector
246    /// * `direction` - Direction of crossing (positive or negative)
247    ///
248    /// # Returns
249    /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
250    ///
251    pub fn hyperplane_crossing<Y1>(
252        &mut self,
253        point: Y1,
254        normal: Y1,
255        extractor: fn(&Y) -> Y1,
256        direction: CrossingDirection,
257    ) -> SDEProblemSoloutPair<'_, T, Y, D, F, HyperplaneCrossingSolout<T, Y1, Y>>
258    where
259        Y1: State<T>,
260    {
261        let solout =
262            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
263
264        SDEProblemSoloutPair::new(self, solout)
265    }
266}
267
268/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
269pub struct SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
270where
271    T: Real,
272    Y: State<T>,
273    D: CallBackData,
274    F: SDE<T, Y, D>,
275    O: Solout<T, Y, D>,
276{
277    pub sde_problem: &'a mut SDEProblem<T, Y, D, F>,
278    pub solout: &'a mut O,
279}
280
281impl<'a, T, Y, D, F, O> SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
282where
283    T: Real,
284    Y: State<T>,
285    D: CallBackData,
286    F: SDE<T, Y, D>,
287    O: Solout<T, Y, D>,
288{
289    /// Create a new SDEProblemMutRefSoloutPair
290    ///
291    /// # Arguments
292    /// * `sde_problem` - Reference to the SDE Problem struct
293    /// * `solout` - Reference to the solout implementation
294    ///
295    pub fn new(sde_problem: &'a mut SDEProblem<T, Y, D, F>, solout: &'a mut O) -> Self {
296        SDEProblemMutRefSoloutPair {
297            sde_problem,
298            solout,
299        }
300    }
301
302    /// Solve the SDE Problem using the provided solout
303    ///
304    /// # Arguments
305    /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
306    ///
307    /// # Returns
308    /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
309    ///
310    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
311    where
312        S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
313    {
314        solve_sde(
315            solver,
316            &mut self.sde_problem.sde,
317            self.sde_problem.t0,
318            self.sde_problem.tf,
319            &self.sde_problem.y0,
320            self.solout,
321        )
322    }
323}
324
325/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
326#[derive(Debug)]
327pub struct SDEProblemSoloutPair<'a, T, Y, D, F, O>
328where
329    T: Real,
330    Y: State<T>,
331    D: CallBackData,
332    F: SDE<T, Y, D>,
333    O: Solout<T, Y, D>,
334{
335    pub sde_problem: &'a mut SDEProblem<T, Y, D, F>,
336    pub solout: O,
337}
338
339impl<'a, T, Y, D, F, O> SDEProblemSoloutPair<'a, T, Y, D, F, O>
340where
341    T: Real,
342    Y: State<T>,
343    D: CallBackData,
344    F: SDE<T, Y, D>,
345    O: Solout<T, Y, D>,
346{
347    /// Create a new SDEProblemSoloutPair
348    ///
349    /// # Arguments
350    /// * `sde_problem` - Reference to the SDE Problem struct
351    /// * `solout` - Solout implementation
352    ///
353    pub fn new(sde_problem: &'a mut SDEProblem<T, Y, D, F>, solout: O) -> Self {
354        SDEProblemSoloutPair {
355            sde_problem,
356            solout,
357        }
358    }
359
360    /// Solve the SDE Problem using the provided solout
361    ///
362    /// # Arguments
363    /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
364    ///
365    /// # Returns
366    /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
367    ///
368    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
369    where
370        S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
371    {
372        solve_sde(
373            solver,
374            &mut self.sde_problem.sde,
375            self.sde_problem.t0,
376            self.sde_problem.tf,
377            &self.sde_problem.y0,
378            &mut self.solout,
379        )
380    }
381}