Skip to main content

numra_dde/
method_of_steps.rs

1//! Method of Steps DDE solver.
2//!
3//! Solves DDEs by treating them as ODEs with interpolated history.
4//!
5//! Author: Moussa Leblouba
6//! Date: 4 February 2026
7//! Modified: 2 May 2026
8
9use crate::history::{History, HistoryStep};
10use crate::system::{DdeOptions, DdeResult, DdeSolver, DdeStats, DdeSystem};
11use numra_core::Scalar;
12
/// Method of Steps DDE solver.
///
/// Integrates a delay differential equation with the adaptive Dormand–Prince
/// 5(4) Runge–Kutta pair (first-same-as-last), reading the delayed states
/// `y(t - tau)` from a solution history. The history starts as the
/// user-supplied history function and is extended with every accepted step.
/// Optionally, steps are shortened so that they land exactly on propagated
/// discontinuity points.
///
/// This is a stateless marker type; it is used through `DdeSolver::solve`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MethodOfSteps;
18
19/// Propagate discontinuities from the initial time.
20///
21/// Discontinuities in DDEs propagate: if there's a discontinuity at t_d,
22/// then derivative discontinuity occurs at t_d + tau for each delay tau.
23/// This function computes all discontinuity points up to a given order.
24/// Maximum number of discontinuity points to track.
25/// Prevents combinatorial explosion when there are many delays and high order.
26const MAX_DISCONTINUITIES: usize = 1000;
27
28fn propagate_discontinuities<S: Scalar>(t0: S, delays: &[S], tf: S, order: usize) -> Vec<S> {
29    let mut discs = vec![t0]; // Initial discontinuity at t0
30
31    for _ in 0..order {
32        let mut new_discs = Vec::new();
33        for &d in &discs {
34            for &tau in delays {
35                let t_new = d + tau;
36                if t_new <= tf && t_new > t0 {
37                    new_discs.push(t_new);
38                }
39            }
40        }
41        discs.extend(new_discs);
42
43        // Cap total discontinuities to prevent combinatorial explosion
44        if discs.len() > MAX_DISCONTINUITIES {
45            break;
46        }
47    }
48
49    // Sort and remove duplicates
50    discs.sort_by(|a, b| a.partial_cmp(b).unwrap());
51    discs.dedup_by(|a, b| (*a - *b).abs() < S::from_f64(1e-14));
52
53    // Filter to only return points > t0 and <= tf, capped
54    let mut result: Vec<S> = discs.into_iter().filter(|&d| d > t0 && d <= tf).collect();
55    result.truncate(MAX_DISCONTINUITIES);
56    result
57}
58
59impl<S: Scalar> DdeSolver<S> for MethodOfSteps {
60    fn solve<Sys, H>(
61        system: &Sys,
62        t0: S,
63        tf: S,
64        history: &H,
65        options: &DdeOptions<S>,
66    ) -> Result<DdeResult<S>, String>
67    where
68        Sys: DdeSystem<S>,
69        H: Fn(S) -> Vec<S>,
70    {
71        let dim = system.dim();
72        let n_delays = system.n_delays();
73
74        // Initialize history
75        let mut hist = History::new(history, t0, dim);
76
77        // Get initial state from history at t0
78        let y0 = history(t0);
79        if y0.len() != dim {
80            return Err(format!(
81                "History function returned dimension {} but system has dimension {}",
82                y0.len(),
83                dim
84            ));
85        }
86
87        // Initialize state
88        let mut t = t0;
89        let mut y = y0.clone();
90
91        // Initial step size
92        let mut h = options.h0.unwrap_or_else(|| {
93            let span = tf - t0;
94            (span * S::from_f64(0.001)).min(S::from_f64(0.1))
95        });
96
97        // Working arrays
98        let mut f = vec![S::ZERO; dim];
99        let mut y_delayed: Vec<Vec<S>> = vec![vec![S::ZERO; dim]; n_delays];
100        let has_state_dependent_delays = system.has_state_dependent_delays();
101        let constant_delays = system.delays();
102
103        // Helper to get delays at a given state, with positivity validation
104        let get_delays = |sys: &Sys, t: S, y: &[S]| -> Result<Vec<S>, String> {
105            let delays = if has_state_dependent_delays {
106                sys.delays_at(t, y)
107            } else {
108                constant_delays.clone()
109            };
110            for (i, &tau) in delays.iter().enumerate() {
111                if tau < S::ZERO {
112                    return Err(format!(
113                        "Delay {} is negative ({}) at t = {}. Delays must be non-negative.",
114                        i,
115                        tau.to_f64(),
116                        t.to_f64()
117                    ));
118                }
119            }
120            Ok(delays)
121        };
122
123        // Evaluate initial derivative
124        let delays = get_delays(system, t, &y)?;
125        for (i, &tau) in delays.iter().enumerate() {
126            y_delayed[i] = hist.evaluate(t - tau);
127        }
128        let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
129        system.rhs(t, &y, &y_delayed_refs, &mut f);
130
131        // Output storage
132        let mut t_out = vec![t];
133        let mut y_out = y.clone();
134        let mut stats = DdeStats::default();
135        stats.n_eval += 1;
136
137        // Dormand-Prince 5(4) coefficients
138        let a21 = S::from_f64(1.0 / 5.0);
139        let a31 = S::from_f64(3.0 / 40.0);
140        let a32 = S::from_f64(9.0 / 40.0);
141        let a41 = S::from_f64(44.0 / 45.0);
142        let a42 = S::from_f64(-56.0 / 15.0);
143        let a43 = S::from_f64(32.0 / 9.0);
144        let a51 = S::from_f64(19372.0 / 6561.0);
145        let a52 = S::from_f64(-25360.0 / 2187.0);
146        let a53 = S::from_f64(64448.0 / 6561.0);
147        let a54 = S::from_f64(-212.0 / 729.0);
148        let a61 = S::from_f64(9017.0 / 3168.0);
149        let a62 = S::from_f64(-355.0 / 33.0);
150        let a63 = S::from_f64(46732.0 / 5247.0);
151        let a64 = S::from_f64(49.0 / 176.0);
152        let a65 = S::from_f64(-5103.0 / 18656.0);
153        let a71 = S::from_f64(35.0 / 384.0);
154        let a73 = S::from_f64(500.0 / 1113.0);
155        let a74 = S::from_f64(125.0 / 192.0);
156        let a75 = S::from_f64(-2187.0 / 6784.0);
157        let a76 = S::from_f64(11.0 / 84.0);
158
159        let c2 = S::from_f64(1.0 / 5.0);
160        let c3 = S::from_f64(3.0 / 10.0);
161        let c4 = S::from_f64(4.0 / 5.0);
162        let c5 = S::from_f64(8.0 / 9.0);
163        let c6 = S::ONE;
164        let c7 = S::ONE;
165
166        // Error estimation coefficients
167        let e1 = S::from_f64(71.0 / 57600.0);
168        let e3 = S::from_f64(-71.0 / 16695.0);
169        let e4 = S::from_f64(71.0 / 1920.0);
170        let e5 = S::from_f64(-17253.0 / 339200.0);
171        let e6 = S::from_f64(22.0 / 525.0);
172        let e7 = S::from_f64(-1.0 / 40.0);
173
174        // Stage values
175        let mut k1 = f.clone();
176        let mut k2 = vec![S::ZERO; dim];
177        let mut k3 = vec![S::ZERO; dim];
178        let mut k4 = vec![S::ZERO; dim];
179        let mut k5 = vec![S::ZERO; dim];
180        let mut k6 = vec![S::ZERO; dim];
181        let mut k7 = vec![S::ZERO; dim];
182        let mut y_stage = vec![S::ZERO; dim];
183        let mut y_new = vec![S::ZERO; dim];
184        let mut y_err = vec![S::ZERO; dim];
185
186        // Step size control parameters
187        let safety = S::from_f64(0.9);
188        let fac_min = S::from_f64(0.2);
189        let fac_max = S::from_f64(10.0);
190        let order = S::from_f64(5.0);
191
192        // Compute discontinuity points if tracking is enabled
193        let discontinuities = if options.track_discontinuities && options.discontinuity_order > 0 {
194            propagate_discontinuities(t0, &delays, tf, options.discontinuity_order)
195        } else {
196            Vec::new()
197        };
198        let mut disc_idx = 0; // Index into discontinuities vector
199
200        let mut step = 0;
201        while t < tf && step < options.max_steps {
202            // Limit step size
203            h = h.min(tf - t).min(options.h_max).max(options.h_min);
204
205            // Check if we would cross a discontinuity
206            if options.track_discontinuities && disc_idx < discontinuities.len() {
207                let next_disc = discontinuities[disc_idx];
208                let t_end = t + h;
209                // If we would cross or overshoot the discontinuity, reduce step to land exactly on it
210                if next_disc > t && next_disc <= t_end {
211                    let h_to_disc = next_disc - t;
212                    if h_to_disc < options.h_min {
213                        // h_min is too large and would skip this discontinuity
214                        return Err(format!(
215                            "h_min ({}) is larger than distance to discontinuity at t = {} (distance = {}). \
216                             Reduce h_min or disable discontinuity tracking.",
217                            options.h_min.to_f64(), next_disc.to_f64(), h_to_disc.to_f64()
218                        ));
219                    }
220                    h = h_to_disc;
221                }
222            }
223
224            // k1 is already computed (FSAL)
225
226            // k2
227            let t2 = t + c2 * h;
228            for j in 0..dim {
229                y_stage[j] = y[j] + h * a21 * k1[j];
230            }
231            let delays_k2 = get_delays(system, t2, &y_stage)?;
232            for (i, &tau) in delays_k2.iter().enumerate() {
233                y_delayed[i] = hist.evaluate(t2 - tau);
234            }
235            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
236            system.rhs(t2, &y_stage, &y_delayed_refs, &mut k2);
237
238            // k3
239            let t3 = t + c3 * h;
240            for j in 0..dim {
241                y_stage[j] = y[j] + h * (a31 * k1[j] + a32 * k2[j]);
242            }
243            let delays_k3 = get_delays(system, t3, &y_stage)?;
244            for (i, &tau) in delays_k3.iter().enumerate() {
245                y_delayed[i] = hist.evaluate(t3 - tau);
246            }
247            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
248            system.rhs(t3, &y_stage, &y_delayed_refs, &mut k3);
249
250            // k4
251            let t4 = t + c4 * h;
252            for j in 0..dim {
253                y_stage[j] = y[j] + h * (a41 * k1[j] + a42 * k2[j] + a43 * k3[j]);
254            }
255            let delays_k4 = get_delays(system, t4, &y_stage)?;
256            for (i, &tau) in delays_k4.iter().enumerate() {
257                y_delayed[i] = hist.evaluate(t4 - tau);
258            }
259            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
260            system.rhs(t4, &y_stage, &y_delayed_refs, &mut k4);
261
262            // k5
263            let t5 = t + c5 * h;
264            for j in 0..dim {
265                y_stage[j] = y[j] + h * (a51 * k1[j] + a52 * k2[j] + a53 * k3[j] + a54 * k4[j]);
266            }
267            let delays_k5 = get_delays(system, t5, &y_stage)?;
268            for (i, &tau) in delays_k5.iter().enumerate() {
269                y_delayed[i] = hist.evaluate(t5 - tau);
270            }
271            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
272            system.rhs(t5, &y_stage, &y_delayed_refs, &mut k5);
273
274            // k6
275            let t6 = t + c6 * h;
276            for j in 0..dim {
277                y_stage[j] = y[j]
278                    + h * (a61 * k1[j] + a62 * k2[j] + a63 * k3[j] + a64 * k4[j] + a65 * k5[j]);
279            }
280            let delays_k6 = get_delays(system, t6, &y_stage)?;
281            for (i, &tau) in delays_k6.iter().enumerate() {
282                y_delayed[i] = hist.evaluate(t6 - tau);
283            }
284            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
285            system.rhs(t6, &y_stage, &y_delayed_refs, &mut k6);
286
287            // Compute 5th order solution
288            for j in 0..dim {
289                y_new[j] = y[j]
290                    + h * (a71 * k1[j] + a73 * k3[j] + a74 * k4[j] + a75 * k5[j] + a76 * k6[j]);
291            }
292
293            // k7 (at new point, for FSAL and error)
294            let t7 = t + c7 * h;
295            let delays_k7 = get_delays(system, t7, &y_new)?;
296            for (i, &tau) in delays_k7.iter().enumerate() {
297                y_delayed[i] = hist.evaluate(t7 - tau);
298            }
299            let y_delayed_refs: Vec<&[S]> = y_delayed.iter().map(|v| v.as_slice()).collect();
300            system.rhs(t7, &y_new, &y_delayed_refs, &mut k7);
301
302            stats.n_eval += 6;
303
304            // Error estimate
305            for j in 0..dim {
306                y_err[j] = h
307                    * (e1 * k1[j] + e3 * k3[j] + e4 * k4[j] + e5 * k5[j] + e6 * k6[j] + e7 * k7[j]);
308            }
309
310            // Compute error norm
311            let mut err_sq = S::ZERO;
312            for j in 0..dim {
313                let scale = options.atol + options.rtol * y[j].abs().max(y_new[j].abs());
314                let ratio = y_err[j] / scale;
315                err_sq = err_sq + ratio * ratio;
316            }
317            let err = (err_sq / S::from_usize(dim)).sqrt();
318
319            // Accept or reject
320            if err <= S::ONE {
321                // Accept step
322                let t_new = t + h;
323
324                // Add step to history
325                hist.add_step(HistoryStep::new(
326                    t,
327                    y.clone(),
328                    k1.clone(),
329                    t_new,
330                    y_new.clone(),
331                    k7.clone(),
332                ));
333
334                t = t_new;
335                y.clone_from(&y_new);
336                k1.clone_from(&k7); // FSAL
337
338                // Save output
339                t_out.push(t);
340                y_out.extend_from_slice(&y);
341                stats.n_accept += 1;
342                step += 1;
343
344                // Check if we hit a discontinuity point
345                if options.track_discontinuities && disc_idx < discontinuities.len() {
346                    let next_disc = discontinuities[disc_idx];
347                    if (t - next_disc).abs() < S::from_f64(1e-10) {
348                        // We've hit this discontinuity, move to the next one
349                        stats.n_discontinuities += 1;
350                        disc_idx += 1;
351                    }
352                }
353            } else {
354                stats.n_reject += 1;
355            }
356
357            // Compute new step size
358            let err_safe = err.max(S::from_f64(1e-10));
359            let fac = safety * err_safe.powf(-S::ONE / (order + S::ONE));
360            h = h * fac.max(fac_min).min(fac_max);
361        }
362
363        if step >= options.max_steps && t < tf {
364            return Err(format!(
365                "Maximum steps ({}) exceeded at t = {}",
366                options.max_steps,
367                t.to_f64()
368            ));
369        }
370
371        Ok(DdeResult::new(t_out, y_out, dim, stats))
372    }
373}
374
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's imports (DdeSystem, DdeOptions, ...).
    use super::*;

    /// Simple test DDE: y'(t) = -y(t-1), y(t) = 1 for t <= 0
    struct SimpleDelay;

    impl DdeSystem<f64> for SimpleDelay {
        fn dim(&self) -> usize {
            1
        }
        fn delays(&self) -> Vec<f64> {
            vec![1.0]
        }
        fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
            dydt[0] = -y_delayed[0][0];
        }
    }

    #[test]
    fn test_simple_delay() {
        let system = SimpleDelay;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default().rtol(1e-6).atol(1e-9);

        let result =
            MethodOfSteps::solve(&system, 0.0, 5.0, &history, &options).expect("Solve failed");

        assert!(result.success);
        assert!(!result.t.is_empty());

        // The solution should decay (oscillatory decay for this equation)
        let y_final = result.y_final().unwrap()[0];
        assert!(y_final.abs() < 2.0); // Should be bounded
    }

    /// Mackey-Glass equation: y'(t) = β * y(t-τ) / (1 + y(t-τ)^n) - γ * y(t)
    struct MackeyGlass {
        beta: f64,
        gamma: f64,
        n: f64,
        tau: f64,
    }

    impl DdeSystem<f64> for MackeyGlass {
        fn dim(&self) -> usize {
            1
        }
        fn delays(&self) -> Vec<f64> {
            vec![self.tau]
        }
        fn rhs(&self, _t: f64, y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
            let y_tau = y_delayed[0][0];
            dydt[0] = self.beta * y_tau / (1.0 + y_tau.powf(self.n)) - self.gamma * y[0];
        }
    }

    #[test]
    fn test_mackey_glass() {
        let system = MackeyGlass {
            beta: 2.0,
            gamma: 1.0,
            n: 9.65,
            tau: 2.0,
        };
        let history = |_t: f64| vec![0.5];
        let options = DdeOptions::default().rtol(1e-4).atol(1e-6).max_steps(50000);

        let result =
            MethodOfSteps::solve(&system, 0.0, 50.0, &history, &options).expect("Solve failed");

        assert!(result.success);
        // Mackey-Glass with these parameters is chaotic
        // Just verify the solution stays bounded and positive
        for &y in result.y.iter() {
            assert!(y > 0.0, "Solution should stay positive");
            assert!(y < 3.0, "Solution should stay bounded");
        }
    }

    #[test]
    fn test_two_delays() {
        struct TwoDelays;
        impl DdeSystem<f64> for TwoDelays {
            fn dim(&self) -> usize {
                1
            }
            fn delays(&self) -> Vec<f64> {
                vec![0.5, 1.0]
            }
            fn rhs(&self, _t: f64, y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y[0] + 0.5 * y_delayed[0][0] + 0.3 * y_delayed[1][0];
            }
        }

        let system = TwoDelays;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default();

        let result =
            MethodOfSteps::solve(&system, 0.0, 10.0, &history, &options).expect("Solve failed");

        assert!(result.success);
        assert!(result.stats.n_accept > 0);
    }

    #[test]
    fn test_2d_system() {
        struct TwoD;
        impl DdeSystem<f64> for TwoD {
            fn dim(&self) -> usize {
                2
            }
            fn delays(&self) -> Vec<f64> {
                vec![1.0]
            }
            fn rhs(&self, _t: f64, y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y[0] + y_delayed[0][1];
                dydt[1] = -y[1] + y_delayed[0][0];
            }
        }

        let system = TwoD;
        let history = |_t: f64| vec![1.0, 0.5];
        let options = DdeOptions::default();

        let result =
            MethodOfSteps::solve(&system, 0.0, 5.0, &history, &options).expect("Solve failed");

        assert!(result.success);
        assert_eq!(result.dim, 2);
    }

    #[test]
    fn test_dde_discontinuity_tracking() {
        // y'(t) = -y(t-1) for t > 0, with constant history y(t) = 1 for t <= 0.
        //
        // The history is smooth, but y'(0-) = 0 while y'(0+) = -1. That jump
        // propagates to t = 1, 2, 3, ... (in successively higher derivatives),
        // and the solver should land on these points exactly.
        let system = SimpleDelay;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default()
            .track_discontinuities(true)
            .discontinuity_order(3);

        let result =
            MethodOfSteps::solve(&system, 0.0, 3.5, &history, &options).expect("Solve failed");

        // Check that t=1.0, t=2.0, t=3.0 are in the output times
        let has_t1 = result.t.iter().any(|&t| (t - 1.0).abs() < 1e-10);
        let has_t2 = result.t.iter().any(|&t| (t - 2.0).abs() < 1e-10);
        let has_t3 = result.t.iter().any(|&t| (t - 3.0).abs() < 1e-10);

        assert!(has_t1, "Discontinuity at t=1 not tracked");
        assert!(has_t2, "Discontinuity at t=2 not tracked");
        assert!(has_t3, "Discontinuity at t=3 not tracked");

        // Verify that discontinuity count is correct
        assert_eq!(
            result.stats.n_discontinuities, 3,
            "Should have tracked 3 discontinuities"
        );
    }

    #[test]
    fn test_propagate_discontinuities() {
        // Test the discontinuity propagation function directly
        let delays = vec![1.0];
        let discs = propagate_discontinuities(0.0, &delays, 5.0, 3);

        // Should have discontinuities at 1, 2, 3
        assert_eq!(discs.len(), 3);
        assert!((discs[0] - 1.0).abs() < 1e-10);
        assert!((discs[1] - 2.0).abs() < 1e-10);
        assert!((discs[2] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_propagate_discontinuities_multiple_delays() {
        // With two delays, discontinuities propagate differently
        let delays = vec![0.5, 1.0];
        let discs = propagate_discontinuities(0.0, &delays, 3.0, 2);

        // Should have discontinuities at 0.5, 1.0, 1.5, 2.0
        // (0.5+0.5=1.0 is duplicate, 0.5+1.0=1.5, 1.0+0.5=1.5 is duplicate, 1.0+1.0=2.0)
        assert!(
            discs.iter().any(|&d| (d - 0.5).abs() < 1e-10),
            "Should have 0.5"
        );
        assert!(
            discs.iter().any(|&d| (d - 1.0).abs() < 1e-10),
            "Should have 1.0"
        );
        assert!(
            discs.iter().any(|&d| (d - 1.5).abs() < 1e-10),
            "Should have 1.5"
        );
        assert!(
            discs.iter().any(|&d| (d - 2.0).abs() < 1e-10),
            "Should have 2.0"
        );
    }

    #[test]
    fn test_propagate_discontinuities_order_zero() {
        // Order 0 means no propagation at all: nothing beyond t0 is returned.
        let delays = vec![1.0];
        let discs = propagate_discontinuities(0.0, &delays, 5.0, 0);
        assert!(discs.is_empty());
    }

    #[test]
    fn test_propagate_discontinuities_sorted_and_bounded() {
        // Every returned point lies in (t0, tf] and the list is strictly increasing.
        let t0 = 1.0;
        let tf = 3.0;
        let delays = vec![0.3, 0.7];
        let discs = propagate_discontinuities(t0, &delays, tf, 4);

        assert!(!discs.is_empty());
        assert!(discs.iter().all(|&d| d > t0 && d <= tf));
        assert!(discs.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn test_state_dependent_delay() {
        // y'(t) = -y(t - tau(y))
        // where tau(y) = 0.5 + 0.1 * y
        //
        // This is a simple state-dependent delay where the delay
        // increases with y. For y(0) = 1, initial delay is 0.6.

        struct StateDependentSystem;

        impl DdeSystem<f64> for StateDependentSystem {
            fn dim(&self) -> usize {
                1
            }

            fn delays(&self) -> Vec<f64> {
                // Return a nominal delay for n_delays() to work
                vec![0.5]
            }

            fn delays_at(&self, _t: f64, y: &[f64]) -> Vec<f64> {
                vec![0.5 + 0.1 * y[0]]
            }

            fn has_state_dependent_delays(&self) -> bool {
                true
            }

            fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y_delayed[0][0];
            }
        }

        let system = StateDependentSystem;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default().rtol(1e-6).atol(1e-8);

        let result = MethodOfSteps::solve(&system, 0.0, 2.0, &history, &options);

        assert!(
            result.is_ok(),
            "State-dependent delay solve failed: {:?}",
            result.err()
        );
        let sol = result.unwrap();

        // The solution should decay (since RHS is negative of delayed value)
        // and remain bounded
        let y_final = sol.y_final().unwrap()[0];
        assert!(y_final < 1.0, "Solution should decay from initial y=1");
        assert!(y_final > -10.0, "Solution should remain bounded");

        // The delay varies during integration:
        // at t=0, y=1 and delay=0.6; as y decreases toward 0, delay decreases toward 0.5.
    }

    #[test]
    fn test_state_dependent_delay_vs_constant() {
        // Compare state-dependent delay vs constant delay to verify
        // the state-dependent delay is actually being used
        //
        // y'(t) = -y(t - tau)
        // For state-dependent: tau(y) = 0.2 + 0.5 * y^2
        // For constant: tau = 0.5
        //
        // With y(0) = 1, initial state-dependent delay is 0.7
        // As solution decays, delay decreases toward 0.2

        struct StateDependentDelay;
        impl DdeSystem<f64> for StateDependentDelay {
            fn dim(&self) -> usize {
                1
            }
            fn delays(&self) -> Vec<f64> {
                vec![0.5]
            }
            fn delays_at(&self, _t: f64, y: &[f64]) -> Vec<f64> {
                vec![0.2 + 0.5 * y[0] * y[0]]
            }
            fn has_state_dependent_delays(&self) -> bool {
                true
            }
            fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y_delayed[0][0];
            }
        }

        struct ConstantDelay;
        impl DdeSystem<f64> for ConstantDelay {
            fn dim(&self) -> usize {
                1
            }
            fn delays(&self) -> Vec<f64> {
                vec![0.5]
            }
            fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y_delayed[0][0];
            }
        }

        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default().rtol(1e-8).atol(1e-10);

        let result_sd = MethodOfSteps::solve(&StateDependentDelay, 0.0, 3.0, &history, &options)
            .expect("State-dependent solve failed");
        let result_const = MethodOfSteps::solve(&ConstantDelay, 0.0, 3.0, &history, &options)
            .expect("Constant delay solve failed");

        // Solutions should be different because delays are different
        let y_sd = result_sd.y_final().unwrap()[0];
        let y_const = result_const.y_final().unwrap()[0];

        // They should differ by a meaningful amount (not just floating point noise)
        let diff = (y_sd - y_const).abs();
        assert!(
            diff > 0.01,
            "State-dependent and constant delay solutions should differ significantly, but diff = {}",
            diff
        );
    }

    #[test]
    fn test_dde_negative_delay_rejected() {
        // State-dependent delay that returns a negative value should be rejected
        struct NegativeDelaySystem;

        impl DdeSystem<f64> for NegativeDelaySystem {
            fn dim(&self) -> usize {
                1
            }
            fn delays(&self) -> Vec<f64> {
                vec![1.0]
            }
            fn delays_at(&self, _t: f64, _y: &[f64]) -> Vec<f64> {
                vec![-0.5] // Negative delay - invalid
            }
            fn has_state_dependent_delays(&self) -> bool {
                true
            }
            fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y_delayed[0][0];
            }
        }

        let system = NegativeDelaySystem;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default();

        let result = MethodOfSteps::solve(&system, 0.0, 1.0, &history, &options);
        assert!(result.is_err(), "Should reject negative delay");
        assert!(
            result.unwrap_err().contains("negative"),
            "Error should mention negative delay"
        );
    }

    #[test]
    fn test_dde_discontinuity_hmin_too_large() {
        // When h_min is larger than the distance to a discontinuity,
        // the solver should return an error instead of skipping it
        struct SimpleDelaySystem;

        impl DdeSystem<f64> for SimpleDelaySystem {
            fn dim(&self) -> usize {
                1
            }
            fn delays(&self) -> Vec<f64> {
                vec![0.001] // Very small delay
            }
            fn rhs(&self, _t: f64, _y: &[f64], y_delayed: &[&[f64]], dydt: &mut [f64]) {
                dydt[0] = -y_delayed[0][0];
            }
        }

        let system = SimpleDelaySystem;
        let history = |_t: f64| vec![1.0];
        // h_min is larger than the delay, so we can't step to the discontinuity
        let options = DdeOptions::default()
            .track_discontinuities(true)
            .discontinuity_order(1)
            .h_max(0.1);

        // Set h_min very large
        let mut opts = options;
        opts.h_min = 0.01; // Larger than delay of 0.001

        let result = MethodOfSteps::solve(&system, 0.0, 0.1, &history, &opts);
        assert!(
            result.is_err(),
            "Should error when h_min would skip discontinuity"
        );
        assert!(
            result.unwrap_err().contains("h_min"),
            "Error should mention h_min"
        );
    }

    #[test]
    fn test_dde_discontinuity_cap() {
        // Many delays with high order should be capped
        let delays = vec![0.01, 0.02, 0.03, 0.04, 0.05];
        let discs = propagate_discontinuities(0.0, &delays, 100.0, 10);

        // Should be capped at MAX_DISCONTINUITIES
        assert!(
            discs.len() <= MAX_DISCONTINUITIES,
            "Discontinuities should be capped, got {}",
            discs.len()
        );
    }

    #[test]
    fn test_accepted_steps_match_output_points() {
        // One output point for t0 plus one per accepted step, `dim` values each.
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default();

        let result =
            MethodOfSteps::solve(&SimpleDelay, 0.0, 2.0, &history, &options).expect("Solve failed");

        assert_eq!(result.stats.n_accept as usize + 1, result.t.len());
        assert_eq!(result.y.len(), result.t.len() * result.dim);
    }

    #[test]
    fn test_empty_interval_returns_initial_state() {
        // tf == t0: no steps are taken and the history value at t0 is returned.
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default();

        let result =
            MethodOfSteps::solve(&SimpleDelay, 0.0, 0.0, &history, &options).expect("Solve failed");

        assert_eq!(result.t.len(), 1);
        assert_eq!(result.t[0], 0.0);
        assert_eq!(result.y_final().unwrap()[0], 1.0);
    }

    #[test]
    fn test_dde_solution_accuracy() {
        // y'(t) = -y(t-1), y(t) = 1 for t <= 0
        // Exact solution for 0 <= t <= 1: y(t) = 1 - t
        // Exact solution for 1 <= t <= 2: y(t) = 1 - t + (t-1)^2/2
        let system = SimpleDelay;
        let history = |_t: f64| vec![1.0];
        let options = DdeOptions::default()
            .rtol(1e-8)
            .atol(1e-10)
            .track_discontinuities(true)
            .discontinuity_order(3);

        let result =
            MethodOfSteps::solve(&system, 0.0, 2.0, &history, &options).expect("Solve failed");

        // Helper: index of the output time closest to `target`
        let find_nearest = |target: f64| -> usize {
            result
                .t
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| (**a - target).abs().total_cmp(&(**b - target).abs()))
                .map(|(i, _)| i)
                .expect("solution has at least one time point")
        };

        // Check at t ~ 0.5: y(0.5) = 1 - 0.5 = 0.5
        let idx_half = find_nearest(0.5);
        let t_half = result.t[idx_half];
        let y_half = result.y_at(idx_half)[0];
        let exact_half = 1.0 - t_half;
        assert!(
            (y_half - exact_half).abs() < 1e-4,
            "At t={}: expected ~{}, got {}",
            t_half,
            exact_half,
            y_half
        );

        // Check at t = 1.0 (discontinuity tracking should land exactly here)
        let idx_one = find_nearest(1.0);
        let y_one = result.y_at(idx_one)[0];
        assert!(y_one.abs() < 1e-5, "At t=1.0: expected ~0, got {}", y_one);

        // Check at t ~ 1.5: y(t) = 1 - t + (t-1)^2/2
        let idx_1_5 = find_nearest(1.5);
        let t_1_5 = result.t[idx_1_5];
        let y_1_5 = result.y_at(idx_1_5)[0];
        let exact_1_5 = 1.0 - t_1_5 + (t_1_5 - 1.0) * (t_1_5 - 1.0) / 2.0;
        assert!(
            (y_1_5 - exact_1_5).abs() < 1e-4,
            "At t={}: expected ~{}, got {}",
            t_1_5,
            exact_1_5,
            y_1_5
        );
    }
}