1use super::*;
4use crate::{linalg::norm, ode::methods::h_init};
5
6pub struct APCV4<T: Real, V: State<T>, D: CallBackData> {
58 pub h0: T,
60
61 h: T,
63
64 t: T,
66 y: V,
67 dydt: V,
68
69 tf: T,
71
72 t_prev: [T; 4],
74 y_prev: [V; 4],
75
76 t_old: T,
78 y_old: V,
79 dydt_old: V,
80
81 k1: V,
83 k2: V,
84 k3: V,
85 k4: V,
86
87 evals: usize,
89 steps: usize,
90
91 status: Status<T, V, D>,
93
94 pub tol: T,
96 pub h_max: T,
97 pub h_min: T,
98 pub max_steps: usize,
99}
100
101impl<T: Real, V: State<T>, D: CallBackData> ODENumericalMethod<T, V, D> for APCV4<T, V, D> {
103 fn init<F>(&mut self, ode: &F, t0: T, tf: T, y0: &V) -> Result<Evals, Error<T, V>>
104 where
105 F: ODE<T, V, D>,
106 {
107 let mut evals = Evals::new();
108
109 self.tf = tf;
110
111 if self.h0 == T::zero() {
113 self.h0 = h_init(
114 ode, t0, tf, y0, 4, self.tol, self.tol, self.h_min, self.h_max,
115 );
116 }
117
118 match validate_step_size_parameters::<T, V, D>(self.h0, T::zero(), T::infinity(), t0, tf) {
120 Ok(h0) => self.h = h0,
121 Err(status) => return Err(status),
122 }
123
124 self.t = t0;
126 self.y = *y0;
127 self.t_prev[0] = t0;
128 self.y_prev[0] = *y0;
129
130 self.t_old = t0;
132 self.y_old = *y0;
133
134 let two = T::from_f64(2.0).unwrap();
136 let six = T::from_f64(6.0).unwrap();
137 for i in 1..=3 {
138 ode.diff(self.t, &self.y, &mut self.k1);
140 ode.diff(
141 self.t + self.h / two,
142 &(self.y + self.k1 * (self.h / two)),
143 &mut self.k2,
144 );
145 ode.diff(
146 self.t + self.h / two,
147 &(self.y + self.k2 * (self.h / two)),
148 &mut self.k3,
149 );
150 ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
151
152 self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
154 self.t += self.h;
155 self.t_prev[i] = self.t;
156 self.y_prev[i] = self.y;
157 evals.fcn += 4; if i == 1 {
160 self.dydt = self.k1;
161 self.dydt_old = self.k1;
162 }
163 }
164
165 self.status = Status::Initialized;
166 Ok(evals)
167 }
168
169 fn step<F>(&mut self, ode: &F) -> Result<Evals, Error<T, V>>
170 where
171 F: ODE<T, V, D>,
172 {
173 let mut evals = Evals::new();
174
175 if self.steps >= self.max_steps {
177 self.status = Status::Error(Error::MaxSteps {
178 t: self.t,
179 y: self.y,
180 });
181 return Err(Error::MaxSteps {
182 t: self.t,
183 y: self.y,
184 });
185 }
186 self.steps += 1;
187
188 if self.h != self.t_prev[0] - self.t_prev[1] && self.t + self.h == self.tf {
190 let two = T::from_f64(2.0).unwrap();
191 let six = T::from_f64(6.0).unwrap();
192
193 ode.diff(self.t, &self.y, &mut self.k1);
195 ode.diff(
196 self.t + self.h / two,
197 &(self.y + self.k1 * (self.h / two)),
198 &mut self.k2,
199 );
200 ode.diff(
201 self.t + self.h / two,
202 &(self.y + self.k2 * (self.h / two)),
203 &mut self.k3,
204 );
205 ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
206 evals.fcn += 4; self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
210 self.t += self.h;
211 return Ok(evals);
212 }
213
214 ode.diff(self.t_prev[3], &self.y_prev[3], &mut self.k1);
216 ode.diff(self.t_prev[2], &self.y_prev[2], &mut self.k2);
217 ode.diff(self.t_prev[1], &self.y_prev[1], &mut self.k3);
218 ode.diff(self.t_prev[0], &self.y_prev[0], &mut self.k4);
219
220 let predictor = self.y_prev[3]
221 + (self.k1 * T::from_f64(55.0).unwrap() - self.k2 * T::from_f64(59.0).unwrap()
222 + self.k3 * T::from_f64(37.0).unwrap()
223 - self.k4 * T::from_f64(9.0).unwrap())
224 * self.h
225 / T::from_f64(24.0).unwrap();
226
227 ode.diff(self.t + self.h, &predictor, &mut self.k4);
229 let corrector = self.y_prev[3]
230 + (self.k4 * T::from_f64(9.0).unwrap() + self.k1 * T::from_f64(19.0).unwrap()
231 - self.k2 * T::from_f64(5.0).unwrap()
232 + self.k3 * T::from_f64(1.0).unwrap())
233 * self.h
234 / T::from_f64(24.0).unwrap();
235
236 evals.fcn += 5;
238
239 let sigma = T::from_f64(19.0).unwrap() * norm(corrector - predictor)
241 / (T::from_f64(270.0).unwrap() * self.h.abs());
242
243 if sigma <= self.tol {
245 self.t_old = self.t;
247 self.y_old = self.y;
248 self.dydt_old = self.dydt;
249
250 self.t += self.h;
252 self.y = corrector;
253
254 if let Status::RejectedStep = self.status {
256 self.status = Status::Solving;
257 }
258
259 let two = T::from_f64(2.0).unwrap();
261 let four = T::from_f64(4.0).unwrap();
262 let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
263 self.h = if q > four { four * self.h } else { q * self.h };
264
265 let tf_t_abs = (self.tf - self.t).abs();
267 let four_div = tf_t_abs / four;
268 let h_max_effective = if self.h_max < four_div {
269 self.h_max
270 } else {
271 four_div
272 };
273
274 self.h = constrain_step_size(self.h, self.h_min, h_max_effective);
275
276 self.t_prev[0] = self.t;
278 self.y_prev[0] = self.y;
279 let two = T::from_f64(2.0).unwrap();
280 let six = T::from_f64(6.0).unwrap();
281 for i in 1..=3 {
282 ode.diff(self.t, &self.y, &mut self.k1);
284 ode.diff(
285 self.t + self.h / two,
286 &(self.y + self.k1 * (self.h / two)),
287 &mut self.k2,
288 );
289 ode.diff(
290 self.t + self.h / two,
291 &(self.y + self.k2 * (self.h / two)),
292 &mut self.k3,
293 );
294 ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
295
296 self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
298 self.t += self.h;
299 self.t_prev[i] = self.t;
300 self.y_prev[i] = self.y;
301 self.evals += 4; if i == 1 {
304 self.dydt = self.k1;
305 }
306 }
307 } else {
308 self.status = Status::RejectedStep;
310
311 let two = T::from_f64(2.0).unwrap();
313 let tenth = T::from_f64(0.1).unwrap();
314 let q = (self.tol / (two * sigma)).powf(T::from_f64(0.25).unwrap());
315 self.h = if q < tenth {
316 tenth * self.h
317 } else {
318 q * self.h
319 };
320
321 self.t_prev[0] = self.t;
323 self.y_prev[0] = self.y;
324 let two = T::from_f64(2.0).unwrap();
325 let six = T::from_f64(6.0).unwrap();
326 for i in 1..=3 {
327 ode.diff(self.t, &self.y, &mut self.k1);
329 ode.diff(
330 self.t + self.h / two,
331 &(self.y + self.k1 * (self.h / two)),
332 &mut self.k2,
333 );
334 ode.diff(
335 self.t + self.h / two,
336 &(self.y + self.k2 * (self.h / two)),
337 &mut self.k3,
338 );
339 ode.diff(self.t + self.h, &(self.y + self.k3 * self.h), &mut self.k4);
340
341 self.y += (self.k1 + self.k2 * two + self.k3 * two + self.k4) * (self.h / six);
343 self.t += self.h;
344 self.t_prev[i] = self.t;
345 self.y_prev[i] = self.y;
346 self.evals += 4; }
348 }
349 Ok(evals)
350 }
351
352 fn t(&self) -> T {
353 self.t
354 }
355
356 fn y(&self) -> &V {
357 &self.y
358 }
359
360 fn t_prev(&self) -> T {
361 self.t_old
362 }
363
364 fn y_prev(&self) -> &V {
365 &self.y_old
366 }
367
368 fn h(&self) -> T {
369 self.h * T::from_f64(4.0).unwrap()
373 }
374
375 fn set_h(&mut self, h: T) {
376 self.h = h;
377 }
378
379 fn status(&self) -> &Status<T, V, D> {
380 &self.status
381 }
382
383 fn set_status(&mut self, status: Status<T, V, D>) {
384 self.status = status;
385 }
386}
387
388impl<T: Real, V: State<T>, D: CallBackData> Interpolation<T, V> for APCV4<T, V, D> {
390 fn interpolate(&mut self, t_interp: T) -> Result<V, Error<T, V>> {
391 if t_interp < self.t_old || t_interp > self.t {
393 return Err(Error::OutOfBounds {
394 t_interp,
395 t_prev: self.t_old,
396 t_curr: self.t,
397 });
398 }
399
400 let y_interp = cubic_hermite_interpolate(
402 self.t_old,
403 self.t,
404 &self.y_old,
405 &self.y,
406 &self.dydt_old,
407 &self.dydt,
408 t_interp,
409 );
410
411 Ok(y_interp)
412 }
413}
414
415impl<T: Real, V: State<T>, D: CallBackData> APCV4<T, V, D> {
417 pub fn new() -> Self {
418 APCV4 {
419 ..Default::default()
420 }
421 }
422
423 pub fn h0(mut self, h0: T) -> Self {
424 self.h0 = h0;
425 self
426 }
427
428 pub fn tol(mut self, tol: T) -> Self {
429 self.tol = tol;
430 self
431 }
432
433 pub fn h_min(mut self, h_min: T) -> Self {
434 self.h_min = h_min;
435 self
436 }
437
438 pub fn h_max(mut self, h_max: T) -> Self {
439 self.h_max = h_max;
440 self
441 }
442
443 pub fn max_steps(mut self, max_steps: usize) -> Self {
444 self.max_steps = max_steps;
445 self
446 }
447}
448
449impl<T: Real, V: State<T>, D: CallBackData> Default for APCV4<T, V, D> {
450 fn default() -> Self {
451 APCV4 {
452 h0: T::zero(),
453 h: T::zero(),
454 t: T::zero(),
455 y: V::zeros(),
456 dydt: V::zeros(),
457 t_prev: [T::zero(); 4],
458 y_prev: [V::zeros(), V::zeros(), V::zeros(), V::zeros()],
459 t_old: T::zero(),
460 y_old: V::zeros(),
461 dydt_old: V::zeros(),
462 k1: V::zeros(),
463 k2: V::zeros(),
464 k3: V::zeros(),
465 k4: V::zeros(),
466 tf: T::zero(),
467 evals: 0,
468 steps: 0,
469 status: Status::Uninitialized,
470 tol: T::from_f64(1.0e-6).unwrap(),
471 h_max: T::infinity(),
472 h_min: T::zero(),
473 max_steps: 1_000_000,
474 }
475 }
476}