1use super::{IVPSolver, IVPStatus};
8use crate::roots::secant;
9use nalgebra::{
10 allocator::Allocator, dimension::DimMin, ComplexField, DefaultAllocator, DimName, RealField,
11 VectorN, U1, U3, U7,
12};
13use num_traits::{FromPrimitive, Zero};
14use std::collections::VecDeque;
15
16pub trait BDFSolver<N: ComplexField, S: DimName, O: DimName>: Sized
26where
27 DefaultAllocator: Allocator<N, S>,
28 DefaultAllocator: Allocator<N::RealField, O>,
29{
30 fn higher_coefficients() -> VectorN<N::RealField, O>;
38 fn lower_coefficients() -> VectorN<N::RealField, O>;
47
48 fn solve_ivp<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
50 self,
51 f: F,
52 params: &mut T,
53 ) -> super::Path<N, N::RealField, S>;
54
55 fn with_tolerance(self, tol: N::RealField) -> Result<Self, String>;
57 fn with_dt_max(self, max: N::RealField) -> Result<Self, String>;
59 fn with_dt_min(self, min: N::RealField) -> Result<Self, String>;
61 fn with_start(self, t_initial: N::RealField) -> Result<Self, String>;
63 fn with_end(self, t_final: N::RealField) -> Result<Self, String>;
65 fn with_initial_conditions(self, start: &[N]) -> Result<Self, String>;
67 fn build(self) -> Self;
69}
70
71#[derive(Debug, Clone)]
75#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
76pub struct BDFInfo<N, S, O>
77where
78 N: ComplexField,
79 S: DimName + DimMin<S, Output = S>,
80 O: DimName,
81 DefaultAllocator: Allocator<N, S>
82 + Allocator<N, O>
83 + Allocator<N, S, S>
84 + Allocator<N, U1, S>
85 + Allocator<(usize, usize), S>,
86{
87 dt: Option<N::RealField>,
88 time: Option<N::RealField>,
89 end: Option<N::RealField>,
90 state: Option<VectorN<N, S>>,
91 dt_max: Option<N::RealField>,
92 dt_min: Option<N::RealField>,
93 tolerance: Option<N::RealField>,
94 higher_coffecients: VectorN<N, O>,
95 lower_coefficients: VectorN<N, O>,
96 memory: VecDeque<(N::RealField, VectorN<N, S>)>,
97 nflag: bool,
98 last: bool,
99}
100
101impl<N, S, O> BDFInfo<N, S, O>
102where
103 N: ComplexField,
104 S: DimName + DimMin<S, Output = S>,
105 O: DimName,
106 DefaultAllocator: Allocator<N, S>
107 + Allocator<N, O>
108 + Allocator<N, S, S>
109 + Allocator<N, U1, S>
110 + Allocator<(usize, usize), S>,
111{
112 pub fn new() -> Self {
113 BDFInfo {
114 dt: None,
115 time: None,
116 end: None,
117 state: None,
118 dt_max: None,
119 dt_min: None,
120 tolerance: None,
121 higher_coffecients: VectorN::<N, O>::zero(),
122 lower_coefficients: VectorN::<N, O>::zero(),
123 memory: VecDeque::new(),
124 nflag: false,
125 last: false,
126 }
127 }
128}
129
130#[allow(clippy::too_many_arguments)]
131fn rk4<
132 N: ComplexField,
133 S: DimName,
134 T: Clone,
135 F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
136>(
137 time: N::RealField,
138 dt: N::RealField,
139 initial: &[N],
140 states: &mut VecDeque<(N::RealField, VectorN<N, S>)>,
141 mut f: F,
142 params: &mut T,
143 num: usize,
144) -> Result<(), String>
145where
146 DefaultAllocator: Allocator<N, S>,
147{
148 let mut state = VectorN::from_column_slice(initial);
149 let mut time = time;
150 for i in 0..num {
151 let k1 = f(time, state.as_slice(), &mut params.clone())? * N::from_real(dt);
152 let intermediate = &state + &k1 * N::from_f64(0.5).unwrap();
153 let k2 = f(
154 time + N::RealField::from_f64(0.5).unwrap() * dt,
155 intermediate.as_slice(),
156 &mut params.clone(),
157 )? * N::from_real(dt);
158 let intermediate = &state + &k2 * N::from_f64(0.5).unwrap();
159 let k3 = f(
160 time + N::RealField::from_f64(0.5).unwrap() * dt,
161 intermediate.as_slice(),
162 &mut params.clone(),
163 )? * N::from_real(dt);
164 let intermediate = &state + &k3;
165 let k4 = f(time + dt, intermediate.as_slice(), &mut params.clone())? * N::from_real(dt);
166 if i != 0 {
167 states.push_back((time, state.clone()));
168 }
169 state += (k1 + k2 * N::from_f64(2.0).unwrap() + k3 * N::from_f64(2.0).unwrap() + k4)
170 * N::from_f64(1.0 / 6.0).unwrap();
171 time += dt;
172 }
173 states.push_back((time, state));
174
175 Ok(())
176}
177
178impl<N, S, O> Default for BDFInfo<N, S, O>
179where
180 N: ComplexField,
181 S: DimName + DimMin<S, Output = S>,
182 O: DimName,
183 DefaultAllocator: Allocator<N, S>
184 + Allocator<N, O>
185 + Allocator<N, S, S>
186 + Allocator<N, U1, S>
187 + Allocator<(usize, usize), S>,
188{
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194impl<N, S, O> IVPSolver<N, S> for BDFInfo<N, S, O>
195where
196 N: ComplexField,
197 S: DimName + DimMin<S, Output = S>,
198 O: DimName,
199 DefaultAllocator: Allocator<N, S>
200 + Allocator<N, O>
201 + Allocator<N, S, S>
202 + Allocator<(usize, usize), S>
203 + Allocator<N, U1, S>,
204{
205 fn step<T: Clone, F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>>(
206 &mut self,
207 mut f: F,
208 params: &mut T,
209 ) -> Result<IVPStatus<N, S>, String> {
210 if self.time.unwrap() >= self.end.unwrap() {
211 return Ok(IVPStatus::Done);
212 }
213
214 let mut output = vec![];
215
216 if self.time.unwrap() + self.dt.unwrap() >= self.end.unwrap() {
217 self.dt = Some(self.end.unwrap() - self.time.unwrap());
218 rk4(
219 self.time.unwrap(),
220 self.dt.unwrap(),
221 self.state.as_ref().unwrap().as_slice(),
222 &mut self.memory,
223 &mut f,
224 params,
225 1,
226 )?;
227 *self.time.get_or_insert(N::RealField::zero()) += self.dt.unwrap();
228 return Ok(IVPStatus::Ok(vec![(
229 self.time.unwrap(),
230 self.memory.back().unwrap().1.clone(),
231 )]));
232 }
233
234 if self.memory.is_empty() {
235 rk4(
236 self.time.unwrap(),
237 self.dt.unwrap(),
238 self.state.as_ref().unwrap().as_slice(),
239 &mut self.memory,
240 &mut f,
241 params,
242 self.higher_coffecients.len(),
243 )?;
244 self.time = Some(
245 self.time.unwrap()
246 + N::RealField::from_usize(self.higher_coffecients.len()).unwrap()
247 * self.dt.unwrap(),
248 );
249 self.state = Some(self.memory.back().unwrap().1.clone());
250 }
251
252 let tenth_real = N::RealField::from_f64(0.1).unwrap();
253 let half_real = N::RealField::from_f64(0.5).unwrap();
254 let two_real = N::RealField::from_i32(2).unwrap();
255
256 let higher_func = |y: &[N]| -> VectorN<N, S> {
257 let y = VectorN::<N, S>::from_column_slice(y);
258 let mut state = -f(
259 self.time.unwrap() + self.dt.unwrap(),
260 y.as_slice(),
261 &mut params.clone(),
262 )
263 .unwrap()
264 * N::from_real(self.dt.unwrap())
265 * self.higher_coffecients[0];
266 for (ind, coeff) in self.higher_coffecients.iter().enumerate().skip(1) {
267 state += &self.memory[self.memory.len() - ind].1 * *coeff;
268 }
269 state + y
270 };
271
272 let higher = secant(
273 self.memory[self.memory.len() - 1].1.as_slice(),
274 higher_func,
275 self.dt.unwrap(),
276 self.tolerance.unwrap(),
277 1000,
278 )?;
279
280 let lower_func = |y: &[N]| -> VectorN<N, S> {
281 let y = VectorN::<N, S>::from_column_slice(y);
282 let mut state = -f(
283 self.time.unwrap() + self.dt.unwrap(),
284 y.as_slice(),
285 &mut params.clone(),
286 )
287 .unwrap()
288 * N::from_real(self.dt.unwrap())
289 * self.lower_coefficients[0];
290 for (ind, coeff) in self.lower_coefficients.iter().enumerate().skip(1) {
291 state += &self.memory[self.memory.len() - ind].1 * *coeff;
292 }
293 state + y
294 };
295 let lower = secant(
296 self.memory[self.memory.len() - 1].1.as_slice(),
297 lower_func,
298 self.dt.unwrap(),
299 self.tolerance.unwrap(),
300 1000,
301 )?;
302
303 let diff = &higher - &lower;
304 let error = diff.dot(&diff).sqrt().abs();
305
306 if error <= self.tolerance.unwrap() {
307 self.state = Some(higher.clone());
308 self.time = Some(self.time.unwrap() + self.dt.unwrap());
309 if self.nflag {
310 for state in self.memory.iter() {
311 output.push((state.0, state.1.clone()));
312 }
313 self.nflag = false;
314 }
315 output.push((self.time.unwrap(), self.state.as_ref().unwrap().clone()));
316
317 self.memory.push_back((self.time.unwrap(), higher));
318 self.memory.pop_front();
319
320 if self.last {
321 return Ok(IVPStatus::Ok(output));
322 }
323
324 if error < tenth_real * self.tolerance.unwrap()
325 || self.time.unwrap() > self.end.unwrap()
326 {
327 self.dt = Some(self.dt.unwrap() * two_real);
328
329 if self.dt.unwrap() > self.dt_max.unwrap() {
330 self.dt = Some(self.dt_max.unwrap());
331 }
332
333 if self.time.unwrap()
334 + N::RealField::from_usize(self.higher_coffecients.len()).unwrap()
335 * self.dt.unwrap()
336 > self.end.unwrap()
337 {
338 self.dt = Some(
339 (self.end.unwrap() - self.time.unwrap())
340 / N::RealField::from_usize(self.higher_coffecients.len()).unwrap(),
341 );
342 self.last = true;
343 }
344
345 self.memory.clear();
346 }
347
348 return Ok(IVPStatus::Ok(output));
349 }
350
351 self.dt = Some(self.dt.unwrap() * half_real);
352 if self.dt.unwrap() < self.dt_min.unwrap() {
353 return Err("BDFInfo step: minimum dt exceeded".to_owned());
354 }
355
356 self.memory.clear();
357 Ok(IVPStatus::Redo)
358 }
359
360 fn with_tolerance(mut self, tol: N::RealField) -> Result<Self, String> {
361 if !tol.is_sign_positive() {
362 return Err("BDFInfo with_tolerance: tolerance must be postive".to_owned());
363 }
364 self.tolerance = Some(tol);
365 Ok(self)
366 }
367
368 fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
369 if !max.is_sign_positive() {
370 return Err("BDFInfo with_dt_max: dt_max must be positive".to_owned());
371 }
372 if let Some(min) = self.dt_min {
373 if max <= min {
374 return Err("BDFInfo with_dt_max: dt_max must be greater than dt_min".to_owned());
375 }
376 }
377 self.dt_max = Some(max);
378 self.dt = Some(max);
379 Ok(self)
380 }
381
382 fn with_dt_min(mut self, min: N::RealField) -> Result<Self, String> {
383 if !min.is_sign_positive() {
384 return Err("BDFInfo with_dt_min: dt_min must be positive".to_owned());
385 }
386 if let Some(max) = self.dt_max {
387 if min >= max {
388 return Err("BDFInfo with_dt_min: dt_min must be less than dt_max".to_owned());
389 }
390 }
391 self.dt_min = Some(min);
392 Ok(self)
393 }
394
395 fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
396 if let Some(end) = self.end {
397 if end <= t_initial {
398 return Err("BDFInfo with_start: Start must be before end".to_owned());
399 }
400 }
401 self.time = Some(t_initial);
402 Ok(self)
403 }
404
405 fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
406 if let Some(start) = self.time {
407 if t_final <= start {
408 return Err("BDFInfo with_end: Start must be before end".to_owned());
409 }
410 }
411 self.end = Some(t_final);
412 Ok(self)
413 }
414
415 fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
416 self.state = Some(VectorN::<N, S>::from_column_slice(start));
417 Ok(self)
418 }
419
420 fn build(self) -> Self {
421 self
422 }
423
424 fn get_initial_conditions(&self) -> Option<VectorN<N, S>> {
425 if let Some(state) = &self.state {
426 Some(state.clone())
427 } else {
428 None
429 }
430 }
431
432 fn get_time(&self) -> Option<N::RealField> {
433 if let Some(time) = &self.time {
434 Some(*time)
435 } else {
436 None
437 }
438 }
439
440 fn check_start(&self) -> Result<(), String> {
441 if self.time == None {
442 Err("BDFInfo check_start: No initial time".to_owned())
443 } else if self.end == None {
444 Err("BDFInfo check_start: No end time".to_owned())
445 } else if self.tolerance == None {
446 Err("BDFInfo check_start: No tolerance".to_owned())
447 } else if self.state == None {
448 Err("BDFInfo check_start: No initial conditions".to_owned())
449 } else if self.dt_max == None {
450 Err("BDFInfo check_start: No dt_max".to_owned())
451 } else if self.dt_min == None {
452 Err("BDFInfo check_start: No dt_min".to_owned())
453 } else {
454 Ok(())
455 }
456 }
457}
458
459#[derive(Debug, Clone)]
489#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
490pub struct BDF6<N, S>
491where
492 N: ComplexField,
493 S: DimName + DimMin<S, Output = S>,
494 DefaultAllocator: Allocator<N, S>
495 + Allocator<N, U7>
496 + Allocator<N, S, S>
497 + Allocator<N, U1, S>
498 + Allocator<(usize, usize), S>,
499{
500 info: BDFInfo<N, S, U7>,
501}
502
503impl<N, S> BDF6<N, S>
504where
505 N: ComplexField,
506 S: DimName + DimMin<S, Output = S>,
507 DefaultAllocator: Allocator<N, S>
508 + Allocator<N, U7>
509 + Allocator<N, S, S>
510 + Allocator<N, U1, S>
511 + Allocator<(usize, usize), S>,
512{
513 pub fn new() -> Self {
514 let mut info = BDFInfo::new();
515 info.higher_coffecients = VectorN::<N, U7>::from_iterator(
516 Self::higher_coefficients().iter().map(|&x| N::from_real(x)),
517 );
518 info.lower_coefficients = VectorN::<N, U7>::from_iterator(
519 Self::lower_coefficients().iter().map(|&x| N::from_real(x)),
520 );
521
522 BDF6 { info }
523 }
524}
525
526impl<N, S> Default for BDF6<N, S>
527where
528 N: ComplexField,
529 S: DimName + DimMin<S, Output = S>,
530 DefaultAllocator: Allocator<N, S>
531 + Allocator<N, U7>
532 + Allocator<N, S, S>
533 + Allocator<N, U1, S>
534 + Allocator<(usize, usize), S>,
535{
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
541impl<N, S> BDFSolver<N, S, U7> for BDF6<N, S>
542where
543 N: ComplexField,
544 S: DimName + DimMin<S, Output = S>,
545 DefaultAllocator: Allocator<N, S>
546 + Allocator<N, U7>
547 + Allocator<N, S, S>
548 + Allocator<N, U1, S>
549 + Allocator<(usize, usize), S>,
550{
551 fn higher_coefficients() -> VectorN<N::RealField, U7> {
552 VectorN::<N::RealField, U7>::from_column_slice(&[
553 N::RealField::from_f64(60.0 / 147.0).unwrap(),
554 N::RealField::from_f64(-360.0 / 147.0).unwrap(),
555 N::RealField::from_f64(450.0 / 147.0).unwrap(),
556 N::RealField::from_f64(-400.0 / 147.0).unwrap(),
557 N::RealField::from_f64(225.0 / 147.0).unwrap(),
558 N::RealField::from_f64(-72.0 / 147.0).unwrap(),
559 N::RealField::from_f64(10.0 / 147.0).unwrap(),
560 ])
561 }
562
563 fn lower_coefficients() -> VectorN<N::RealField, U7> {
564 VectorN::<N::RealField, U7>::from_column_slice(&[
565 N::RealField::from_f64(60.0 / 137.0).unwrap(),
566 N::RealField::from_f64(-300.0 / 137.0).unwrap(),
567 N::RealField::from_f64(300.0 / 137.0).unwrap(),
568 N::RealField::from_f64(-200.0 / 137.0).unwrap(),
569 N::RealField::from_f64(75.0 / 137.0).unwrap(),
570 N::RealField::from_f64(-12.0 / 137.0).unwrap(),
571 N::RealField::zero(),
572 ])
573 }
574
575 fn solve_ivp<
576 T: Clone,
577 F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
578 >(
579 self,
580 f: F,
581 params: &mut T,
582 ) -> super::Path<N, N::RealField, S> {
583 self.info.solve_ivp(f, params)
584 }
585
586 fn with_tolerance(mut self, tol: N::RealField) -> Result<Self, String> {
587 self.info = self.info.with_tolerance(tol)?;
588 Ok(self)
589 }
590
591 fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
592 self.info = self.info.with_dt_max(max)?;
593 Ok(self)
594 }
595
596 fn with_dt_min(mut self, min: N::RealField) -> Result<Self, String> {
597 self.info = self.info.with_dt_min(min)?;
598 Ok(self)
599 }
600
601 fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
602 self.info = self.info.with_start(t_initial)?;
603 Ok(self)
604 }
605
606 fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
607 self.info = self.info.with_end(t_final)?;
608 Ok(self)
609 }
610
611 fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
612 self.info = self.info.with_initial_conditions(start)?;
613 Ok(self)
614 }
615
616 fn build(mut self) -> Self {
617 self.info = self.info.build();
618 self
619 }
620}
621
622impl<N, S> From<BDF6<N, S>> for BDFInfo<N, S, U7>
623where
624 N: ComplexField,
625 S: DimName + DimMin<S, Output = S>,
626 DefaultAllocator: Allocator<N, S>
627 + Allocator<N, U7>
628 + Allocator<N, S, S>
629 + Allocator<N, U1, S>
630 + Allocator<(usize, usize), S>,
631{
632 fn from(bdf: BDF6<N, S>) -> BDFInfo<N, S, U7> {
633 bdf.info
634 }
635}
636
637#[derive(Debug, Clone)]
667#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
668pub struct BDF2<N, S>
669where
670 N: ComplexField,
671 S: DimName + DimMin<S, Output = S>,
672 DefaultAllocator: Allocator<N, S>
673 + Allocator<N, U3>
674 + Allocator<N, S, S>
675 + Allocator<N, U1, S>
676 + Allocator<(usize, usize), S>,
677{
678 info: BDFInfo<N, S, U3>,
679}
680
681impl<N, S> BDF2<N, S>
682where
683 N: ComplexField,
684 S: DimName + DimMin<S, Output = S>,
685 DefaultAllocator: Allocator<N, S>
686 + Allocator<N, U3>
687 + Allocator<N, S, S>
688 + Allocator<N, U1, S>
689 + Allocator<(usize, usize), S>,
690{
691 pub fn new() -> Self {
692 let mut info = BDFInfo::new();
693 info.higher_coffecients = VectorN::<N, U3>::from_iterator(
694 Self::higher_coefficients().iter().map(|&x| N::from_real(x)),
695 );
696 info.lower_coefficients = VectorN::<N, U3>::from_iterator(
697 Self::lower_coefficients().iter().map(|&x| N::from_real(x)),
698 );
699
700 BDF2 { info }
701 }
702}
703
704impl<N, S> Default for BDF2<N, S>
705where
706 N: ComplexField,
707 S: DimName + DimMin<S, Output = S>,
708 DefaultAllocator: Allocator<N, S>
709 + Allocator<N, U3>
710 + Allocator<N, S, S>
711 + Allocator<N, U1, S>
712 + Allocator<(usize, usize), S>,
713{
714 fn default() -> Self {
715 Self::new()
716 }
717}
718
719impl<N, S> BDFSolver<N, S, U3> for BDF2<N, S>
720where
721 N: ComplexField,
722 S: DimName + DimMin<S, Output = S>,
723 DefaultAllocator: Allocator<N, S>
724 + Allocator<N, U3>
725 + Allocator<N, S, S>
726 + Allocator<N, U1, S>
727 + Allocator<(usize, usize), S>,
728{
729 fn higher_coefficients() -> VectorN<N::RealField, U3> {
730 VectorN::<N::RealField, U3>::from_column_slice(&[
731 N::RealField::from_f64(2.0 / 3.0).unwrap(),
732 N::RealField::from_f64(-4.0 / 3.0).unwrap(),
733 N::RealField::from_f64(1.0 / 3.0).unwrap(),
734 ])
735 }
736
737 fn lower_coefficients() -> VectorN<N::RealField, U3> {
738 VectorN::<N::RealField, U3>::from_column_slice(&[
739 N::RealField::from_f64(1.0).unwrap(),
740 N::RealField::from_f64(-1.0).unwrap(),
741 N::RealField::zero(),
742 ])
743 }
744
745 fn solve_ivp<
746 T: Clone,
747 F: FnMut(N::RealField, &[N], &mut T) -> Result<VectorN<N, S>, String>,
748 >(
749 self,
750 f: F,
751 params: &mut T,
752 ) -> super::Path<N, N::RealField, S> {
753 self.info.solve_ivp(f, params)
754 }
755
756 fn with_tolerance(mut self, tol: N::RealField) -> Result<Self, String> {
757 self.info = self.info.with_tolerance(tol)?;
758 Ok(self)
759 }
760
761 fn with_dt_max(mut self, max: N::RealField) -> Result<Self, String> {
762 self.info = self.info.with_dt_max(max)?;
763 Ok(self)
764 }
765
766 fn with_dt_min(mut self, min: N::RealField) -> Result<Self, String> {
767 self.info = self.info.with_dt_min(min)?;
768 Ok(self)
769 }
770
771 fn with_start(mut self, t_initial: N::RealField) -> Result<Self, String> {
772 self.info = self.info.with_start(t_initial)?;
773 Ok(self)
774 }
775
776 fn with_end(mut self, t_final: N::RealField) -> Result<Self, String> {
777 self.info = self.info.with_end(t_final)?;
778 Ok(self)
779 }
780
781 fn with_initial_conditions(mut self, start: &[N]) -> Result<Self, String> {
782 self.info = self.info.with_initial_conditions(start)?;
783 Ok(self)
784 }
785
786 fn build(mut self) -> Self {
787 self.info = self.info.build();
788 self
789 }
790}
791
792impl<N, S> From<BDF2<N, S>> for BDFInfo<N, S, U3>
793where
794 N: ComplexField,
795 S: DimName + DimMin<S, Output = S>,
796 DefaultAllocator: Allocator<N, S>
797 + Allocator<N, U3>
798 + Allocator<N, S, S>
799 + Allocator<N, U1, S>
800 + Allocator<(usize, usize), S>,
801{
802 fn from(bdf: BDF2<N, S>) -> BDFInfo<N, S, U3> {
803 bdf.info
804 }
805}