differential_equations/sde/problem.rs
1//! SDE Problem Struct and Constructors
2
3use crate::{
4 error::Error,
5 interpolate::Interpolation,
6 sde::{SDE, StochasticNumericalMethod, solve_sde},
7 solout::*,
8 solution::Solution,
9 traits::{Real, State},
10};
11
12/// Initial Value Problem for Stochastic Differential Equations (SDEProblem)
13///
14/// The Initial Value Problem takes the form:
15/// dY = a(t, Y)dt + b(t, Y)dW, t0 <= t <= tf, Y(t0) = y0
16///
17/// where:
18/// - a(t, Y) is the drift term (deterministic part)
19/// - b(t, Y) is the diffusion term (stochastic part)
20/// - dW represents a Wiener process increment
21///
22/// # Overview
23///
24/// The SDEProblem struct provides a simple interface for solving stochastic differential equations:
25///
26/// # Example
27///
28/// ```
29/// use differential_equations::prelude::*;
30/// use nalgebra::SVector;
31/// use rand::SeedableRng;
32/// use rand_distr::{Distribution, Normal};
33///
34/// struct GBM {
35/// rng: rand::rngs::StdRng,
36/// }
37///
38/// impl GBM {
39/// fn new(seed: u64) -> Self {
40/// Self {
41/// rng: rand::rngs::StdRng::seed_from_u64(seed),
42/// }
43/// }
44/// }
45///
46/// impl SDE<f64, SVector<f64, 1>> for GBM {
47/// fn drift(&self, _t: f64, y: &SVector<f64, 1>, dydt: &mut SVector<f64, 1>) {
48/// dydt[0] = 0.1 * y[0]; // μS
49/// }
50///
51/// fn diffusion(&self, _t: f64, y: &SVector<f64, 1>, dydw: &mut SVector<f64, 1>) {
52/// dydw[0] = 0.2 * y[0]; // σS
53/// }
54///
55/// fn noise(&mut self, dt: f64, dw: &mut SVector<f64, 1>) {
56/// let normal = Normal::new(0.0, dt.sqrt()).unwrap();
57/// dw[0] = normal.sample(&mut self.rng);
58/// }
59/// }
60///
61/// let t0 = 0.0;
62/// let tf = 1.0;
63/// let y0 = SVector::<f64, 1>::new(100.0);
64/// let mut solver = ExplicitRungeKutta::three_eighths(0.01);
65/// let gbm = GBM::new(42);
66/// let mut gbm_problem = SDEProblem::new(gbm, t0, tf, y0);
67///
68/// // Solve the SDE
69/// let result = gbm_problem.solve(&mut solver);
70/// ```
71///
72/// # Fields
73///
74/// * `sde` - SDE implementing the stochastic differential equation
75/// * `t0` - Initial time
76/// * `tf` - Final time
77/// * `y0` - Initial state vector
78///
79/// # Basic Usage
80///
81/// * `new(sde, t0, tf, y0)` - Create a new SDE Problem
82/// * `solve(&mut solver)` - Solve using default output (solver step points)
83///
84/// # Output Control Methods
85///
86/// These methods configure how solution points are generated and returned:
87///
88/// * `even(dt)` - Generate evenly spaced output points with interval `dt`
89/// * `dense(n)` - Include `n` interpolated points between each solver step
90/// * `t_eval(points)` - Evaluate solution at specific time points
91/// * `solout(custom_solout)` - Use a custom output handler
92/// * `seed(u64)` - Set a specific random seed for reproducible simulations
93///
94#[derive(Debug)]
95pub struct SDEProblem<'a, T, Y, F>
96where
97 T: Real,
98 Y: State<T>,
99 F: SDE<T, Y>,
100{
101 // SDE Problem Fields
102 pub sde: &'a mut F, // SDE containing the Stochastic Differential Equation and Optional Terminate Function
103 pub t0: T, // Initial Time
104 pub tf: T, // Final Time
105 pub y0: Y, // Initial State Vector
106}
107
108impl<'a, T, Y, F> SDEProblem<'a, T, Y, F>
109where
110 T: Real,
111 Y: State<T>,
112 F: SDE<T, Y>,
113{
114 /// Create a new Stochastic Differential Equation Problem
115 ///
116 /// # Arguments
117 /// * `sde` - SDE containing the Stochastic Differential Equation and Optional Terminate Function
118 /// * `t0` - Initial Time
119 /// * `tf` - Final Time
120 /// * `y0` - Initial State Vector
121 ///
122 /// # Returns
123 /// * SDE Problem ready to be solved
124 ///
125 pub fn new(sde: &'a mut F, t0: T, tf: T, y0: Y) -> Self {
126 SDEProblem { sde, t0, tf, y0 }
127 }
128
129 /// Solve the SDE Problem using a default solout, e.g. outputting solutions at calculated steps
130 ///
131 /// # Returns
132 /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
133 ///
134 pub fn solve<S>(&mut self, solver: &'a mut S) -> Result<Solution<T, Y>, Error<T, Y>>
135 where
136 S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
137 {
138 let mut default_solout = DefaultSolout::new(); // Default solout implementation
139 solve_sde(
140 solver,
141 self.sde,
142 self.t0,
143 self.tf,
144 &self.y0,
145 &mut default_solout,
146 )
147 }
148
149 /// Returns an SDE Problem with the provided solout function for outputting points
150 ///
151 /// # Returns
152 /// * SDE Problem with the provided solout function ready for .solve() method
153 ///
154 pub fn solout<O: Solout<T, Y>>(
155 &'a mut self,
156 solout: &'a mut O,
157 ) -> SDEProblemMutRefSoloutPair<'a, T, Y, F, O> {
158 SDEProblemMutRefSoloutPair::new(self, solout)
159 }
160
161 /// Uses the an Even Solout implementation to output evenly spaced points between the initial and final time
162 /// Note that this does not include the solution of the calculated steps
163 ///
164 /// # Arguments
165 /// * `dt` - Interval between each output point
166 ///
167 /// # Returns
168 /// * SDE Problem with Even Solout function ready for .solve() method
169 ///
170 pub fn even(&'a mut self, dt: T) -> SDEProblemSoloutPair<'a, T, Y, F, EvenSolout<T>> {
171 let even_solout = EvenSolout::new(dt, self.t0, self.tf);
172 SDEProblemSoloutPair::new(self, even_solout)
173 }
174
175 /// Uses the Dense Output method to output n number of interpolation points between each step
176 /// Note this includes the solution of the calculated steps
177 ///
178 /// # Arguments
179 /// * `n` - Number of interpolation points between each step
180 ///
181 /// # Returns
182 /// * SDE Problem with Dense Output function ready for .solve() method
183 ///
184 pub fn dense(&'a mut self, n: usize) -> SDEProblemSoloutPair<'a, T, Y, F, DenseSolout> {
185 let dense_solout = DenseSolout::new(n);
186 SDEProblemSoloutPair::new(self, dense_solout)
187 }
188
189 /// Uses the provided time points for evaluation instead of the default method
190 /// Note this does not include the solution of the calculated steps
191 ///
192 /// # Arguments
193 /// * `points` - Custom output points
194 ///
195 /// # Returns
196 /// * SDE Problem with Custom Time Evaluation function ready for .solve() method
197 ///
198 pub fn t_eval(
199 &'a mut self,
200 points: impl AsRef<[T]>,
201 ) -> SDEProblemSoloutPair<'a, T, Y, F, TEvalSolout<T>> {
202 let t_eval_solout = TEvalSolout::new(points, self.t0, self.tf);
203 SDEProblemSoloutPair::new(self, t_eval_solout)
204 }
205
206 /// Uses the CrossingSolout method to output points when a specific component crosses a threshold
207 /// Note this does not include the solution of the calculated steps
208 ///
209 /// # Arguments
210 /// * `component_idx` - Index of the component to monitor for crossing
211 /// * `threshold` - Value to cross
212 /// * `direction` - Direction of crossing (positive or negative)
213 ///
214 /// # Returns
215 /// * SDE Problem with CrossingSolout function ready for .solve() method
216 ///
217 pub fn crossing(
218 &'a mut self,
219 component_idx: usize,
220 threshold: T,
221 direction: CrossingDirection,
222 ) -> SDEProblemSoloutPair<'a, T, Y, F, CrossingSolout<T>> {
223 let crossing_solout =
224 CrossingSolout::new(component_idx, threshold).with_direction(direction);
225 SDEProblemSoloutPair::new(self, crossing_solout)
226 }
227
228 /// Uses the HyperplaneCrossingSolout method to output points when a specific hyperplane is crossed
229 /// Note this does not include the solution of the calculated steps
230 ///
231 /// # Arguments
232 /// * `point` - Point on the hyperplane
233 /// * `normal` - Normal vector of the hyperplane
234 /// * `extractor` - Function to extract the component from the state vector
235 /// * `direction` - Direction of crossing (positive or negative)
236 ///
237 /// # Returns
238 /// * SDE Problem with HyperplaneCrossingSolout function ready for .solve() method
239 ///
240 pub fn hyperplane_crossing<Y1>(
241 &'a mut self,
242 point: Y1,
243 normal: Y1,
244 extractor: fn(&Y) -> Y1,
245 direction: CrossingDirection,
246 ) -> SDEProblemSoloutPair<'a, T, Y, F, HyperplaneCrossingSolout<T, Y1, Y>>
247 where
248 Y1: State<T>,
249 {
250 let solout =
251 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
252
253 SDEProblemSoloutPair::new(self, solout)
254 }
255
256 /// Uses an `EventSolout` to capture zero crossings of a user-defined event function (SciPy style).
257 /// The event implements `Event<T,Y>` returning g(t,y); roots are located with Brent-Dekker.
258 pub fn event<E>(
259 &'a mut self,
260 event: &'a E,
261 ) -> SDEProblemSoloutPair<'a, T, Y, F, EventSolout<'a, T, Y, E>>
262 where
263 E: Event<T, Y>,
264 {
265 let solout = EventSolout::new(event, self.t0, self.tf);
266 SDEProblemSoloutPair::new(self, solout)
267 }
268}
269
270/// SDEProblemMutRefSoloutPair serves as an intermediate between the SDEProblem struct and a custom solout provided by the user
271pub struct SDEProblemMutRefSoloutPair<'a, T, Y, F, O>
272where
273 T: Real,
274 Y: State<T>,
275 F: SDE<T, Y>,
276 O: Solout<T, Y>,
277{
278 pub sde_problem: &'a mut SDEProblem<'a, T, Y, F>,
279 pub solout: &'a mut O,
280}
281
282impl<'a, T, Y, F, O> SDEProblemMutRefSoloutPair<'a, T, Y, F, O>
283where
284 T: Real,
285 Y: State<T>,
286 F: SDE<T, Y>,
287 O: Solout<T, Y>,
288{
289 /// Create a new SDEProblemMutRefSoloutPair
290 ///
291 /// # Arguments
292 /// * `sde_problem` - Reference to the SDE Problem struct
293 /// * `solout` - Reference to the solout implementation
294 ///
295 pub fn new(sde_problem: &'a mut SDEProblem<'a, T, Y, F>, solout: &'a mut O) -> Self {
296 SDEProblemMutRefSoloutPair {
297 sde_problem,
298 solout,
299 }
300 }
301
302 /// Solve the SDE Problem using the provided solout
303 ///
304 /// # Arguments
305 /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
306 ///
307 /// # Returns
308 /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
309 ///
310 pub fn solve<S>(&mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
311 where
312 S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
313 {
314 solve_sde(
315 solver,
316 self.sde_problem.sde,
317 self.sde_problem.t0,
318 self.sde_problem.tf,
319 &self.sde_problem.y0,
320 self.solout,
321 )
322 }
323}
324
325/// SDEProblemSoloutPair serves as an intermediate between the SDEProblem struct and solve_sde when a predefined solout is used
326#[derive(Debug)]
327pub struct SDEProblemSoloutPair<'a, T, Y, F, O>
328where
329 T: Real,
330 Y: State<T>,
331 F: SDE<T, Y>,
332 O: Solout<T, Y>,
333{
334 pub sde_problem: &'a mut SDEProblem<'a, T, Y, F>,
335 pub solout: O,
336}
337
338impl<'a, T, Y, F, O> SDEProblemSoloutPair<'a, T, Y, F, O>
339where
340 T: Real,
341 Y: State<T>,
342 F: SDE<T, Y>,
343 O: Solout<T, Y>,
344{
345 /// Create a new SDEProblemSoloutPair
346 ///
347 /// # Arguments
348 /// * `sde_problem` - Reference to the SDE Problem struct
349 /// * `solout` - Solout implementation
350 ///
351 pub fn new(sde_problem: &'a mut SDEProblem<'a, T, Y, F>, solout: O) -> Self {
352 SDEProblemSoloutPair {
353 sde_problem,
354 solout,
355 }
356 }
357
358 /// Solve the SDE Problem using the provided solout
359 ///
360 /// # Arguments
361 /// * `solver` - StochasticNumericalMethod to use for solving the SDE Problem
362 ///
363 /// # Returns
364 /// * `Result<Solution<T, Y>, Error<T, Y>>` - `Ok(Solution)` if successful or interrupted by events, `Err(Error)` if errors or issues are encountered
365 ///
366 pub fn solve<S>(mut self, solver: &mut S) -> Result<Solution<T, Y>, Error<T, Y>>
367 where
368 S: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
369 {
370 solve_sde(
371 solver,
372 self.sde_problem.sde,
373 self.sde_problem.t0,
374 self.sde_problem.tf,
375 &self.sde_problem.y0,
376 &mut self.solout,
377 )
378 }
379
380 /// Wrap current solout with event detection while preserving original output strategy.
381 pub fn event<E>(
382 self,
383 event: &'a E,
384 ) -> SDEProblemSoloutPair<'a, T, Y, F, EventWrappedSolout<'a, T, Y, O, E>>
385 where
386 E: Event<T, Y>,
387 {
388 let wrapped = EventWrappedSolout::new(self.solout, event, self.sde_problem.t0, self.sde_problem.tf);
389 SDEProblemSoloutPair::new(self.sde_problem, wrapped)
390 }
391}