Skip to main content

numra_ode/
tsit5.rs

1//! Tsit5: Tsitouras 5(4) explicit Runge-Kutta method.
2//!
3//! A modern 5th order explicit RK method with optimized coefficients
4//! for efficient error estimation.
5//!
6//! ## Features
7//! - 5th order accuracy with embedded 4th order for error estimation
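//! - 7 stages with FSAL (First Same As Last): the last stage of an accepted step is reused as
//!   the first stage of the next one, so each step attempt needs only 6 new RHS evaluations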
9//! - Optimized coefficients for typical ODE problems
10//!
11//! ## Reference
//! Tsitouras, Ch. (2011), "Runge-Kutta pairs of order 5(4) satisfying only the
//! first column simplifying assumption", Computers & Mathematics with Applications,
//! 62(2), 770-775.
14//!
15//! Author: Moussa Leblouba
16//! Date: 4 February 2026
17//! Modified: 2 May 2026
18
19use crate::error::SolverError;
20use crate::problem::OdeSystem;
21use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
22use crate::t_eval::{validate_grid, TEvalEmitter};
23use numra_core::Scalar;
24
/// Tsit5 solver: the Tsitouras 5(4) explicit Runge-Kutta pair.
///
/// This is a stateless marker type; all of the integration work happens in
/// [`Solver::solve`]. `Tsit5::new()` and `Tsit5::default()` are interchangeable.
#[derive(Clone, Debug, Default)]
pub struct Tsit5;

impl Tsit5 {
    /// Build a Tsit5 solver value (identical to `Tsit5::default()`).
    pub fn new() -> Self {
        Tsit5
    }
}
35
/// Tsitouras 5(4) coefficients (Butcher tableau).
///
/// Row 7 of the `A` matrix equals the 5th-order weights `B1..B6` (with `B7 = 0`), which is
/// what makes the method FSAL: the stage-7 derivative, evaluated at the new solution, is
/// exactly the next step's first stage.
// `C7`, the `A7*` row and `B7` are kept so the tableau is complete and self-documenting,
// but the solver uses the `B*` weights directly and never reads them, hence `dead_code`.
#[allow(dead_code)]
mod tableau {
    // Nodes (c_i)
    pub const C2: f64 = 0.161;
    pub const C3: f64 = 0.327;
    pub const C4: f64 = 0.9;
    pub const C5: f64 = 0.9800255409045097;
    pub const C6: f64 = 1.0;
    pub const C7: f64 = 1.0;

    // A matrix (lower triangular, row by row); each row sums to its node c_i.
    pub const A21: f64 = 0.161;

    pub const A31: f64 = -0.008480655492356989;
    pub const A32: f64 = 0.335480655492357;

    pub const A41: f64 = 2.8971530571054935;
    pub const A42: f64 = -6.359448489975075;
    pub const A43: f64 = 4.3622954328695815;

    pub const A51: f64 = 5.325864828439257;
    pub const A52: f64 = -11.748883564062828;
    pub const A53: f64 = 7.4955393428898365;
    pub const A54: f64 = -0.09249506636175525;

    pub const A61: f64 = 5.86145544294642;
    pub const A62: f64 = -12.92096931784711;
    pub const A63: f64 = 8.159367898576159;
    pub const A64: f64 = -0.071584973281401;
    pub const A65: f64 = -0.028269050394068383;

    pub const A71: f64 = 0.09646076681806523;
    pub const A72: f64 = 0.01;
    pub const A73: f64 = 0.4798896504144996;
    pub const A74: f64 = 1.379008574103742;
    pub const A75: f64 = -3.290069515436081;
    pub const A76: f64 = 2.324710524099774;

    // 5th-order weights (b_i); they sum to 1.
    pub const B1: f64 = 0.09646076681806523;
    pub const B2: f64 = 0.01;
    pub const B3: f64 = 0.4798896504144996;
    pub const B4: f64 = 1.379008574103742;
    pub const B5: f64 = -3.290069515436081;
    pub const B6: f64 = 2.324710524099774;
    pub const B7: f64 = 0.0;

    // Error coefficients E_i = b_i - b_hat_i: the difference between the 5th-order weights
    // and the embedded 4th-order weights. b_hat_7 = 1/66, so E7 = 0 - 1/66.
    // Both weight sets sum to 1, so the E_i sum to 0 (up to rounding). The values are given
    // to full double precision, consistent with the rest of the tableau.
    pub const E1: f64 = 0.00178001105222577714;
    pub const E2: f64 = 0.0008164344596567469;
    pub const E3: f64 = -0.007880878010261995;
    pub const E4: f64 = 0.1447110071732629;
    pub const E5: f64 = -0.5823571654525552;
    pub const E6: f64 = 0.45808210592918697;
    pub const E7: f64 = -1.0 / 66.0; // -0.015151515151515...
}
96
97impl<S: Scalar> Solver<S> for Tsit5 {
98    fn solve<Sys: OdeSystem<S>>(
99        problem: &Sys,
100        t0: S,
101        tf: S,
102        y0: &[S],
103        options: &SolverOptions<S>,
104    ) -> Result<SolverResult<S>, SolverError> {
105        let dim = problem.dim();
106        if y0.len() != dim {
107            return Err(SolverError::DimensionMismatch {
108                expected: dim,
109                actual: y0.len(),
110            });
111        }
112
113        let mut t = t0;
114        let mut y = y0.to_vec();
115
116        let direction = if tf > t0 { S::ONE } else { -S::ONE };
117        if let Some(grid) = options.t_eval.as_deref() {
118            validate_grid(grid, t0, tf)?;
119        }
120        let mut grid_emitter = options
121            .t_eval
122            .as_deref()
123            .map(|g| TEvalEmitter::new(g, direction));
124        let (mut t_out, mut y_out) = if grid_emitter.is_some() {
125            (Vec::new(), Vec::new())
126        } else {
127            (vec![t0], y0.to_vec())
128        };
129
130        // Stage derivatives
131        let mut k1 = vec![S::ZERO; dim];
132        let mut k2 = vec![S::ZERO; dim];
133        let mut k3 = vec![S::ZERO; dim];
134        let mut k4 = vec![S::ZERO; dim];
135        let mut k5 = vec![S::ZERO; dim];
136        let mut k6 = vec![S::ZERO; dim];
137        let mut k7 = vec![S::ZERO; dim];
138        let mut y_stage = vec![S::ZERO; dim];
139        let mut y_new = vec![S::ZERO; dim];
140        let mut err = vec![S::ZERO; dim];
141
142        let mut stats = SolverStats::default();
143
144        // Initial step size
145        problem.rhs(t, &y, &mut k1);
146        stats.n_eval += 1;
147        let mut h = initial_step_size(&y, &k1, options, dim);
148        let h_min = options.h_min;
149        let h_max = options.h_max.min((tf - t0).abs());
150
151        let mut step_count = 0_usize;
152
153        while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
154            if step_count >= options.max_steps {
155                return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
156            }
157
158            // Adjust final step
159            if (t + h - tf) * direction > S::ZERO {
160                h = tf - t;
161            }
162
163            // Clamp step size
164            h = h.abs().max(h_min) * direction;
165            if h.abs() > h_max {
166                h = h_max * direction;
167            }
168
169            // Compute stages
170            // k1 is already computed (FSAL)
171
172            // k2
173            for i in 0..dim {
174                y_stage[i] = y[i] + h * S::from_f64(tableau::A21) * k1[i];
175            }
176            problem.rhs(t + S::from_f64(tableau::C2) * h, &y_stage, &mut k2);
177
178            // k3
179            for i in 0..dim {
180                y_stage[i] = y[i]
181                    + h * (S::from_f64(tableau::A31) * k1[i] + S::from_f64(tableau::A32) * k2[i]);
182            }
183            problem.rhs(t + S::from_f64(tableau::C3) * h, &y_stage, &mut k3);
184
185            // k4
186            for i in 0..dim {
187                y_stage[i] = y[i]
188                    + h * (S::from_f64(tableau::A41) * k1[i]
189                        + S::from_f64(tableau::A42) * k2[i]
190                        + S::from_f64(tableau::A43) * k3[i]);
191            }
192            problem.rhs(t + S::from_f64(tableau::C4) * h, &y_stage, &mut k4);
193
194            // k5
195            for i in 0..dim {
196                y_stage[i] = y[i]
197                    + h * (S::from_f64(tableau::A51) * k1[i]
198                        + S::from_f64(tableau::A52) * k2[i]
199                        + S::from_f64(tableau::A53) * k3[i]
200                        + S::from_f64(tableau::A54) * k4[i]);
201            }
202            problem.rhs(t + S::from_f64(tableau::C5) * h, &y_stage, &mut k5);
203
204            // k6
205            for i in 0..dim {
206                y_stage[i] = y[i]
207                    + h * (S::from_f64(tableau::A61) * k1[i]
208                        + S::from_f64(tableau::A62) * k2[i]
209                        + S::from_f64(tableau::A63) * k3[i]
210                        + S::from_f64(tableau::A64) * k4[i]
211                        + S::from_f64(tableau::A65) * k5[i]);
212            }
213            problem.rhs(t + S::from_f64(tableau::C6) * h, &y_stage, &mut k6);
214
215            // k7 and y_new (5th order solution)
216            for i in 0..dim {
217                y_new[i] = y[i]
218                    + h * (S::from_f64(tableau::B1) * k1[i]
219                        + S::from_f64(tableau::B2) * k2[i]
220                        + S::from_f64(tableau::B3) * k3[i]
221                        + S::from_f64(tableau::B4) * k4[i]
222                        + S::from_f64(tableau::B5) * k5[i]
223                        + S::from_f64(tableau::B6) * k6[i]);
224            }
225            problem.rhs(t + h, &y_new, &mut k7);
226            stats.n_eval += 6;
227
228            // Error estimate
229            for i in 0..dim {
230                err[i] = h
231                    * (S::from_f64(tableau::E1) * k1[i]
232                        + S::from_f64(tableau::E2) * k2[i]
233                        + S::from_f64(tableau::E3) * k3[i]
234                        + S::from_f64(tableau::E4) * k4[i]
235                        + S::from_f64(tableau::E5) * k5[i]
236                        + S::from_f64(tableau::E6) * k6[i]
237                        + S::from_f64(tableau::E7) * k7[i]);
238            }
239
240            let err_norm = error_norm(&err, &y, &y_new, options, dim);
241
242            // Step size control
243            let safety = S::from_f64(0.9);
244            let fac_max = S::from_f64(5.0);
245            let fac_min = S::from_f64(0.2);
246
247            if err_norm <= S::ONE {
248                // Accept step
249                stats.n_accept += 1;
250
251                let t_new = t + h;
252                if let Some(ref mut emitter) = grid_emitter {
253                    emitter.emit_step(t, &y, &k1, t_new, &y_new, &k7, &mut t_out, &mut y_out);
254                } else {
255                    t_out.push(t_new);
256                    y_out.extend_from_slice(&y_new);
257                }
258
259                t = t_new;
260                y.copy_from_slice(&y_new);
261                k1.copy_from_slice(&k7); // FSAL
262
263                // New step size for 5th order method: exponent = -1/(p+1) = -1/6
264                let err_safe = err_norm.max(S::from_f64(1e-10));
265                let fac = safety * err_safe.powf(S::from_f64(-1.0 / 6.0));
266                let fac = fac.min(fac_max).max(fac_min);
267                h = h * fac;
268            } else {
269                // Reject step - use slightly more aggressive reduction
270                stats.n_reject += 1;
271
272                let err_safe = err_norm.max(S::from_f64(1e-10));
273                let fac = safety * err_safe.powf(S::from_f64(-1.0 / 5.0));
274                let fac = fac.max(fac_min);
275                h = h * fac;
276            }
277
278            if h.abs() < h_min {
279                return Err(SolverError::StepSizeTooSmall {
280                    t: t.to_f64(),
281                    h: h.to_f64(),
282                    h_min: h_min.to_f64(),
283                });
284            }
285
286            step_count += 1;
287        }
288
289        Ok(SolverResult::new(t_out, y_out, dim, stats))
290    }
291}
292
293fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
294    if let Some(h0) = options.h0 {
295        return h0;
296    }
297
298    let mut y_norm = S::ZERO;
299    let mut f_norm = S::ZERO;
300    for i in 0..dim {
301        let sc = options.atol + options.rtol * y0[i].abs();
302        y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
303        f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
304    }
305    y_norm = (y_norm / S::from_usize(dim)).sqrt();
306    f_norm = (f_norm / S::from_usize(dim)).sqrt();
307
308    if y_norm < S::from_f64(1e-5) || f_norm < S::from_f64(1e-5) {
309        S::from_f64(1e-6)
310    } else {
311        (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
312    }
313}
314
315fn error_norm<S: Scalar>(
316    err: &[S],
317    y: &[S],
318    y_new: &[S],
319    options: &SolverOptions<S>,
320    dim: usize,
321) -> S {
322    let mut err_norm = S::ZERO;
323    for i in 0..dim {
324        let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
325        let sc = sc.max(S::from_f64(1e-15));
326        let scaled_err = err[i] / sc;
327        err_norm = err_norm + scaled_err * scaled_err;
328    }
329    (err_norm / S::from_usize(dim)).sqrt()
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    #[test]
    fn test_tsit5_exponential_decay() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Tsit5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-5.0_f64).exp();
        assert!(
            (y_final[0] - expected).abs() < 1e-5,
            "Tsit5 exponential: got {}, expected {}",
            y_final[0],
            expected
        );
    }

    #[test]
    fn test_tsit5_harmonic_oscillator() {
        // y'' + y = 0, or y1' = y2, y2' = -y1
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        // Use moderate tolerances (1e-4 is reasonable for unit tests)
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Tsit5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        // y1 = cos(t), y2 = -sin(t)
        let expected_y1 = 10.0_f64.cos();
        let expected_y2 = -10.0_f64.sin();
        assert!(
            (y_final[0] - expected_y1).abs() < 1e-3,
            "Tsit5 harmonic y[0]: got {}, expected {}",
            y_final[0],
            expected_y1
        );
        assert!(
            (y_final[1] - expected_y2).abs() < 1e-3,
            "Tsit5 harmonic y[1]: got {}, expected {}",
            y_final[1],
            expected_y2
        );
    }

    #[test]
    fn test_tsit5_lorenz() {
        let sigma = 10.0;
        let rho = 28.0;
        let beta = 8.0 / 3.0;

        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = sigma * (y[1] - y[0]);
                dydt[1] = y[0] * (rho - y[2]) - y[1];
                dydt[2] = y[0] * y[1] - beta * y[2];
            },
            0.0,
            10.0,
            vec![1.0, 1.0, 1.0],
        );
        // Use moderate tolerances for chaotic system
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Tsit5::solve(&problem, 0.0, 10.0, &[1.0, 1.0, 1.0], &options);

        assert!(result.is_ok());
    }

    #[test]
    fn test_tsit5_efficiency() {
        // Sanity-check the RHS-evaluation budget on a smooth scalar problem.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        // Use looser tolerances for efficiency test
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Tsit5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        // The bound is deliberately generous: this solver uses a plain (non-PI) step-size
        // controller, which can take somewhat more steps than a PI-controlled one.
        assert!(
            result.stats.n_eval < 500,
            "Tsit5 used {} evaluations, expected < 500",
            result.stats.n_eval
        );
    }

    #[test]
    fn test_tsit5_backward_integration() {
        // Integrate y' = -y backwards from t = 5 to t = 0, starting at y(5) = e^-5.
        // The exact solution gives y(0) = 1.
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let y_start = (-5.0_f64).exp();
        let result = Tsit5::solve(&problem, 5.0, 0.0, &[y_start], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!(
            (y_final[0] - 1.0).abs() < 1e-5,
            "Tsit5 backward: got {}, expected 1",
            y_final[0]
        );
    }
}