1use crate::error::SolverError;
19use crate::problem::OdeSystem;
20use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
21use crate::t_eval::{validate_grid, TEvalEmitter};
22use numra_core::Scalar;
23
/// Verner's 6th-order explicit Runge–Kutta solver with an embedded error estimate.
///
/// The type carries no state: all configuration is passed through `SolverOptions`
/// and the integration itself happens in `Solver::solve`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Vern6;

impl Vern6 {
    /// Creates the (stateless) solver; equivalent to `Vern6::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
37
// Butcher tableau used by `Vern6`.
//
// `A{i}{j}` couples stage `i` to stage `j` (both 1-based), `C` holds the stage
// nodes, `B` the propagated weights, `B_HAT` the embedded weights and
// `E = B - B_HAT` the error-estimate weights.
//
// The ninth row (`A91`..`A98`) is identical to `B[0..8]`, so the ninth stage is
// evaluated at the new solution and can be reused as the next step's first stage.
// The tenth row (`A101`..`A109`, node `C[9] = 0.5`) is never read by the stepper
// (presumably intended for dense output; TODO confirm), hence `dead_code`.
#[allow(dead_code)]
mod vern6_tableau {
    // Stage nodes c_1..c_10.
    pub const C: [f64; 10] = [
        0.0,
        0.06,
        0.09593333333333333,
        0.1439,
        0.4973,
        0.9725,
        0.9995,
        1.0,
        1.0,
        0.5,
    ];

    pub const A21: f64 = 0.06;

    pub const A31: f64 = 0.019239962962962962;
    pub const A32: f64 = 0.07669337037037037;

    pub const A41: f64 = 0.035975;
    pub const A42: f64 = 0.0;
    pub const A43: f64 = 0.107925;

    pub const A51: f64 = 1.3186834152331484;
    pub const A52: f64 = 0.0;
    pub const A53: f64 = -5.042058063628562;
    pub const A54: f64 = 4.220674648395414;

    pub const A61: f64 = -41.872591664327516;
    pub const A62: f64 = 0.0;
    pub const A63: f64 = 159.4325621631375;
    pub const A64: f64 = -122.11921356501003;
    pub const A65: f64 = 5.531743066200054;

    pub const A71: f64 = -54.430156935316504;
    pub const A72: f64 = 0.0;
    pub const A73: f64 = 207.06725136501848;
    pub const A74: f64 = -158.61081378459;
    pub const A75: f64 = 6.991816585950242;
    pub const A76: f64 = -0.018597231062203234;

    pub const A81: f64 = -54.66374178728198;
    pub const A82: f64 = 0.0;
    pub const A83: f64 = 207.95280625538936;
    pub const A84: f64 = -159.2889574744995;
    pub const A85: f64 = 7.018743740796944;
    pub const A86: f64 = -0.018338785905045722;
    pub const A87: f64 = -0.0005119484997882099;

    // Ninth row: equal to B[0..8] (first-same-as-last structure).
    pub const A91: f64 = 0.03438957868357036;
    pub const A92: f64 = 0.0;
    pub const A93: f64 = 0.0;
    pub const A94: f64 = 0.2582624555633503;
    pub const A95: f64 = 0.4209371189673537;
    pub const A96: f64 = 4.40539646966931;
    pub const A97: f64 = -176.48311902429865;
    pub const A98: f64 = 172.36413340141507;

    // Tenth row (node C[9] = 0.5); not used by the stepper.
    pub const A101: f64 = 0.016524159013572806;
    pub const A102: f64 = 0.0;
    pub const A103: f64 = 0.0;
    pub const A104: f64 = 0.3053128187514179;
    pub const A105: f64 = 0.2071200938201979;
    pub const A106: f64 = -1.293879140655123;
    pub const A107: f64 = 57.11988411588149;
    pub const A108: f64 = -55.87979207510932;
    pub const A109: f64 = 0.024830028297766014;

    // Weights of the propagated solution (entries 8 and 9 are zero).
    pub const B: [f64; 10] = [
        0.03438957868357036,
        0.0,
        0.0,
        0.2582624555633503,
        0.4209371189673537,
        4.40539646966931,
        -176.48311902429865,
        172.36413340141507,
        0.0,
        0.0,
    ];

    // Embedded weights; note the nonzero entry on the ninth (FSAL) stage.
    pub const B_HAT: [f64; 10] = [
        0.0490996764838249,
        0.0,
        0.0,
        0.22511122295165242,
        0.4694682253029562,
        0.8065792249988868,
        0.0,
        -0.607119489177796,
        0.056861139440475696,
        0.0,
    ];

    // Error-estimate weights E = B - B_HAT (entries that are zero in both are 0.0).
    pub const E: [f64; 10] = [
        B[0] - B_HAT[0],
        0.0,
        0.0,
        B[3] - B_HAT[3],
        B[4] - B_HAT[4],
        B[5] - B_HAT[5],
        B[6] - B_HAT[6],
        B[7] - B_HAT[7],
        B[8] - B_HAT[8],
        0.0,
    ];
}
156
157impl<S: Scalar> Solver<S> for Vern6 {
158 fn solve<Sys: OdeSystem<S>>(
159 problem: &Sys,
160 t0: S,
161 tf: S,
162 y0: &[S],
163 options: &SolverOptions<S>,
164 ) -> Result<SolverResult<S>, SolverError> {
165 use vern6_tableau::*;
166
167 let dim = problem.dim();
168 if dim == 0 {
169 return Err(SolverError::DimensionMismatch {
170 expected: 1,
171 actual: 0,
172 });
173 }
174
175 let mut t = t0;
176 let mut y = y0.to_vec();
177 let direction = if tf >= t0 { S::ONE } else { -S::ONE };
178
179 let mut k = vec![S::ZERO; 10 * dim];
182 let mut y_stage = vec![S::ZERO; dim];
183 let mut y_new = vec![S::ZERO; dim];
184 let mut err = vec![S::ZERO; dim];
185
186 let mut stats = SolverStats::default();
188
189 if let Some(grid) = options.t_eval.as_deref() {
190 validate_grid(grid, t0, tf)?;
191 }
192 let mut grid_emitter = options
193 .t_eval
194 .as_deref()
195 .map(|g| TEvalEmitter::new(g, direction));
196 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
197 (Vec::new(), Vec::new())
198 } else {
199 (vec![t0], y0.to_vec())
200 };
201
202 problem.rhs(t, &y, &mut k[0..dim]);
204 stats.n_eval += 1;
205
206 let mut f_norm = S::ZERO;
207 for i in 0..dim {
208 f_norm = f_norm + k[i] * k[i];
209 }
210 f_norm = f_norm.sqrt();
211
212 let mut h = if f_norm > S::from_f64(1e-10) {
213 S::from_f64(0.01) * (S::ONE / f_norm)
214 } else {
215 S::from_f64(0.01)
216 };
217 h = h.min(options.h_max).max(options.h_min) * direction;
218
219 let max_steps = options.max_steps;
220 let mut step_count = 0;
221
222 while (tf - t) * direction > S::from_f64(1e-14) * (tf - t0).abs() {
223 if step_count >= max_steps {
224 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
225 }
226
227 if (t + h - tf) * direction > S::ZERO {
229 h = tf - t;
230 }
231
232 let c2 = S::from_f64(C[1]);
237 let a21 = S::from_f64(A21);
238 for i in 0..dim {
239 y_stage[i] = y[i] + h * a21 * k[i];
240 }
241 problem.rhs(t + c2 * h, &y_stage, &mut k[dim..2 * dim]);
242
243 let c3 = S::from_f64(C[2]);
245 let a31 = S::from_f64(A31);
246 let a32 = S::from_f64(A32);
247 for i in 0..dim {
248 y_stage[i] = y[i] + h * (a31 * k[i] + a32 * k[dim + i]);
249 }
250 problem.rhs(t + c3 * h, &y_stage, &mut k[2 * dim..3 * dim]);
251
252 let c4 = S::from_f64(C[3]);
254 let a41 = S::from_f64(A41);
255 let a43 = S::from_f64(A43);
256 for i in 0..dim {
257 y_stage[i] = y[i] + h * (a41 * k[i] + a43 * k[2 * dim + i]);
258 }
259 problem.rhs(t + c4 * h, &y_stage, &mut k[3 * dim..4 * dim]);
260
261 let c5 = S::from_f64(C[4]);
263 let a51 = S::from_f64(A51);
264 let a53 = S::from_f64(A53);
265 let a54 = S::from_f64(A54);
266 for i in 0..dim {
267 y_stage[i] = y[i] + h * (a51 * k[i] + a53 * k[2 * dim + i] + a54 * k[3 * dim + i]);
268 }
269 problem.rhs(t + c5 * h, &y_stage, &mut k[4 * dim..5 * dim]);
270
271 let c6 = S::from_f64(C[5]);
273 let a61 = S::from_f64(A61);
274 let a63 = S::from_f64(A63);
275 let a64 = S::from_f64(A64);
276 let a65 = S::from_f64(A65);
277 for i in 0..dim {
278 y_stage[i] = y[i]
279 + h * (a61 * k[i]
280 + a63 * k[2 * dim + i]
281 + a64 * k[3 * dim + i]
282 + a65 * k[4 * dim + i]);
283 }
284 problem.rhs(t + c6 * h, &y_stage, &mut k[5 * dim..6 * dim]);
285
286 let c7 = S::from_f64(C[6]);
288 let a71 = S::from_f64(A71);
289 let a73 = S::from_f64(A73);
290 let a74 = S::from_f64(A74);
291 let a75 = S::from_f64(A75);
292 let a76 = S::from_f64(A76);
293 for i in 0..dim {
294 y_stage[i] = y[i]
295 + h * (a71 * k[i]
296 + a73 * k[2 * dim + i]
297 + a74 * k[3 * dim + i]
298 + a75 * k[4 * dim + i]
299 + a76 * k[5 * dim + i]);
300 }
301 problem.rhs(t + c7 * h, &y_stage, &mut k[6 * dim..7 * dim]);
302
303 let c8 = S::from_f64(C[7]);
305 let a81 = S::from_f64(A81);
306 let a83 = S::from_f64(A83);
307 let a84 = S::from_f64(A84);
308 let a85 = S::from_f64(A85);
309 let a86 = S::from_f64(A86);
310 let a87 = S::from_f64(A87);
311 for i in 0..dim {
312 y_stage[i] = y[i]
313 + h * (a81 * k[i]
314 + a83 * k[2 * dim + i]
315 + a84 * k[3 * dim + i]
316 + a85 * k[4 * dim + i]
317 + a86 * k[5 * dim + i]
318 + a87 * k[6 * dim + i]);
319 }
320 problem.rhs(t + c8 * h, &y_stage, &mut k[7 * dim..8 * dim]);
321
322 let a91 = S::from_f64(A91);
324 let a94 = S::from_f64(A94);
325 let a95 = S::from_f64(A95);
326 let a96 = S::from_f64(A96);
327 let a97 = S::from_f64(A97);
328 let a98 = S::from_f64(A98);
329 for i in 0..dim {
330 y_new[i] = y[i]
331 + h * (a91 * k[i]
332 + a94 * k[3 * dim + i]
333 + a95 * k[4 * dim + i]
334 + a96 * k[5 * dim + i]
335 + a97 * k[6 * dim + i]
336 + a98 * k[7 * dim + i]);
337 }
338 problem.rhs(t + h, &y_new, &mut k[8 * dim..9 * dim]);
339
340 stats.n_eval += 8;
341
342 let e = &E;
344 for i in 0..dim {
345 err[i] = h
346 * (S::from_f64(e[0]) * k[i]
347 + S::from_f64(e[3]) * k[3 * dim + i]
348 + S::from_f64(e[4]) * k[4 * dim + i]
349 + S::from_f64(e[5]) * k[5 * dim + i]
350 + S::from_f64(e[6]) * k[6 * dim + i]
351 + S::from_f64(e[7]) * k[7 * dim + i]
352 + S::from_f64(e[8]) * k[8 * dim + i]);
353 }
354
355 let mut err_norm = S::ZERO;
357 for i in 0..dim {
358 let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
359 let ratio = err[i] / sc;
360 err_norm = err_norm + ratio * ratio;
361 }
362 err_norm = (err_norm / S::from_f64(dim as f64)).sqrt();
363
364 if err_norm <= S::ONE {
365 stats.n_accept += 1;
367
368 let t_new = t + h;
369 if let Some(ref mut emitter) = grid_emitter {
370 let (dy_start, dy_end_block) = k.split_at(dim);
373 let dy_end = &dy_end_block[(8 - 1) * dim..8 * dim];
374 emitter.emit_step(
375 t, &y, dy_start, t_new, &y_new, dy_end, &mut t_out, &mut y_out,
376 );
377 } else {
378 t_out.push(t_new);
379 y_out.extend_from_slice(&y_new);
380 }
381
382 t = t_new;
383 y.copy_from_slice(&y_new);
384
385 k.copy_within(8 * dim..9 * dim, 0);
387 } else {
388 stats.n_reject += 1;
389 }
390
391 let safety = S::from_f64(0.9);
393 let min_factor = S::from_f64(0.2);
394 let max_factor = S::from_f64(10.0);
395
396 let factor = if err_norm > S::from_f64(1e-10) {
397 safety * (S::ONE / err_norm).powf(S::from_f64(1.0 / 7.0))
398 } else {
399 max_factor
400 };
401 let factor = factor.min(max_factor).max(min_factor);
402 h = h * factor;
403
404 let h_abs = h.abs();
406 let h_abs = h_abs.min(options.h_max).max(options.h_min);
407 h = h_abs * direction;
408
409 step_count += 1;
410 }
411
412 if stats.n_accept > 0 {
414 } else {
416 problem.rhs(t, &y, &mut k[0..dim]);
417 stats.n_eval += 1;
418 }
419
420 Ok(SolverResult::new(t_out, y_out, dim, stats))
421 }
422}
423
/// Verner's 7th-order explicit Runge–Kutta solver with an embedded error estimate.
///
/// The type carries no state: all configuration is passed through `SolverOptions`
/// and the integration itself happens in `Solver::solve`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Vern7;

impl Vern7 {
    /// Creates the (stateless) solver; equivalent to `Vern7::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
437
// Butcher tableau used by `Vern7` (10 stages, consumed by `solve_erk`).
//
// `A[s - 1]` holds the coefficients that build stage `s` (0-based), padded with
// zeros to 9 columns. The last stage (row `A[8]`, node `C[9] = 1.0`) has a row
// different from `B` and `B[9] == 0.0`, so it only feeds the error estimate; the
// method is therefore not FSAL. The subtrahends in `E` are presumably the embedded
// lower-order weights b_hat (so `E = B - B_HAT`); TODO confirm against the source tableau.
mod vern7_tableau {
    // Stage nodes c_1..c_10.
    pub const C: [f64; 10] = [
        0.0,
        0.005,
        0.10888888888888888888888888888888888889,
        0.16333333333333333333333333333333333333,
        0.4555,
        0.60950944899783813170870044214860249496,
        0.884,
        0.925,
        1.0,
        1.0,
    ];

    // Strictly lower-triangular coupling matrix; row r builds stage r + 1.
    pub const A: [[f64; 9]; 9] = [
        [0.005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [
            -1.076790123456790123456790123456790123457,
            1.185679012345679012345679012345679012346,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        [
            0.04083333333333333333333333333333333333333,
            0.0,
            0.1225,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        [
            0.6389139236255726780508121615993336109954,
            0.0,
            -2.455672638223656809662640566430653894211,
            2.272258714598084131611828404831320283215,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        [
            -2.661577375018757131119259297861818119279,
            0.0,
            10.80451388645613769565396655365532838482,
            -8.353914657396199411968048547819291691541,
            0.8204875949566569791420417341743839209619,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        [
            6.067741434696770992718360183877276714679,
            0.0,
            -24.71127363591108579734203485290746001803,
            20.42751793078889394045773111748346612697,
            -1.906157978816647150624096784352757010879,
            1.006172249242068014790040335899474187268,
            0.0,
            0.0,
            0.0,
        ],
        [
            12.05467007625320299509109452892778311648,
            0.0,
            -49.75478495046898932807257615331444758322,
            41.14288863860467663259698416710157354209,
            -4.461760149974004185641911603484815375051,
            2.042334822239174959821717077708608543738,
            -0.09834843665406107379530801693870224403537,
            0.0,
            0.0,
        ],
        [
            10.13814652288180787641845141981689030769,
            0.0,
            -42.64113603171750214622846006736635730625,
            35.76384003992257007135021178023160054034,
            -4.348022840392907653340370296908245943710,
            2.009862268377035895441943593011827554771,
            0.3487490460338272405953822853053145879140,
            -0.2714390051048312842371587140910297407572,
            0.0,
        ],
        // Last row: builds the error-only stage at c = 1.0.
        [
            -45.03007203429867712435322405073769635151,
            0.0,
            187.3272437654588840752418206154201997384,
            -154.0288236935018690596728621034510402582,
            18.56465306347536233859492332958439136765,
            -7.141809679295078854925420496823551192821,
            1.308808578161378625114762706007696696508,
            0.0,
            0.0,
        ],
    ];

    // Weights of the propagated solution; the last stage has weight zero.
    pub const B: [f64; 10] = [
        0.047155618486272221704317651088381756796,
        0.0,
        0.0,
        0.25750564298434151895964361010376875810,
        0.26216653977412620477138630957645277111,
        0.15216092656738557403231331991651175355,
        0.49399691700324842469071758932278768443,
        -0.29430311714032504415572447440927034291,
        0.081317472324951099997345994401367618925,
        0.0,
    ];

    // Error-estimate weights, written as B[j] minus the embedded weight of stage j.
    pub const E: [f64; 10] = [
        0.047155618486272221704317651088381756796 - 0.044608606606341176287318175974791977814,
        0.0,
        0.0,
        0.25750564298434151895964361010376875810 - 0.26716403785713726805091022609438378997,
        0.26216653977412620477138630957645277111 - 0.22010183001772930199797157766507530963,
        0.15216092656738557403231331991651175355 - 0.21884317031431568309831208335128938246,
        0.49399691700324842469071758932278768443 - 0.22898717054112028833781738897635523654,
        -0.29430311714032504415572447440927034291 - 0.0,
        0.081317472324951099997345994401367618925 - 0.0,
        0.0 - 0.020295184663356282227670547938104303586,
    ];
}
595
596impl<S: Scalar> Solver<S> for Vern7 {
597 fn solve<Sys: OdeSystem<S>>(
598 problem: &Sys,
599 t0: S,
600 tf: S,
601 y0: &[S],
602 options: &SolverOptions<S>,
603 ) -> Result<SolverResult<S>, SolverError> {
604 let a_slices: Vec<&[f64]> = vern7_tableau::A.iter().map(|row| row.as_slice()).collect();
606 solve_erk(
607 problem,
608 t0,
609 tf,
610 y0,
611 options,
612 &vern7_tableau::C,
613 &a_slices,
614 &vern7_tableau::B,
615 &vern7_tableau::E,
616 7, 10, )
619 }
620}
621
/// Verner's 8th-order explicit Runge–Kutta solver with an embedded error estimate.
///
/// The type carries no state: all configuration is passed through `SolverOptions`
/// and the integration itself happens in `Solver::solve`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Vern8;

impl Vern8 {
    /// Creates the (stateless) solver; equivalent to `Vern8::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
635
// Butcher tableau used by `solve_vern8` (13 stages).
//
// `A{n}` is the coupling row that builds stage `n` (1-based) from stages
// `1..n`; `C` holds the nodes, `B` the propagated weights, `B_HAT` the embedded
// weights and `E = B - B_HAT` the error-estimate weights (written out with the
// literal values rather than referencing `B_HAT`). Stage 13 (`A13`, `C[12] = 1.0`)
// has weight zero in `B`, so it only feeds the error estimate; the method is
// therefore not FSAL.
mod vern8_tableau {
    // Stage nodes c_1..c_13.
    pub const C: [f64; 13] = [
        0.0,
        0.05,
        0.1065625,
        0.15984375,
        0.39,
        0.465,
        0.155,
        0.943,
        0.9018020417358569582597079406783721499560,
        0.909,
        0.94,
        1.0,
        1.0,
    ];

    pub const A2: [f64; 1] = [0.05];
    pub const A3: [f64; 2] = [-0.0069931640625, 0.1135556640625];
    pub const A4: [f64; 3] = [0.0399609375, 0.0, 0.1198828125];
    pub const A5: [f64; 4] = [
        0.3613975628004575124052940721184028345129,
        0.0,
        -1.341524066700492771819987788202715834917,
        1.370126503900035259414693716084313000404,
    ];
    pub const A6: [f64; 5] = [
        0.04904720279720279720279720279720279720280,
        0.0,
        0.0,
        0.2350972042214404739862988335493427143122,
        0.1808555929813567288109039636534544884850,
    ];
    pub const A7: [f64; 6] = [
        0.06169289044289044289044289044289044289044,
        0.0,
        0.0,
        0.1123656831464027662262557035130015442303,
        -0.03885046071451366767049048108111244567456,
        0.01979188712522045855379188712522045855379,
    ];
    pub const A8: [f64; 7] = [
        -1.767630240222326875735597119572145586714,
        0.0,
        0.0,
        -62.5,
        -6.061889377376669100821361459659331999758,
        5.650823198222763138561298030600840174201,
        65.62169641937623283799566054863063741227,
    ];
    pub const A9: [f64; 8] = [
        -1.180945066554970799825116282628297957882,
        0.0,
        0.0,
        -41.50473441114320841606641502701994225874,
        -4.434438319103725011225169229846100211776,
        4.260408188586133024812193710744693240761,
        43.75364022446171584987676829438379303004,
        0.007871425489912310687446475044226307550860,
    ];
    pub const A10: [f64; 9] = [
        -1.281405999441488405459510291182054246266,
        0.0,
        0.0,
        -45.04713996013986630220754257136007322267,
        -4.731362069449576477311464265491282810943,
        4.514967016593807841185851584597240996214,
        47.44909557172985134869022392235929015114,
        0.01059228297111661135687393955516542875228,
        -0.005746842263844616254432318478286296232021,
    ];
    pub const A11: [f64; 10] = [
        -1.724470134262485191756709817484481861731,
        0.0,
        0.0,
        -60.92349008483054016518434619253765246063,
        -5.951518376222392455202832767061854868290,
        5.556523730698456235979791650843592496839,
        63.98301198033305336837536378635995939281,
        0.01464202825041496159275921391759452676003,
        0.06460408772358203603621865144977650714892,
        -0.07930323169008878984024452548693373291447,
    ];
    pub const A12: [f64; 11] = [
        -3.301622667747079016353994789790983625569,
        0.0,
        0.0,
        -118.0112723597525085666923303957898868510,
        -10.14142238845611248642783916034510897595,
        9.139311332232057923544012273556827000619,
        123.3759428284042683684847180986501894364,
        4.623244378874580474839807625067630924792,
        -3.383277738068201923652550971536811240814,
        4.527592100324618189451265339351129035325,
        -5.828495485811622963193088019162985703755,
    ];
    // Row for the error-only stage 13; its last two couplings are zero.
    pub const A13: [f64; 12] = [
        -3.039515033766309030040102851821200251056,
        0.0,
        0.0,
        -109.2608680894176254686444192322164623352,
        -9.290642497400293449717665542656897549158,
        8.430504981764911142134299253836167803454,
        114.2010010378331313557424041095523427476,
        -0.9637271342145479358162375658987901652762,
        -5.034884088802189791198680336183332323118,
        5.958130824002923177540402165388172072794,
        0.0,
        0.0,
    ];

    // Weights of the propagated solution; stage 13 has weight zero.
    pub const B: [f64; 13] = [
        0.04427989419007951074716746668098518862111,
        0.0,
        0.0,
        0.0,
        0.0,
        0.3541049391724448744815552028733568354121,
        0.2479692154956437828667629415370663023884,
        -15.69420203883808405099207034271191213468,
        25.08406496555856261343930031237186278518,
        -31.73836778626027646833156112007297739997,
        22.93828327398878395231483560344797018313,
        -0.2361324633071542145259900641263517600737,
        0.0,
    ];

    // Embedded weights; only referenced (as literals) through `E`, hence unused.
    #[allow(dead_code)]
    pub const B_HAT: [f64; 13] = [
        0.04431261522908979212486436510209029764893,
        0.0,
        0.0,
        0.0,
        0.0,
        0.3546095642343226447863179350895055038855,
        0.2478480431366653069619986721504458660016,
        4.448134732475784492725128317159648871312,
        19.84688636611873369930932399297687935291,
        -23.58162337746561841969517960870394965085,
        0.0,
        0.0,
        -0.3601679437289775162124536737746202409110,
    ];

    // Error-estimate weights E[j] = B[j] - B_HAT[j].
    pub const E: [f64; 13] = [
        0.04427989419007951074716746668098518862111 - 0.04431261522908979212486436510209029764893,
        0.0,
        0.0,
        0.0,
        0.0,
        0.3541049391724448744815552028733568354121 - 0.3546095642343226447863179350895055038855,
        0.2479692154956437828667629415370663023884 - 0.2478480431366653069619986721504458660016,
        -15.69420203883808405099207034271191213468 - 4.448134732475784492725128317159648871312,
        25.08406496555856261343930031237186278518 - 19.84688636611873369930932399297687935291,
        -31.73836778626027646833156112007297739997 - (-23.58162337746561841969517960870394965085),
        22.93828327398878395231483560344797018313 - 0.0,
        -0.2361324633071542145259900641263517600737 - 0.0,
        0.0 - (-0.3601679437289775162124536737746202409110),
    ];
}
807
808impl<S: Scalar> Solver<S> for Vern8 {
809 fn solve<Sys: OdeSystem<S>>(
810 problem: &Sys,
811 t0: S,
812 tf: S,
813 y0: &[S],
814 options: &SolverOptions<S>,
815 ) -> Result<SolverResult<S>, SolverError> {
816 solve_vern8(problem, t0, tf, y0, options)
818 }
819}
820
821fn solve_erk<S, Sys>(
823 problem: &Sys,
824 t0: S,
825 tf: S,
826 y0: &[S],
827 options: &SolverOptions<S>,
828 c: &[f64],
829 a: &[&[f64]],
830 b: &[f64],
831 e: &[f64],
832 order: usize,
833 stages: usize,
834) -> Result<SolverResult<S>, SolverError>
835where
836 S: Scalar,
837 Sys: OdeSystem<S>,
838{
839 let dim = problem.dim();
840 if y0.len() != dim {
841 return Err(SolverError::DimensionMismatch {
842 expected: dim,
843 actual: y0.len(),
844 });
845 }
846
847 let mut t = t0;
848 let mut y = y0.to_vec();
849
850 let mut k = vec![S::ZERO; stages * dim];
853 let mut y_stage = vec![S::ZERO; dim];
854 let mut y_new = vec![S::ZERO; dim];
855 let mut err = vec![S::ZERO; dim];
856 let mut dy_old = vec![S::ZERO; dim];
860 let mut dy_new = vec![S::ZERO; dim];
861
862 let mut stats = SolverStats::default();
863
864 problem.rhs(t, &y, &mut k[0..dim]);
866 stats.n_eval += 1;
867 let mut h = initial_step_size(&y, &k[0..dim], options, dim);
868 let h_min = options.h_min;
869 let h_max = options.h_max.min((tf - t0).abs());
870
871 let direction = if tf > t0 { S::ONE } else { -S::ONE };
872 if let Some(grid) = options.t_eval.as_deref() {
873 validate_grid(grid, t0, tf)?;
874 }
875 let mut grid_emitter = options
876 .t_eval
877 .as_deref()
878 .map(|g| TEvalEmitter::new(g, direction));
879 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
880 (Vec::new(), Vec::new())
881 } else {
882 (vec![t0], y0.to_vec())
883 };
884 let mut step_count = 0_usize;
885
886 while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
887 if step_count >= options.max_steps {
888 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
889 }
890
891 if (t + h - tf) * direction > S::ZERO {
892 h = tf - t;
893 }
894
895 h = h.abs().max(h_min) * direction;
896 if h.abs() > h_max {
897 h = h_max * direction;
898 }
899
900 for s in 1..stages {
902 for i in 0..dim {
903 let mut sum = S::ZERO;
904 for j in 0..s {
905 if s - 1 < a.len() && j < a[s - 1].len() {
906 sum = sum + S::from_f64(a[s - 1][j]) * k[j * dim + i];
907 }
908 }
909 y_stage[i] = y[i] + h * sum;
910 }
911 problem.rhs(
912 t + S::from_f64(c[s]) * h,
913 &y_stage,
914 &mut k[s * dim..(s + 1) * dim],
915 );
916 }
917 stats.n_eval += stages - 1;
918
919 for i in 0..dim {
921 let mut sum_b = S::ZERO;
922 let mut sum_e = S::ZERO;
923 for s in 0..stages {
924 sum_b = sum_b + S::from_f64(b[s]) * k[s * dim + i];
925 sum_e = sum_e + S::from_f64(e[s]) * k[s * dim + i];
926 }
927 y_new[i] = y[i] + h * sum_b;
928 err[i] = h * sum_e;
929 }
930
931 let err_norm = error_norm(&err, &y, &y_new, options, dim);
932
933 let safety = S::from_f64(0.9);
934 let fac_max = S::from_f64(4.0);
935 let fac_min = S::from_f64(0.2);
936 let order_f = S::from_usize(order + 1);
937
938 if err_norm <= S::ONE {
939 stats.n_accept += 1;
940
941 let t_new = t + h;
942 dy_old.copy_from_slice(&k[0..dim]);
947 problem.rhs(t_new, &y_new, &mut dy_new);
948 stats.n_eval += 1;
949
950 if let Some(ref mut emitter) = grid_emitter {
951 emitter.emit_step(
952 t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
953 );
954 } else {
955 t_out.push(t_new);
956 y_out.extend_from_slice(&y_new);
957 }
958
959 t = t_new;
960 y.copy_from_slice(&y_new);
961 k[0..dim].copy_from_slice(&dy_new);
962
963 let err_safe = err_norm.max(S::from_f64(1e-10));
964 let fac = safety * err_safe.powf(-S::ONE / order_f);
965 let fac = fac.min(fac_max).max(fac_min);
966 h = h * fac;
967 } else {
968 stats.n_reject += 1;
969
970 let err_safe = err_norm.max(S::from_f64(1e-10));
971 let fac = safety * err_safe.powf(-S::ONE / (order_f - S::ONE));
972 let fac = fac.max(fac_min);
973 h = h * fac;
974 }
975
976 if h.abs() < h_min {
977 return Err(SolverError::StepSizeTooSmall {
978 t: t.to_f64(),
979 h: h.to_f64(),
980 h_min: h_min.to_f64(),
981 });
982 }
983
984 step_count += 1;
985 }
986
987 Ok(SolverResult::new(t_out, y_out, dim, stats))
988}
989
990fn solve_vern8<S, Sys>(
992 problem: &Sys,
993 t0: S,
994 tf: S,
995 y0: &[S],
996 options: &SolverOptions<S>,
997) -> Result<SolverResult<S>, SolverError>
998where
999 S: Scalar,
1000 Sys: OdeSystem<S>,
1001{
1002 let dim = problem.dim();
1003 if y0.len() != dim {
1004 return Err(SolverError::DimensionMismatch {
1005 expected: dim,
1006 actual: y0.len(),
1007 });
1008 }
1009
1010 let mut t = t0;
1011 let mut y = y0.to_vec();
1012
1013 let mut k = vec![S::ZERO; 13 * dim];
1015 let mut y_stage = vec![S::ZERO; dim];
1016 let mut y_new = vec![S::ZERO; dim];
1017 let mut err = vec![S::ZERO; dim];
1018 let mut dy_old = vec![S::ZERO; dim];
1019 let mut dy_new = vec![S::ZERO; dim];
1020
1021 let mut stats = SolverStats::default();
1022
1023 problem.rhs(t, &y, &mut k[0..dim]);
1024 stats.n_eval += 1;
1025 let mut h = initial_step_size(&y, &k[0..dim], options, dim);
1026 let h_min = options.h_min;
1027 let h_max = options.h_max.min((tf - t0).abs());
1028
1029 let direction = if tf > t0 { S::ONE } else { -S::ONE };
1030 if let Some(grid) = options.t_eval.as_deref() {
1031 validate_grid(grid, t0, tf)?;
1032 }
1033 let mut grid_emitter = options
1034 .t_eval
1035 .as_deref()
1036 .map(|g| TEvalEmitter::new(g, direction));
1037 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
1038 (Vec::new(), Vec::new())
1039 } else {
1040 (vec![t0], y0.to_vec())
1041 };
1042 let mut step_count = 0_usize;
1043
1044 while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
1045 if step_count >= options.max_steps {
1046 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
1047 }
1048
1049 if (t + h - tf) * direction > S::ZERO {
1050 h = tf - t;
1051 }
1052
1053 h = h.abs().max(h_min) * direction;
1054 if h.abs() > h_max {
1055 h = h_max * direction;
1056 }
1057
1058 for i in 0..dim {
1061 y_stage[i] = y[i] + h * S::from_f64(vern8_tableau::A2[0]) * k[i];
1062 }
1063 problem.rhs(
1064 t + S::from_f64(vern8_tableau::C[1]) * h,
1065 &y_stage,
1066 &mut k[dim..2 * dim],
1067 );
1068
1069 for i in 0..dim {
1071 y_stage[i] = y[i]
1072 + h * (S::from_f64(vern8_tableau::A3[0]) * k[i]
1073 + S::from_f64(vern8_tableau::A3[1]) * k[dim + i]);
1074 }
1075 problem.rhs(
1076 t + S::from_f64(vern8_tableau::C[2]) * h,
1077 &y_stage,
1078 &mut k[2 * dim..3 * dim],
1079 );
1080
1081 for i in 0..dim {
1083 y_stage[i] = y[i]
1084 + h * (S::from_f64(vern8_tableau::A4[0]) * k[i]
1085 + S::from_f64(vern8_tableau::A4[2]) * k[2 * dim + i]);
1086 }
1087 problem.rhs(
1088 t + S::from_f64(vern8_tableau::C[3]) * h,
1089 &y_stage,
1090 &mut k[3 * dim..4 * dim],
1091 );
1092
1093 for i in 0..dim {
1095 y_stage[i] = y[i]
1096 + h * (S::from_f64(vern8_tableau::A5[0]) * k[i]
1097 + S::from_f64(vern8_tableau::A5[2]) * k[2 * dim + i]
1098 + S::from_f64(vern8_tableau::A5[3]) * k[3 * dim + i]);
1099 }
1100 problem.rhs(
1101 t + S::from_f64(vern8_tableau::C[4]) * h,
1102 &y_stage,
1103 &mut k[4 * dim..5 * dim],
1104 );
1105
1106 for i in 0..dim {
1108 y_stage[i] = y[i]
1109 + h * (S::from_f64(vern8_tableau::A6[0]) * k[i]
1110 + S::from_f64(vern8_tableau::A6[3]) * k[3 * dim + i]
1111 + S::from_f64(vern8_tableau::A6[4]) * k[4 * dim + i]);
1112 }
1113 problem.rhs(
1114 t + S::from_f64(vern8_tableau::C[5]) * h,
1115 &y_stage,
1116 &mut k[5 * dim..6 * dim],
1117 );
1118
1119 for i in 0..dim {
1121 y_stage[i] = y[i]
1122 + h * (S::from_f64(vern8_tableau::A7[0]) * k[i]
1123 + S::from_f64(vern8_tableau::A7[3]) * k[3 * dim + i]
1124 + S::from_f64(vern8_tableau::A7[4]) * k[4 * dim + i]
1125 + S::from_f64(vern8_tableau::A7[5]) * k[5 * dim + i]);
1126 }
1127 problem.rhs(
1128 t + S::from_f64(vern8_tableau::C[6]) * h,
1129 &y_stage,
1130 &mut k[6 * dim..7 * dim],
1131 );
1132
1133 for i in 0..dim {
1135 y_stage[i] = y[i]
1136 + h * (S::from_f64(vern8_tableau::A8[0]) * k[i]
1137 + S::from_f64(vern8_tableau::A8[3]) * k[3 * dim + i]
1138 + S::from_f64(vern8_tableau::A8[4]) * k[4 * dim + i]
1139 + S::from_f64(vern8_tableau::A8[5]) * k[5 * dim + i]
1140 + S::from_f64(vern8_tableau::A8[6]) * k[6 * dim + i]);
1141 }
1142 problem.rhs(
1143 t + S::from_f64(vern8_tableau::C[7]) * h,
1144 &y_stage,
1145 &mut k[7 * dim..8 * dim],
1146 );
1147
1148 for i in 0..dim {
1150 y_stage[i] = y[i]
1151 + h * (S::from_f64(vern8_tableau::A9[0]) * k[i]
1152 + S::from_f64(vern8_tableau::A9[3]) * k[3 * dim + i]
1153 + S::from_f64(vern8_tableau::A9[4]) * k[4 * dim + i]
1154 + S::from_f64(vern8_tableau::A9[5]) * k[5 * dim + i]
1155 + S::from_f64(vern8_tableau::A9[6]) * k[6 * dim + i]
1156 + S::from_f64(vern8_tableau::A9[7]) * k[7 * dim + i]);
1157 }
1158 problem.rhs(
1159 t + S::from_f64(vern8_tableau::C[8]) * h,
1160 &y_stage,
1161 &mut k[8 * dim..9 * dim],
1162 );
1163
1164 for i in 0..dim {
1166 y_stage[i] = y[i]
1167 + h * (S::from_f64(vern8_tableau::A10[0]) * k[i]
1168 + S::from_f64(vern8_tableau::A10[3]) * k[3 * dim + i]
1169 + S::from_f64(vern8_tableau::A10[4]) * k[4 * dim + i]
1170 + S::from_f64(vern8_tableau::A10[5]) * k[5 * dim + i]
1171 + S::from_f64(vern8_tableau::A10[6]) * k[6 * dim + i]
1172 + S::from_f64(vern8_tableau::A10[7]) * k[7 * dim + i]
1173 + S::from_f64(vern8_tableau::A10[8]) * k[8 * dim + i]);
1174 }
1175 problem.rhs(
1176 t + S::from_f64(vern8_tableau::C[9]) * h,
1177 &y_stage,
1178 &mut k[9 * dim..10 * dim],
1179 );
1180
1181 for i in 0..dim {
1183 y_stage[i] = y[i]
1184 + h * (S::from_f64(vern8_tableau::A11[0]) * k[i]
1185 + S::from_f64(vern8_tableau::A11[3]) * k[3 * dim + i]
1186 + S::from_f64(vern8_tableau::A11[4]) * k[4 * dim + i]
1187 + S::from_f64(vern8_tableau::A11[5]) * k[5 * dim + i]
1188 + S::from_f64(vern8_tableau::A11[6]) * k[6 * dim + i]
1189 + S::from_f64(vern8_tableau::A11[7]) * k[7 * dim + i]
1190 + S::from_f64(vern8_tableau::A11[8]) * k[8 * dim + i]
1191 + S::from_f64(vern8_tableau::A11[9]) * k[9 * dim + i]);
1192 }
1193 problem.rhs(
1194 t + S::from_f64(vern8_tableau::C[10]) * h,
1195 &y_stage,
1196 &mut k[10 * dim..11 * dim],
1197 );
1198
1199 for i in 0..dim {
1201 y_stage[i] = y[i]
1202 + h * (S::from_f64(vern8_tableau::A12[0]) * k[i]
1203 + S::from_f64(vern8_tableau::A12[3]) * k[3 * dim + i]
1204 + S::from_f64(vern8_tableau::A12[4]) * k[4 * dim + i]
1205 + S::from_f64(vern8_tableau::A12[5]) * k[5 * dim + i]
1206 + S::from_f64(vern8_tableau::A12[6]) * k[6 * dim + i]
1207 + S::from_f64(vern8_tableau::A12[7]) * k[7 * dim + i]
1208 + S::from_f64(vern8_tableau::A12[8]) * k[8 * dim + i]
1209 + S::from_f64(vern8_tableau::A12[9]) * k[9 * dim + i]
1210 + S::from_f64(vern8_tableau::A12[10]) * k[10 * dim + i]);
1211 }
1212 problem.rhs(
1213 t + S::from_f64(vern8_tableau::C[11]) * h,
1214 &y_stage,
1215 &mut k[11 * dim..12 * dim],
1216 );
1217
1218 for i in 0..dim {
1220 y_stage[i] = y[i]
1221 + h * (S::from_f64(vern8_tableau::A13[0]) * k[i]
1222 + S::from_f64(vern8_tableau::A13[3]) * k[3 * dim + i]
1223 + S::from_f64(vern8_tableau::A13[4]) * k[4 * dim + i]
1224 + S::from_f64(vern8_tableau::A13[5]) * k[5 * dim + i]
1225 + S::from_f64(vern8_tableau::A13[6]) * k[6 * dim + i]
1226 + S::from_f64(vern8_tableau::A13[7]) * k[7 * dim + i]
1227 + S::from_f64(vern8_tableau::A13[8]) * k[8 * dim + i]
1228 + S::from_f64(vern8_tableau::A13[9]) * k[9 * dim + i]);
1229 }
1230 problem.rhs(
1231 t + S::from_f64(vern8_tableau::C[12]) * h,
1232 &y_stage,
1233 &mut k[12 * dim..13 * dim],
1234 );
1235
1236 stats.n_eval += 12;
1237
1238 for i in 0..dim {
1240 let mut sum_b = S::ZERO;
1241 let mut sum_e = S::ZERO;
1242 for s in 0..13 {
1243 sum_b = sum_b + S::from_f64(vern8_tableau::B[s]) * k[s * dim + i];
1244 sum_e = sum_e + S::from_f64(vern8_tableau::E[s]) * k[s * dim + i];
1245 }
1246 y_new[i] = y[i] + h * sum_b;
1247 err[i] = h * sum_e;
1248 }
1249
1250 let err_norm = error_norm(&err, &y, &y_new, options, dim);
1251
1252 let safety = S::from_f64(0.9);
1253 let fac_max = S::from_f64(3.0);
1254 let fac_min = S::from_f64(0.2);
1255
1256 if err_norm <= S::ONE {
1257 stats.n_accept += 1;
1258
1259 let t_new = t + h;
1260 dy_old.copy_from_slice(&k[0..dim]);
1261 problem.rhs(t_new, &y_new, &mut dy_new);
1262 stats.n_eval += 1;
1263
1264 if let Some(ref mut emitter) = grid_emitter {
1265 emitter.emit_step(
1266 t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
1267 );
1268 } else {
1269 t_out.push(t_new);
1270 y_out.extend_from_slice(&y_new);
1271 }
1272
1273 t = t_new;
1274 y.copy_from_slice(&y_new);
1275 k[0..dim].copy_from_slice(&dy_new);
1276
1277 let err_safe = err_norm.max(S::from_f64(1e-10));
1278 let fac = safety * err_safe.powf(S::from_f64(-1.0 / 9.0));
1279 let fac = fac.min(fac_max).max(fac_min);
1280 h = h * fac;
1281 } else {
1282 stats.n_reject += 1;
1283
1284 let err_safe = err_norm.max(S::from_f64(1e-10));
1285 let fac = safety * err_safe.powf(S::from_f64(-1.0 / 8.0));
1286 let fac = fac.max(fac_min);
1287 h = h * fac;
1288 }
1289
1290 if h.abs() < h_min {
1291 return Err(SolverError::StepSizeTooSmall {
1292 t: t.to_f64(),
1293 h: h.to_f64(),
1294 h_min: h_min.to_f64(),
1295 });
1296 }
1297
1298 step_count += 1;
1299 }
1300
1301 Ok(SolverResult::new(t_out, y_out, dim, stats))
1302}
1303
1304fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
1305 if let Some(h0) = options.h0 {
1306 return h0;
1307 }
1308
1309 let mut y_norm = S::ZERO;
1310 let mut f_norm = S::ZERO;
1311 for i in 0..dim {
1312 let sc = options.atol + options.rtol * y0[i].abs();
1313 y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
1314 f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
1315 }
1316 y_norm = (y_norm / S::from_usize(dim)).sqrt();
1317 f_norm = (f_norm / S::from_usize(dim)).sqrt();
1318
1319 if y_norm < S::from_f64(1e-5) || f_norm < S::from_f64(1e-5) {
1320 S::from_f64(1e-6)
1321 } else {
1322 (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
1323 }
1324}
1325
1326fn error_norm<S: Scalar>(
1327 err: &[S],
1328 y: &[S],
1329 y_new: &[S],
1330 options: &SolverOptions<S>,
1331 dim: usize,
1332) -> S {
1333 let mut err_norm = S::ZERO;
1334 for i in 0..dim {
1335 let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
1336 let sc = sc.max(S::from_f64(1e-15));
1337 let scaled_err = err[i] / sc;
1338 err_norm = err_norm + scaled_err * scaled_err;
1339 }
1340 (err_norm / S::from_usize(dim)).sqrt()
1341}
1342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Scalar decay `y' = -y`; from `y(0) = 1` the exact solution is `e^{-t}`.
    fn decay(_t: f64, y: &[f64], dydt: &mut [f64]) {
        dydt[0] = -y[0];
    }

    /// Harmonic oscillator `y0' = y1`, `y1' = -y0`; from `(1, 0)` the first
    /// component is `cos(t)`.
    fn oscillator(_t: f64, y: &[f64], dydt: &mut [f64]) {
        dydt[0] = y[1];
        dydt[1] = -y[0];
    }

    /// Vern6 on exponential decay over [0, 5] with tight tolerances.
    #[test]
    fn test_vern6_exponential() {
        let problem = OdeProblem::new(decay, 0.0, 5.0, vec![1.0]);
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);

        let result = Vern6::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
        assert!(result.success);

        let expected = (-5.0_f64).exp();
        let got = result.y_final().unwrap()[0];
        let abs_err = (got - expected).abs();
        assert!(
            abs_err < 1e-5,
            "Vern6 exponential: got {}, expected {}",
            got,
            expected
        );
    }

    /// Vern7 on the harmonic oscillator over [0, 10]; checks the position component.
    #[test]
    fn test_vern7_harmonic() {
        let problem = OdeProblem::new(oscillator, 0.0, 10.0, vec![1.0, 0.0]);
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let result = Vern7::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();
        assert!(result.success);

        let got = result.y_final().unwrap()[0];
        let abs_err = (got - 10.0_f64.cos()).abs();
        assert!(
            abs_err < 0.01,
            "Vern7 harmonic: got {}, expected {}",
            got,
            10.0_f64.cos()
        );
    }

    /// Vern8 on exponential decay over [0, 5], checked to a 5% relative bound.
    ///
    /// NOTE(review): 5% is far looser than an 8th-order method should need at
    /// rtol = 1e-4, despite the test's name; confirm whether this is intentional.
    #[test]
    fn test_vern8_high_accuracy() {
        let problem = OdeProblem::new(decay, 0.0, 5.0, vec![1.0]);
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let result = Vern8::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();
        assert!(result.success);

        let expected = (-5.0_f64).exp();
        let got = result.y_final().unwrap()[0];
        let abs_err = (got - expected).abs();
        assert!(
            abs_err < expected * 0.05,
            "Vern8 exponential: got {}, expected {}",
            got,
            expected
        );
    }

    /// All three Verner methods solve decay on [0, 2] and are each compared to
    /// the exact value.
    #[test]
    fn test_vern_methods_agree() {
        let problem = OdeProblem::new(decay, 0.0, 2.0, vec![1.0]);
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);

        let r6 = Vern6::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
        let r7 = Vern7::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();
        let r8 = Vern8::solve(&problem, 0.0, 2.0, &[1.0], &options).unwrap();

        let expected = (-2.0_f64).exp();
        let y6 = r6.y_final().unwrap()[0];
        let y7 = r7.y_final().unwrap()[0];
        let y8 = r8.y_final().unwrap()[0];

        assert!(
            (y6 - expected).abs() < expected * 0.01,
            "Vern6 disagrees: {}",
            y6
        );
        assert!(
            (y7 - expected).abs() < expected * 0.01,
            "Vern7 disagrees: {}",
            y7
        );
        assert!(
            (y8 - expected).abs() < expected * 0.05,
            "Vern8 disagrees: {}",
            y8
        );
    }
}