differential_equations/sde/problem.rs
1//! SDE Problem Struct and Constructors
2
3use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 sde::{SDE, StochasticNumericalMethod, solve_sde},
7 solout::*,
8 solution::Solution,
9 traits::{CallBackData, Real, State},
10};
11
12/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
13///
14/// The Initial Value Problem takes the form:
15/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
16///
17/// where:
18/// - a(t, Y) is the drift term (deterministic part)
19/// - b(t, Y) is the diffusion term (stochastic part)
20/// - dW represents a Wiener process increment
21///
22/// # Overview
23///
24/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
25///
26/// # Example
27///
28/// ```
29/// use differential_equations::prelude::*;
30/// use nalgebra::SVector;
31/// use rand::SeedableRng;
32/// use rand_distr::{Distribution, Normal};
33///
34/// struct GBM {
35/// rng: rand::rngs::StdRng,
36/// }
37///
38/// impl GBM {
39/// fn new(seed: u64) -> Self {
40/// Self {
41/// rng: rand::rngs::StdRng::seed_from_u64(seed),
42/// }
43/// }
44/// }
45///
46/// impl SDE<f64, SVector<f64, 1>> for GBM {
47/// fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
48/// dydt[0] = 0.1 * y[0]; // μS
49/// }
50///
51/// fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
52/// dydw[0] = 0.2 * y[0]; // σS
53/// }
54///
55/// fn noise(&mut self, dt: f64, dw: &mut SVector<f64, 1>) {
56/// let normal = Normal::new(0.0, dt.sqrt()).unwrap();
57/// dw[0] = normal.sample(&mut self.rng);
58/// }
59/// }
60///
61/// let t0 = 0.0;
62/// let tf = 1.0;
63/// let y0 = SVector::<f64, 1>::new(100.0);
64/// let mut solver = ExplicitRungeKutta::three_eighths(0.01);
65/// let gbm = GBM::new(42);
66/// let mut gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
67///
68/// // Solve the SDE
69/// let result = gbm_problem.solve(&mut solver);
70/// ```
71///
72/// # Fields
73///
74/// * `sde` - SDE implementing the stochastic differential equation
75/// * `t0` - Initial time
76/// * `tf` - Final time
77/// * `y0` - Initial state vector
78///
79/// # Basic Usage
80///
81/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
82/// * `solve(&mut solver)` - Solve using default output (solver step points)
83///
84/// # Output Control Methods
85///
86/// These methods configure how solution points are generated and returned:
87///
88/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
89/// * `dense(n)` - Include `n` interpolated points between each solver step
90/// * `t_eval(points)` - Evaluate solution at specific time points
91/// * `solout(custom_solout)` - Use a custom output handler
92/// * `seed(u64)` - Set a specific random seed for reproducible simulations
93///
94#[derive(Clone, Debug)]
95pub struct SDEProblem<T, Y, D, F>
96where
97 T: Real,
98 Y: State<T>,
99 D: CallBackData,
100 F: SDE<T, Y, D>,
101{
102 // SDE Problem Fields
103 pub sde: F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
104 pub t0: T, // Initial Time
105 pub tf: T, // Final Time
106 pub y0: Y, // Initial State Vector
107
108 // Phantom Data for Users event output
109 _event_output_type: std::marker::PhantomData<D>,
110}
111
112impl<T, Y, D, F> SDEProblem<T, Y, D, F>
113where
114 T: Real,
115 Y: State<T>,
116 D: CallBackData,
117 F: SDE<T, Y, D>,
118{
119 /// Create a new Stochastic Differential Equation Problem
120 ///
121 /// # Arguments
122 /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
123 /// * `t0` - Initial Time
124 /// * `tf` - Final Time
125 /// * `y0` - Initial State Vector
126 ///
127 /// # Returns
128 /// * SDE Problem ready to be solved
129 ///
130 pub fn new(sde: F, t0: T, tf: T, y0: Y) -> Self {
131 SDEProblem {
132 sde,
133 t0,
134 tf,
135 y0,
136 _event_output_type: std::marker::PhantomData,
137 }
138 }
139
140 /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
141 ///
142 /// # Returns
143 /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
144 ///
145 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
146 where
147 S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
148 {
149 let mut default_solout = DefaultSolout::new(); // Default solout implementation
150 solve_sde(
151 solver,
152 &mut self.sde,
153 self.t0,
154 self.tf,
155 &self.y0,
156 &mut default_solout,
157 )
158 }
159
160 /// Returns an SDE Problem with the provided solout function for outputting points
161 ///
162 /// # Returns
163 /// * SDE Problem with the provided solout function ready for .solve() method
164 ///
165 pub fn solout<'a, O: Solout<T, Y, D>>(
166 &'a mut self,
167 solout: &'a mut O,
168 ) -> SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O> {
169 SDEProblemMutRefSoloutPair::new(self, solout)
170 }
171
172 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
173 /// Note that this does not include the solution of the calculated steps
174 ///
175 /// # Arguments
176 /// * `dt` - Interval between each output point
177 ///
178 /// # Returns
179 /// * SDE Problem with Even Solout function ready for .solve() method
180 ///
181 pub fn even(&mut self, dt: T) -> SDEProblemSoloutPair<'_, T, Y, D, F, EvenSolout<T>> {
182 let even_solout = EvenSolout::new(dt, self.t0, self.tf);
183 SDEProblemSoloutPair::new(self, even_solout)
184 }
185
186 /// Uses the Dense Output method to output n number of interpolation points between each step
187 /// Note this includes the solution of the calculated steps
188 ///
189 /// # Arguments
190 /// * `n` - Number of interpolation points between each step
191 ///
192 /// # Returns
193 /// * SDE Problem with Dense Output function ready for .solve() method
194 ///
195 pub fn dense(&mut self, n: usize) -> SDEProblemSoloutPair<'_, T, Y, D, F, DenseSolout> {
196 let dense_solout = DenseSolout::new(n);
197 SDEProblemSoloutPair::new(self, dense_solout)
198 }
199
200 /// Uses the provided time points for evaluation instead of the default method
201 /// Note this does not include the solution of the calculated steps
202 ///
203 /// # Arguments
204 /// * `points` - Custom output points
205 ///
206 /// # Returns
207 /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
208 ///
209 pub fn t_eval(
210 &mut self,
211 points: impl AsRef<[T]>,
212 ) -> SDEProblemSoloutPair<'_, T, Y, D, F, TEvalSolout<T>> {
213 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
214 SDEProblemSoloutPair::new(self, t_eval_solout)
215 }
216
217 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
218 /// Note this does not include the solution of the calculated steps
219 ///
220 /// # Arguments
221 /// * `component_idx` - Index of the component to monitor for crossing
222 /// * `threshold` - Value to cross
223 /// * `direction` - Direction of crossing (positive or negative)
224 ///
225 /// # Returns
226 /// * SDE Problem with CrossingSolout function ready for .solve() method
227 ///
228 pub fn crossing(
229 &mut self,
230 component_idx: usize,
231 threshold: T,
232 direction: CrossingDirection,
233 ) -> SDEProblemSoloutPair<'_, T, Y, D, F, CrossingSolout<T>> {
234 let crossing_solout =
235 CrossingSolout::new(component_idx, threshold).with_direction(direction);
236 SDEProblemSoloutPair::new(self, crossing_solout)
237 }
238
239 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
240 /// Note this does not include the solution of the calculated steps
241 ///
242 /// # Arguments
243 /// * `point` - Point on the hyperplane
244 /// * `normal` - Normal vector of the hyperplane
245 /// * `extractor` - Function to extract the component from the state vector
246 /// * `direction` - Direction of crossing (positive or negative)
247 ///
248 /// # Returns
249 /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
250 ///
251 pub fn hyperplane_crossing<Y1>(
252 &mut self,
253 point: Y1,
254 normal: Y1,
255 extractor: fn(&Y) -> Y1,
256 direction: CrossingDirection,
257 ) -> SDEProblemSoloutPair<'_, T, Y, D, F, HyperplaneCrossingSolout<T, Y1, Y>>
258 where
259 Y1: State<T>,
260 {
261 let solout =
262 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
263
264 SDEProblemSoloutPair::new(self, solout)
265 }
266}
267
268/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
269pub struct SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
270where
271 T: Real,
272 Y: State<T>,
273 D: CallBackData,
274 F: SDE<T, Y, D>,
275 O: Solout<T, Y, D>,
276{
277 pub sde_problem: &'a mut SDEProblem<T, Y, D, F>,
278 pub solout: &'a mut O,
279}
280
281impl<'a, T, Y, D, F, O> SDEProblemMutRefSoloutPair<'a, T, Y, D, F, O>
282where
283 T: Real,
284 Y: State<T>,
285 D: CallBackData,
286 F: SDE<T, Y, D>,
287 O: Solout<T, Y, D>,
288{
289 /// Create a new SDEProblemMutRefSoloutPair
290 ///
291 /// # Arguments
292 /// * `sde_problem` - Reference to the SDE Problem struct
293 /// * `solout` - Reference to the solout implementation
294 ///
295 pub fn new(sde_problem: &'a mut SDEProblem<T, Y, D, F>, solout: &'a mut O) -> Self {
296 SDEProblemMutRefSoloutPair {
297 sde_problem,
298 solout,
299 }
300 }
301
302 /// Solve the SDE Problem using the provided solout
303 ///
304 /// # Arguments
305 /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
306 ///
307 /// # Returns
308 /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
309 ///
310 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
311 where
312 S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
313 {
314 solve_sde(
315 solver,
316 &mut self.sde_problem.sde,
317 self.sde_problem.t0,
318 self.sde_problem.tf,
319 &self.sde_problem.y0,
320 self.solout,
321 )
322 }
323}
324
325/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
326#[derive(Debug)]
327pub struct SDEProblemSoloutPair<'a, T, Y, D, F, O>
328where
329 T: Real,
330 Y: State<T>,
331 D: CallBackData,
332 F: SDE<T, Y, D>,
333 O: Solout<T, Y, D>,
334{
335 pub sde_problem: &'a mut SDEProblem<T, Y, D, F>,
336 pub solout: O,
337}
338
339impl<'a, T, Y, D, F, O> SDEProblemSoloutPair<'a, T, Y, D, F, O>
340where
341 T: Real,
342 Y: State<T>,
343 D: CallBackData,
344 F: SDE<T, Y, D>,
345 O: Solout<T, Y, D>,
346{
347 /// Create a new SDEProblemSoloutPair
348 ///
349 /// # Arguments
350 /// * `sde_problem` - Reference to the SDE Problem struct
351 /// * `solout` - Solout implementation
352 ///
353 pub fn new(sde_problem: &'a mut SDEProblem<T, Y, D, F>, solout: O) -> Self {
354 SDEProblemSoloutPair {
355 sde_problem,
356 solout,
357 }
358 }
359
360 /// Solve the SDE Problem using the provided solout
361 ///
362 /// # Arguments
363 /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
364 ///
365 /// # Returns
366 /// * `Result<Solution<T, Y, D>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
367 ///
368 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y, D>, Error<T, Y>>
369 where
370 S: StochasticNumericalMethod<T, Y, D> + Interpolation<T, Y>,
371 {
372 solve_sde(
373 solver,
374 &mut self.sde_problem.sde,
375 self.sde_problem.t0,
376 self.sde_problem.tf,
377 &self.sde_problem.y0,
378 &mut self.solout,
379 )
380 }
381}