differential_equations/sde/sde_problem.rs
1//! SDE Problem Struct and Constructors
2
3use crate::{
4 Error, Solution,
5 interpolate::Interpolation,
6 sde::{SDENumericalMethod, SDE, solve_sde},
7 solout::*,
8 traits::{CallBackData, Real, State},
9};
10
11/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
12///
13/// The Initial Value Problem takes the form:
14/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
15///
16/// where:
17/// - a(t, Y) is the drift term (deterministic part)
18/// - b(t, Y) is the diffusion term (stochastic part)
19/// - dW represents a Wiener process increment
20///
21/// # Overview
22///
23/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
24///
25/// # Example
26///
27/// ```
28/// use differential_equations::prelude::*;
29/// use nalgebra::SVector;
30/// use rand::SeedableRng;
31/// use rand_distr::{Distribution, Normal};
32///
33/// struct GBM {
34/// rng: rand::rngs::StdRng,
35/// }
36///
37/// impl GBM {
38/// fn new(seed: u64) -> Self {
39/// Self {
40/// rng: rand::rngs::StdRng::seed_from_u64(seed),
41/// }
42/// }
43/// }
44///
45/// impl SDE<f64, SVector<f64, 1>> for GBM {
46/// fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
47/// dydt[0] = 0.1 * y[0]; // μS
48/// }
49///
50/// fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
51/// dydw[0] = 0.2 * y[0]; // σS
52/// }
53///
54/// fn noise(&self, dt: f64, dw: &mut SVector<f64, 1>) {
55/// let normal = Normal::new(0.0, dt.sqrt()).unwrap();
56/// dw[0] = normal.sample(&mut self.rng.clone());
57/// }
58/// }
59///
60/// let t0 = 0.0;
61/// let tf = 1.0;
62/// let y0 = SVector::<f64, 1>::new(100.0);
63/// let mut solver = EM::new(0.01);
64/// let gbm = GBM::new(42);
65/// let gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
66///
67/// // Solve the SDE
68/// let result = gbm_problem.solve(&mut solver);
69/// ```
70///
71/// # Fields
72///
73/// * `sde` - SDE implementing the stochastic differential equation
74/// * `t0` - Initial time
75/// * `tf` - Final time
76/// * `y0` - Initial state vector
77///
78/// # Basic Usage
79///
80/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
81/// * `solve(&mut solver)` - Solve using default output (solver step points)
82///
83/// # Output Control Methods
84///
85/// These methods configure how solution points are generated and returned:
86///
87/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
88/// * `dense(n)` - Include `n` interpolated points between each solver step
89/// * `t_eval(points)` - Evaluate solution at specific time points
90/// * `solout(custom_solout)` - Use a custom output handler
91/// * `seed(u64)` - Set a specific random seed for reproducible simulations
92///
93#[derive(Clone, Debug)]
94pub struct SDEProblem<T, V, D, F>
95where
96 T: Real,
97 V: State<T>,
98 D: CallBackData,
99 F: SDE<T, V, D>,
100{
101 // SDE Problem Fields
102 pub sde: F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
103 pub t0: T, // Initial Time
104 pub tf: T, // Final Time
105 pub y0: V, // Initial State Vector
106
107 // Phantom Data for Users event output
108 _event_output_type: std::marker::PhantomData<D>,
109}
110
111impl<T, V, D, F> SDEProblem<T, V, D, F>
112where
113 T: Real,
114 V: State<T>,
115 D: CallBackData,
116 F: SDE<T, V, D>,
117{
118 /// Create a new Stochastic Differential Equation Problem
119 ///
120 /// # Arguments
121 /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
122 /// * `t0` - Initial Time
123 /// * `tf` - Final Time
124 /// * `y0` - Initial State Vector
125 ///
126 /// # Returns
127 /// * SDE Problem ready to be solved
128 ///
129 pub fn new(sde: F, t0: T, tf: T, y0: V) -> Self {
130 SDEProblem {
131 sde,
132 t0,
133 tf,
134 y0,
135 _event_output_type: std::marker::PhantomData,
136 }
137 }
138
139 /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
140 ///
141 /// # Returns
142 /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
143 ///
144 pub fn solve<S>(&self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
145 where
146 S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
147 {
148 let mut default_solout = DefaultSolout::new(); // Default solout implementation
149 solve_sde(
150 solver,
151 &self.sde,
152 self.t0,
153 self.tf,
154 &self.y0,
155 &mut default_solout,
156 )
157 }
158
159 /// Returns an SDE Problem with the provided solout function for outputting points
160 ///
161 /// # Returns
162 /// * SDE Problem with the provided solout function ready for .solve() method
163 ///
164 pub fn solout<'a, O: Solout<T, V, D>>(
165 &'a self,
166 solout: &'a mut O,
167 ) -> SDEProblemMutRefSoloutPair<'a, T, V, D, F, O> {
168 SDEProblemMutRefSoloutPair::new(self, solout)
169 }
170
171 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
172 /// Note that this does not include the solution of the calculated steps
173 ///
174 /// # Arguments
175 /// * `dt` - Interval between each output point
176 ///
177 /// # Returns
178 /// * SDE Problem with Even Solout function ready for .solve() method
179 ///
180 pub fn even(&self, dt: T) -> SDEProblemSoloutPair<'_, T, V, D, F, EvenSolout<T>> {
181 let even_solout = EvenSolout::new(dt, self.t0, self.tf); // Even solout implementation
182 SDEProblemSoloutPair::new(self, even_solout)
183 }
184
185 /// Uses the Dense Output method to output n number of interpolation points between each step
186 /// Note this includes the solution of the calculated steps
187 ///
188 /// # Arguments
189 /// * `n` - Number of interpolation points between each step
190 ///
191 /// # Returns
192 /// * SDE Problem with Dense Output function ready for .solve() method
193 ///
194 pub fn dense(&self, n: usize) -> SDEProblemSoloutPair<'_, T, V, D, F, DenseSolout> {
195 let dense_solout = DenseSolout::new(n); // Dense solout implementation
196 SDEProblemSoloutPair::new(self, dense_solout)
197 }
198
199 /// Uses the provided time points for evaluation instead of the default method
200 /// Note this does not include the solution of the calculated steps
201 ///
202 /// # Arguments
203 /// * `points` - Custom output points
204 ///
205 /// # Returns
206 /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
207 ///
208 pub fn t_eval(&self, points: Vec<T>) -> SDEProblemSoloutPair<'_, T, V, D, F, TEvalSolout<T>> {
209 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf); // Custom time evaluation solout implementation
210 SDEProblemSoloutPair::new(self, t_eval_solout)
211 }
212
213 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
214 /// Note this does not include the solution of the calculated steps
215 ///
216 /// # Arguments
217 /// * `component_idx` - Index of the component to monitor for crossing
218 /// * `threshold` - Value to cross
219 /// * `direction` - Direction of crossing (positive or negative)
220 ///
221 /// # Returns
222 /// * SDE Problem with CrossingSolout function ready for .solve() method
223 ///
224 pub fn crossing(
225 &self,
226 component_idx: usize,
227 threshold: T,
228 direction: CrossingDirection,
229 ) -> SDEProblemSoloutPair<'_, T, V, D, F, CrossingSolout<T>> {
230 let crossing_solout =
231 CrossingSolout::new(component_idx, threshold).with_direction(direction); // Crossing solout implementation
232 SDEProblemSoloutPair::new(self, crossing_solout)
233 }
234
235 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
236 /// Note this does not include the solution of the calculated steps
237 ///
238 /// # Arguments
239 /// * `point` - Point on the hyperplane
240 /// * `normal` - Normal vector of the hyperplane
241 /// * `extractor` - Function to extract the component from the state vector
242 /// * `direction` - Direction of crossing (positive or negative)
243 ///
244 /// # Returns
245 /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
246 ///
247 pub fn hyperplane_crossing<V1>(
248 &self,
249 point: V1,
250 normal: V1,
251 extractor: fn(&V) -> V1,
252 direction: CrossingDirection,
253 ) -> SDEProblemSoloutPair<'_, T, V, D, F, HyperplaneCrossingSolout<T, V1, V>>
254 where
255 V1: State<T>,
256 {
257 let solout =
258 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
259
260 SDEProblemSoloutPair::new(self, solout)
261 }
262}
263
264/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
265pub struct SDEProblemMutRefSoloutPair<'a, T, V, D, F, O>
266where
267 T: Real,
268 V: State<T>,
269 D: CallBackData,
270 F: SDE<T, V, D>,
271 O: Solout<T, V, D>,
272{
273 pub sde_problem: &'a SDEProblem<T, V, D, F>, // Reference to the SDE Problem struct
274 pub solout: &'a mut O, // Reference to the solout implementation
275}
276
277impl<'a, T, V, D, F, O> SDEProblemMutRefSoloutPair<'a, T, V, D, F, O>
278where
279 T: Real,
280 V: State<T>,
281 D: CallBackData,
282 F: SDE<T, V, D>,
283 O: Solout<T, V, D>,
284{
285 /// Create a new SDEProblemMutRefSoloutPair
286 ///
287 /// # Arguments
288 /// * `sde_problem` - Reference to the SDE Problem struct
289 /// * `solout` - Reference to the solout implementation
290 ///
291 pub fn new(sde_problem: &'a SDEProblem<T, V, D, F>, solout: &'a mut O) -> Self {
292 SDEProblemMutRefSoloutPair {
293 sde_problem,
294 solout,
295 }
296 }
297
298 /// Solve the SDE Problem using the provided solout
299 ///
300 /// # Arguments
301 /// * `solver` - SDENumericalMethod to use for solving the SDE Problem
302 ///
303 /// # Returns
304 /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
305 ///
306 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
307 where
308 S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
309 {
310 solve_sde(
311 solver,
312 &self.sde_problem.sde,
313 self.sde_problem.t0,
314 self.sde_problem.tf,
315 &self.sde_problem.y0,
316 self.solout,
317 )
318 }
319}
320
321/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
322#[derive(Clone, Debug)]
323pub struct SDEProblemSoloutPair<'a, T, V, D, F, O>
324where
325 T: Real,
326 V: State<T>,
327 D: CallBackData,
328 F: SDE<T, V, D>,
329 O: Solout<T, V, D>,
330{
331 pub sde_problem: &'a SDEProblem<T, V, D, F>, // Reference to the SDE Problem struct
332 pub solout: O, // Solout implementation
333}
334
335impl<'a, T, V, D, F, O> SDEProblemSoloutPair<'a, T, V, D, F, O>
336where
337 T: Real,
338 V: State<T>,
339 D: CallBackData,
340 F: SDE<T, V, D>,
341 O: Solout<T, V, D>,
342{
343 /// Create a new SDEProblemSoloutPair
344 ///
345 /// # Arguments
346 /// * `sde_problem` - Reference to the SDE Problem struct
347 /// * `solout` - Solout implementation
348 ///
349 pub fn new(sde_problem: &'a SDEProblem<T, V, D, F>, solout: O) -> Self {
350 SDEProblemSoloutPair {
351 sde_problem,
352 solout,
353 }
354 }
355
356 /// Solve the SDE Problem using the provided solout
357 ///
358 /// # Arguments
359 /// * `solver` - SDENumericalMethod to use for solving the SDE Problem
360 ///
361 /// # Returns
362 /// * `Result<Solution<T, V, D>, Error<T, V>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
363 ///
364 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, V, D>, Error<T, V>>
365 where
366 S: SDENumericalMethod<T, V, D> + Interpolation<T, V>,
367 {
368 solve_sde(
369 solver,
370 &self.sde_problem.sde,
371 self.sde_problem.t0,
372 self.sde_problem.tf,
373 &self.sde_problem.y0,
374 &mut self.solout,
375 )
376 }
377}