numerical_integration/
runge_kutta.rs

1use super::*;
2
3use std::fmt::{Debug, Formatter};
4
/// Reasons a coefficient table can be rejected when building a Runge-Kutta integrator.
#[derive(Clone, Copy, PartialEq)]
pub enum RKError {
    /// The tableau is empty.
    EmptyTableau,
    /// At least one row has a different length than the first row.
    JaggedTableau,
    /// The column count is too large for the number of rows; carries `(rows, columns)`.
    TooManyColumns(usize, usize),
    /// A fixed-step method was requested but the tableau is not square; carries
    /// `(number of rows, length of the first row)`.
    NonSquareTableau(usize, usize),
    /// The tableau describes an implicit method, which the integrators here do not handle.
    UnsupportedImplicit
}
13
14impl Debug for RKError {
15    fn fmt(&self, f: &mut Formatter) -> ::std::fmt::Result {
16        match self {
17            RKError::EmptyTableau =>
18                write!(f, "Zero-length Runge-Kutta matrix"),
19            RKError::JaggedTableau =>
20                write!(f, "Tableau is non-rectangular"),
21            RKError::TooManyColumns(r, c) =>
22                write!(f, "Tableau has {} rows but {} columns", r, c),
23            RKError::NonSquareTableau(r, c) =>
24                write!(f, "Non-square tableau; number of rows is {} but there is a row of length {}", r, c),
25            RKError::UnsupportedImplicit =>
26                write!(f, "Implicit Runge-Kutta not supported")
27        }
28    }
29}
30
/// A coefficient table classified by its shape and by whether it contains
/// implicit coefficients. Every variant borrows the table it was built from.
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum ButcherTableau<'a> {
    /// As many rows as columns and no implicit coefficients.
    Fixed(&'a[&'a[f64]]),
    /// More rows than columns and no implicit coefficients.
    Adaptive(&'a[&'a[f64]]),
    /// As many rows as columns, with implicit coefficients.
    Implicit(&'a[&'a[f64]]),
    /// More rows than columns, with implicit coefficients.
    AdaptiveImplicit(&'a[&'a[f64]])
}
38
39impl<'a> ButcherTableau<'a> {
40    fn new(table: &'a[&'a[f64]]) -> Result<Self, RKError> {
41        use ButcherTableau::*;
42        use RKError::*;
43
44        //make sure the tableau is non-empty
45        if table.len()==0 {
46            Err(EmptyTableau)
47        } else {
48            let rows = table.len();
49            let columns = table[0].len();
50
51            //make sure we have enough rows
52            if columns>rows { return Err(TooManyColumns(rows, columns)); }
53
54            //check if the tableau is of an implict method and make sure we have a non-jagged array
55            let mut implicit = false;
56            for i in 0..rows {
57                if table[i].len()!=columns { return Err(JaggedTableau); }
58                for j in i..columns {
59                    if table[i][j] != 0.0 {
60                        implicit = true;
61                        break;
62                    }
63                }
64            }
65
66            Ok(match (rows>columns, implicit) {
67                (false, false) => Fixed(table),
68                (true, false) => Adaptive(table),
69                (false, true) => Implicit(table),
70                (true, true) => AdaptiveImplicit(table),
71            })
72
73        }
74    }
75}
76
77
/// A fixed-step explicit Runge-Kutta method, stored as a borrowed coefficient table.
///
/// Each row but the last is `[c_i, a_i0, a_i1, ...]` (time fraction followed by the
/// coefficients applied to earlier stage slopes); the last row holds the output
/// weights starting at column 1.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct RungeKutta<'a>(&'a[&'a[f64]]);
80
/// An explicit Runge-Kutta method with an embedded second set of output weights,
/// stored as a borrowed coefficient table.
///
/// The stage rows are laid out as `[c_i, a_i0, a_i1, ...]`; they are followed by two
/// weight rows (weights start at column 1) whose results are compared to estimate
/// the error of a step.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct AdaptiveRungeKutta<'a>(&'a[&'a[f64]]);
83
// Fixed-step tableaux. In every table below, each row except the last reads
// `[c_i, a_i0, a_i1, ...]` and the last row reads `[0, b_0, b_1, ...]`.

/// Alias of [`RK1`].
pub const EULER: RungeKutta = RK1;
/// Alias of [`RK2`].
pub const MIDPOINT: RungeKutta = RK2;
/// One stage evaluated at the start of the step, taken with weight 1 (forward Euler).
pub const RK1: RungeKutta = RungeKutta(
    &[&[0.0,0.0],
      &[0.0,1.0]]
);
/// Two stages; only the slope evaluated at the half step contributes to the result
/// (explicit midpoint rule).
pub const RK2: RungeKutta = RungeKutta(
    &[&[0.0,0.0,0.0],
      &[0.5,0.5,0.0],
      &[0.0,0.0,1.0]]
);
/// Two stages at the start and end of the step, averaged with equal weights.
pub const HEUN2: RungeKutta = RungeKutta(
    &[&[0.0,0.0,0.0],
      &[1.0,1.0,0.0],
      &[0.0,0.5,0.5]]
);
/// Two stages with the second at 2/3 of the step and weights 1/4 and 3/4.
pub const RALSTON: RungeKutta = RungeKutta(
    &[&[0.0,    0.0,    0.0 ],
      &[2.0/3.0,2.0/3.0,0.0 ],
      &[0.0,    0.25,   0.75]]
);
/// Three stages at 0, 1/2 and 1 with weights 1/6, 2/3, 1/6.
pub const RK3: RungeKutta = RungeKutta(
    &[&[0.0, 0.0,     0.0,     0.0],
      &[0.5, 0.5,     0.0,     0.0],
      &[1.0,-1.0,     2.0,     0.0],
      &[0.0, 1.0/6.0, 2.0/3.0, 1.0/6.0]]
);
/// Three stages at 0, 1/3 and 2/3 with weights 1/4, 0, 3/4.
pub const HEUN3: RungeKutta = RungeKutta(
    &[&[0.0,    0.0,    0.0,    0.0],
      &[1.0/3.0,1.0/3.0,0.0,    0.0],
      &[2.0/3.0,0.0,    2.0/3.0,0.0],
      &[0.0,    0.25,   0.0,    0.75]]
);
/// Four stages at 0, 1/2, 1/2 and 1 with weights 1/6, 1/3, 1/3, 1/6
/// (the classical four-stage scheme).
pub const RK4: RungeKutta = RungeKutta(
    &[&[0.0, 0.0,     0.0,     0.0,     0.0],
      &[0.5, 0.5,     0.0,     0.0,     0.0],
      &[0.5, 0.0,     0.5,     0.0,     0.0],
      &[1.0, 0.0,     0.0,     1.0,     0.0],
      &[0.0, 1.0/6.0, 1.0/3.0, 1.0/3.0, 1.0/6.0]]
);
/// Four stages at 0, 1/3, 2/3 and 1 with weights 1/8, 3/8, 3/8, 1/8 (the "3/8 rule").
pub const RK_3_8: RungeKutta = RungeKutta(
    &[&[0.0,     0.0,      0.0,   0.0,   0.0],
      &[1.0/3.0, 1.0/3.0,  0.0,   0.0,   0.0],
      &[2.0/3.0, -1.0/3.0, 1.0,   0.0,   0.0],
      &[1.0,     1.0,      -1.0,  1.0,   0.0],
      &[0.0,     0.125,    0.375, 0.375, 0.125]]
);
131
/// Two-stage embedded pair: the first weight row averages both slopes, the second
/// weight row uses only the first slope. The difference between the two results
/// serves as the error estimate.
pub const EULER_HEUN: AdaptiveRungeKutta = AdaptiveRungeKutta(
    &[&[0.0, 0.0, 0.0],
      &[1.0, 1.0, 0.0],
      &[0.0, 0.5, 0.5],
      &[0.0, 1.0, 0.0]]
);
138
/// Four-stage embedded pair with stages at 0, 1/2, 3/4 and 1.
///
/// The last stage row repeats the first weight row, so the final slope is evaluated
/// at the first estimate's end point. The second weight row gives the comparison estimate.
pub const BOGACKI_SHAMPINE: AdaptiveRungeKutta = AdaptiveRungeKutta(
    &[&[0.0,  0.0,      0.0,     0.0,     0.0],
      &[0.5,  0.5,      0.0,     0.0,     0.0],
      &[0.75, 0.0,      0.75,    0.0,     0.0],
      &[1.0,  2.0/9.0,  1.0/3.0, 4.0/9.0, 0.0],
      &[0.0,  2.0/9.0,  1.0/3.0, 4.0/9.0, 0.0],
      &[0.0,  7.0/24.0, 0.25,    1.0/3.0, 0.125]]
);
147
/// Six-stage embedded pair (Runge-Kutta-Fehlberg coefficients; the constant keeps
/// its existing spelling so callers are unaffected).
///
/// The first weight row uses all six slopes, the second leaves the last slope out.
pub const RK_FELBERG: AdaptiveRungeKutta = AdaptiveRungeKutta(
    &[&[0.0,       0.0,            0.0,            0.0,            0.0,              0.0,       0.0],
      &[0.25,      0.25,           0.0,            0.0,            0.0,              0.0,       0.0],
      &[0.375,     3.0/32.0,       9.0/32.0,       0.0,            0.0,              0.0,       0.0],
      &[12.0/13.0, 1932.0/2197.0, -7200.0/2197.0,  7296.0/2197.0,  0.0,              0.0,       0.0],
      &[1.0,       439.0/216.0,   -8.0,            3680.0/513.0,  -845.0/4104.0,     0.0,       0.0],
      &[0.5,      -8.0/27.0,       2.0,           -3544.0/2565.0,  1859.0/4104.0,   -11.0/40.0, 0.0],
      &[0.0,       16.0/135.0,     0.0,            6656.0/12825.0, 28561.0/56430.0, -9.0/50.0,  2.0/55.0],
      &[0.0,       25.0/216.0,     0.0,            1408.0/2565.0,  2197.0/4104.0,   -1.0/5.0,   0.0]]
);
158
159pub const DORMAND_PRINCE: AdaptiveRungeKutta = AdaptiveRungeKutta(
160    &[&[0.0,     0.0,             0.0,            0.0,             0.0,          0.0,              0.0,          0.0],
161      &[0.2,     0.2,             0.0,            0.0,             0.0,          0.0,              0.0,          0.0],
162      &[0.3,     3.0/40.0,        9.0/40.0,       0.0,             0.0,          0.0,              0.0,          0.0],
163      &[0.4,     44.0/45.0,      -56.0/15.0,      32.0/9.0,        0.0,          0.0,              0.0,          0.0],
164      &[8.0/9.0, 19372.0/6561.0, -25360.0/2187.0, 64448.0/6561.0, -212.0/729.0,  0.0,              0.0,          0.0],
165      &[1.0,     9017.0/3168.0,  -355.0/33.0,     46732.0/5247.0,  49.0/176.0,  -5103.0/18656.0,   0.0,          0.0],
166      &[1.0,     35.0/384.0,      0.0,            500.0/1113.0,    125.0/192.0, -2187.0/6784.0,    11.0/84.0,    0.0],
167      &[0.0,     35.0/384.0,      0.0,            500.0/1113.0,    125.0/192.0, -2187.0/6784.0,    11.0/84.0,    0.0],
168      &[0.0,     5179.0/57600.0,  0.0,            7571.0/16695.0,  393.0/640.0, -92097.0/339200.0, 187.0/2100.0, 1.0/4.0]]
169);
170
171
172impl<'a> RungeKutta<'a> {
173    pub fn order(&self) -> usize {self.0.len()-1}
174    pub fn from_matrix(rk_matrix: &'a[&'a[f64]]) -> Result<Self, RKError> {
175        match ButcherTableau::new(rk_matrix)? {
176            ButcherTableau::Fixed(t) => Ok(RungeKutta(t)),
177            ButcherTableau::Implicit(_) => Err(RKError::UnsupportedImplicit),
178            _ => Err(RKError::NonSquareTableau(rk_matrix.len(), rk_matrix[0].len()))
179        }
180    }
181}
182
183impl<'a> AdaptiveRungeKutta<'a> {
184    pub fn order(&self) -> usize {self.0[0].len()-1}
185    pub fn from_matrix(rk_matrix: &'a[&'a[f64]]) -> Result<Self, RKError> {
186        match ButcherTableau::new(rk_matrix)? {
187            ButcherTableau::Adaptive(t) => Ok(AdaptiveRungeKutta(t)),
188            ButcherTableau::Fixed(t) => Err(RKError::TooManyColumns(t.len(), t[0].len())),
189            _ => Err(RKError::UnsupportedImplicit),
190        }
191    }
192}
193
194impl<'a> VelIntegrator for RungeKutta<'a> {
195    fn step_with_vel<R:Real, S:VectorSpace<R>, V:Fn(R,S)->S, F:Fn(R,S)->S>(&self, time:R, state: &mut [S], dt:R, _:V, force:F) -> S {
196        Integrator::step(self, time, state, dt, force)
197    }
198}
199
200fn compute_k<R:Real, S:VectorSpace<R>, F:Fn(R, S) -> S>(tableau: &[&[f64]], time:R, state:&S, dt:R, force: F) -> Vec<S> {
201    let order = tableau[0].len()-1;
202    let mut k:Vec<S> = Vec::with_capacity(order);
203
204    for i in 0..order {
205        let t = time.clone() + dt.clone() * R::repr(tableau[i][0]);
206        let mut y_i = state.clone();
207        for j in 1..=i {
208            if tableau[i][j]!=0.0 {
209                y_i += k[j-1].clone() * (dt.clone() * R::repr(tableau[i][j]));
210            }
211        }
212        k.push(force(t, y_i));
213    }
214
215    k
216}
217
218impl<'a> Integrator for RungeKutta<'a> {
219    fn step<R:Real, S:VectorSpace<R>, F:Fn(R, S) -> S>(&self, time:R, state: &mut [S], dt:R, force: F) -> S {
220
221        let order = self.order();
222        let k:Vec<S> = compute_k(self.0, time, &state[0], dt.clone(), force);
223
224        let mut j = 1;
225        for k_j in k {
226            if self.0[order][j]!=0.0 { state[0] += k_j * (dt.clone()*R::repr(self.0[order][j]));}
227            j += 1;
228        }
229
230        state[0].clone()
231    }
232}
233
234impl<'a> AdaptiveIntegrator for AdaptiveRungeKutta<'a> {
235    fn adaptive_init<R:Real, S:VectorSpace<R>, M:Metric<S,R>, F:Fn(R, S) -> S>(&self, t0:R, state: S, ds:R, _force:F, _d:M) -> Box<[(R,S)]>{
236        Box::new([(t0, state.clone()), (ds, state.clone())])
237    }
238
239    fn adaptive_step<R:Real, S:VectorSpace<R>, M:Metric<S,R>, F:Fn(R, S) -> S>(&self, state: &mut [(R,S)], ds:R, force:F, d:M) -> (R,S) {
240        let order = self.order();
241        let mut dt = state[1].0.clone();
242        let time = state[0].0.clone();
243
244        loop {
245            let k:Vec<S> = compute_k(self.0, time.clone(), &state[0].1, dt.clone(), &force);
246
247            let mut est1 = state[0].1.clone();
248            let mut est2 = state[0].1.clone();
249
250            let mut j = 1;
251            for k_j in k {
252                if self.0[order][j]!=0.0 { est1 += k_j.clone() * (dt.clone()*R::repr(self.0[order][j]));}
253                if self.0[order+1][j]!=0.0 { est2 += k_j * (dt.clone()*R::repr(self.0[order+1][j]));}
254                j += 1;
255            }
256
257            let err = d.distance(est1.clone(), est2.clone());
258
259            if err < ds {
260                let next_dt = dt.clone() * R::repr(1.5);
261                state[0].0 += dt;
262                state[0].1 = est1;
263                state[1] = (next_dt, est2);
264                return state[0].clone();
265            } else {
266                dt *= R::repr(0.5);
267            }
268        }
269
270    }
271}