1use super::{Derivative, IVPError, IVPIterator, IVPSolver, IVPStatus, IVPStepper, Step};
8use crate::{BMatrix, BSMatrix, BSVector, BVector, Dimension};
9use nalgebra::{
10 allocator::Allocator, ComplexField, Const, DefaultAllocator, Dim, DimName, RealField, U1,
11};
12use num_traits::{FromPrimitive, One, Zero};
13use std::marker::PhantomData;
14
15pub trait RungeKuttaCoefficients<const O: usize> {
21 type RealField: RealField;
23
24 fn t_coefficients() -> Option<BSVector<Self::RealField, O>>;
27
28 fn k_coefficients() -> Option<BSMatrix<Self::RealField, O, O>>;
33
34 fn avg_coefficients() -> Option<BSVector<Self::RealField, O>>;
39
40 fn error_coefficients() -> Option<BSVector<Self::RealField, O>>;
45}
46
47pub struct RungeKutta<'a, N, D, const O: usize, T, F, R>
51where
52 D: Dimension,
53 N: ComplexField + Copy,
54 T: Clone,
55 F: Derivative<N, D, T> + 'a,
56 R: RungeKuttaCoefficients<O, RealField = N::RealField>,
57 DefaultAllocator: Allocator<N, D>,
58 DefaultAllocator: Allocator<N, Const<O>>,
59{
60 init_dt_max: Option<N::RealField>,
61 init_dt_min: Option<N::RealField>,
62 init_time: Option<N::RealField>,
63 init_end: Option<N::RealField>,
64 init_tolerance: Option<N::RealField>,
65 init_state: Option<BVector<N, D>>,
66 init_derivative: Option<F>,
67 dim: D,
68 _data: PhantomData<&'a (T, R)>,
69}
70
71pub struct RungeKuttaSolver<'a, N, D, const O: usize, T, F>
76where
77 D: Dimension,
78 N: ComplexField + Copy,
79 T: Clone,
80 F: Derivative<N, D, T> + 'a,
81 DefaultAllocator: Allocator<N, D>,
82 DefaultAllocator: Allocator<N, Const<O>>,
83 DefaultAllocator: Allocator<N, D, Const<O>>,
84{
85 dt_max: N,
87 dt_min: N,
88 time: N,
89 end: N,
90 tolerance: N,
91 derivative: F,
92 data: T,
93
94 dt: N,
96 state: BVector<N, D>,
97
98 t_coefficients: BSVector<N, O>,
100 k_coefficients: BSMatrix<N, O, O>,
101 avg_coefficients: BSVector<N, O>,
102 error_coefficients: BSVector<N, O>,
103
104 half_steps: BMatrix<N, D, Const<O>>,
106 step: BVector<N, D>,
107 scratch_pad: BVector<N, D>,
108
109 one_tenth: N,
111 one_fourth: N,
112 point_eighty_four: N,
113 four: N,
114
115 _lifetime: PhantomData<&'a ()>,
116}
117
118impl<'a, N, D, const O: usize, T, F, R> IVPSolver<'a, D> for RungeKutta<'a, N, D, O, T, F, R>
119where
120 D: Dimension,
121 N: ComplexField + Copy,
122 T: Clone,
123 F: Derivative<N, D, T> + 'a,
124 R: RungeKuttaCoefficients<O, RealField = N::RealField>,
125 DefaultAllocator: Allocator<N, D>,
126 DefaultAllocator: Allocator<N, Const<O>>,
127 DefaultAllocator: Allocator<N, D, Const<O>>,
128{
129 type Error = IVPError;
130 type Field = N;
131 type RealField = N::RealField;
132 type Derivative = F;
133 type UserData = T;
134 type Solver = RungeKuttaSolver<'a, N, D, O, T, F>;
135
136 fn new() -> Result<Self, IVPError> {
137 Ok(Self {
138 init_dt_max: None,
139 init_dt_min: None,
140 init_time: None,
141 init_end: None,
142 init_tolerance: None,
143 init_state: None,
144 init_derivative: None,
145 dim: D::dim()?,
146 _data: PhantomData,
147 })
148 }
149
150 fn new_dyn(size: usize) -> Result<Self, IVPError> {
151 Ok(Self {
152 init_dt_max: None,
153 init_dt_min: None,
154 init_time: None,
155 init_end: None,
156 init_tolerance: None,
157 init_state: None,
158 init_derivative: None,
159 dim: D::dim_dyn(size)?,
160 _data: PhantomData,
161 })
162 }
163
164 fn dim(&self) -> D {
165 self.dim
166 }
167
168 fn with_tolerance(mut self, tol: Self::RealField) -> Result<Self, Self::Error> {
169 if tol <= <Self::RealField as Zero>::zero() {
170 return Err(IVPError::ToleranceOOB);
171 }
172 self.init_tolerance = Some(tol);
173 Ok(self)
174 }
175
176 fn with_maximum_dt(mut self, max: Self::RealField) -> Result<Self, Self::Error> {
180 if max <= <Self::RealField as Zero>::zero() {
181 return Err(IVPError::TimeDeltaOOB);
182 }
183
184 self.init_dt_max = Some(max.clone());
185 if let Some(dt_min) = self.init_dt_min.as_mut() {
186 if *dt_min > max {
187 *dt_min = max;
188 }
189 }
190
191 Ok(self)
192 }
193
194 fn with_minimum_dt(mut self, min: Self::RealField) -> Result<Self, Self::Error> {
198 if min <= <Self::RealField as Zero>::zero() {
199 return Err(IVPError::TimeDeltaOOB);
200 }
201
202 self.init_dt_min = Some(min.clone());
203 if let Some(dt_max) = self.init_dt_max.as_mut() {
204 if *dt_max < min {
205 *dt_max = min;
206 }
207 }
208
209 Ok(self)
210 }
211
212 fn with_initial_time(mut self, initial: Self::RealField) -> Result<Self, Self::Error> {
213 self.init_time = Some(initial.clone());
214
215 if let Some(end) = self.init_end.as_ref() {
216 if *end <= initial {
217 return Err(IVPError::TimeStartOOB);
218 }
219 }
220
221 Ok(self)
222 }
223
224 fn with_ending_time(mut self, ending: Self::RealField) -> Result<Self, Self::Error> {
225 self.init_end = Some(ending.clone());
226
227 if let Some(initial) = self.init_time.as_ref() {
228 if *initial >= ending {
229 return Err(IVPError::TimeEndOOB);
230 }
231 }
232
233 Ok(self)
234 }
235
236 fn with_initial_conditions(
237 mut self,
238 start: BVector<Self::Field, D>,
239 ) -> Result<Self, Self::Error> {
240 self.init_state = Some(start);
241 Ok(self)
242 }
243
244 fn with_derivative(mut self, derivative: Self::Derivative) -> Self {
245 self.init_derivative = Some(derivative);
246 self
247 }
248
249 fn solve(self, data: Self::UserData) -> Result<IVPIterator<D, Self::Solver>, Self::Error> {
250 let dt_max = self.init_dt_max.ok_or(IVPError::MissingParameters)?;
251 let dt_min = self.init_dt_min.ok_or(IVPError::MissingParameters)?;
252 let tolerance = self.init_tolerance.ok_or(IVPError::MissingParameters)?;
253 let time = self.init_time.ok_or(IVPError::MissingParameters)?;
254 let end = self.init_end.ok_or(IVPError::MissingParameters)?;
255 let state = self.init_state.ok_or(IVPError::MissingParameters)?;
256 let derivative = self.init_derivative.ok_or(IVPError::MissingParameters)?;
257
258 let two = Self::Field::from_u8(2).ok_or(IVPError::FromPrimitiveFailure)?;
259 let half = Self::Field::one() / two;
260
261 let one_tenth =
262 Self::Field::one() / Self::Field::from_u8(10).ok_or(IVPError::FromPrimitiveFailure)?;
263 let four = Self::Field::from_u8(4).ok_or(IVPError::FromPrimitiveFailure)?;
264 let one_fourth = Self::Field::one() / four;
265
266 let one_hundred = Self::Field::from_u8(100).ok_or(IVPError::FromPrimitiveFailure)?;
267 let eighty_four = Self::Field::from_u8(100).ok_or(IVPError::FromPrimitiveFailure)?;
268 let point_eighty_four = eighty_four / one_hundred;
269
270 let t_coefficients = BSVector::from_iterator(
271 R::t_coefficients()
272 .ok_or(IVPError::FromPrimitiveFailure)?
273 .as_slice()
274 .iter()
275 .cloned()
276 .map(Self::Field::from_real),
277 );
278
279 let k_coefficients = BSMatrix::<N, O, O>::from_iterator_generic(
280 <Const<O> as Dim>::from_usize(O),
281 <Const<O> as Dim>::from_usize(O),
282 R::k_coefficients()
283 .ok_or(IVPError::FromPrimitiveFailure)?
284 .as_slice()
285 .iter()
286 .cloned()
287 .map(Self::Field::from_real),
288 );
289
290 let avg_coefficients = BSVector::from_iterator(
291 R::avg_coefficients()
292 .ok_or(IVPError::FromPrimitiveFailure)?
293 .as_slice()
294 .iter()
295 .cloned()
296 .map(Self::Field::from_real),
297 );
298
299 let error_coefficients = BSVector::from_iterator(
300 R::error_coefficients()
301 .ok_or(IVPError::FromPrimitiveFailure)?
302 .as_slice()
303 .iter()
304 .cloned()
305 .map(Self::Field::from_real),
306 );
307
308 Ok(IVPIterator {
309 solver: RungeKuttaSolver {
310 dt_max: Self::Field::from_real(dt_max.clone()),
311 dt_min: Self::Field::from_real(dt_min.clone()),
312 time: Self::Field::from_real(time),
313 end: Self::Field::from_real(end),
314 tolerance: Self::Field::from_real(tolerance),
315 dt: Self::Field::from_real(dt_max + dt_min) * half,
316 state,
317 derivative,
318 data,
319 t_coefficients,
320 k_coefficients,
321 avg_coefficients,
322 error_coefficients,
323 half_steps: BMatrix::from_element_generic(
324 self.dim,
325 <Const<O> as DimName>::name(),
326 Self::Field::zero(),
327 ),
328 scratch_pad: BVector::from_element_generic(
329 self.dim,
330 U1::name(),
331 Self::Field::zero(),
332 ),
333 step: BVector::from_element_generic(self.dim, U1::name(), Self::Field::zero()),
334 one_tenth,
335 one_fourth,
336 point_eighty_four,
337 four,
338 _lifetime: PhantomData,
339 },
340 finished: false,
341 _dim: PhantomData,
342 })
343 }
344}
345
346impl<'a, N, D, const O: usize, T, F> IVPStepper<D> for RungeKuttaSolver<'a, N, D, O, T, F>
347where
348 D: Dimension,
349 N: ComplexField + Copy,
350 T: Clone,
351 F: Derivative<N, D, T> + 'a,
352 DefaultAllocator: Allocator<N, D>,
353 DefaultAllocator: Allocator<N, Const<O>>,
354 DefaultAllocator: Allocator<N, D, Const<O>>,
355{
356 type Error = IVPError;
357 type Field = N;
358 type RealField = N::RealField;
359 type UserData = T;
360
361 fn step(&mut self) -> Step<Self::RealField, Self::Field, D, Self::Error> {
362 if self.time.real() >= self.end.real() {
363 return Err(IVPStatus::Done);
364 }
365
366 if self.time.real() + self.dt.real() >= self.end.real() {
367 self.dt = self.end - self.time;
368 }
369
370 for (i, k_row) in self.k_coefficients.row_iter().enumerate() {
371 self.scratch_pad = self.state.clone();
372 for (j, &k_coeff) in k_row.iter().enumerate() {
373 self.scratch_pad += self.half_steps.column(j) * k_coeff;
374 }
375
376 let step_time = self.time + self.t_coefficients[i] * self.dt;
377 self.step = (self.derivative)(
378 step_time.real(),
379 self.scratch_pad.as_slice(),
380 &mut self.data.clone(),
381 )? * self.dt;
382
383 self.half_steps.set_column(i, &self.step);
384 }
385
386 self.scratch_pad = self.half_steps.column(0) * self.error_coefficients[0];
387 for (ind, &e_coeff) in self.error_coefficients.iter().enumerate().skip(1) {
388 self.scratch_pad += self.half_steps.column(ind) * e_coeff;
389 }
390 let error = self.scratch_pad.norm() / self.dt.real();
391
392 if error <= self.tolerance.real() {
393 self.time += self.dt;
394
395 for (ind, &avg_coeff) in self.avg_coefficients.iter().enumerate() {
396 self.state += self.half_steps.column(ind) * avg_coeff;
397 }
398 }
399
400 let delta = self.point_eighty_four.real()
401 * (self.tolerance.real() / error.clone()).powf(self.one_fourth.real());
402 if delta <= self.one_tenth.real() {
403 self.dt *= self.one_tenth;
404 } else if delta >= self.four.real() {
405 self.dt *= self.four;
406 } else {
407 self.dt *= Self::Field::from_real(delta);
408 }
409
410 if self.dt.real() > self.dt_max.real() {
411 self.dt = self.dt_max;
412 }
413
414 if self.dt.real() < self.dt_min.real() && self.time.real() < self.end.real() {
415 return Err(IVPStatus::Failure(IVPError::MinimumTimeDeltaExceeded));
416 }
417
418 if error <= self.tolerance.real() {
419 Ok((self.time.real(), self.state.clone()))
420 } else {
421 Err(IVPStatus::Redo)
422 }
423 }
424
425 fn time(&self) -> Self::RealField {
426 self.time.real()
427 }
428}
429
430pub struct RKCoefficients45<N: ComplexField>(PhantomData<N>);
431
432impl<N: ComplexField> RungeKuttaCoefficients<6> for RKCoefficients45<N> {
433 type RealField = N::RealField;
434
435 fn t_coefficients() -> Option<BSVector<Self::RealField, 6>> {
436 let one_fourth = Self::RealField::from_u8(4)?.recip();
437 let one_half = Self::RealField::from_u8(2)?.recip();
438 let three = Self::RealField::from_u8(3)?;
439 let eight = Self::RealField::from_u8(8)?;
440 let twelve = Self::RealField::from_u8(12)?;
441 let thirteen = Self::RealField::from_u8(13)?;
442
443 Some(BSVector::from_column_slice(&[
444 Self::RealField::zero(),
445 one_fourth,
446 three / eight,
447 twelve / thirteen,
448 Self::RealField::one(),
449 one_half,
450 ]))
451 }
452
453 fn k_coefficients() -> Option<BSMatrix<Self::RealField, 6, 6>> {
454 let zero = Self::RealField::zero();
455 let one_fourth = Self::RealField::from_u8(4)?.recip();
456 let thirty_two = Self::RealField::from_u8(32)?;
457 let two_one_nine_seven = Self::RealField::from_u16(2197)?;
458
459 Some(BSMatrix::from_vec(vec![
460 zero.clone(),
462 zero.clone(),
463 zero.clone(),
464 zero.clone(),
465 zero.clone(),
466 zero.clone(),
467 one_fourth,
469 zero.clone(),
470 zero.clone(),
471 zero.clone(),
472 zero.clone(),
473 zero.clone(),
474 Self::RealField::from_u8(3)? / thirty_two.clone(),
476 Self::RealField::from_u8(9)? / thirty_two.clone(),
477 zero.clone(),
478 zero.clone(),
479 zero.clone(),
480 zero.clone(),
481 Self::RealField::from_u16(1932)? / two_one_nine_seven.clone(),
483 -Self::RealField::from_u16(7200)? / two_one_nine_seven.clone(),
484 Self::RealField::from_u16(7296)? / two_one_nine_seven,
485 zero.clone(),
486 zero.clone(),
487 zero.clone(),
488 Self::RealField::from_u16(439)? / Self::RealField::from_u8(216)?,
490 -Self::RealField::from_u8(8)?,
491 Self::RealField::from_u16(3680)? / Self::RealField::from_u16(513)?,
492 -Self::RealField::from_u16(845)? / Self::RealField::from_u16(4104)?,
493 zero.clone(),
494 zero.clone(),
495 -Self::RealField::from_u8(8)? / Self::RealField::from_u8(27)?,
497 Self::RealField::from_u8(2)?,
498 -Self::RealField::from_u16(3544)? / Self::RealField::from_u16(2565)?,
499 Self::RealField::from_u16(1859)? / Self::RealField::from_u16(4014)?,
500 -Self::RealField::from_u8(11)? / Self::RealField::from_u8(40)?,
501 zero,
502 ]))
503 }
504
505 fn avg_coefficients() -> Option<BSVector<Self::RealField, 6>> {
506 Some(BSVector::from_column_slice(&[
507 Self::RealField::from_u8(25)? / Self::RealField::from_u8(216)?,
508 Self::RealField::zero(),
509 Self::RealField::from_u16(1408)? / Self::RealField::from_u16(2565)?,
510 Self::RealField::from_u16(2197)? / Self::RealField::from_u16(4104)?,
511 -Self::RealField::from_u8(5)?.recip(),
512 Self::RealField::zero(),
513 ]))
514 }
515
516 fn error_coefficients() -> Option<BSVector<Self::RealField, 6>> {
517 Some(BSVector::from_column_slice(&[
518 Self::RealField::from_u16(360)?.recip(),
519 Self::RealField::from_f64(0.0).unwrap(),
520 Self::RealField::from_f64(-128.0 / 4275.0).unwrap(),
521 Self::RealField::from_f64(-2197.0 / 75240.0).unwrap(),
522 Self::RealField::from_f64(1.0 / 50.0).unwrap(),
523 Self::RealField::from_f64(2.0 / 55.0).unwrap(),
524 ]))
525 }
526}
527
528pub type RungeKutta45<'a, N, D, T, F> = RungeKutta<'a, N, D, 6, T, F, RKCoefficients45<N>>;
562
563pub struct RK23Coefficients<N: ComplexField>(PhantomData<N>);
564
565impl<N: ComplexField> RungeKuttaCoefficients<4> for RK23Coefficients<N> {
566 type RealField = N::RealField;
567
568 fn t_coefficients() -> Option<BSVector<Self::RealField, 4>> {
569 Some(BSVector::from_column_slice(&[
570 Self::RealField::zero(),
571 Self::RealField::from_u8(2)?.recip(),
572 Self::RealField::from_u8(3)? / Self::RealField::from_u8(4)?,
573 Self::RealField::one(),
574 ]))
575 }
576
577 fn k_coefficients() -> Option<BSMatrix<Self::RealField, 4, 4>> {
578 let zero = Self::RealField::zero();
579
580 Some(BSMatrix::from_vec(vec![
581 zero.clone(),
583 zero.clone(),
584 zero.clone(),
585 zero.clone(),
586 Self::RealField::from_u8(2)?.recip(),
588 zero.clone(),
589 zero.clone(),
590 zero.clone(),
591 zero.clone(),
593 Self::RealField::from_u8(3)? / Self::RealField::from_u8(4)?,
594 zero.clone(),
595 zero.clone(),
596 Self::RealField::from_u8(2)? / Self::RealField::from_u8(9)?,
598 Self::RealField::from_u8(3)?.recip(),
599 Self::RealField::from_u8(4)? / Self::RealField::from_u8(9)?,
600 zero,
601 ]))
602 }
603
604 fn avg_coefficients() -> Option<BSVector<Self::RealField, 4>> {
605 Some(BSVector::from_column_slice(&[
606 Self::RealField::from_u8(2)? / Self::RealField::from_u8(9)?,
607 Self::RealField::from_u8(3)?.recip(),
608 Self::RealField::from_u8(4)? / Self::RealField::from_u8(9)?,
609 Self::RealField::zero(),
610 ]))
611 }
612
613 fn error_coefficients() -> Option<BSVector<Self::RealField, 4>> {
614 Some(BSVector::from_column_slice(&[
615 -Self::RealField::from_u8(5)? / Self::RealField::from_u8(72)?,
616 Self::RealField::from_u8(12)?.recip(),
617 Self::RealField::from_u8(9)?.recip(),
618 -Self::RealField::from_u8(8)?.recip(),
619 ]))
620 }
621}
622
623pub type RungeKutta23<'a, N, D, T, F> = RungeKutta<'a, N, D, 4, T, F, RK23Coefficients<N>>;
657
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ivp::UserError, BSVector};
    use rstest::rstest;

    /// `dy/dt = -2t`; with `y(0) = 1` the exact solution is `y = 1 - t^2`.
    fn quadratic_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        let slope = -2.0 * t;
        Ok(BSVector::from_column_slice(&[slope]))
    }

    /// `dy/dt = cos(t)`; with `y(0) = 0` the exact solution is `y = sin(t)`.
    fn sine_deriv(t: f64, _y: &[f64], _: &mut ()) -> Result<BSVector<f64, 1>, UserError> {
        let slope = t.cos();
        Ok(BSVector::from_column_slice(&[slope]))
    }

    /// Plain function-pointer derivative over scalar `f64` state with no user data.
    type DerivFn = fn(f64, &[f64], &mut ()) -> Result<BVector<f64, Const<1>>, UserError>;

    type TestRK<'a, const O: usize, R> = RungeKutta<'a, f64, Const<1>, O, (), DerivFn, R>;

    #[rstest]
    #[case::rk23(RungeKutta23::new().unwrap())]
    #[case::rk45(RungeKutta45::new().unwrap())]
    fn rungekutta_quadratic<'a, const O: usize, R>(#[case] rk: TestRK<'a, O, R>)
    where
        R: RungeKuttaCoefficients<O, RealField = f64>,
    {
        let (start, stop) = (0.0, 10.0);

        let path = rk
            .with_minimum_dt(0.0001)
            .unwrap()
            .with_maximum_dt(0.1)
            .unwrap()
            .with_initial_time(start)
            .unwrap()
            .with_ending_time(stop)
            .unwrap()
            .with_tolerance(1e-5)
            .unwrap()
            .with_initial_conditions_slice(&[1.0])
            .unwrap()
            .with_derivative(quadratic_deriv)
            .solve(())
            .unwrap()
            .collect_vec()
            .unwrap();

        for point in &path {
            let t = point.0;
            let expected = 1.0 - t.powi(2);
            assert!(approx_eq!(f64, point.1.column(0)[0], expected, epsilon = 0.0001));
        }
    }

    #[rstest]
    #[case::rk23(RungeKutta23::new().unwrap())]
    #[case::rk45(RungeKutta45::new().unwrap())]
    fn rungekutta_sine<'a, const O: usize, R>(#[case] rk: TestRK<'a, O, R>)
    where
        R: RungeKuttaCoefficients<O, RealField = f64>,
    {
        let (start, stop) = (0.0, 10.0);

        let path = rk
            .with_minimum_dt(0.001)
            .unwrap()
            .with_maximum_dt(0.01)
            .unwrap()
            .with_tolerance(0.0001)
            .unwrap()
            .with_initial_time(start)
            .unwrap()
            .with_ending_time(stop)
            .unwrap()
            .with_initial_conditions_slice(&[0.0])
            .unwrap()
            .with_derivative(sine_deriv)
            .solve(())
            .unwrap()
            .collect_vec()
            .unwrap();

        for point in &path {
            let expected = point.0.sin();
            assert!(approx_eq!(f64, point.1.column(0)[0], expected, epsilon = 0.01));
        }
    }
}