differential_equations/methods/erk/fixed/
ordinary.rs1use crate::{
4 error::Error,
5 interpolate::{Interpolation, cubic_hermite_interpolate},
6 methods::{ExplicitRungeKutta, Fixed, Ordinary},
7 ode::{ODE, OrdinaryNumericalMethod},
8 stats::Evals,
9 status::Status,
10 traits::{CallBackData, Real, State},
11 utils::validate_step_size_parameters,
12};
13
14impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
15 OrdinaryNumericalMethod<T, Y, D> for ExplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
16{
17 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
18 where
19 F: ODE<T, Y, D>,
20 {
21 let mut evals = Evals::new();
22
23 if self.h0 == T::zero() {
25 let duration = (tf - t0).abs();
27 let default_steps = T::from_usize(100).unwrap();
28 self.h0 = duration / default_steps;
29 }
30
31 match validate_step_size_parameters::<T, Y, D>(self.h0, self.h_min, self.h_max, t0, tf) {
33 Ok(h0) => self.h = h0,
34 Err(status) => return Err(status),
35 } self.t = t0;
39 self.y = *y0;
40 ode.diff(self.t, &self.y, &mut self.dydt);
41 evals.function += 1;
42
43 self.t_prev = self.t;
45 self.y_prev = self.y;
46 self.dydt_prev = self.dydt;
47
48 self.status = Status::Initialized;
50
51 Ok(evals)
52 }
53
54 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
55 where
56 F: ODE<T, Y, D>,
57 {
58 let mut evals = Evals::new();
59
60 if self.steps >= self.max_steps {
62 self.status = Status::Error(Error::MaxSteps {
63 t: self.t,
64 y: self.y,
65 });
66 return Err(Error::MaxSteps {
67 t: self.t,
68 y: self.y,
69 });
70 }
71 self.steps += 1;
72
73 self.k[0] = self.dydt;
75
76 for i in 1..self.stages {
78 let mut y_stage = self.y;
79
80 for j in 0..i {
81 y_stage += self.k[j] * (self.a[i][j] * self.h);
82 }
83
84 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
85 }
86 evals.function += self.stages - 1; self.t_prev = self.t;
90 self.y_prev = self.y;
91 self.dydt_prev = self.k[0];
92
93 let mut y_next = self.y;
95 for i in 0..self.stages {
96 y_next += self.k[i] * (self.b[i] * self.h);
97 }
98
99 if self.bi.is_some() {
101 for i in 0..(I - S) {
103 let mut y_stage = self.y;
104 for j in 0..self.stages + i {
105 y_stage += self.k[j] * (self.a[self.stages + i][j] * self.h);
106 }
107
108 ode.diff(
109 self.t + self.c[self.stages + i] * self.h,
110 &y_stage,
111 &mut self.k[self.stages + i],
112 );
113 }
114 evals.function += I - S;
115 }
116
117 self.t += self.h;
119 self.y = y_next;
120
121 if self.fsal {
123 self.dydt = self.k[S - 1];
125 } else {
126 ode.diff(self.t, &self.y, &mut self.dydt);
128 evals.function += 1;
129 }
130
131 self.status = Status::Solving;
132 Ok(evals)
133 }
134
135 fn t(&self) -> T {
136 self.t
137 }
138 fn y(&self) -> &Y {
139 &self.y
140 }
141 fn t_prev(&self) -> T {
142 self.t_prev
143 }
144 fn y_prev(&self) -> &Y {
145 &self.y_prev
146 }
147 fn h(&self) -> T {
148 self.h
149 }
150 fn set_h(&mut self, h: T) {
151 self.h = h;
152 }
153 fn status(&self) -> &Status<T, Y, D> {
154 &self.status
155 }
156 fn set_status(&mut self, status: Status<T, Y, D>) {
157 self.status = status;
158 }
159}
160
161impl<T: Real, Y: State<T>, D: CallBackData, const O: usize, const S: usize, const I: usize>
162 Interpolation<T, Y> for ExplicitRungeKutta<Ordinary, Fixed, T, Y, D, O, S, I>
163{
164 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
165 if t_interp < self.t_prev || t_interp > self.t {
167 return Err(Error::OutOfBounds {
168 t_interp,
169 t_prev: self.t_prev,
170 t_curr: self.t,
171 });
172 }
173
174 if self.bi.is_some() {
176 let s = (t_interp - self.t_prev) / self.h_prev;
178
179 let bi = self.bi.as_ref().unwrap();
181
182 let mut cont = [T::zero(); I];
183 for i in 0..self.dense_stages {
185 cont[i] = bi[i][self.order - 1];
187
188 for j in (0..self.order - 1).rev() {
190 cont[i] = cont[i] * s + bi[i][j];
191 }
192
193 cont[i] *= s;
195 }
196
197 let mut y_interp = self.y_prev;
199 for i in 0..I {
200 y_interp += self.k[i] * cont[i] * self.h_prev;
201 }
202
203 Ok(y_interp)
204 } else {
205 let y_interp = cubic_hermite_interpolate(
207 self.t_prev,
208 self.t,
209 &self.y_prev,
210 &self.y,
211 &self.dydt_prev,
212 &self.dydt,
213 t_interp,
214 );
215
216 Ok(y_interp)
217 }
218 }
219}