differential_equations/sde/
sde_problem.rs

1//! SDE Problem Struct and Constructors
2
3use crate::{
4    Error, Solution,
5    interpolate::Interpolation,
6    sde::{SDENumericalMethod, SDE, solve_sde},
7    solout::*,
8    traits::{CallBackData, Real, State},
9};
10
11/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
12///
13/// The Initial Value Problem takes the form:
14/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
15///
16/// where:
17/// - a(t, Y) is the drift term (deterministic part)
18/// - b(t, Y) is the diffusion term (stochastic part)
19/// - dW represents a Wiener process increment
20///
21/// # Overview
22///
23/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
24///
25/// # Example
26///
27/// ```
28/// use differential_equations::prelude::*;
29/// use nalgebra::SVector;
30/// use rand::SeedableRng;
31/// use rand_distr::{Distribution, Normal};
32///
33/// struct GBM {
34///     rng: rand::rngs::StdRng,
35/// }
36///
37/// impl GBM {
38///     fn new(seed: u64) -> Self {
39///         Self {
40///             rng: rand::rngs::StdRng::seed_from_u64(seed),
41///         }
42///     }
43/// }
44///
45/// impl SDE<f64, SVector<f64, 1>> for GBM {
46///     fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
47///         dydt[0] = 0.1 * y[0]; // μS
48///     }
49///     
50///     fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
51///         dydw[0] = 0.2 * y[0]; // σS
52///     }
53///     
54///     fn noise(&self, dt: f64, dw: &mut SVector<f64, 1>) {
55///         let normal = Normal::new(0.0, dt.sqrt()).unwrap();
56///         dw[0] = normal.sample(&mut self.rng.clone());
57///     }
58/// }
59///
60/// let t0 = 0.0;
61/// let tf = 1.0;
62/// let y0 = SVector::<f64, 1>::new(100.0);
63/// let mut solver = EM::new(0.01);
64/// let gbm = GBM::new(42);
65/// let gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
66///
67/// // Solve the SDE
68/// let result = gbm_problem.solve(&mut solver);
69/// ```
70///
71/// # Fields
72///
73/// * `sde` - SDE implementing the stochastic differential equation
74/// * `t0` - Initial time
75/// * `tf` - Final time
76/// * `y0` - Initial state vector
77///
78/// # Basic Usage
79///
80/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
81/// * `solve(&mut solver)` - Solve using default output (solver step points)
82///
83/// # Output Control Methods
84///
85/// These methods configure how solution points are generated and returned:
86///
87/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
88/// * `dense(n)` - Include `n` interpolated points between each solver step
89/// * `t_eval(points)` - Evaluate solution at specific time points
90/// * `solout(custom_solout)` - Use a custom output handler
91/// * `seed(u64)` - Set a specific random seed for reproducible simulations
92///
93#[derive(Clone, Debug)]
94pub struct SDEProblem<T, V, D, F>
95where
96    T: Real,
97    V: State<T>,
98    D: CallBackData,
99    F: SDE<T, V, D>,
100{
101    // SDE Problem Fields
102    pub sde: F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
103    pub t0: T,  // Initial Time
104    pub tf: T,  // Final Time
105    pub y0: V,  // Initial State Vector
106
107    // Phantom Data for Users event output
108    _event_output_type: std::marker::PhantomData<D>,
109}
110
111impl<T, V, D, F> SDEProblem<T, V, D, F>
112where
113    T: Real,
114    V: State<T>,
115    D: CallBackData,
116    F: SDE<T, V, D>,
117{
118    /// Create a new Stochastic Differential Equation Problem
119    ///
120    /// # Arguments
121    /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
122    /// * `t0` - Initial Time
123    /// * `tf` - Final Time
124    /// * `y0` - Initial State Vector
125    ///
126    /// # Returns
127    /// * SDE Problem ready to be solved
128    ///
129    pub fn new(sde: F, t0: T, tf: T, y0: V) -> Self {
130        SDEProblem {
131            sde,
132            t0,
133            tf,
134            y0,
135            _event_output_type: std::marker::PhantomData,
136        }
137    }
138
139    /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
140    ///
141    /// # Returns
142    /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
143    ///
144    pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
145    where
146        S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
147    {
148        let mut default_solout = DefaultSolout::new(); // Default solout implementation
149        solve_sde(
150            solver,
151            &self.sde,
152            self.t0,
153            self.tf,
154            &self.y0,
155            &mut default_solout,
156        )
157    }
158
159    /// Returns an SDE Problem with the provided solout function for outputting points
160    ///
161    /// # Returns
162    /// * SDE Problem with the provided solout function ready for .solve() method
163    ///
164    pub fn solout<'a, O: Solout<T, V, D>>(
165        &'a self,
166        solout: &'a mut O,
167    ) -> SDEProblemMutRefSoloutPair<'a, T, V, D, F, O> {
168        SDEProblemMutRefSoloutPair::new(self, solout)
169    }
170
171    /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
172    /// Note that this does not include the solution of the calculated steps
173    ///
174    /// # Arguments
175    /// * `dt` - Interval between each output point
176    ///
177    /// # Returns
178    /// * SDE Problem with Even Solout function ready for .solve() method
179    ///
180    pub fn even(&self, dt: T) -> SDEProblemSoloutPair<'_, T, V, D, F, EvenSolout<T>> {
181        let even_solout = EvenSolout::new(dt, self.t0, self.tf); // Even solout implementation
182        SDEProblemSoloutPair::new(self, even_solout)
183    }
184
185    /// Uses the Dense Output method to output n number of interpolation points between each step
186    /// Note this includes the solution of the calculated steps
187    ///
188    /// # Arguments
189    /// * `n` - Number of interpolation points between each step
190    ///
191    /// # Returns
192    /// * SDE Problem with Dense Output function ready for .solve() method
193    ///
194    pub fn dense(&self, n: usize) -> SDEProblemSoloutPair<'_, T, V, D, F, DenseSolout> {
195        let dense_solout = DenseSolout::new(n); // Dense solout implementation
196        SDEProblemSoloutPair::new(self, dense_solout)
197    }
198
199    /// Uses the provided time points for evaluation instead of the default method
200    /// Note this does not include the solution of the calculated steps
201    ///
202    /// # Arguments
203    /// * `points` - Custom output points
204    ///
205    /// # Returns
206    /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
207    ///
208    pub fn t_eval(&self, points: Vec<T>) -> SDEProblemSoloutPair<'_, T, V, D, F, TEvalSolout<T>> {
209        let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf); // Custom time evaluation solout implementation
210        SDEProblemSoloutPair::new(self, t_eval_solout)
211    }
212
213    /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
214    /// Note this does not include the solution of the calculated steps
215    ///
216    /// # Arguments
217    /// * `component_idx` - Index of the component to monitor for crossing
218    /// * `threshold` - Value to cross
219    /// * `direction` - Direction of crossing (positive or negative)
220    ///
221    /// # Returns
222    /// * SDE Problem with CrossingSolout function ready for .solve() method
223    ///
224    pub fn crossing(
225        &self,
226        component_idx: usize,
227        threshold: T,
228        direction: CrossingDirection,
229    ) -> SDEProblemSoloutPair<'_, T, V, D, F, CrossingSolout<T>> {
230        let crossing_solout =
231            CrossingSolout::new(component_idx, threshold).with_direction(direction); // Crossing solout implementation
232        SDEProblemSoloutPair::new(self, crossing_solout)
233    }
234
235    /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
236    /// Note this does not include the solution of the calculated steps
237    ///
238    /// # Arguments
239    /// * `point` - Point on the hyperplane
240    /// * `normal` - Normal vector of the hyperplane
241    /// * `extractor` - Function to extract the component from the state vector
242    /// * `direction` - Direction of crossing (positive or negative)
243    ///
244    /// # Returns
245    /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
246    ///
247    pub fn hyperplane_crossing<V1>(
248        &self,
249        point: V1,
250        normal: V1,
251        extractor: fn(&V) -> V1,
252        direction: CrossingDirection,
253    ) -> SDEProblemSoloutPair<'_, T, V, D, F, HyperplaneCrossingSolout<T, V1, V>>
254    where
255        V1: State<T>,
256    {
257        let solout =
258            HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
259
260        SDEProblemSoloutPair::new(self, solout)
261    }
262}
263
264/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
265pub struct SDEProblemMutRefSoloutPair<'a, T, V, D, F, O>
266where
267    T: Real,
268    V: State<T>,
269    D: CallBackData,
270    F: SDE<T, V, D>,
271    O: Solout<T, V, D>,
272{
273    pub sde_problem: &'a SDEProblem<T, V, D, F>, // Reference to the SDE Problem struct
274    pub solout: &'a mut O,                       // Reference to the solout implementation
275}
276
277impl<'a, T, V, D, F, O> SDEProblemMutRefSoloutPair<'a, T, V, D, F, O>
278where
279    T: Real,
280    V: State<T>,
281    D: CallBackData,
282    F: SDE<T, V, D>,
283    O: Solout<T, V, D>,
284{
285    /// Create a new SDEProblemMutRefSoloutPair
286    ///
287    /// # Arguments
288    /// * `sde_problem` - Reference to the SDE Problem struct
289    /// * `solout` - Reference to the solout implementation
290    ///
291    pub fn new(sde_problem: &'a SDEProblem<T, V, D, F>, solout: &'a mut O) -> Self {
292        SDEProblemMutRefSoloutPair {
293            sde_problem,
294            solout,
295        }
296    }
297
298    /// Solve the SDE Problem using the provided solout
299    ///
300    /// # Arguments
301    /// * `solver` - SDENumericalMethod to use for solving the SDE Problem
302    ///
303    /// # Returns
304    /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
305    ///
306    pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
307    where
308        S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
309    {
310        solve_sde(
311            solver,
312            &self.sde_problem.sde,
313            self.sde_problem.t0,
314            self.sde_problem.tf,
315            &self.sde_problem.y0,
316            self.solout,
317        )
318    }
319}
320
321/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
322#[derive(Clone, Debug)]
323pub struct SDEProblemSoloutPair<'a, T, V, D, F, O>
324where
325    T: Real,
326    V: State<T>,
327    D: CallBackData,
328    F: SDE<T, V, D>,
329    O: Solout<T, V, D>,
330{
331    pub sde_problem: &'a SDEProblem<T, V, D, F>, // Reference to the SDE Problem struct
332    pub solout: O,                               // Solout implementation
333}
334
335impl<'a, T, V, D, F, O> SDEProblemSoloutPair<'a, T, V, D, F, O>
336where
337    T: Real,
338    V: State<T>,
339    D: CallBackData,
340    F: SDE<T, V, D>,
341    O: Solout<T, V, D>,
342{
343    /// Create a new SDEProblemSoloutPair
344    ///
345    /// # Arguments
346    /// * `sde_problem` - Reference to the SDE Problem struct
347    /// * `solout` - Solout implementation
348    ///
349    pub fn new(sde_problem: &'a SDEProblem<T, V, D, F>, solout: O) -> Self {
350        SDEProblemSoloutPair {
351            sde_problem,
352            solout,
353        }
354    }
355
356    /// Solve the SDE Problem using the provided solout
357    ///
358    /// # Arguments
359    /// * `solver` - SDENumericalMethod to use for solving the SDE Problem
360    ///
361    /// # Returns
362    /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
363    ///
364    pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
365    where
366        S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
367    {
368        solve_sde(
369            solver,
370            &self.sde_problem.sde,
371            self.sde_problem.t0,
372            self.sde_problem.tf,
373            &self.sde_problem.y0,
374            &mut self.solout,
375        )
376    }
377}