1use crate::error::SolverError;
20use crate::problem::OdeSystem;
21use crate::solver::{Solver, SolverOptions, SolverResult, SolverStats};
22use crate::t_eval::{validate_grid, TEvalEmitter};
23use numra_core::Scalar;
24
/// Tsitouras 5(4) explicit Runge–Kutta solver with adaptive step-size control.
///
/// `Tsit5` is a stateless unit type: every tunable (tolerances, step bounds,
/// output grid) is supplied per call through [`SolverOptions`] when invoking
/// [`Solver::solve`].
#[derive(Clone, Debug, Default)]
pub struct Tsit5;

impl Tsit5 {
    /// Returns a solver handle; identical to `Tsit5::default()`.
    pub fn new() -> Self {
        Tsit5
    }
}
35
/// Coefficients of the Tsitouras 5(4) Runge–Kutta pair.
///
/// Everything is stored as `f64` and converted to the solver's scalar type at
/// the point of use. Layout:
/// * `C*` — stage nodes (stage `i` is evaluated at `t + C_i * h`);
/// * `A*` — coupling coefficients (`Aij` multiplies stage `j` when forming
///   the input of stage `i`);
/// * `B*` — weights of the propagated 5th-order solution;
/// * `E*` — weights of the embedded local-error estimate.
// `C7`, `B7` and the `A7*` row are not read directly by the solver loop; they
// are kept so the tableau is complete, which is why dead-code lints are muted.
#[allow(dead_code)]
mod tableau {
    // Nodes. C6 = C7 = 1: the last two stages are both evaluated at t + h.
    pub const C2: f64 = 0.161;
    pub const C3: f64 = 0.327;
    pub const C4: f64 = 0.9;
    pub const C5: f64 = 0.9800255409045097;
    pub const C6: f64 = 1.0;
    pub const C7: f64 = 1.0;

    // Coupling matrix; each row i sums to C_i.
    pub const A21: f64 = 0.161;

    pub const A31: f64 = -0.008480655492356989;
    pub const A32: f64 = 0.335480655492357;

    pub const A41: f64 = 2.8971530571054935;
    pub const A42: f64 = -6.359448489975075;
    pub const A43: f64 = 4.3622954328695815;

    pub const A51: f64 = 5.325864828439257;
    pub const A52: f64 = -11.748883564062828;
    pub const A53: f64 = 7.4955393428898365;
    pub const A54: f64 = -0.09249506636175525;

    pub const A61: f64 = 5.86145544294642;
    pub const A62: f64 = -12.92096931784711;
    pub const A63: f64 = 8.159367898576159;
    pub const A64: f64 = -0.071584973281401;
    pub const A65: f64 = -0.028269050394068383;

    pub const A71: f64 = 0.09646076681806523;
    pub const A72: f64 = 0.01;
    pub const A73: f64 = 0.4798896504144996;
    pub const A74: f64 = 1.379008574103742;
    pub const A75: f64 = -3.290069515436081;
    pub const A76: f64 = 2.324710524099774;

    // First-same-as-last: the solution weights are exactly the last row of A,
    // so stage 7 is f(t + h, y_new) and can serve as the next step's stage 1.
    pub const B1: f64 = A71;
    pub const B2: f64 = A72;
    pub const B3: f64 = A73;
    pub const B4: f64 = A74;
    pub const B5: f64 = A75;
    pub const B6: f64 = A76;
    pub const B7: f64 = 0.0;

    // Error-estimate weights. They sum to zero, so a constant derivative
    // produces no estimated error.
    pub const E1: f64 = 0.001780011052226;
    pub const E2: f64 = 0.000816434459657;
    pub const E3: f64 = -0.007880878010262;
    pub const E4: f64 = 0.144711007173263;
    pub const E5: f64 = -0.582357165452555;
    pub const E6: f64 = 0.458082105929187;
    pub const E7: f64 = -1.0 / 66.0;
}
96
97impl<S: Scalar> Solver<S> for Tsit5 {
98 fn solve<Sys: OdeSystem<S>>(
99 problem: &Sys,
100 t0: S,
101 tf: S,
102 y0: &[S],
103 options: &SolverOptions<S>,
104 ) -> Result<SolverResult<S>, SolverError> {
105 let dim = problem.dim();
106 if y0.len() != dim {
107 return Err(SolverError::DimensionMismatch {
108 expected: dim,
109 actual: y0.len(),
110 });
111 }
112
113 let mut t = t0;
114 let mut y = y0.to_vec();
115
116 let direction = if tf > t0 { S::ONE } else { -S::ONE };
117 if let Some(grid) = options.t_eval.as_deref() {
118 validate_grid(grid, t0, tf)?;
119 }
120 let mut grid_emitter = options
121 .t_eval
122 .as_deref()
123 .map(|g| TEvalEmitter::new(g, direction));
124 let (mut t_out, mut y_out) = if grid_emitter.is_some() {
125 (Vec::new(), Vec::new())
126 } else {
127 (vec![t0], y0.to_vec())
128 };
129
130 let mut k1 = vec![S::ZERO; dim];
132 let mut k2 = vec![S::ZERO; dim];
133 let mut k3 = vec![S::ZERO; dim];
134 let mut k4 = vec![S::ZERO; dim];
135 let mut k5 = vec![S::ZERO; dim];
136 let mut k6 = vec![S::ZERO; dim];
137 let mut k7 = vec![S::ZERO; dim];
138 let mut y_stage = vec![S::ZERO; dim];
139 let mut y_new = vec![S::ZERO; dim];
140 let mut err = vec![S::ZERO; dim];
141
142 let mut stats = SolverStats::default();
143
144 problem.rhs(t, &y, &mut k1);
146 stats.n_eval += 1;
147 let mut h = initial_step_size(&y, &k1, options, dim);
148 let h_min = options.h_min;
149 let h_max = options.h_max.min((tf - t0).abs());
150
151 let mut step_count = 0_usize;
152
153 while (tf - t) * direction > S::from_f64(1e-10) * (tf - t0).abs() {
154 if step_count >= options.max_steps {
155 return Err(SolverError::MaxIterationsExceeded { t: t.to_f64() });
156 }
157
158 if (t + h - tf) * direction > S::ZERO {
160 h = tf - t;
161 }
162
163 h = h.abs().max(h_min) * direction;
165 if h.abs() > h_max {
166 h = h_max * direction;
167 }
168
169 for i in 0..dim {
174 y_stage[i] = y[i] + h * S::from_f64(tableau::A21) * k1[i];
175 }
176 problem.rhs(t + S::from_f64(tableau::C2) * h, &y_stage, &mut k2);
177
178 for i in 0..dim {
180 y_stage[i] = y[i]
181 + h * (S::from_f64(tableau::A31) * k1[i] + S::from_f64(tableau::A32) * k2[i]);
182 }
183 problem.rhs(t + S::from_f64(tableau::C3) * h, &y_stage, &mut k3);
184
185 for i in 0..dim {
187 y_stage[i] = y[i]
188 + h * (S::from_f64(tableau::A41) * k1[i]
189 + S::from_f64(tableau::A42) * k2[i]
190 + S::from_f64(tableau::A43) * k3[i]);
191 }
192 problem.rhs(t + S::from_f64(tableau::C4) * h, &y_stage, &mut k4);
193
194 for i in 0..dim {
196 y_stage[i] = y[i]
197 + h * (S::from_f64(tableau::A51) * k1[i]
198 + S::from_f64(tableau::A52) * k2[i]
199 + S::from_f64(tableau::A53) * k3[i]
200 + S::from_f64(tableau::A54) * k4[i]);
201 }
202 problem.rhs(t + S::from_f64(tableau::C5) * h, &y_stage, &mut k5);
203
204 for i in 0..dim {
206 y_stage[i] = y[i]
207 + h * (S::from_f64(tableau::A61) * k1[i]
208 + S::from_f64(tableau::A62) * k2[i]
209 + S::from_f64(tableau::A63) * k3[i]
210 + S::from_f64(tableau::A64) * k4[i]
211 + S::from_f64(tableau::A65) * k5[i]);
212 }
213 problem.rhs(t + S::from_f64(tableau::C6) * h, &y_stage, &mut k6);
214
215 for i in 0..dim {
217 y_new[i] = y[i]
218 + h * (S::from_f64(tableau::B1) * k1[i]
219 + S::from_f64(tableau::B2) * k2[i]
220 + S::from_f64(tableau::B3) * k3[i]
221 + S::from_f64(tableau::B4) * k4[i]
222 + S::from_f64(tableau::B5) * k5[i]
223 + S::from_f64(tableau::B6) * k6[i]);
224 }
225 problem.rhs(t + h, &y_new, &mut k7);
226 stats.n_eval += 6;
227
228 for i in 0..dim {
230 err[i] = h
231 * (S::from_f64(tableau::E1) * k1[i]
232 + S::from_f64(tableau::E2) * k2[i]
233 + S::from_f64(tableau::E3) * k3[i]
234 + S::from_f64(tableau::E4) * k4[i]
235 + S::from_f64(tableau::E5) * k5[i]
236 + S::from_f64(tableau::E6) * k6[i]
237 + S::from_f64(tableau::E7) * k7[i]);
238 }
239
240 let err_norm = error_norm(&err, &y, &y_new, options, dim);
241
242 let safety = S::from_f64(0.9);
244 let fac_max = S::from_f64(5.0);
245 let fac_min = S::from_f64(0.2);
246
247 if err_norm <= S::ONE {
248 stats.n_accept += 1;
250
251 let t_new = t + h;
252 if let Some(ref mut emitter) = grid_emitter {
253 emitter.emit_step(t, &y, &k1, t_new, &y_new, &k7, &mut t_out, &mut y_out);
254 } else {
255 t_out.push(t_new);
256 y_out.extend_from_slice(&y_new);
257 }
258
259 t = t_new;
260 y.copy_from_slice(&y_new);
261 k1.copy_from_slice(&k7); let err_safe = err_norm.max(S::from_f64(1e-10));
265 let fac = safety * err_safe.powf(S::from_f64(-1.0 / 6.0));
266 let fac = fac.min(fac_max).max(fac_min);
267 h = h * fac;
268 } else {
269 stats.n_reject += 1;
271
272 let err_safe = err_norm.max(S::from_f64(1e-10));
273 let fac = safety * err_safe.powf(S::from_f64(-1.0 / 5.0));
274 let fac = fac.max(fac_min);
275 h = h * fac;
276 }
277
278 if h.abs() < h_min {
279 return Err(SolverError::StepSizeTooSmall {
280 t: t.to_f64(),
281 h: h.to_f64(),
282 h_min: h_min.to_f64(),
283 });
284 }
285
286 step_count += 1;
287 }
288
289 Ok(SolverResult::new(t_out, y_out, dim, stats))
290 }
291}
292
293fn initial_step_size<S: Scalar>(y0: &[S], f0: &[S], options: &SolverOptions<S>, dim: usize) -> S {
294 if let Some(h0) = options.h0 {
295 return h0;
296 }
297
298 let mut y_norm = S::ZERO;
299 let mut f_norm = S::ZERO;
300 for i in 0..dim {
301 let sc = options.atol + options.rtol * y0[i].abs();
302 y_norm = y_norm + (y0[i] / sc) * (y0[i] / sc);
303 f_norm = f_norm + (f0[i] / sc) * (f0[i] / sc);
304 }
305 y_norm = (y_norm / S::from_usize(dim)).sqrt();
306 f_norm = (f_norm / S::from_usize(dim)).sqrt();
307
308 if y_norm < S::from_f64(1e-5) || f_norm < S::from_f64(1e-5) {
309 S::from_f64(1e-6)
310 } else {
311 (S::from_f64(0.01) * y_norm / f_norm).min(options.h_max)
312 }
313}
314
315fn error_norm<S: Scalar>(
316 err: &[S],
317 y: &[S],
318 y_new: &[S],
319 options: &SolverOptions<S>,
320 dim: usize,
321) -> S {
322 let mut err_norm = S::ZERO;
323 for i in 0..dim {
324 let sc = options.atol + options.rtol * y[i].abs().max(y_new[i].abs());
325 let sc = sc.max(S::from_f64(1e-15));
326 let scaled_err = err[i] / sc;
327 err_norm = err_norm + scaled_err * scaled_err;
328 }
329 (err_norm / S::from_usize(dim)).sqrt()
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem::OdeProblem;

    #[test]
    fn test_tsit5_exponential_decay() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-6).atol(1e-8);
        let result = Tsit5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected = (-5.0_f64).exp();
        assert!(
            (y_final[0] - expected).abs() < 1e-5,
            "Tsit5 exponential: got {}, expected {}",
            y_final[0],
            expected
        );
    }

    #[test]
    fn test_tsit5_harmonic_oscillator() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = y[1];
                dydt[1] = -y[0];
            },
            0.0,
            10.0,
            vec![1.0, 0.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Tsit5::solve(&problem, 0.0, 10.0, &[1.0, 0.0], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        let expected_y1 = 10.0_f64.cos();
        let expected_y2 = -10.0_f64.sin();
        assert!(
            (y_final[0] - expected_y1).abs() < 1e-3,
            "Tsit5 harmonic y[0]: got {}, expected {}",
            y_final[0],
            expected_y1
        );
        assert!(
            (y_final[1] - expected_y2).abs() < 1e-3,
            "Tsit5 harmonic y[1]: got {}, expected {}",
            y_final[1],
            expected_y2
        );
    }

    #[test]
    fn test_tsit5_lorenz() {
        let sigma = 10.0;
        let rho = 28.0;
        let beta = 8.0 / 3.0;

        let problem = OdeProblem::new(
            move |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = sigma * (y[1] - y[0]);
                dydt[1] = y[0] * (rho - y[2]) - y[1];
                dydt[2] = y[0] * y[1] - beta * y[2];
            },
            0.0,
            10.0,
            vec![1.0, 1.0, 1.0],
        );
        let options = SolverOptions::default().rtol(1e-4).atol(1e-6);
        let result = Tsit5::solve(&problem, 0.0, 10.0, &[1.0, 1.0, 1.0], &options);

        assert!(result.is_ok());
    }

    #[test]
    fn test_tsit5_efficiency() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            5.0,
            vec![1.0],
        );
        let options = SolverOptions::default().rtol(1e-3).atol(1e-5);
        let result = Tsit5::solve(&problem, 0.0, 5.0, &[1.0], &options).unwrap();

        assert!(
            result.stats.n_eval < 500,
            "Tsit5 used {} evaluations, expected < 500",
            result.stats.n_eval
        );
    }

    /// Integrating y' = -y from t = 5 back to t = 0 must recover y(0) = 1
    /// (exercises the `direction = -1` path).
    #[test]
    fn test_tsit5_backward_integration() {
        let y_start = (-5.0_f64).exp();
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            5.0,
            0.0,
            vec![y_start],
        );
        let options = SolverOptions::default().rtol(1e-8).atol(1e-10);
        let result = Tsit5::solve(&problem, 5.0, 0.0, &[y_start], &options).unwrap();

        assert!(result.success);
        let y_final = result.y_final().unwrap();
        assert!(
            (y_final[0] - 1.0).abs() < 1e-5,
            "Tsit5 backward: got {}, expected {}",
            y_final[0],
            1.0
        );
    }

    /// A state vector whose length differs from the system dimension is rejected.
    #[test]
    fn test_tsit5_dimension_mismatch() {
        let problem = OdeProblem::new(
            |_t, y: &[f64], dydt: &mut [f64]| {
                dydt[0] = -y[0];
            },
            0.0,
            1.0,
            vec![1.0],
        );
        let options: SolverOptions<f64> = SolverOptions::default();
        let result = Tsit5::solve(&problem, 0.0, 1.0, &[1.0, 2.0], &options);

        assert!(matches!(
            result,
            Err(SolverError::DimensionMismatch {
                expected: 1,
                actual: 2
            })
        ));
    }
}