Skip to main content

numra_ode/
verner.rs

1//! Verner: High-order explicit Runge-Kutta methods.
2//!
3//! Jim Verner's efficient RK pairs designed for high accuracy applications.
4//!
5//! ## Available Methods
//! - `Vern6` - Verner's 6(5) "Efficient" pair (9 stages with FSAL, i.e. 8 new evaluations per step)
7//! - `Vern7` - Verner's 7(6) "Efficient" pair (10 stages)
8//! - `Vern8` - Verner's 8(7) "Efficient" pair (13 stages)
9//!
10//! ## Reference
11//! Verner, J.H. (2010), "Numerically optimal Runge-Kutta pairs with interpolants"
12//! Numerical Algorithms, 53, 383-396.
13//!
14//! Author: Moussa Leblouba
15//! Date: 5 March 2026
16//! Modified: 2 May 2026
17
18use crate::error::SolverError;
19use crate::problem::OdeSystem;
20use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
21use crate::t_eval::{validate_grid, TEvalEmitter};
22use numra_core::Scalar;
23
24// ============================================================================
25// Verner 6(5) "Efficient"
26// ============================================================================
27
/// Marker type selecting Verner's 6(5) "Efficient" embedded Runge-Kutta pair.
///
/// The type is stateless: the integration itself lives in the `Solver`
/// implementation, which uses nine stages with first-same-as-last reuse
/// (eight new right-hand-side evaluations per attempted step).
#[derive(Clone, Debug, Default)]
pub struct Vern6;

impl Vern6 {
    /// Returns the `Vern6` marker value; equivalent to `Vern6::default()`.
    pub fn new() -> Self {
        Vern6
    }
}
37
/// Verner "most efficient" 6(5) FSAL pair (RKV65 Efficient).
/// Reference: Jim Verner's website <https://www.sfu.ca/~jverner/>
///
/// Ten rows are tabulated. Rows 1-9 form the embedded pair: row 9 repeats the
/// 6th-order weights `B`, so stage 9 is the derivative at the new point and
/// can be reused as stage 1 of the next step (FSAL). Row 10 is an additional
/// stage at c = 0.5 tabulated for dense output.
// Several entries (zero coefficients, row 10, some node values) are not
// referenced by the solver in this file, hence the lint suppression.
#[allow(dead_code)]
mod vern6_tableau {
    // Nodes (c coefficients), one per tabulated row (10 entries).
    pub const C: [f64; 10] = [
        0.0,
        0.06,
        0.09593333333333333,
        0.1439,
        0.4973,
        0.9725,
        0.9995,
        1.0,
        1.0,
        0.5,
    ];

    // A matrix coefficients (strictly lower triangular, row by row).
    // Each constant `Aij` is the coefficient a_{i,j} (1-based stage indices):
    // the weight of stage j's slope when forming the argument of stage i.
    pub const A21: f64 = 0.06;

    pub const A31: f64 = 0.019239962962962962;
    pub const A32: f64 = 0.07669337037037037;

    pub const A41: f64 = 0.035975;
    pub const A42: f64 = 0.0;
    pub const A43: f64 = 0.107925;

    pub const A51: f64 = 1.3186834152331484;
    pub const A52: f64 = 0.0;
    pub const A53: f64 = -5.042058063628562;
    pub const A54: f64 = 4.220674648395414;

    pub const A61: f64 = -41.872591664327516;
    pub const A62: f64 = 0.0;
    pub const A63: f64 = 159.4325621631375;
    pub const A64: f64 = -122.11921356501003;
    pub const A65: f64 = 5.531743066200054;

    pub const A71: f64 = -54.430156935316504;
    pub const A72: f64 = 0.0;
    pub const A73: f64 = 207.06725136501848;
    pub const A74: f64 = -158.61081378459;
    pub const A75: f64 = 6.991816585950242;
    pub const A76: f64 = -0.018597231062203234;

    pub const A81: f64 = -54.66374178728198;
    pub const A82: f64 = 0.0;
    pub const A83: f64 = 207.95280625538936;
    pub const A84: f64 = -159.2889574744995;
    pub const A85: f64 = 7.018743740796944;
    pub const A86: f64 = -0.018338785905045722;
    pub const A87: f64 = -0.0005119484997882099;

    // Row 9 (c = 1): identical to the 6th-order weights `B`, which is what
    // makes the pair first-same-as-last.
    pub const A91: f64 = 0.03438957868357036;
    pub const A92: f64 = 0.0;
    pub const A93: f64 = 0.0;
    pub const A94: f64 = 0.2582624555633503;
    pub const A95: f64 = 0.4209371189673537;
    pub const A96: f64 = 4.40539646966931;
    pub const A97: f64 = -176.48311902429865;
    pub const A98: f64 = 172.36413340141507;

    // Row 10 (for dense output stage at c=0.5)
    pub const A101: f64 = 0.016524159013572806;
    pub const A102: f64 = 0.0;
    pub const A103: f64 = 0.0;
    pub const A104: f64 = 0.3053128187514179;
    pub const A105: f64 = 0.2071200938201979;
    pub const A106: f64 = -1.293879140655123;
    pub const A107: f64 = 57.11988411588149;
    pub const A108: f64 = -55.87979207510932;
    pub const A109: f64 = 0.024830028297766014;

    // 6th order weights (B)
    pub const B: [f64; 10] = [
        0.03438957868357036,
        0.0,
        0.0,
        0.2582624555633503,
        0.4209371189673537,
        4.40539646966931,
        -176.48311902429865,
        172.36413340141507,
        0.0,
        0.0,
    ];

    // 5th order embedded weights (B_HAT)
    pub const B_HAT: [f64; 10] = [
        0.0490996764838249,
        0.0,
        0.0,
        0.22511122295165242,
        0.4694682253029562,
        0.8065792249988868,
        0.0,
        -0.607119489177796,
        0.056861139440475696,
        0.0,
    ];

    // Error coefficients (E = B - B_HAT), evaluated at compile time.
    // The trailing comments give the approximate resulting values.
    pub const E: [f64; 10] = [
        B[0] - B_HAT[0], // -0.01471009780...
        0.0,
        0.0,
        B[3] - B_HAT[3], // 0.03315123261...
        B[4] - B_HAT[4], // -0.04853110633...
        B[5] - B_HAT[5], // 3.598817244...
        B[6] - B_HAT[6], // -176.48311902...
        B[7] - B_HAT[7], // 172.97125289...
        B[8] - B_HAT[8], // -0.05686113944...
        0.0,
    ];
}
156
157impl<S: Scalar> Solver<S> for Vern6 {
158    fn solve<Sys: OdeSystem<S>>(
159        problem: &Sys,
160        t0: S,
161        tf: S,
162        y0: &[S],
163        options: &SolverOptions<S>,
164    ) -> Result<SolverResult<S>, SolverError> {
165        use vern6_tableau::*;
166
167        let dim = problem.dim();
168        if dim == 0 {
169            return Err(SolverError::DimensionMismatch {
170                expected: 1,
171                actual: 0,
172            });
173        }
174
175        let mut t = t0;
176        let mut y = y0.to_vec();
177        let direction = if tf >= t0 { S::ONE } else { -S::ONE };
178
179        // Flat storage for 10 stage vectors: k[s*dim + i] instead of k[s][i].
180        // Avoids heap fragmentation from Vec<Vec<S>>.
181        let mut k = vec![S::ZERO; 10 * dim];
182        let mut y_stage = vec![S::ZERO; dim];
183        let mut y_new = vec![S::ZERO; dim];
184        let mut err = vec![S::ZERO; dim];
185
186        // Statistics
187        let mut stats = SolverStats::default();
188
189        if let Some(grid) = options.t_eval.as_deref() {
190            validate_grid(grid, t0, tf)?;
191        }
192        let mut grid_emitter = options
193            .t_eval
194            .as_deref()
195            .map(|g| TEvalEmitter::new(g, direction));
196        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
197            (Vec::new(), Vec::new())
198        } else {
199            (vec![t0], y0.to_vec())
200        };
201
202        // Initial step size estimation
203        problem.rhs(t, &y, &mut k[0..dim]);
204        stats.n_eval += 1;
205
206        let mut f_norm = S::ZERO;
207        for i in 0..dim {
208            f_norm = f_norm + k[i] * k[i];
209        }
210        f_norm = f_norm.sqrt();
211
212        let mut h = if f_norm > S::from_f64(1e-10) {
213            S::from_f64(0.01) * (S::ONE / f_norm)
214        } else {
215            S::from_f64(0.01)
216        };
217        h = h.min(options.h_max).max(options.h_min) * direction;
218
219        let max_steps = options.max_steps;
220        let mut step_count = 0;
221
222        while (tf - t) * direction > S::from_f64(1e-14) * (tf - t0).abs() {
223            if step_count >= max_steps {
224                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
225            }
226
227            // Adjust final step
228            if (t + h - tf) * direction > S::ZERO {
229                h = tf - t;
230            }
231
232            // k1 is already computed (FSAL property after first step)
233            // All stage accesses use flat indexing: k[s*dim + i]
234
235            // k2: y + h * A21 * k1
236            let c2 = S::from_f64(C[1]);
237            let a21 = S::from_f64(A21);
238            for i in 0..dim {
239                y_stage[i] = y[i] + h * a21 * k[i];
240            }
241            problem.rhs(t + c2 * h, &y_stage, &mut k[dim..2 * dim]);
242
243            // k3: y + h * (A31*k1 + A32*k2)
244            let c3 = S::from_f64(C[2]);
245            let a31 = S::from_f64(A31);
246            let a32 = S::from_f64(A32);
247            for i in 0..dim {
248                y_stage[i] = y[i] + h * (a31 * k[i] + a32 * k[dim + i]);
249            }
250            problem.rhs(t + c3 * h, &y_stage, &mut k[2 * dim..3 * dim]);
251
252            // k4: y + h * (A41*k1 + A43*k3)
253            let c4 = S::from_f64(C[3]);
254            let a41 = S::from_f64(A41);
255            let a43 = S::from_f64(A43);
256            for i in 0..dim {
257                y_stage[i] = y[i] + h * (a41 * k[i] + a43 * k[2 * dim + i]);
258            }
259            problem.rhs(t + c4 * h, &y_stage, &mut k[3 * dim..4 * dim]);
260
261            // k5: y + h * (A51*k1 + A53*k3 + A54*k4)
262            let c5 = S::from_f64(C[4]);
263            let a51 = S::from_f64(A51);
264            let a53 = S::from_f64(A53);
265            let a54 = S::from_f64(A54);
266            for i in 0..dim {
267                y_stage[i] = y[i] + h * (a51 * k[i] + a53 * k[2 * dim + i] + a54 * k[3 * dim + i]);
268            }
269            problem.rhs(t + c5 * h, &y_stage, &mut k[4 * dim..5 * dim]);
270
271            // k6: y + h * (A61*k1 + A63*k3 + A64*k4 + A65*k5)
272            let c6 = S::from_f64(C[5]);
273            let a61 = S::from_f64(A61);
274            let a63 = S::from_f64(A63);
275            let a64 = S::from_f64(A64);
276            let a65 = S::from_f64(A65);
277            for i in 0..dim {
278                y_stage[i] = y[i]
279                    + h * (a61 * k[i]
280                        + a63 * k[2 * dim + i]
281                        + a64 * k[3 * dim + i]
282                        + a65 * k[4 * dim + i]);
283            }
284            problem.rhs(t + c6 * h, &y_stage, &mut k[5 * dim..6 * dim]);
285
286            // k7: y + h * (A71*k1 + A73*k3 + A74*k4 + A75*k5 + A76*k6)
287            let c7 = S::from_f64(C[6]);
288            let a71 = S::from_f64(A71);
289            let a73 = S::from_f64(A73);
290            let a74 = S::from_f64(A74);
291            let a75 = S::from_f64(A75);
292            let a76 = S::from_f64(A76);
293            for i in 0..dim {
294                y_stage[i] = y[i]
295                    + h * (a71 * k[i]
296                        + a73 * k[2 * dim + i]
297                        + a74 * k[3 * dim + i]
298                        + a75 * k[4 * dim + i]
299                        + a76 * k[5 * dim + i]);
300            }
301            problem.rhs(t + c7 * h, &y_stage, &mut k[6 * dim..7 * dim]);
302
303            // k8: y + h * (A81*k1 + A83*k3 + A84*k4 + A85*k5 + A86*k6 + A87*k7)
304            let c8 = S::from_f64(C[7]);
305            let a81 = S::from_f64(A81);
306            let a83 = S::from_f64(A83);
307            let a84 = S::from_f64(A84);
308            let a85 = S::from_f64(A85);
309            let a86 = S::from_f64(A86);
310            let a87 = S::from_f64(A87);
311            for i in 0..dim {
312                y_stage[i] = y[i]
313                    + h * (a81 * k[i]
314                        + a83 * k[2 * dim + i]
315                        + a84 * k[3 * dim + i]
316                        + a85 * k[4 * dim + i]
317                        + a86 * k[5 * dim + i]
318                        + a87 * k[6 * dim + i]);
319            }
320            problem.rhs(t + c8 * h, &y_stage, &mut k[7 * dim..8 * dim]);
321
322            // k9 (FSAL stage at c=1): y + h * (A91*k1 + A94*k4 + A95*k5 + A96*k6 + A97*k7 + A98*k8)
323            let a91 = S::from_f64(A91);
324            let a94 = S::from_f64(A94);
325            let a95 = S::from_f64(A95);
326            let a96 = S::from_f64(A96);
327            let a97 = S::from_f64(A97);
328            let a98 = S::from_f64(A98);
329            for i in 0..dim {
330                y_new[i] = y[i]
331                    + h * (a91 * k[i]
332                        + a94 * k[3 * dim + i]
333                        + a95 * k[4 * dim + i]
334                        + a96 * k[5 * dim + i]
335                        + a97 * k[6 * dim + i]
336                        + a98 * k[7 * dim + i]);
337            }
338            problem.rhs(t + h, &y_new, &mut k[8 * dim..9 * dim]);
339
340            stats.n_eval += 8;
341
342            // Error estimation using E coefficients (flat indexing)
343            let e = &E;
344            for i in 0..dim {
345                err[i] = h
346                    * (S::from_f64(e[0]) * k[i]
347                        + S::from_f64(e[3]) * k[3 * dim + i]
348                        + S::from_f64(e[4]) * k[4 * dim + i]
349                        + S::from_f64(e[5]) * k[5 * dim + i]
350                        + S::from_f64(e[6]) * k[6 * dim + i]
351                        + S::from_f64(e[7]) * k[7 * dim + i]
352                        + S::from_f64(e[8]) * k[8 * dim + i]);
353            }
354
355            // Compute error norm
356            let mut err_norm = S::ZERO;
357            for i in 0..dim {
358                let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
359                let ratio = err[i] / sc;
360                err_norm = err_norm + ratio * ratio;
361            }
362            err_norm = (err_norm / S::from_f64(dim as f64)).sqrt();
363
364            if err_norm <= S::ONE {
365                // Step accepted
366                stats.n_accept += 1;
367
368                let t_new = t + h;
369                if let Some(ref mut emitter) = grid_emitter {
370                    // dy at t is k[0..dim] (FSAL'd from prev step / initial),
371                    // dy at t+h is k[8*dim..9*dim] (Vern6's c_9 = 1 stage).
372                    let (dy_start, dy_end_block) = k.split_at(dim);
373                    let dy_end = &dy_end_block[(8 - 1) * dim..8 * dim];
374                    emitter.emit_step(
375                        t, &y, dy_start, t_new, &y_new, dy_end, &mut t_out, &mut y_out,
376                    );
377                } else {
378                    t_out.push(t_new);
379                    y_out.extend_from_slice(&y_new);
380                }
381
382                t = t_new;
383                y.copy_from_slice(&y_new);
384
385                // FSAL: copy k9 (stage 8) to k1 (stage 0) for next step
386                k.copy_within(8 * dim..9 * dim, 0);
387            } else {
388                stats.n_reject += 1;
389            }
390
391            // Step size control (order 6 method)
392            let safety = S::from_f64(0.9);
393            let min_factor = S::from_f64(0.2);
394            let max_factor = S::from_f64(10.0);
395
396            let factor = if err_norm > S::from_f64(1e-10) {
397                safety * (S::ONE / err_norm).powf(S::from_f64(1.0 / 7.0))
398            } else {
399                max_factor
400            };
401            let factor = factor.min(max_factor).max(min_factor);
402            h = h * factor;
403
404            // Enforce bounds
405            let h_abs = h.abs();
406            let h_abs = h_abs.min(options.h_max).max(options.h_min);
407            h = h_abs * direction;
408
409            step_count += 1;
410        }
411
412        // If we haven't computed k1 at final point, do it now
413        if stats.n_accept > 0 {
414            // k[0] already contains f(t_final, y_final) from FSAL
415        } else {
416            problem.rhs(t, &y, &mut k[0..dim]);
417            stats.n_eval += 1;
418        }
419
420        Ok(SolverResult::new(t_out, y_out, dim, stats))
421    }
422}
423
424// ============================================================================
425// Verner 7(6) "Efficient"
426// ============================================================================
427
/// Marker type selecting Verner's 7(6) "Efficient" embedded Runge-Kutta pair
/// (ten stages).
///
/// The type carries no state; the integration lives in its `Solver`
/// implementation.
#[derive(Clone, Debug, Default)]
pub struct Vern7;

impl Vern7 {
    /// Returns the `Vern7` marker value; equivalent to `Vern7::default()`.
    pub fn new() -> Self {
        Vern7
    }
}
437
/// Verner 7(6) coefficients.
/// Reference: Jim Verner's website <https://www.sfu.ca/~jverner/>
/// File: RKV76.IIa.Efficient.00001675585.240711.FLOAT6040OnWeb
///
/// `C` holds the ten stage nodes, `A` the strictly lower-triangular stage
/// coefficients, `B` the 7th-order weights, `B_HAT` the embedded 6th-order
/// weights and `E = B - B_HAT` the error-estimate weights.
mod vern7_tableau {
    pub const C: [f64; 10] = [
        0.0,
        0.005,
        0.10888888888888888888888888888888888889,
        0.16333333333333333333333333333333333333,
        0.4555,
        0.60950944899783813170870044214860249496,
        0.884,
        0.925,
        1.0,
        1.0,
    ];

    // A matrix coefficients from Jim Verner's website
    // File: RKV76.IIa.Efficient.00001675585.240712.CoeffsOnlyFLOAT
    // Row i (0-based) holds the coefficients of stage i+2 (1-based):
    // A[i][j] = a_{i+2, j+1} in standard notation. Rows are padded with
    // zeros to a common width of 9.
    pub const A: [[f64; 9]; 9] = [
        // Stage 2: a21 = 0.005
        [0.005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        // Stage 3: a31, a32
        [
            -1.076790123456790123456790123456790123457,
            1.185679012345679012345679012345679012346,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        // Stage 4: a41, a42=0, a43
        [
            0.04083333333333333333333333333333333333333,
            0.0,
            0.1225,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        // Stage 5
        [
            0.6389139236255726780508121615993336109954,
            0.0,
            -2.455672638223656809662640566430653894211,
            2.272258714598084131611828404831320283215,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        // Stage 6
        [
            -2.661577375018757131119259297861818119279,
            0.0,
            10.80451388645613769565396655365532838482,
            -8.353914657396199411968048547819291691541,
            0.8204875949566569791420417341743839209619,
            0.0,
            0.0,
            0.0,
            0.0,
        ],
        // Stage 7
        [
            6.067741434696770992718360183877276714679,
            0.0,
            -24.71127363591108579734203485290746001803,
            20.42751793078889394045773111748346612697,
            -1.906157978816647150624096784352757010879,
            1.006172249242068014790040335899474187268,
            0.0,
            0.0,
            0.0,
        ],
        // Stage 8
        [
            12.05467007625320299509109452892778311648,
            0.0,
            -49.75478495046898932807257615331444758322,
            41.14288863860467663259698416710157354209,
            -4.461760149974004185641911603484815375051,
            2.042334822239174959821717077708608543738,
            -0.09834843665406107379530801693870224403537,
            0.0,
            0.0,
        ],
        // Stage 9
        [
            10.13814652288180787641845141981689030769,
            0.0,
            -42.64113603171750214622846006736635730625,
            35.76384003992257007135021178023160054034,
            -4.348022840392907653340370296908245943710,
            2.009862268377035895441943593011827554771,
            0.3487490460338272405953822853053145879140,
            -0.2714390051048312842371587140910297407572,
            0.0,
        ],
        // Stage 10 (used for embedded 6th order estimate)
        [
            -45.03007203429867712435322405073769635151,
            0.0,
            187.3272437654588840752418206154201997384,
            -154.0288236935018690596728621034510402582,
            18.56465306347536233859492332958439136765,
            -7.141809679295078854925420496823551192821,
            1.308808578161378625114762706007696696508,
            0.0,
            0.0,
        ],
    ];

    // 7th order weights from Jim Verner
    pub const B: [f64; 10] = [
        0.047155618486272221704317651088381756796,
        0.0,
        0.0,
        0.25750564298434151895964361010376875810,
        0.26216653977412620477138630957645277111,
        0.15216092656738557403231331991651175355,
        0.49399691700324842469071758932278768443,
        -0.29430311714032504415572447440927034291,
        0.081317472324951099997345994401367618925,
        0.0,
    ];

    // 6th order embedded weights (B_HAT). Non-zero entries are bh1, bh4,
    // bh5, bh6, bh7 and bh10 in 1-based stage numbering.
    pub const B_HAT: [f64; 10] = [
        0.044608606606341176287318175974791977814,
        0.0,
        0.0,
        0.26716403785713726805091022609438378997,
        0.22010183001772930199797157766507530963,
        0.21884317031431568309831208335128938246,
        0.22898717054112028833781738897635523654,
        0.0,
        0.0,
        0.020295184663356282227670547938104303586,
    ];

    // Error coefficients: E = B - B_HAT (7th order - 6th order), derived from
    // the two weight arrays so the literals live in exactly one place.
    pub const E: [f64; 10] = [
        B[0] - B_HAT[0], // b1 - bh1
        0.0,
        0.0,
        B[3] - B_HAT[3], // b4 - bh4
        B[4] - B_HAT[4], // b5 - bh5
        B[5] - B_HAT[5], // b6 - bh6
        B[6] - B_HAT[6], // b7 - bh7
        B[7] - B_HAT[7], // b8 - bh8 (bh8 = 0)
        B[8] - B_HAT[8], // b9 - bh9 (bh9 = 0)
        B[9] - B_HAT[9], // b10 - bh10 (b10 = 0)
    ];
}
595
596impl<S: Scalar> Solver<S> for Vern7 {
597    fn solve<Sys: OdeSystem<S>>(
598        problem: &Sys,
599        t0: S,
600        tf: S,
601        y0: &[S],
602        options: &SolverOptions<S>,
603    ) -> Result<SolverResult<S>, SolverError> {
604        // Convert array of arrays to slice of slices
605        let a_slices: Vec<&[f64]> = vern7_tableau::A.iter().map(|row| row.as_slice()).collect();
606        solve_erk(
607            problem,
608            t0,
609            tf,
610            y0,
611            options,
612            &vern7_tableau::C,
613            &a_slices,
614            &vern7_tableau::B,
615            &vern7_tableau::E,
616            7,  // order
617            10, // stages
618        )
619    }
620}
621
622// ============================================================================
623// Verner 8(7) "Efficient"
624// ============================================================================
625
/// Marker type selecting Verner's 8(7) "Efficient" embedded Runge-Kutta pair
/// (thirteen stages).
///
/// The type carries no state; the integration lives in its `Solver`
/// implementation.
#[derive(Clone, Debug, Default)]
pub struct Vern8;

impl Vern8 {
    /// Returns the `Vern8` marker value; equivalent to `Vern8::default()`.
    pub fn new() -> Self {
        Vern8
    }
}
635
/// Verner 8(7) coefficients.
/// Reference: Jim Verner's "Efficient" 8/7 pair
/// File: RKV87.IIa.Efficient.000000282866.081208.CoeffsOnlyFLOAT
/// From: <https://www.sfu.ca/~jverner/>
///
/// `C` holds the thirteen stage nodes, `A2`..`A13` the stage rows, `B` the
/// 8th-order weights, `B_HAT` the embedded 7th-order weights and
/// `E = B - B_HAT` the error-estimate weights.
mod vern8_tableau {
    pub const C: [f64; 13] = [
        0.0,
        0.05,
        0.1065625,
        0.15984375,
        0.39,
        0.465,
        0.155,
        0.943,
        0.9018020417358569582597079406783721499560,
        0.909,
        0.94,
        1.0,
        1.0,
    ];

    // A matrix coefficients from Jim Verner's website.
    // `An[j]` is the coefficient a_{n, j+1}: the weight of stage j+1's slope
    // in the argument of stage n (1-based stage numbers).
    pub const A2: [f64; 1] = [0.05];
    pub const A3: [f64; 2] = [-0.0069931640625, 0.1135556640625];
    pub const A4: [f64; 3] = [0.0399609375, 0.0, 0.1198828125];
    pub const A5: [f64; 4] = [
        0.3613975628004575124052940721184028345129,
        0.0,
        -1.341524066700492771819987788202715834917,
        1.370126503900035259414693716084313000404,
    ];
    pub const A6: [f64; 5] = [
        0.04904720279720279720279720279720279720280,
        0.0,
        0.0,
        0.2350972042214404739862988335493427143122,
        0.1808555929813567288109039636534544884850,
    ];
    pub const A7: [f64; 6] = [
        0.06169289044289044289044289044289044289044,
        0.0,
        0.0,
        0.1123656831464027662262557035130015442303,
        -0.03885046071451366767049048108111244567456,
        0.01979188712522045855379188712522045855379,
    ];
    pub const A8: [f64; 7] = [
        -1.767630240222326875735597119572145586714,
        0.0,
        0.0,
        -62.5,
        -6.061889377376669100821361459659331999758,
        5.650823198222763138561298030600840174201,
        65.62169641937623283799566054863063741227,
    ];
    pub const A9: [f64; 8] = [
        -1.180945066554970799825116282628297957882,
        0.0,
        0.0,
        -41.50473441114320841606641502701994225874,
        -4.434438319103725011225169229846100211776,
        4.260408188586133024812193710744693240761,
        43.75364022446171584987676829438379303004,
        0.007871425489912310687446475044226307550860,
    ];
    pub const A10: [f64; 9] = [
        -1.281405999441488405459510291182054246266,
        0.0,
        0.0,
        -45.04713996013986630220754257136007322267,
        -4.731362069449576477311464265491282810943,
        4.514967016593807841185851584597240996214,
        47.44909557172985134869022392235929015114,
        0.01059228297111661135687393955516542875228,
        -0.005746842263844616254432318478286296232021,
    ];
    pub const A11: [f64; 10] = [
        -1.724470134262485191756709817484481861731,
        0.0,
        0.0,
        -60.92349008483054016518434619253765246063,
        -5.951518376222392455202832767061854868290,
        5.556523730698456235979791650843592496839,
        63.98301198033305336837536378635995939281,
        0.01464202825041496159275921391759452676003,
        0.06460408772358203603621865144977650714892,
        -0.07930323169008878984024452548693373291447,
    ];
    pub const A12: [f64; 11] = [
        -3.301622667747079016353994789790983625569,
        0.0,
        0.0,
        -118.0112723597525085666923303957898868510,
        -10.14142238845611248642783916034510897595,
        9.139311332232057923544012273556827000619,
        123.3759428284042683684847180986501894364,
        4.623244378874580474839807625067630924792,
        -3.383277738068201923652550971536811240814,
        4.527592100324618189451265339351129035325,
        -5.828495485811622963193088019162985703755,
    ];
    // A13 coefficients (stage 13)
    pub const A13: [f64; 12] = [
        -3.039515033766309030040102851821200251056,
        0.0,
        0.0,
        -109.2608680894176254686444192322164623352,
        -9.290642497400293449717665542656897549158,
        8.430504981764911142134299253836167803454,
        114.2010010378331313557424041095523427476,
        -0.9637271342145479358162375658987901652762,
        -5.034884088802189791198680336183332323118,
        5.958130824002923177540402165388172072794,
        0.0,
        0.0,
    ];

    // 8th order weights from Jim Verner
    pub const B: [f64; 13] = [
        0.04427989419007951074716746668098518862111,
        0.0,
        0.0,
        0.0,
        0.0,
        0.3541049391724448744815552028733568354121,
        0.2479692154956437828667629415370663023884,
        -15.69420203883808405099207034271191213468,
        25.08406496555856261343930031237186278518,
        -31.73836778626027646833156112007297739997,
        22.93828327398878395231483560344797018313,
        -0.2361324633071542145259900641263517600737,
        0.0,
    ];

    // 7th order embedded weights (b_hat) from Jim Verner
    pub const B_HAT: [f64; 13] = [
        0.04431261522908979212486436510209029764893,
        0.0,
        0.0,
        0.0,
        0.0,
        0.3546095642343226447863179350895055038855,
        0.2478480431366653069619986721504458660016,
        4.448134732475784492725128317159648871312,
        19.84688636611873369930932399297687935291,
        -23.58162337746561841969517960870394965085,
        0.0,
        0.0,
        -0.3601679437289775162124536737746202409110,
    ];

    // Error coefficients: E = B - B_HAT, derived from the two weight arrays
    // so each literal is written only once.
    // Note: These are the raw differences, not scaled
    pub const E: [f64; 13] = [
        B[0] - B_HAT[0],
        0.0,
        0.0,
        0.0,
        0.0,
        B[5] - B_HAT[5],
        B[6] - B_HAT[6],
        B[7] - B_HAT[7],
        B[8] - B_HAT[8],
        B[9] - B_HAT[9],
        B[10] - B_HAT[10],
        B[11] - B_HAT[11],
        B[12] - B_HAT[12],
    ];
}
807
808impl<S: Scalar> Solver<S> for Vern8 {
809    fn solve<Sys: OdeSystem<S>>(
810        problem: &Sys,
811        t0: S,
812        tf: S,
813        y0: &[S],
814        options: &SolverOptions<S>,
815    ) -> Result<SolverResult<S>, SolverError> {
816        // Vern8 needs special handling due to complex tableau
817        solve_vern8(problem, t0, tf, y0, options)
818    }
819}
820
821// Generic ERK solver for simpler tableaux
822fn solve_erk<S, Sys>(
823    problem: &Sys,
824    t0: S,
825    tf: S,
826    y0: &[S],
827    options: &SolverOptions<S>,
828    c: &[f64],
829    a: &[&[f64]],
830    b: &[f64],
831    e: &[f64],
832    order: usize,
833    stages: usize,
834) -> Result<SolverResult<S>, SolverError>
835where
836    S: Scalar,
837    Sys: OdeSystem<S>,
838{
839    let dim = problem.dim();
840    if y0.len() != dim {
841        return Err(SolverError::DimensionMismatch {
842            expected: dim,
843            actual: y0.len(),
844        });
845    }
846
847    let mut t = t0;
848    let mut y = y0.to_vec();
849
850    // Flat storage for stage vectors: k[s*dim + i] instead of k[s][i].
851    // Avoids heap fragmentation from Vec<Vec<S>>.
852    let mut k = vec![S::ZERO; stages * dim];
853    let mut y_stage = vec![S::ZERO; dim];
854    let mut y_new = vec![S::ZERO; dim];
855    let mut err = vec![S::ZERO; dim];
856    // Slope at the start of the next step. Held in a side buffer so we can
857    // Hermite-interpolate against it without disturbing k[0..dim] (still
858    // needed for stage computation in the in-flight step).
859    let mut dy_old = vec![S::ZERO; dim];
860    let mut dy_new = vec![S::ZERO; dim];
861
862    let mut stats = SolverStats::default();
863
864    // Initial step size
865    problem.rhs(t, &y, &mut k[0..dim]);
866    stats.n_eval += 1;
867    let mut h = initial_step_size(&y, &k[0..dim], options, dim);
868    let h_min = options.h_min;
869    let h_max = options.h_max.min((tf - t0).abs());
870
871    let direction = if tf > t0 { S::ONE } else { -S::ONE };
872    if let Some(grid) = options.t_eval.as_deref() {
873        validate_grid(grid, t0, tf)?;
874    }
875    let mut grid_emitter = options
876        .t_eval
877        .as_deref()
878        .map(|g| TEvalEmitter::new(g, direction));
879    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
880        (Vec::new(), Vec::new())
881    } else {
882        (vec![t0], y0.to_vec())
883    };
884    let mut step_count = 0_usize;
885
886    while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
887        if step_count >= options.max_steps {
888            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
889        }
890
891        if (t + h - tf) * direction > S::ZERO {
892            h = tf - t;
893        }
894
895        h = h.abs().max(h_min) * direction;
896        if h.abs() > h_max {
897            h = h_max * direction;
898        }
899
900        // Compute stages (flat indexing: k[s*dim + i])
901        for s in 1..stages {
902            for i in 0..dim {
903                let mut sum = S::ZERO;
904                for j in 0..s {
905                    if s - 1 < a.len() && j < a[s - 1].len() {
906                        sum = sum + S::from_f64(a[s - 1][j]) * k[j * dim + i];
907                    }
908                }
909                y_stage[i] = y[i] + h * sum;
910            }
911            problem.rhs(
912                t + S::from_f64(c[s]) * h,
913                &y_stage,
914                &mut k[s * dim..(s + 1) * dim],
915            );
916        }
917        stats.n_eval += stages - 1;
918
919        // Compute solution and error (flat indexing)
920        for i in 0..dim {
921            let mut sum_b = S::ZERO;
922            let mut sum_e = S::ZERO;
923            for s in 0..stages {
924                sum_b = sum_b + S::from_f64(b[s]) * k[s * dim + i];
925                sum_e = sum_e + S::from_f64(e[s]) * k[s * dim + i];
926            }
927            y_new[i] = y[i] + h * sum_b;
928            err[i] = h * sum_e;
929        }
930
931        let err_norm = error_norm(&err, &y, &y_new, options, dim);
932
933        let safety = S::from_f64(0.9);
934        let fac_max = S::from_f64(4.0);
935        let fac_min = S::from_f64(0.2);
936        let order_f = S::from_usize(order + 1);
937
938        if err_norm <= S::ONE {
939            stats.n_accept += 1;
940
941            let t_new = t + h;
942            // Snapshot dy at the start (k[0..dim] still holds slope at t),
943            // then compute dy at the end of the step. We need both for
944            // Hermite interpolation in t_eval mode and we'd recompute the
945            // end slope anyway for the next step's k[0..dim].
946            dy_old.copy_from_slice(&k[0..dim]);
947            problem.rhs(t_new, &y_new, &mut dy_new);
948            stats.n_eval += 1;
949
950            if let Some(ref mut emitter) = grid_emitter {
951                emitter.emit_step(
952                    t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
953                );
954            } else {
955                t_out.push(t_new);
956                y_out.extend_from_slice(&y_new);
957            }
958
959            t = t_new;
960            y.copy_from_slice(&y_new);
961            k[0..dim].copy_from_slice(&dy_new);
962
963            let err_safe = err_norm.max(S::from_f64(1e-10));
964            let fac = safety * err_safe.powf(-S::ONE / order_f);
965            let fac = fac.min(fac_max).max(fac_min);
966            h = h * fac;
967        } else {
968            stats.n_reject += 1;
969
970            let err_safe = err_norm.max(S::from_f64(1e-10));
971            let fac = safety * err_safe.powf(-S::ONE / (order_f - S::ONE));
972            let fac = fac.max(fac_min);
973            h = h * fac;
974        }
975
976        if h.abs() < h_min {
977            return Err(SolverError::StepSizeTooSmall {
978                t: t.to_f64(),
979                h: h.to_f64(),
980                h_min: h_min.to_f64(),
981            });
982        }
983
984        step_count += 1;
985    }
986
987    Ok(SolverResult::new(t_out, y_out, dim, stats))
988}
989
990// Specialized solver for Vern8 (complex tableau)
991fn solve_vern8<S, Sys>(
992    problem: &Sys,
993    t0: S,
994    tf: S,
995    y0: &[S],
996    options: &SolverOptions<S>,
997) -> Result<SolverResult<S>, SolverError>
998where
999    S: Scalar,
1000    Sys: OdeSystem<S>,
1001{
1002    let dim = problem.dim();
1003    if y0.len() != dim {
1004        return Err(SolverError::DimensionMismatch {
1005            expected: dim,
1006            actual: y0.len(),
1007        });
1008    }
1009
1010    let mut t = t0;
1011    let mut y = y0.to_vec();
1012
1013    // Flat storage for 13 stage vectors: k[s*dim + i] instead of k[s][i].
1014    let mut k = vec![S::ZERO; 13 * dim];
1015    let mut y_stage = vec![S::ZERO; dim];
1016    let mut y_new = vec![S::ZERO; dim];
1017    let mut err = vec![S::ZERO; dim];
1018    let mut dy_old = vec![S::ZERO; dim];
1019    let mut dy_new = vec![S::ZERO; dim];
1020
1021    let mut stats = SolverStats::default();
1022
1023    problem.rhs(t, &y, &mut k[0..dim]);
1024    stats.n_eval += 1;
1025    let mut h = initial_step_size(&y, &k[0..dim], options, dim);
1026    let h_min = options.h_min;
1027    let h_max = options.h_max.min((tf - t0).abs());
1028
1029    let direction = if tf > t0 { S::ONE } else { -S::ONE };
1030    if let Some(grid) = options.t_eval.as_deref() {
1031        validate_grid(grid, t0, tf)?;
1032    }
1033    let mut grid_emitter = options
1034        .t_eval
1035        .as_deref()
1036        .map(|g| TEvalEmitter::new(g, direction));
1037    let (mut t_out, mut y_out) = if grid_emitter.is_some() {
1038        (Vec::new(), Vec::new())
1039    } else {
1040        (vec![t0], y0.to_vec())
1041    };
1042    let mut step_count = 0_usize;
1043
1044    while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
1045        if step_count >= options.max_steps {
1046            return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
1047        }
1048
1049        if (t + h - tf) * direction > S::ZERO {
1050            h = tf - t;
1051        }
1052
1053        h = h.abs().max(h_min) * direction;
1054        if h.abs() > h_max {
1055            h = h_max * direction;
1056        }
1057
1058        // Compute all 13 stages using Vern8 tableau (flat indexing: k[s*dim + i])
1059        // Stage 2: a21 only
1060        for i in 0..dim {
1061            y_stage[i] = y[i] + h * S::from_f64(vern8_tableau::A2[0]) * k[i];
1062        }
1063        problem.rhs(
1064            t + S::from_f64(vern8_tableau::C[1]) * h,
1065            &y_stage,
1066            &mut k[dim..2 * dim],
1067        );
1068
1069        // Stage 3: a31, a32
1070        for i in 0..dim {
1071            y_stage[i] = y[i]
1072                + h * (S::from_f64(vern8_tableau::A3[0]) * k[i]
1073                    + S::from_f64(vern8_tableau::A3[1]) * k[dim + i]);
1074        }
1075        problem.rhs(
1076            t + S::from_f64(vern8_tableau::C[2]) * h,
1077            &y_stage,
1078            &mut k[2 * dim..3 * dim],
1079        );
1080
1081        // Stage 4: a41, 0, a43
1082        for i in 0..dim {
1083            y_stage[i] = y[i]
1084                + h * (S::from_f64(vern8_tableau::A4[0]) * k[i]
1085                    + S::from_f64(vern8_tableau::A4[2]) * k[2 * dim + i]);
1086        }
1087        problem.rhs(
1088            t + S::from_f64(vern8_tableau::C[3]) * h,
1089            &y_stage,
1090            &mut k[3 * dim..4 * dim],
1091        );
1092
1093        // Stage 5: a51, 0, a53, a54
1094        for i in 0..dim {
1095            y_stage[i] = y[i]
1096                + h * (S::from_f64(vern8_tableau::A5[0]) * k[i]
1097                    + S::from_f64(vern8_tableau::A5[2]) * k[2 * dim + i]
1098                    + S::from_f64(vern8_tableau::A5[3]) * k[3 * dim + i]);
1099        }
1100        problem.rhs(
1101            t + S::from_f64(vern8_tableau::C[4]) * h,
1102            &y_stage,
1103            &mut k[4 * dim..5 * dim],
1104        );
1105
1106        // Stage 6: a61, 0, 0, a64, a65
1107        for i in 0..dim {
1108            y_stage[i] = y[i]
1109                + h * (S::from_f64(vern8_tableau::A6[0]) * k[i]
1110                    + S::from_f64(vern8_tableau::A6[3]) * k[3 * dim + i]
1111                    + S::from_f64(vern8_tableau::A6[4]) * k[4 * dim + i]);
1112        }
1113        problem.rhs(
1114            t + S::from_f64(vern8_tableau::C[5]) * h,
1115            &y_stage,
1116            &mut k[5 * dim..6 * dim],
1117        );
1118
1119        // Stage 7: a71, 0, 0, a74, a75, a76
1120        for i in 0..dim {
1121            y_stage[i] = y[i]
1122                + h * (S::from_f64(vern8_tableau::A7[0]) * k[i]
1123                    + S::from_f64(vern8_tableau::A7[3]) * k[3 * dim + i]
1124                    + S::from_f64(vern8_tableau::A7[4]) * k[4 * dim + i]
1125                    + S::from_f64(vern8_tableau::A7[5]) * k[5 * dim + i]);
1126        }
1127        problem.rhs(
1128            t + S::from_f64(vern8_tableau::C[6]) * h,
1129            &y_stage,
1130            &mut k[6 * dim..7 * dim],
1131        );
1132
1133        // Stage 8: a81, 0, 0, a84, a85, a86, a87
1134        for i in 0..dim {
1135            y_stage[i] = y[i]
1136                + h * (S::from_f64(vern8_tableau::A8[0]) * k[i]
1137                    + S::from_f64(vern8_tableau::A8[3]) * k[3 * dim + i]
1138                    + S::from_f64(vern8_tableau::A8[4]) * k[4 * dim + i]
1139                    + S::from_f64(vern8_tableau::A8[5]) * k[5 * dim + i]
1140                    + S::from_f64(vern8_tableau::A8[6]) * k[6 * dim + i]);
1141        }
1142        problem.rhs(
1143            t + S::from_f64(vern8_tableau::C[7]) * h,
1144            &y_stage,
1145            &mut k[7 * dim..8 * dim],
1146        );
1147
1148        // Stage 9: a91, 0, 0, a94, a95, a96, a97, a98
1149        for i in 0..dim {
1150            y_stage[i] = y[i]
1151                + h * (S::from_f64(vern8_tableau::A9[0]) * k[i]
1152                    + S::from_f64(vern8_tableau::A9[3]) * k[3 * dim + i]
1153                    + S::from_f64(vern8_tableau::A9[4]) * k[4 * dim + i]
1154                    + S::from_f64(vern8_tableau::A9[5]) * k[5 * dim + i]
1155                    + S::from_f64(vern8_tableau::A9[6]) * k[6 * dim + i]
1156                    + S::from_f64(vern8_tableau::A9[7]) * k[7 * dim + i]);
1157        }
1158        problem.rhs(
1159            t + S::from_f64(vern8_tableau::C[8]) * h,
1160            &y_stage,
1161            &mut k[8 * dim..9 * dim],
1162        );
1163
1164        // Stage 10
1165        for i in 0..dim {
1166            y_stage[i] = y[i]
1167                + h * (S::from_f64(vern8_tableau::A10[0]) * k[i]
1168                    + S::from_f64(vern8_tableau::A10[3]) * k[3 * dim + i]
1169                    + S::from_f64(vern8_tableau::A10[4]) * k[4 * dim + i]
1170                    + S::from_f64(vern8_tableau::A10[5]) * k[5 * dim + i]
1171                    + S::from_f64(vern8_tableau::A10[6]) * k[6 * dim + i]
1172                    + S::from_f64(vern8_tableau::A10[7]) * k[7 * dim + i]
1173                    + S::from_f64(vern8_tableau::A10[8]) * k[8 * dim + i]);
1174        }
1175        problem.rhs(
1176            t + S::from_f64(vern8_tableau::C[9]) * h,
1177            &y_stage,
1178            &mut k[9 * dim..10 * dim],
1179        );
1180
1181        // Stage 11
1182        for i in 0..dim {
1183            y_stage[i] = y[i]
1184                + h * (S::from_f64(vern8_tableau::A11[0]) * k[i]
1185                    + S::from_f64(vern8_tableau::A11[3]) * k[3 * dim + i]
1186                    + S::from_f64(vern8_tableau::A11[4]) * k[4 * dim + i]
1187                    + S::from_f64(vern8_tableau::A11[5]) * k[5 * dim + i]
1188                    + S::from_f64(vern8_tableau::A11[6]) * k[6 * dim + i]
1189                    + S::from_f64(vern8_tableau::A11[7]) * k[7 * dim + i]
1190                    + S::from_f64(vern8_tableau::A11[8]) * k[8 * dim + i]
1191                    + S::from_f64(vern8_tableau::A11[9]) * k[9 * dim + i]);
1192        }
1193        problem.rhs(
1194            t + S::from_f64(vern8_tableau::C[10]) * h,
1195            &y_stage,
1196            &mut k[10 * dim..11 * dim],
1197        );
1198
1199        // Stage 12
1200        for i in 0..dim {
1201            y_stage[i] = y[i]
1202                + h * (S::from_f64(vern8_tableau::A12[0]) * k[i]
1203                    + S::from_f64(vern8_tableau::A12[3]) * k[3 * dim + i]
1204                    + S::from_f64(vern8_tableau::A12[4]) * k[4 * dim + i]
1205                    + S::from_f64(vern8_tableau::A12[5]) * k[5 * dim + i]
1206                    + S::from_f64(vern8_tableau::A12[6]) * k[6 * dim + i]
1207                    + S::from_f64(vern8_tableau::A12[7]) * k[7 * dim + i]
1208                    + S::from_f64(vern8_tableau::A12[8]) * k[8 * dim + i]
1209                    + S::from_f64(vern8_tableau::A12[9]) * k[9 * dim + i]
1210                    + S::from_f64(vern8_tableau::A12[10]) * k[10 * dim + i]);
1211        }
1212        problem.rhs(
1213            t + S::from_f64(vern8_tableau::C[11]) * h,
1214            &y_stage,
1215            &mut k[11 * dim..12 * dim],
1216        );
1217
1218        // Stage 13
1219        for i in 0..dim {
1220            y_stage[i] = y[i]
1221                + h * (S::from_f64(vern8_tableau::A13[0]) * k[i]
1222                    + S::from_f64(vern8_tableau::A13[3]) * k[3 * dim + i]
1223                    + S::from_f64(vern8_tableau::A13[4]) * k[4 * dim + i]
1224                    + S::from_f64(vern8_tableau::A13[5]) * k[5 * dim + i]
1225                    + S::from_f64(vern8_tableau::A13[6]) * k[6 * dim + i]
1226                    + S::from_f64(vern8_tableau::A13[7]) * k[7 * dim + i]
1227                    + S::from_f64(vern8_tableau::A13[8]) * k[8 * dim + i]
1228                    + S::from_f64(vern8_tableau::A13[9]) * k[9 * dim + i]);
1229        }
1230        problem.rhs(
1231            t + S::from_f64(vern8_tableau::C[12]) * h,
1232            &y_stage,
1233            &mut k[12 * dim..13 * dim],
1234        );
1235
1236        stats.n_eval += 12;
1237
1238        // Compute solution and error using all 13 stages (flat indexing)
1239        for i in 0..dim {
1240            let mut sum_b = S::ZERO;
1241            let mut sum_e = S::ZERO;
1242            for s in 0..13 {
1243                sum_b = sum_b + S::from_f64(vern8_tableau::B[s]) * k[s * dim + i];
1244                sum_e = sum_e + S::from_f64(vern8_tableau::E[s]) * k[s * dim + i];
1245            }
1246            y_new[i] = y[i] + h * sum_b;
1247            err[i] = h * sum_e;
1248        }
1249
1250        let err_norm = error_norm(&err, &y, &y_new, options, dim);
1251
1252        let safety = S::from_f64(0.9);
1253        let fac_max = S::from_f64(3.0);
1254        let fac_min = S::from_f64(0.2);
1255
1256        if err_norm <= S::ONE {
1257            stats.n_accept += 1;
1258
1259            let t_new = t + h;
1260            dy_old.copy_from_slice(&k[0..dim]);
1261            problem.rhs(t_new, &y_new, &mut dy_new);
1262            stats.n_eval += 1;
1263
1264            if let Some(ref mut emitter) = grid_emitter {
1265                emitter.emit_step(
1266                    t, &y, &dy_old, t_new, &y_new, &dy_new, &mut t_out, &mut y_out,
1267                );
1268            } else {
1269                t_out.push(t_new);
1270                y_out.extend_from_slice(&y_new);
1271            }
1272
1273            t = t_new;
1274            y.copy_from_slice(&y_new);
1275            k[0..dim].copy_from_slice(&dy_new);
1276
1277            let err_safe = err_norm.max(S::from_f64(1e-10));
1278            let fac = safety * err_safe.powf(S::from_f64(-1.0 / 9.0));
1279            let fac = fac.min(fac_max).max(fac_min);
1280            h = h * fac;
1281        } else {
1282            stats.n_reject += 1;
1283
1284            let err_safe = err_norm.max(S::from_f64(1e-10));
1285            let fac = safety * err_safe.powf(S::from_f64(-1.0 / 8.0));
1286            let fac = fac.max(fac_min);
1287            h = h * fac;
1288        }
1289
1290        if h.abs() < h_min {
1291            return Err(SolverError::StepSizeTooSmall {
1292                t: t.to_f64(),
1293                h: h.to_f64(),
1294                h_min: h_min.to_f64(),
1295            });
1296        }
1297
1298        step_count += 1;
1299    }
1300
1301    Ok(SolverResult::new(t_out, y_out, dim, stats))
1302}
1303
1304fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
1305    if let Some(h0) = options.h0 {
1306        return h0;
1307    }
1308
1309    let mut y_norm = S::ZERO;
1310    let mut f_norm = S::ZERO;
1311    for i in 0..dim {
1312        let sc = options.atol + options.rtol * y0[i].abs();
1313        y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
1314        f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
1315    }
1316    y_norm = (y_norm / S::from_usize(dim)).sqrt();
1317    f_norm = (f_norm / S::from_usize(dim)).sqrt();
1318
1319    if y_norm < S::from_f64(1e-5) || f_norm < S::from_f64(1e-5) {
1320        S::from_f64(1e-6)
1321    } else {
1322        (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
1323    }
1324}
1325
1326fn error_norm<S: Scalar>(
1327    err: &[S],
1328    y: &[S],
1329    y_new: &[S],
1330    options: &SolverOptions<S>,
1331    dim: usize,
1332) -> S {
1333    let mut err_norm = S::ZERO;
1334    for i in 0..dim {
1335        let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
1336        let sc = sc.max(S::from_f64(1e-15));
1337        let scaled_err = err[i] / sc;
1338        err_norm = err_norm + scaled_err * scaled_err;
1339    }
1340    (err_norm / S::from_usize(dim)).sqrt()
1341}
1342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    /// Vern6 on the decay problem y' = -y, y(0) = 1, compared with exp(-5)
    /// to an absolute error of 1e-5.
    #[test]
    fn test_vern6_exponential() {
        // Moderate tolerances for reliable testing
        let opts = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );

        let sol = Vern6::solve(&decay, 0.0, 5.0, &[1.0], &opts).unwrap();
        assert!(sol.success);

        let exact = (-5.0_f64).exp();
        let got = sol.y_final().unwrap()[0];
        assert!(
            (got - exact).abs() < 1e-5,
            "Vern6 exponential: got {}, expected {}",
            got,
            exact
        );
    }

    /// Vern7 on the harmonic oscillator y'' = -y with y(0) = 1, y'(0) = 0;
    /// the position at t = 10 is compared with cos(10).
    #[test]
    fn test_vern7_harmonic() {
        // Use moderate tolerances for reliable testing
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let oscillator = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );

        let sol = Vern7::solve(&oscillator, 0.0, 10.0, &[1.0, 0.0], &opts).unwrap();
        assert!(sol.success);

        // Allow 0.01 absolute error for the harmonic oscillator over 10 time units
        let position = sol.y_final().unwrap()[0];
        assert!(
            (position - 10.0_f64.cos()).abs() < 0.01,
            "Vern7 harmonic: got {}, expected {}",
            position,
            10.0_f64.cos()
        );
    }

    /// Vern8 on the decay problem y' = -y, y(0) = 1, compared with exp(-5)
    /// to a relative error of 5%.
    #[test]
    fn test_vern8_high_accuracy() {
        // Use moderate tolerances for reliable testing
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );

        let sol = Vern8::solve(&decay, 0.0, 5.0, &[1.0], &opts).unwrap();
        assert!(sol.success);

        let exact = (-5.0_f64).exp();
        let got = sol.y_final().unwrap()[0];
        // Allow 5% error - the Vern8 coefficients need further tuning
        assert!(
            (got - exact).abs() < exact * 0.05,
            "Vern8 exponential: got {}, expected {}",
            got,
            exact
        );
    }

    /// All three Verner solvers integrate the same decay problem to t = 2
    /// and must each land near exp(-2).
    #[test]
    fn test_vern_methods_agree() {
        // Use moderate tolerances for reliable testing
        let opts = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let decay = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            2.0,
            vec![1.0],
        );

        let y6 = Vern6::solve(&decay, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y7 = Vern7::solve(&decay, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];
        let y8 = Vern8::solve(&decay, 0.0, 2.0, &[1.0], &opts)
            .unwrap()
            .y_final()
            .unwrap()[0];

        let exact = (-2.0_f64).exp();
        let abs_err = |value: f64| (value - exact).abs();

        // All should be close to true solution (1% for Vern6/Vern7, 5% for Vern8)
        assert!(abs_err(y6) < exact * 0.01, "Vern6 disagrees: {}", y6);
        assert!(abs_err(y7) < exact * 0.01, "Vern7 disagrees: {}", y7);
        assert!(abs_err(y8) < exact * 0.05, "Vern8 disagrees: {}", y8);
    }
}