1use crate::{
10 dae::{AlgebraicNumericalMethod, DAE, solve_dae},
11 dde::{DDE, DelayNumericalMethod, solve_dde},
12 error::Error,
13 interpolate::Interpolation,
14 methods::ToleranceConfig,
15 ode::{ODE, OrdinaryNumericalMethod, solve_ode},
16 sde::{SDE, StochasticNumericalMethod, solve_sde},
17 solout::{
18 CrossingDirection, CrossingSolout, DefaultSolout, DenseSolout, EvenSolout, Event,
19 EventWrappedSolout, HyperplaneCrossingSolout, Solout, TEvalSolout,
20 },
21 solution::Solution,
22 tolerance::Tolerance,
23 traits::{Real, State},
24};
25
26#[derive(Clone, Debug)]
30pub struct IVP<EqType, T: Real, Y: State<T>, Method, SoloutType> {
31 equation: EqType,
32 t0: T,
33 tf: T,
34 y0: Y,
35 method: Method,
36 solout: SoloutType,
37}
38
/// Borrowed ODE system stored inside an `IVP` (created by `IVP::ode`).
///
/// Only a shared reference is held, so the wrapper is `Clone` and `Copy` for
/// every `F`. `#[derive(Clone, Copy)]` would add `F: Clone` / `F: Copy`
/// bounds, so these impls are written by hand.
#[derive(Debug)]
pub struct OdeEq<'a, F> {
    // Borrowed system; the caller keeps ownership.
    ode: &'a F,
}

// `*self` is the canonical `Clone` body for a `Copy` type.
impl<F> Clone for OdeEq<'_, F> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<F> Copy for OdeEq<'_, F> {}
52
/// ODE system owned by an `IVP` (created by `IVP::ode_from_fn`).
///
/// It is `Clone` when the system is `Clone`, and `Copy` when the system is
/// `Copy`; the derives place exactly those bounds on `F`.
#[derive(Clone, Copy, Debug)]
pub struct OdeEqOwned<F> {
    ode: F,
}
68
/// Borrowed DAE system stored inside an `IVP` (created by `IVP::dae`).
///
/// Only a shared reference is held, so the wrapper is `Clone` and `Copy` for
/// every `F`. `#[derive(Clone, Copy)]` would add `F: Clone` / `F: Copy`
/// bounds, so these impls are written by hand.
#[derive(Debug)]
pub struct DaeEq<'a, F> {
    // Borrowed system; the caller keeps ownership.
    dae: &'a F,
}

// `*self` is the canonical `Clone` body for a `Copy` type.
impl<F> Clone for DaeEq<'_, F> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<F> Copy for DaeEq<'_, F> {}
82
/// DAE system owned by an `IVP` (created by `IVP::dae_from_fn`).
///
/// It is `Clone` when the system is `Clone`, and `Copy` when the system is
/// `Copy`; the derives place exactly those bounds on `F`.
#[derive(Clone, Copy, Debug)]
pub struct DaeEqOwned<F> {
    dae: F,
}
98
/// Mutably borrowed SDE system stored inside an `IVP` (created by `IVP::sde`).
///
/// The borrow is `&mut` because the SDE's `noise` method takes `&mut self`,
/// and `solve_sde` is handed the system mutably. Holding a `&mut` means this
/// wrapper can be neither `Clone` nor `Copy`.
#[derive(Debug)]
pub struct SdeEq<'a, F> {
    // Exclusive borrow of the caller's system.
    sde: &'a mut F,
}
104
/// SDE system owned by an `IVP` (created by `IVP::sde_from_fn`).
///
/// The other owned equation wrappers are all `Clone` when their payload is;
/// this one now matches them. Without it, a problem built from an owned SDE
/// could never be cloned, for example to re-run it with a fresh copy of the
/// noise state.
///
/// `Copy` is deliberately not provided. The noise source is typically
/// stateful, and implicit copies would silently duplicate that state.
#[derive(Clone, Debug)]
pub struct SdeEqOwned<F> {
    sde: F,
}
110
/// Borrowed DDE system plus an owned history function, stored inside an
/// `IVP` (created by `IVP::dde`).
///
/// Cloning copies the reference and clones only the history. No `F: Clone`
/// bound is required, which `#[derive(Clone)]` would have added.
#[derive(Debug)]
pub struct DdeEq<'a, const L: usize, F, H> {
    dde: &'a F,
    history: H,
}

impl<const L: usize, F, H: Clone> Clone for DdeEq<'_, L, F, H> {
    fn clone(&self) -> Self {
        // `dde` is a shared reference, so it is copied out of the pattern.
        let &Self { dde, ref history } = self;
        Self {
            dde,
            history: history.clone(),
        }
    }
}
126
/// DDE system and history function both owned by an `IVP` (created by
/// `IVP::dde_from_fn`).
///
/// It is `Clone` when both the system and the history are `Clone`; the
/// derive places exactly those bounds.
#[derive(Clone, Debug)]
pub struct DdeEqOwned<const L: usize, F, H> {
    dde: F,
    history: H,
}
142
/// Adapts a plain closure `f(t, y, dydt)` so it can be used as an ODE system.
///
/// `Clone` and `Copy` are derived, conditional on `F`. Without them,
/// `OdeEqOwned<OdeFnWrapper<F>>` is never `Clone`, because `OdeEqOwned`
/// requires its payload to be `Clone`. That would leave any problem built
/// with `IVP::ode_from_fn` uncloneable even when the closure is cloneable.
#[derive(Clone, Copy, Debug)]
pub struct OdeFnWrapper<F> {
    // User-supplied right-hand side.
    f: F,
}
148
149impl<T, Y, F> ODE<T, Y> for OdeFnWrapper<F>
150where
151 T: Real,
152 Y: State<T>,
153 F: Fn(T, &Y, &mut Y),
154{
155 fn diff(&self, t: T, y: &Y, dydt: &mut Y) {
156 (self.f)(t, y, dydt)
157 }
158}
159
/// Adapts two closures into a DAE system: `f(t, y, out)` for the right-hand
/// side and `m(matrix)` for the mass matrix.
///
/// `Clone` and `Copy` are derived, conditional on both closures. Without them,
/// `DaeEqOwned<DaeFnWrapper<F, M>>` is never `Clone`, so problems built with
/// `IVP::dae_from_fn` could not be cloned.
#[derive(Clone, Copy, Debug)]
pub struct DaeFnWrapper<F, M> {
    // Right-hand side closure.
    f: F,
    // Closure that fills in the mass matrix.
    m: M,
}
166
167impl<T, Y, F, M> DAE<T, Y> for DaeFnWrapper<F, M>
168where
169 T: Real,
170 Y: State<T>,
171 F: Fn(T, &Y, &mut Y),
172 M: Fn(&mut crate::linalg::Matrix<T>),
173{
174 fn diff(&self, t: T, y: &Y, f: &mut Y) {
175 (self.f)(t, y, f)
176 }
177
178 fn mass(&self, m: &mut crate::linalg::Matrix<T>) {
179 (self.m)(m)
180 }
181}
182
/// Adapts three closures into an SDE system: drift `(t, y, out)`, diffusion
/// `(t, y, out)` and a mutable noise generator `(dt, dw)`.
///
/// `Clone` is derived, conditional on all three closures, so that an owned
/// SDE problem can be cloned when its parts can. `Copy` is intentionally not
/// derived: the noise closure is `FnMut` and typically carries generator
/// state that should only be duplicated by an explicit `.clone()`.
#[derive(Clone, Debug)]
pub struct SdeFnWrapper<Drift, Diff, Noise> {
    drift_fn: Drift,
    diffusion_fn: Diff,
    // `FnMut`: may advance internal state (e.g. an RNG) on every call.
    noise_fn: Noise,
}
190
191impl<T, Y, Drift, Diff, Noise> SDE<T, Y> for SdeFnWrapper<Drift, Diff, Noise>
192where
193 T: Real,
194 Y: State<T>,
195 Drift: Fn(T, &Y, &mut Y),
196 Diff: Fn(T, &Y, &mut Y),
197 Noise: FnMut(T, &mut Y),
198{
199 fn drift(&self, t: T, y: &Y, dydt: &mut Y) {
200 (self.drift_fn)(t, y, dydt)
201 }
202
203 fn diffusion(&self, t: T, y: &Y, dydw: &mut Y) {
204 (self.diffusion_fn)(t, y, dydw)
205 }
206
207 fn noise(&mut self, dt: T, dw: &mut Y) {
208 (self.noise_fn)(dt, dw)
209 }
210}
211
/// Adapts two closures into a DDE system with `L` delays:
/// `diff_fn(t, y, delayed_states, out)` and `lags_fn(t, y, lags_out)`.
///
/// `Clone` and `Copy` are derived, conditional on both closures. Without them,
/// `DdeEqOwned<L, DdeFnWrapper<..>, H>` is never `Clone`, so problems built
/// with `IVP::dde_from_fn` could not be cloned.
#[derive(Clone, Copy, Debug)]
pub struct DdeFnWrapper<const L: usize, Diff, Lags> {
    // Right-hand side that also receives the `L` delayed states.
    diff_fn: Diff,
    // Closure that writes the `L` lag values.
    lags_fn: Lags,
}
218
219impl<const L: usize, T, Y, Diff, Lags> DDE<L, T, Y> for DdeFnWrapper<L, Diff, Lags>
220where
221 T: Real,
222 Y: State<T>,
223 Diff: Fn(T, &Y, &[Y; L], &mut Y),
224 Lags: Fn(T, &Y, &mut [T; L]),
225{
226 fn diff(&self, t: T, y: &Y, yd: &[Y; L], dydt: &mut Y) {
227 (self.diff_fn)(t, y, yd, dydt)
228 }
229
230 fn lags(&self, t: T, y: &Y, lags: &mut [T; L]) {
231 (self.lags_fn)(t, y, lags)
232 }
233}
234
235impl<'a, F, T: Real, Y: State<T>> IVP<OdeEq<'a, F>, T, Y, (), DefaultSolout> {
236 pub fn ode(system: &'a F, t0: T, tf: T, y0: Y) -> Self {
238 Self {
239 equation: OdeEq { ode: system },
240 t0,
241 tf,
242 y0,
243 method: (),
244 solout: DefaultSolout::new(),
245 }
246 }
247}
248
249impl<F, T: Real, Y: State<T>> IVP<OdeEqOwned<OdeFnWrapper<F>>, T, Y, (), DefaultSolout>
250where
251 F: Fn(T, &Y, &mut Y),
252{
253 pub fn ode_from_fn(f: F, t0: T, tf: T, y0: Y) -> Self {
264 Self {
265 equation: OdeEqOwned {
266 ode: OdeFnWrapper { f },
267 },
268 t0,
269 tf,
270 y0,
271 method: (),
272 solout: DefaultSolout::new(),
273 }
274 }
275}
276
277impl<'a, F, T: Real, Y: State<T>> IVP<DaeEq<'a, F>, T, Y, (), DefaultSolout> {
278 pub fn dae(system: &'a F, t0: T, tf: T, y0: Y) -> Self {
280 Self {
281 equation: DaeEq { dae: system },
282 t0,
283 tf,
284 y0,
285 method: (),
286 solout: DefaultSolout::new(),
287 }
288 }
289}
290
291impl<F, M, T: Real, Y: State<T>> IVP<DaeEqOwned<DaeFnWrapper<F, M>>, T, Y, (), DefaultSolout>
292where
293 F: Fn(T, &Y, &mut Y),
294 M: Fn(&mut crate::linalg::Matrix<T>),
295{
296 pub fn dae_from_fn(f: F, m: M, t0: T, tf: T, y0: Y) -> Self {
298 Self {
299 equation: DaeEqOwned {
300 dae: DaeFnWrapper { f, m },
301 },
302 t0,
303 tf,
304 y0,
305 method: (),
306 solout: DefaultSolout::new(),
307 }
308 }
309}
310
311impl<'a, F, T: Real, Y: State<T>> IVP<SdeEq<'a, F>, T, Y, (), DefaultSolout> {
312 pub fn sde(system: &'a mut F, t0: T, tf: T, y0: Y) -> Self {
314 Self {
315 equation: SdeEq { sde: system },
316 t0,
317 tf,
318 y0,
319 method: (),
320 solout: DefaultSolout::new(),
321 }
322 }
323}
324
325impl<Drift, Diff, Noise, T: Real, Y: State<T>>
326 IVP<SdeEqOwned<SdeFnWrapper<Drift, Diff, Noise>>, T, Y, (), DefaultSolout>
327where
328 Drift: Fn(T, &Y, &mut Y),
329 Diff: Fn(T, &Y, &mut Y),
330 Noise: FnMut(T, &mut Y),
331{
332 pub fn sde_from_fn(drift: Drift, diffusion: Diff, noise: Noise, t0: T, tf: T, y0: Y) -> Self {
334 Self {
335 equation: SdeEqOwned {
336 sde: SdeFnWrapper {
337 drift_fn: drift,
338 diffusion_fn: diffusion,
339 noise_fn: noise,
340 },
341 },
342 t0,
343 tf,
344 y0,
345 method: (),
346 solout: DefaultSolout::new(),
347 }
348 }
349}
350
351impl<'a, F, H, T: Real, Y: State<T>, const L: usize>
352 IVP<DdeEq<'a, L, F, H>, T, Y, (), DefaultSolout>
353{
354 pub fn dde(system: &'a F, t0: T, tf: T, y0: Y, history_function: H) -> Self {
356 Self {
357 equation: DdeEq {
358 dde: system,
359 history: history_function,
360 },
361 t0,
362 tf,
363 y0,
364 method: (),
365 solout: DefaultSolout::new(),
366 }
367 }
368}
369
370impl<const L: usize, Diff, Lags, H, T: Real, Y: State<T>>
371 IVP<DdeEqOwned<L, DdeFnWrapper<L, Diff, Lags>, H>, T, Y, (), DefaultSolout>
372where
373 Diff: Fn(T, &Y, &[Y; L], &mut Y),
374 Lags: Fn(T, &Y, &mut [T; L]),
375 H: Fn(T) -> Y + Clone,
376{
377 pub fn dde_from_fn(diff: Diff, lags: Lags, t0: T, tf: T, y0: Y, history_function: H) -> Self {
379 Self {
380 equation: DdeEqOwned {
381 dde: DdeFnWrapper {
382 diff_fn: diff,
383 lags_fn: lags,
384 },
385 history: history_function,
386 },
387 t0,
388 tf,
389 y0,
390 method: (),
391 solout: DefaultSolout::new(),
392 }
393 }
394}
395
396impl<EqType, T: Real, Y: State<T>, Method, SoloutType> IVP<EqType, T, Y, Method, SoloutType> {
397 fn with_method<NextMethod>(
398 self,
399 method: NextMethod,
400 ) -> IVP<EqType, T, Y, NextMethod, SoloutType> {
401 IVP {
402 equation: self.equation,
403 t0: self.t0,
404 tf: self.tf,
405 y0: self.y0,
406 method,
407 solout: self.solout,
408 }
409 }
410
411 fn map_method<NextMethod>(
412 self,
413 map: impl FnOnce(Method) -> NextMethod,
414 ) -> IVP<EqType, T, Y, NextMethod, SoloutType> {
415 IVP {
416 equation: self.equation,
417 t0: self.t0,
418 tf: self.tf,
419 y0: self.y0,
420 method: map(self.method),
421 solout: self.solout,
422 }
423 }
424
425 fn with_solout<NextSolout>(self, solout: NextSolout) -> IVP<EqType, T, Y, Method, NextSolout> {
426 IVP {
427 equation: self.equation,
428 t0: self.t0,
429 tf: self.tf,
430 y0: self.y0,
431 method: self.method,
432 solout,
433 }
434 }
435
436 pub fn method<SNew>(self, method: SNew) -> IVP<EqType, T, Y, SNew, SoloutType> {
442 self.with_method(method)
443 }
444
445 pub fn solout<ONew>(self, solout: ONew) -> IVP<EqType, T, Y, Method, ONew> {
447 self.with_solout(solout)
448 }
449
450 pub fn even(self, dt: T) -> IVP<EqType, T, Y, Method, EvenSolout<T>> {
453 let solout = EvenSolout::new(dt, self.t0, self.tf);
454 self.with_solout(solout)
455 }
456
457 pub fn dense(self, n: usize) -> IVP<EqType, T, Y, Method, DenseSolout> {
460 self.with_solout(DenseSolout::new(n))
461 }
462
463 pub fn t_eval(self, points: impl AsRef<[T]>) -> IVP<EqType, T, Y, Method, TEvalSolout<T>> {
466 let solout = TEvalSolout::new(points, self.t0, self.tf);
467 self.with_solout(solout)
468 }
469
470 pub fn event<'a, E>(
472 self,
473 event: &'a E,
474 ) -> IVP<EqType, T, Y, Method, EventWrappedSolout<'a, T, Y, SoloutType, E>>
475 where
476 E: Event<T, Y> + ?Sized,
477 SoloutType: Solout<T, Y>,
478 {
479 IVP {
480 equation: self.equation,
481 t0: self.t0,
482 tf: self.tf,
483 y0: self.y0,
484 method: self.method,
485 solout: EventWrappedSolout::new(self.solout, event, self.t0, self.tf),
486 }
487 }
488
489 pub fn crossing(
492 self,
493 component_idx: usize,
494 threshold: T,
495 direction: CrossingDirection,
496 ) -> IVP<EqType, T, Y, Method, CrossingSolout<T>> {
497 let crossing_solout =
498 CrossingSolout::new(component_idx, threshold).with_direction(direction);
499 self.with_solout(crossing_solout)
500 }
501
502 pub fn hyperplane_crossing<Y1: State<T>>(
505 self,
506 point: Y1,
507 normal: Y1,
508 extractor: fn(&Y) -> Y1,
509 direction: CrossingDirection,
510 ) -> IVP<EqType, T, Y, Method, HyperplaneCrossingSolout<T, Y1, Y>> {
511 let solout =
512 HyperplaneCrossingSolout::new(point, normal, extractor).with_direction(direction);
513 self.with_solout(solout)
514 }
515}
516
517impl<EqType, T: Real, Y: State<T>, Method, SoloutType> IVP<EqType, T, Y, Method, SoloutType>
518where
519 Method: ToleranceConfig<T>,
520{
521 pub fn rtol<V: Into<Tolerance<T>>>(self, rtol: V) -> Self {
523 self.map_method(|method| method.rtol(rtol))
524 }
525
526 pub fn atol<V: Into<Tolerance<T>>>(self, atol: V) -> Self {
528 self.map_method(|method| method.atol(atol))
529 }
530}
531
532impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<OdeEq<'a, F>, T, Y, Method, SoloutType>
533where
534 F: ODE<T, Y>,
535 Method: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
536 SoloutType: Solout<T, Y>,
537{
538 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
540 solve_ode(
541 &mut self.method,
542 self.equation.ode,
543 self.t0,
544 self.tf,
545 &self.y0,
546 &mut self.solout,
547 )
548 }
549}
550
551impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<OdeEqOwned<F>, T, Y, Method, SoloutType>
552where
553 F: ODE<T, Y>,
554 Method: OrdinaryNumericalMethod<T, Y> + Interpolation<T, Y>,
555 SoloutType: Solout<T, Y>,
556{
557 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
559 solve_ode(
560 &mut self.method,
561 &self.equation.ode,
562 self.t0,
563 self.tf,
564 &self.y0,
565 &mut self.solout,
566 )
567 }
568}
569
570impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<DaeEq<'a, F>, T, Y, Method, SoloutType>
571where
572 F: DAE<T, Y>,
573 Method: AlgebraicNumericalMethod<T, Y> + Interpolation<T, Y>,
574 SoloutType: Solout<T, Y>,
575{
576 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
578 solve_dae(
579 &mut self.method,
580 self.equation.dae,
581 self.t0,
582 self.tf,
583 &self.y0,
584 &mut self.solout,
585 )
586 }
587}
588
589impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<DaeEqOwned<F>, T, Y, Method, SoloutType>
590where
591 F: DAE<T, Y>,
592 Method: AlgebraicNumericalMethod<T, Y> + Interpolation<T, Y>,
593 SoloutType: Solout<T, Y>,
594{
595 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
597 solve_dae(
598 &mut self.method,
599 &self.equation.dae,
600 self.t0,
601 self.tf,
602 &self.y0,
603 &mut self.solout,
604 )
605 }
606}
607
608impl<'a, F, T: Real, Y: State<T>, Method, SoloutType> IVP<SdeEq<'a, F>, T, Y, Method, SoloutType>
609where
610 F: SDE<T, Y>,
611 Method: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
612 SoloutType: Solout<T, Y>,
613{
614 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
616 solve_sde(
617 &mut self.method,
618 self.equation.sde,
619 self.t0,
620 self.tf,
621 &self.y0,
622 &mut self.solout,
623 )
624 }
625}
626
627impl<F, T: Real, Y: State<T>, Method, SoloutType> IVP<SdeEqOwned<F>, T, Y, Method, SoloutType>
628where
629 F: SDE<T, Y>,
630 Method: StochasticNumericalMethod<T, Y> + Interpolation<T, Y>,
631 SoloutType: Solout<T, Y>,
632{
633 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
635 solve_sde(
636 &mut self.method,
637 &mut self.equation.sde,
638 self.t0,
639 self.tf,
640 &self.y0,
641 &mut self.solout,
642 )
643 }
644}
645
646impl<'a, const L: usize, F, H, T: Real, Y: State<T>, Method, SoloutType>
647 IVP<DdeEq<'a, L, F, H>, T, Y, Method, SoloutType>
648where
649 F: DDE<L, T, Y>,
650 H: Fn(T) -> Y + Clone,
651 Method: DelayNumericalMethod<L, T, Y, H> + Interpolation<T, Y>,
652 SoloutType: Solout<T, Y>,
653{
654 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
656 solve_dde(
657 &mut self.method,
658 self.equation.dde,
659 self.t0,
660 self.tf,
661 &self.y0,
662 self.equation.history.clone(),
663 &mut self.solout,
664 )
665 }
666}
667
668impl<const L: usize, F, H, T: Real, Y: State<T>, Method, SoloutType>
669 IVP<DdeEqOwned<L, F, H>, T, Y, Method, SoloutType>
670where
671 F: DDE<L, T, Y>,
672 H: Fn(T) -> Y + Clone,
673 Method: DelayNumericalMethod<L, T, Y, H> + Interpolation<T, Y>,
674 SoloutType: Solout<T, Y>,
675{
676 pub fn solve(mut self) -> Result<Solution<T, Y>, Error<T, Y>> {
678 solve_dde(
679 &mut self.method,
680 &self.equation.dde,
681 self.t0,
682 self.tf,
683 &self.y0,
684 self.equation.history.clone(),
685 &mut self.solout,
686 )
687 }
688}