differential_equations/methods/erk/fixed/
ordinary.rs1use crate::{
3 error::Error,
4 interpolate::{Interpolation, cubic_hermite_interpolate},
5 methods::{ExplicitRungeKutta, Fixed, Ordinary},
6 ode::{ODE, OrdinaryNumericalMethod},
7 stats::Evals,
8 status::Status,
9 traits::{Real, State},
10 utils::validate_step_size_parameters,
11};
12
13impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize>
14 OrdinaryNumericalMethod<T, Y> for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
15{
16 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
17 where
18 F: ODE<T, Y> + ?Sized,
19 {
20 let mut evals = Evals::new();
21
22 if self.h0 == T::zero() {
24 let duration = (tf - t0).abs();
26 let default_steps = T::from_usize(100).unwrap();
27 self.h0 = duration / default_steps;
28 }
29
30 match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
32 Ok(h0) => self.h = h0,
33 Err(status) => return Err(status),
34 } self.t = t0;
38 self.y = y0.clone();
39 self.dydt = y0.zeros_like();
40 self.y_prev = y0.clone();
41 self.dydt_prev = y0.zeros_like();
42 self.k = core::array::from_fn(|_| y0.zeros_like());
43 self.cont = core::array::from_fn(|_| y0.zeros_like());
44 ode.diff(self.t, &self.y, &mut self.dydt);
45 evals.function += 1;
46
47 self.t_prev = self.t;
49 self.y_prev = self.y.clone();
50 self.dydt_prev = self.dydt.clone();
51
52 self.status = Status::Initialized;
54
55 Ok(evals)
56 }
57
58 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, Y>>
59 where
60 F: ODE<T, Y> + ?Sized,
61 {
62 let mut evals = Evals::new();
63
64 if self.steps >= self.max_steps {
66 self.status = Status::Error(Error::MaxSteps {
67 t: self.t,
68 y: self.y.clone(),
69 });
70 return Err(Error::MaxSteps {
71 t: self.t,
72 y: self.y.clone(),
73 });
74 }
75 self.steps += 1;
76
77 self.k[0] = self.dydt.clone();
79
80 for i in 1..self.stages {
82 let mut y_stage = self.y.clone();
83
84 for j in 0..i {
85 y_stage.add_scaled(self.a[i][j] * self.h, &self.k[j]);
86 }
87
88 ode.diff(self.t + self.c[i] * self.h, &y_stage, &mut self.k[i]);
89 }
90 evals.function += self.stages - 1; self.t_prev = self.t;
94 self.y_prev = self.y.clone();
95 self.dydt_prev = self.k[0].clone();
96 self.h_prev = self.h;
97
98 let mut y_next = self.y.clone();
100 for i in 0..self.stages {
101 y_next.add_scaled(self.b[i] * self.h, &self.k[i]);
102 }
103
104 if self.bi.is_some() {
106 for i in 0..(I - S) {
108 let mut y_stage = self.y.clone();
109 for j in 0..self.stages + i {
110 y_stage.add_scaled(self.a[self.stages + i][j] * self.h, &self.k[j]);
111 }
112
113 ode.diff(
114 self.t + self.c[self.stages + i] * self.h,
115 &y_stage,
116 &mut self.k[self.stages + i],
117 );
118 }
119 evals.function += I - S;
120 }
121
122 self.t += self.h;
124 self.y = y_next;
125
126 if self.fsal {
128 self.dydt = self.k[S - 1].clone();
130 } else {
131 ode.diff(self.t, &self.y, &mut self.dydt);
133 evals.function += 1;
134 }
135
136 self.status = Status::Solving;
137 Ok(evals)
138 }
139
140 fn t(&self) -> T {
141 self.t
142 }
143 fn y(&self) -> &Y {
144 &self.y
145 }
146 fn t_prev(&self) -> T {
147 self.t_prev
148 }
149 fn y_prev(&self) -> &Y {
150 &self.y_prev
151 }
152 fn h(&self) -> T {
153 self.h
154 }
155 fn set_h(&mut self, h: T) {
156 self.h = h;
157 }
158 fn status(&self) -> &Status<T, Y> {
159 &self.status
160 }
161 fn set_status(&mut self, status: Status<T, Y>) {
162 self.status = status;
163 }
164}
165
166impl<T: Real, Y: State<T>, const O: usize, const S: usize, const I: usize> Interpolation<T, Y>
167 for ExplicitRungeKutta<Ordinary, Fixed, T, Y, O, S, I>
168{
169 fn interpolate(&mut self, t_interp: T) -> Result<Y, Error<T, Y>> {
170 if t_interp < self.t_prev || t_interp > self.t {
172 return Err(Error::OutOfBounds {
173 t_interp,
174 t_prev: self.t_prev,
175 t_curr: self.t,
176 });
177 }
178
179 if let Some(bi) = self.bi.as_ref() {
181 let s = (t_interp - self.t_prev) / self.h_prev;
183
184 let mut cont = [T::zero(); I];
185 for i in 0..self.dense_stages {
187 cont[i] = bi[i][self.order - 1];
189
190 for j in (0..self.order - 1).rev() {
192 cont[i] = cont[i] * s + bi[i][j];
193 }
194
195 cont[i] *= s;
197 }
198
199 let mut y_interp = self.y_prev.clone();
201 for i in 0..I {
202 y_interp.add_scaled(cont[i] * self.h_prev, &self.k[i]);
203 }
204
205 Ok(y_interp)
206 } else {
207 let y_interp = cubic_hermite_interpolate(
209 self.t_prev,
210 self.t,
211 &self.y_prev,
212 &self.y,
213 &self.dydt_prev,
214 &self.dydt,
215 t_interp,
216 );
217
218 Ok(y_interp)
219 }
220 }
221}