Skip to main content

lox_orbits/
events.rs

1// SPDX-FileCopyrightText: 2024 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5use std::collections::VecDeque;
6use std::fmt::Display;
7
8use itertools::Itertools;
9use lox_math::roots::{BoxedError, Callback, CallbackError, FindBracketedRoot, RootFinderError};
10use lox_time::Time;
11use lox_time::deltas::TimeDelta;
12use lox_time::intervals::TimeInterval;
13use lox_time::time_scales::TimeScale;
14use thiserror::Error;
15
16// ---------------------------------------------------------------------------
17// Core event types
18// ---------------------------------------------------------------------------
19
/// Direction of a zero-crossing event.
///
/// Describes how the sign of a detect function changes across the crossing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ZeroCrossing {
    /// Signal crosses from negative to positive.
    Up,
    /// Signal crosses from positive to negative.
    Down,
}

impl ZeroCrossing {
    /// Classifies the sign change between two consecutive samples.
    ///
    /// Returns `None` unless one sample is strictly negative and the other is
    /// strictly positive; zeros and NaNs never count as a crossing.
    fn new(s0: f64, s1: f64) -> Option<ZeroCrossing> {
        match (s0, s1) {
            (before, after) if before < 0.0 && after > 0.0 => Some(ZeroCrossing::Up),
            (before, after) if before > 0.0 && after < 0.0 => Some(ZeroCrossing::Down),
            _ => None,
        }
    }
}

impl Display for ZeroCrossing {
    /// Writes the lowercase direction name (`"up"` or `"down"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ZeroCrossing::Up => "up",
            ZeroCrossing::Down => "down",
        };
        f.write_str(label)
    }
}
50
51/// A zero-crossing event at a specific time.
52#[derive(Clone, Copy, Debug, Eq, PartialEq)]
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54pub struct Event<T: TimeScale> {
55    crossing: ZeroCrossing,
56    time: Time<T>,
57}
58
59impl<T: TimeScale> Event<T> {
60    /// Creates a new event at the given time with the specified crossing direction.
61    pub fn new(time: Time<T>, crossing: ZeroCrossing) -> Self {
62        Self { crossing, time }
63    }
64
65    /// Returns the time of the event.
66    pub fn time(&self) -> Time<T>
67    where
68        T: Copy,
69    {
70        self.time
71    }
72
73    /// Returns the crossing direction.
74    pub fn crossing(&self) -> ZeroCrossing {
75        self.crossing
76    }
77}
78
79// ---------------------------------------------------------------------------
80// Error types
81// ---------------------------------------------------------------------------
82
83/// Errors that can occur during event detection.
84#[derive(Debug, Error)]
85pub enum DetectError {
86    /// The root-finding algorithm failed.
87    #[error(transparent)]
88    RootFinder(#[from] RootFinderError),
89    /// The user-provided callback returned an error.
90    #[error(transparent)]
91    Callback(Box<dyn std::error::Error + Send + Sync>),
92}
93
94// ---------------------------------------------------------------------------
95// Core traits
96// ---------------------------------------------------------------------------
97
98/// Scalar function whose zero-crossings define events.
99pub trait DetectFn<T: TimeScale> {
100    /// The error type returned by [`eval`](Self::eval).
101    type Error: std::error::Error + Send + Sync + 'static;
102    /// Evaluates the detection function at the given time.
103    fn eval(&self, time: Time<T>) -> Result<f64, Self::Error>;
104}
105
106/// Detects instantaneous events (zero-crossings) within a time interval.
107pub trait EventDetector<T: TimeScale> {
108    /// Detects all zero-crossing events within the given time interval.
109    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<Event<T>>, DetectError>;
110}
111
112/// Detects intervals where a condition holds within a time interval.
113pub trait IntervalDetector<T: TimeScale> {
114    /// Detects all sub-intervals where the condition holds.
115    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError>;
116}
117
118// ---------------------------------------------------------------------------
119// Callback wrapper for DetectFn → root finder bridge
120// ---------------------------------------------------------------------------
121
122/// A `Callback`-compatible wrapper that bridges `DetectFn` to the root-finder
123/// interface.
124pub(crate) struct DetectCallback<'a, T: TimeScale, F: DetectFn<T>> {
125    func: &'a F,
126    start: Time<T>,
127}
128
129impl<T: TimeScale + Copy, F: DetectFn<T>> Clone for DetectCallback<'_, T, F> {
130    fn clone(&self) -> Self {
131        *self
132    }
133}
134
135impl<T: TimeScale + Copy, F: DetectFn<T>> Copy for DetectCallback<'_, T, F> {}
136
137impl<'a, T: TimeScale + Copy, F: DetectFn<T>> DetectCallback<'a, T, F> {
138    fn new(func: &'a F, start: Time<T>) -> Self {
139        Self { func, start }
140    }
141}
142
143impl<T: TimeScale + Copy, F: DetectFn<T>> Callback for DetectCallback<'_, T, F> {
144    fn call(&self, v: f64) -> Result<f64, CallbackError> {
145        let time = self.start + TimeDelta::from_seconds_f64(v);
146        self.func
147            .eval(time)
148            .map_err(|e| CallbackError::from(Box::new(e) as BoxedError))
149    }
150}
151
152// ---------------------------------------------------------------------------
153// RootFindingDetector — wraps DetectFn + root finder → EventDetector
154// ---------------------------------------------------------------------------
155
156use lox_math::roots::Brent;
157
158/// Wraps a `DetectFn` with a root finder to produce an `EventDetector`.
159pub struct RootFindingDetector<F, R = Brent> {
160    pub(crate) func: F,
161    root_finder: R,
162    step: TimeDelta,
163    coarse_step: Option<TimeDelta>,
164}
165
166impl<F> RootFindingDetector<F, Brent> {
167    /// Creates a new detector with Brent's root-finding method and the given step size.
168    pub fn new(func: F, step: TimeDelta) -> Self {
169        Self {
170            func,
171            root_finder: Brent::default(),
172            step,
173            coarse_step: None,
174        }
175    }
176}
177
178impl<F, R> RootFindingDetector<F, R> {
179    /// Creates a new detector with a custom root-finding algorithm.
180    pub fn with_root_finder(func: F, step: TimeDelta, root_finder: R) -> Self {
181        Self {
182            func,
183            root_finder,
184            step,
185            coarse_step: None,
186        }
187    }
188
189    /// Enables two-level detection with the given coarse step size.
190    pub fn with_coarse_step(mut self, coarse_step: TimeDelta) -> Self {
191        self.coarse_step = Some(coarse_step);
192        self
193    }
194}
195
/// Build a uniform time grid from 0 to `total` with the given `step`,
/// always including the endpoint.
///
/// Grid points are computed as `i * step` rather than by repeated addition,
/// so rounding error does not accumulate over long grids. Accumulated error
/// could otherwise leave a point a few ULPs below `total`, immediately
/// followed by `total` itself.
///
/// Degenerate inputs:
/// - `total == 0` yields `[0.0]`; a negative or NaN `total` yields `[total]`.
/// - A zero, negative or NaN `step` would never advance, so the grid collapses
///   to the two endpoints `[0.0, total]` instead of looping forever.
fn build_time_grid(total: f64, step: f64) -> Vec<f64> {
    let mut grid = Vec::new();
    if total >= 0.0 {
        grid.push(0.0);
        // `step > 0.0` is false for zero, negative and NaN steps.
        if step > 0.0 {
            let mut i = 1.0_f64;
            let mut t = step;
            while t <= total {
                grid.push(t);
                i += 1.0;
                t = i * step;
            }
        }
    }
    if grid.last().is_none_or(|&last| last < total) {
        grid.push(total);
    }
    grid
}
210
211impl<F, R> RootFindingDetector<F, R> {
212    /// Core detection returning events and the sign at the interval start.
213    ///
214    /// The start sign is needed by [`EventsToIntervals`] to determine whether
215    /// the condition holds throughout when no zero-crossings are found.
216    /// Returning it here avoids a redundant function evaluation.
217    pub(crate) fn detect_with_start_sign<T>(
218        &self,
219        interval: TimeInterval<T>,
220    ) -> Result<(Vec<Event<T>>, f64), DetectError>
221    where
222        T: TimeScale + Copy,
223        F: DetectFn<T>,
224        for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
225    {
226        let start = interval.start();
227        let end = interval.end();
228        let total_seconds = (end - start).to_seconds().to_f64();
229        let step_seconds = self.step.to_seconds().to_f64();
230        let callback = DetectCallback::new(&self.func, start);
231
232        match self.coarse_step {
233            Some(coarse_step) => {
234                let coarse_seconds = coarse_step.to_seconds().to_f64();
235                self.detect_two_level(callback, start, total_seconds, step_seconds, coarse_seconds)
236            }
237            None => self.detect_single_level(callback, start, total_seconds, step_seconds),
238        }
239    }
240
241    /// Single-level detection: evaluate at every fine step then root-find.
242    fn detect_single_level<T>(
243        &self,
244        callback: DetectCallback<'_, T, F>,
245        start: Time<T>,
246        total_seconds: f64,
247        step_seconds: f64,
248    ) -> Result<(Vec<Event<T>>, f64), DetectError>
249    where
250        T: TimeScale + Copy,
251        F: DetectFn<T>,
252        for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
253    {
254        let steps = build_time_grid(total_seconds, step_seconds);
255
256        let mut signs = Vec::with_capacity(steps.len());
257        for &t in &steps {
258            let v = callback
259                .call(t)
260                .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
261            signs.push(v.signum());
262        }
263
264        let start_sign = signs[0];
265
266        if signs.iter().all(|&s| s < 0.0) || signs.iter().all(|&s| s > 0.0) {
267            return Ok((vec![], start_sign));
268        }
269
270        let mut events = Vec::new();
271        for ((&t0, &s0), (&t1, &s1)) in std::iter::zip(&steps, &signs).tuple_windows() {
272            if let Some(crossing) = ZeroCrossing::new(s0, s1) {
273                let t = self
274                    .root_finder
275                    .find_in_bracket(callback, (t0, t1))
276                    .map_err(DetectError::RootFinder)?;
277                let time = start + TimeDelta::from_seconds_f64(t);
278                events.push(Event { crossing, time });
279            }
280        }
281
282        Ok((events, start_sign))
283    }
284
285    /// Two-level detection: coarse grid to find sign-change brackets, then
286    /// fine grid within each bracket to locate precise crossings.
287    fn detect_two_level<T>(
288        &self,
289        callback: DetectCallback<'_, T, F>,
290        start: Time<T>,
291        total_seconds: f64,
292        step_seconds: f64,
293        coarse_seconds: f64,
294    ) -> Result<(Vec<Event<T>>, f64), DetectError>
295    where
296        T: TimeScale + Copy,
297        F: DetectFn<T>,
298        for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
299    {
300        // 1. Build coarse grid and evaluate signs.
301        let coarse_grid = build_time_grid(total_seconds, coarse_seconds);
302        let mut coarse_signs = Vec::with_capacity(coarse_grid.len());
303        for &t in &coarse_grid {
304            let v = callback
305                .call(t)
306                .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
307            coarse_signs.push(v.signum());
308        }
309
310        let start_sign = coarse_signs[0];
311
312        // 2. For each coarse bracket with a sign change, subdivide with fine steps.
313        let mut events = Vec::new();
314        for ((&tc0, &sc0), (&tc1, &sc1)) in
315            std::iter::zip(&coarse_grid, &coarse_signs).tuple_windows()
316        {
317            if ZeroCrossing::new(sc0, sc1).is_none() {
318                continue;
319            }
320
321            // Build fine grid within this coarse bracket.
322            // Reuse the known sign at tc0 to avoid a redundant evaluation.
323            let bracket_len = tc1 - tc0;
324            let fine_grid = build_time_grid(bracket_len, step_seconds);
325
326            let mut fine_times = Vec::with_capacity(fine_grid.len());
327            let mut fine_signs = Vec::with_capacity(fine_grid.len());
328
329            // First point: reuse coarse sign.
330            fine_times.push(tc0);
331            fine_signs.push(sc0);
332
333            // Interior and last points: evaluate.
334            for &ft in &fine_grid[1..] {
335                let abs_t = tc0 + ft;
336                fine_times.push(abs_t);
337                let v = callback
338                    .call(abs_t)
339                    .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
340                fine_signs.push(v.signum());
341            }
342
343            // Root-find on fine-level sign changes.
344            for ((&t0, &s0), (&t1, &s1)) in std::iter::zip(&fine_times, &fine_signs).tuple_windows()
345            {
346                if let Some(crossing) = ZeroCrossing::new(s0, s1) {
347                    let t = self
348                        .root_finder
349                        .find_in_bracket(callback, (t0, t1))
350                        .map_err(DetectError::RootFinder)?;
351                    let time = start + TimeDelta::from_seconds_f64(t);
352                    events.push(Event { crossing, time });
353                }
354            }
355        }
356
357        Ok((events, start_sign))
358    }
359}
360
361impl<T, F, R> EventDetector<T> for RootFindingDetector<F, R>
362where
363    T: TimeScale + Copy,
364    F: DetectFn<T>,
365    for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
366{
367    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<Event<T>>, DetectError> {
368        self.detect_with_start_sign(interval)
369            .map(|(events, _)| events)
370    }
371}
372
373// ---------------------------------------------------------------------------
374// EventsToIntervals — converts EventDetector → IntervalDetector
375// ---------------------------------------------------------------------------
376
377/// Converts a [`RootFindingDetector`] into an [`IntervalDetector`] by pairing
378/// Up/Down crossings into intervals.
379///
380/// When no events are found, the sign of the detect function at the interval
381/// start is checked: if positive, the entire interval is returned; if
382/// negative, an empty list is returned.
383pub struct EventsToIntervals<F, R = Brent> {
384    detector: RootFindingDetector<F, R>,
385}
386
387impl<F> EventsToIntervals<F, Brent> {
388    /// Creates a new converter from a Brent-based root-finding detector.
389    pub fn new(detector: RootFindingDetector<F>) -> Self {
390        Self { detector }
391    }
392}
393
394impl<F, R> EventsToIntervals<F, R> {
395    /// Creates a new converter from a detector with a custom root finder.
396    pub fn with_root_finder(detector: RootFindingDetector<F, R>) -> Self {
397        Self { detector }
398    }
399}
400
401impl<T, F, R> IntervalDetector<T> for EventsToIntervals<F, R>
402where
403    T: TimeScale + Copy,
404    F: DetectFn<T>,
405    for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
406{
407    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
408        let start = interval.start();
409        let end = interval.end();
410
411        let (events, start_sign) = self.detector.detect_with_start_sign(interval)?;
412        if events.is_empty() {
413            // No zero crossings — use the sign at the start (already computed
414            // during step evaluation) to determine if the condition holds
415            // throughout or not at all.
416            return if start_sign >= 0.0 {
417                Ok(vec![interval])
418            } else {
419                Ok(vec![])
420            };
421        }
422
423        let mut events: VecDeque<Event<T>> = events.into();
424
425        if events.front().unwrap().crossing == ZeroCrossing::Down {
426            events.push_front(Event {
427                crossing: ZeroCrossing::Up,
428                time: start,
429            });
430        }
431
432        if events.back().unwrap().crossing == ZeroCrossing::Up {
433            events.push_back(Event {
434                crossing: ZeroCrossing::Down,
435                time: end,
436            });
437        }
438
439        let mut intervals = Vec::with_capacity(events.len() / 2);
440        for (up, down) in events.into_iter().tuples() {
441            debug_assert!(up.crossing == ZeroCrossing::Up);
442            debug_assert!(down.crossing == ZeroCrossing::Down);
443            intervals.push(TimeInterval::new(up.time, down.time));
444        }
445
446        Ok(intervals)
447    }
448}
449
450// ---------------------------------------------------------------------------
451// Combinators
452// ---------------------------------------------------------------------------
453
454/// Intervals where BOTH A and B are active (intersection).
455pub struct Intersection<A, B> {
456    a: A,
457    b: B,
458}
459
460impl<T, A, B> IntervalDetector<T> for Intersection<A, B>
461where
462    T: TimeScale + Ord + Copy,
463    A: IntervalDetector<T>,
464    B: IntervalDetector<T>,
465{
466    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
467        let a = self.a.detect(interval)?;
468        let b = self.b.detect(interval)?;
469        Ok(lox_time::intervals::intersect_intervals(&a, &b))
470    }
471}
472
473/// Intervals where EITHER A or B is active (union).
474pub struct Union<A, B> {
475    a: A,
476    b: B,
477}
478
479impl<T, A, B> IntervalDetector<T> for Union<A, B>
480where
481    T: TimeScale + Ord + Copy,
482    A: IntervalDetector<T>,
483    B: IntervalDetector<T>,
484{
485    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
486        let a = self.a.detect(interval)?;
487        let b = self.b.detect(interval)?;
488        Ok(lox_time::intervals::union_intervals(&a, &b))
489    }
490}
491
492/// Intervals where D is NOT active (complement within the search interval).
493pub struct Complement<D> {
494    detector: D,
495}
496
497impl<T, D> IntervalDetector<T> for Complement<D>
498where
499    T: TimeScale + Ord + Copy,
500    D: IntervalDetector<T>,
501{
502    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
503        let inner = self.detector.detect(interval)?;
504        Ok(lox_time::intervals::complement_intervals(&inner, interval))
505    }
506}
507
508/// Optimization: B only evaluates within A's detected intervals.
509pub struct Chain<A, B> {
510    a: A,
511    b: B,
512}
513
514impl<T, A, B> IntervalDetector<T> for Chain<A, B>
515where
516    T: TimeScale + Copy,
517    A: IntervalDetector<T>,
518    B: IntervalDetector<T>,
519{
520    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
521        let a_intervals = self.a.detect(interval)?;
522        let mut result = Vec::new();
523        for sub in a_intervals {
524            result.extend(self.b.detect(sub)?);
525        }
526        Ok(result)
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Extension trait for IntervalDetector combinators
532// ---------------------------------------------------------------------------
533
534/// Extension trait providing combinator methods for [`IntervalDetector`] implementations.
535pub trait IntervalDetectorExt<T: TimeScale>: IntervalDetector<T> + Sized {
536    /// Returns intervals where both `self` and `other` are active (intersection).
537    fn intersect<B>(self, other: B) -> Intersection<Self, B> {
538        Intersection { a: self, b: other }
539    }
540
541    /// Returns intervals where either `self` or `other` is active (union).
542    fn union<B>(self, other: B) -> Union<Self, B> {
543        Union { a: self, b: other }
544    }
545
546    /// Returns intervals where `self` is NOT active (complement).
547    fn complement(self) -> Complement<Self> {
548        Complement { detector: self }
549    }
550
551    /// Evaluates `other` only within intervals detected by `self`.
552    fn chain<B>(self, other: B) -> Chain<Self, B> {
553        Chain { a: self, b: other }
554    }
555}
556
557impl<T: TimeScale, D: IntervalDetector<T>> IntervalDetectorExt<T> for D {}
558
559// ---------------------------------------------------------------------------
560// IntervalDetector impls for boxed trait objects
561// ---------------------------------------------------------------------------
562
563impl<T: TimeScale> IntervalDetector<T> for Box<dyn IntervalDetector<T> + '_> {
564    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
565        (**self).detect(interval)
566    }
567}
568
569impl<T: TimeScale> IntervalDetector<T> for Box<dyn IntervalDetector<T> + Send + '_> {
570    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
571        (**self).detect(interval)
572    }
573}
574
575// ---------------------------------------------------------------------------
576// Closure-based DetectFn adapters
577// ---------------------------------------------------------------------------
578
579/// Wraps an infallible closure into a [`DetectFn`].
580pub struct FnDetect<F>(pub F);
581
582impl<T, F> DetectFn<T> for FnDetect<F>
583where
584    T: TimeScale + Copy,
585    F: Fn(Time<T>) -> f64,
586{
587    type Error = std::convert::Infallible;
588    fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
589        Ok((self.0)(time))
590    }
591}
592
593/// Wraps a fallible closure into a [`DetectFn`].
594pub struct TryFnDetect<F>(pub F);
595
596impl<T, F, E> DetectFn<T> for TryFnDetect<F>
597where
598    T: TimeScale + Copy,
599    F: Fn(Time<T>) -> Result<f64, E>,
600    E: std::error::Error + Send + Sync + 'static,
601{
602    type Error = E;
603    fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
604        (self.0)(time)
605    }
606}
607
608// ---------------------------------------------------------------------------
609// Convenience functions
610// ---------------------------------------------------------------------------
611
612/// Find zero-crossing events for an infallible closure over a time interval.
613pub fn find_events<T, F>(
614    func: F,
615    interval: TimeInterval<T>,
616    step: TimeDelta,
617) -> Result<Vec<Event<T>>, DetectError>
618where
619    T: TimeScale + Copy,
620    F: Fn(Time<T>) -> f64,
621{
622    RootFindingDetector::new(FnDetect(func), step).detect(interval)
623}
624
625/// Find zero-crossing events for a fallible closure over a time interval.
626pub fn try_find_events<T, F, E>(
627    func: F,
628    interval: TimeInterval<T>,
629    step: TimeDelta,
630) -> Result<Vec<Event<T>>, DetectError>
631where
632    T: TimeScale + Copy,
633    F: Fn(Time<T>) -> Result<f64, E>,
634    E: std::error::Error + Send + Sync + 'static,
635{
636    RootFindingDetector::new(TryFnDetect(func), step).detect(interval)
637}
638
639/// Find intervals where an infallible closure is positive.
640pub fn find_windows<T, F>(
641    func: F,
642    interval: TimeInterval<T>,
643    step: TimeDelta,
644) -> Result<Vec<TimeInterval<T>>, DetectError>
645where
646    T: TimeScale + Copy,
647    F: Fn(Time<T>) -> f64,
648{
649    let detector = RootFindingDetector::new(FnDetect(func), step);
650    EventsToIntervals::new(detector).detect(interval)
651}
652
653/// Find intervals where a fallible closure is positive.
654pub fn try_find_windows<T, F, E>(
655    func: F,
656    interval: TimeInterval<T>,
657    step: TimeDelta,
658) -> Result<Vec<TimeInterval<T>>, DetectError>
659where
660    T: TimeScale + Copy,
661    F: Fn(Time<T>) -> Result<f64, E>,
662    E: std::error::Error + Send + Sync + 'static,
663{
664    let detector = RootFindingDetector::new(TryFnDetect(func), step);
665    EventsToIntervals::new(detector).detect(interval)
666}
667
668#[cfg(test)]
669mod tests {
670    use super::*;
671    use lox_test_utils::assert_approx_eq;
672    use lox_time::time;
673    use lox_time::time_scales::Tai;
674    use std::f64::consts::{PI, TAU};
675    use std::sync::atomic::{AtomicUsize, Ordering};
676
677    /// A `DetectFn` wrapper that counts evaluations via an `AtomicUsize`.
678    struct CountingDetectFn<'a, F> {
679        inner: F,
680        counter: &'a AtomicUsize,
681    }
682
683    impl<'a, T, F> DetectFn<T> for CountingDetectFn<'a, F>
684    where
685        T: TimeScale + Copy,
686        F: Fn(Time<T>) -> f64,
687    {
688        type Error = std::convert::Infallible;
689        fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
690            self.counter.fetch_add(1, Ordering::Relaxed);
691            Ok((self.inner)(time))
692        }
693    }
694
695    #[test]
696    fn test_events() {
697        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
698        let end = start + TimeDelta::from_seconds(7);
699        let interval = TimeInterval::new(start, end);
700
701        let detect_fn = FnDetect(|t: Time<Tai>| (t - start).to_seconds().to_f64().sin());
702        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
703        let events = detector.detect(interval).unwrap();
704
705        assert_eq!(events.len(), 2);
706        assert_eq!(events[0].crossing, ZeroCrossing::Down);
707        assert_approx_eq!(
708            events[0].time,
709            start + TimeDelta::from_seconds_f64(PI),
710            rtol <= 1e-6
711        );
712        assert_eq!(events[1].crossing, ZeroCrossing::Up);
713        assert_approx_eq!(
714            events[1].time,
715            start + TimeDelta::from_seconds_f64(TAU),
716            rtol <= 1e-6
717        );
718    }
719
720    #[test]
721    fn test_windows() {
722        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
723        let end = start + TimeDelta::from_seconds(7);
724        let interval = TimeInterval::new(start, end);
725
726        let detect_fn = FnDetect(|t: Time<Tai>| (t - start).to_seconds().to_f64().sin());
727        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
728        let intervals_detector = EventsToIntervals::new(detector);
729        let windows = intervals_detector.detect(interval).unwrap();
730
731        assert_eq!(windows.len(), 2);
732        assert_eq!(windows[0].start(), start);
733        assert_approx_eq!(
734            windows[0].end(),
735            start + TimeDelta::from_seconds_f64(PI),
736            rtol <= 1e-6
737        );
738    }
739
740    #[test]
741    fn test_windows_no_windows() {
742        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
743        let end = start + TimeDelta::from_seconds(7);
744        let interval = TimeInterval::new(start, end);
745
746        let detect_fn = FnDetect(|_t: Time<Tai>| -1.0);
747        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
748        let intervals_detector = EventsToIntervals::new(detector);
749        let windows = intervals_detector.detect(interval).unwrap();
750
751        assert!(windows.is_empty());
752    }
753
754    #[test]
755    fn test_windows_full_coverage() {
756        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
757        let end = start + TimeDelta::from_seconds(7);
758        let interval = TimeInterval::new(start, end);
759
760        let detect_fn = FnDetect(|_t: Time<Tai>| 1.0);
761        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
762        let intervals_detector = EventsToIntervals::new(detector);
763        let windows = intervals_detector.detect(interval).unwrap();
764
765        assert_eq!(windows.len(), 1);
766        assert_eq!(windows[0].start(), start);
767        assert_eq!(windows[0].end(), end);
768    }
769
770    // -----------------------------------------------------------------------
771    // Two-level stepping tests
772    // -----------------------------------------------------------------------
773
774    #[test]
775    fn test_two_level_matches_single_level() {
776        // sin(t) over [0, 7]: zero crossings at PI and TAU.
777        // Two-level with coarse_step=3s, fine step=1s should find the same events.
778        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
779        let end = start + TimeDelta::from_seconds(7);
780        let interval = TimeInterval::new(start, end);
781
782        let single = RootFindingDetector::new(
783            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
784            TimeDelta::from_seconds(1),
785        )
786        .detect(interval)
787        .unwrap();
788
789        let two_level = RootFindingDetector::new(
790            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
791            TimeDelta::from_seconds(1),
792        )
793        .with_coarse_step(TimeDelta::from_seconds(3))
794        .detect(interval)
795        .unwrap();
796
797        assert_eq!(single.len(), two_level.len());
798        for (s, tl) in single.iter().zip(&two_level) {
799            assert_eq!(s.crossing, tl.crossing);
800            assert_approx_eq!(s.time, tl.time, rtol <= 1e-6);
801        }
802    }
803
804    #[test]
805    fn test_two_level_multiple_crossings_in_bracket() {
806        // sin(t + 0.5) over [0, 10]: one coarse bracket [0, 10] contains 3
807        // zero crossings (at t ≈ 2.64, 5.78, 8.92). The bracket has a sign
808        // change (sin(0.5) > 0, sin(10.5) < 0) so the fine grid is applied
809        // and all 3 crossings are found.
810        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
811        let end = start + TimeDelta::from_seconds(10);
812        let interval = TimeInterval::new(start, end);
813
814        let func = move |t: Time<Tai>| ((t - start).to_seconds().to_f64() + 0.5).sin();
815
816        let single = RootFindingDetector::new(FnDetect(func), TimeDelta::from_seconds(1))
817            .detect(interval)
818            .unwrap();
819
820        let two_level = RootFindingDetector::new(FnDetect(func), TimeDelta::from_seconds(1))
821            .with_coarse_step(TimeDelta::from_seconds(10))
822            .detect(interval)
823            .unwrap();
824
825        assert_eq!(single.len(), 3, "expected 3 crossings");
826        assert_eq!(single.len(), two_level.len());
827        for (s, tl) in single.iter().zip(&two_level) {
828            assert_eq!(s.crossing, tl.crossing);
829            assert_approx_eq!(s.time, tl.time, rtol <= 1e-6);
830        }
831    }
832
833    #[test]
834    fn test_two_level_no_events() {
835        // Constant negative function — no events, correct start_sign.
836        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
837        let end = start + TimeDelta::from_seconds(10);
838        let interval = TimeInterval::new(start, end);
839
840        let det =
841            RootFindingDetector::new(FnDetect(|_t: Time<Tai>| -1.0), TimeDelta::from_seconds(1))
842                .with_coarse_step(TimeDelta::from_seconds(3));
843
844        let (events, start_sign) = det.detect_with_start_sign(interval).unwrap();
845        assert!(events.is_empty());
846        assert!(start_sign < 0.0);
847    }
848
849    #[test]
850    fn test_two_level_windows_roundtrip() {
851        // EventsToIntervals with a coarse-stepped detector produces the same
852        // windows as without for sin(t) over [0, 7].
853        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
854        let end = start + TimeDelta::from_seconds(7);
855        let interval = TimeInterval::new(start, end);
856
857        let single_windows = EventsToIntervals::new(RootFindingDetector::new(
858            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
859            TimeDelta::from_seconds(1),
860        ))
861        .detect(interval)
862        .unwrap();
863
864        let two_level_windows = EventsToIntervals::new(
865            RootFindingDetector::new(
866                FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
867                TimeDelta::from_seconds(1),
868            )
869            .with_coarse_step(TimeDelta::from_seconds(3)),
870        )
871        .detect(interval)
872        .unwrap();
873
874        assert_eq!(single_windows.len(), two_level_windows.len());
875        for (s, tl) in single_windows.iter().zip(&two_level_windows) {
876            assert_approx_eq!(s.start(), tl.start(), rtol <= 1e-6);
877            assert_approx_eq!(s.end(), tl.end(), rtol <= 1e-6);
878        }
879    }
880
881    #[test]
882    fn test_two_level_eval_count_reduction() {
883        // Verify that two-level stepping uses fewer evaluations than single-level
884        // for a long interval with sparse events.
885        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
886        let end = start + TimeDelta::from_seconds(1000);
887        let interval = TimeInterval::new(start, end);
888
889        // sin(t/100) — two crossings in [0, 1000] at ~314s and ~628s.
890        let func = move |t: Time<Tai>| ((t - start).to_seconds().to_f64() / 100.0).sin();
891
892        let counter_single = AtomicUsize::new(0);
893        let single = RootFindingDetector::new(
894            CountingDetectFn {
895                inner: func,
896                counter: &counter_single,
897            },
898            TimeDelta::from_seconds(1),
899        )
900        .detect(interval)
901        .unwrap();
902
903        let counter_two = AtomicUsize::new(0);
904        // min_pass_duration = 300s → coarse_step = 150s
905        let two_level = RootFindingDetector::new(
906            CountingDetectFn {
907                inner: func,
908                counter: &counter_two,
909            },
910            TimeDelta::from_seconds(1),
911        )
912        .with_coarse_step(TimeDelta::from_seconds(150))
913        .detect(interval)
914        .unwrap();
915
916        // Both should find the same events.
917        assert_eq!(single.len(), two_level.len());
918
919        let single_evals = counter_single.load(Ordering::Relaxed);
920        let two_level_evals = counter_two.load(Ordering::Relaxed);
921
922        // Two-level should use significantly fewer evals.
923        assert!(
924            two_level_evals < single_evals,
925            "two-level ({two_level_evals}) should use fewer evals than single-level ({single_evals})"
926        );
927    }
928
929    // -----------------------------------------------------------------------
930    // Combinator and extension-trait tests
931    // -----------------------------------------------------------------------
932
933    /// Helper: build an `EventsToIntervals` detector from an infallible closure.
934    fn make_window_detector<F: Fn(Time<Tai>) -> f64>(
935        func: F,
936        step: TimeDelta,
937    ) -> EventsToIntervals<FnDetect<F>> {
938        EventsToIntervals::new(RootFindingDetector::new(FnDetect(func), step))
939    }
940
941    /// sin(t) is positive on (0, PI) within [0, 7].
942    fn sin_detector(start: Time<Tai>) -> EventsToIntervals<FnDetect<impl Fn(Time<Tai>) -> f64>> {
943        make_window_detector(
944            move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
945            TimeDelta::from_seconds(1),
946        )
947    }
948
949    /// cos(t) is positive on [0, PI/2) and (3PI/2, 7] within [0, 7].
950    fn cos_detector(start: Time<Tai>) -> EventsToIntervals<FnDetect<impl Fn(Time<Tai>) -> f64>> {
951        make_window_detector(
952            move |t: Time<Tai>| (t - start).to_seconds().to_f64().cos(),
953            TimeDelta::from_seconds(1),
954        )
955    }
956
957    fn test_interval() -> (Time<Tai>, TimeInterval<Tai>) {
958        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
959        let end = start + TimeDelta::from_seconds(7);
960        (start, TimeInterval::new(start, end))
961    }
962
963    #[test]
964    fn test_intersect() {
965        let (start, interval) = test_interval();
966        // sin > 0 on (0, PI), cos > 0 on [0, PI/2) ∪ (3PI/2, 7]
967        // intersection: (0, PI/2) — both positive only here (within [0, PI])
968        let det = sin_detector(start).intersect(cos_detector(start));
969        let windows = det.detect(interval).unwrap();
970        assert_eq!(windows.len(), 2);
971        // First window: start..PI/2
972        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
973        assert_approx_eq!(
974            windows[0].end(),
975            start + TimeDelta::from_seconds_f64(PI / 2.0),
976            rtol <= 1e-4
977        );
978    }
979
980    #[test]
981    fn test_union() {
982        let (start, interval) = test_interval();
983        // sin > 0 on (0, PI), cos > 0 on [0, PI/2) ∪ (3PI/2, 7]
984        // union covers most of the interval
985        let det = sin_detector(start).union(cos_detector(start));
986        let windows = det.detect(interval).unwrap();
987        // The union should cover [0, PI] ∪ [3PI/2, 7]
988        assert_eq!(windows.len(), 2);
989        // First window spans from start to PI (sin covers 0..PI, cos covers 0..PI/2)
990        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
991        assert_approx_eq!(
992            windows[0].end(),
993            start + TimeDelta::from_seconds_f64(PI),
994            rtol <= 1e-4
995        );
996    }
997
998    #[test]
999    fn test_complement() {
1000        let (start, interval) = test_interval();
1001        // sin > 0 on [start, PI] ∪ [TAU, end] within [0, 7]
1002        // complement: [PI, TAU]
1003        let det = sin_detector(start).complement();
1004        let windows = det.detect(interval).unwrap();
1005        assert_eq!(windows.len(), 1);
1006        assert_approx_eq!(
1007            windows[0].start(),
1008            start + TimeDelta::from_seconds_f64(PI),
1009            rtol <= 1e-4
1010        );
1011        assert_approx_eq!(
1012            windows[0].end(),
1013            start + TimeDelta::from_seconds_f64(TAU),
1014            rtol <= 1e-4
1015        );
1016    }
1017
1018    #[test]
1019    fn test_chain() {
1020        let (start, interval) = test_interval();
1021        // Chain: first find sin > 0 windows, then within those evaluate cos > 0.
1022        // sin > 0 on [start, PI] ∪ [TAU, end].
1023        // Within [start, PI]: cos > 0 on [start, PI/2].
1024        // Within [TAU, end]: cos(TAU..7) > 0 throughout, so [TAU, end].
1025        let det = sin_detector(start).chain(cos_detector(start));
1026        let windows = det.detect(interval).unwrap();
1027        assert_eq!(windows.len(), 2);
1028        // First window: [start, PI/2]
1029        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
1030        assert_approx_eq!(
1031            windows[0].end(),
1032            start + TimeDelta::from_seconds_f64(PI / 2.0),
1033            rtol <= 1e-4
1034        );
1035        // Second window: [TAU, end]
1036        assert_approx_eq!(
1037            windows[1].start(),
1038            start + TimeDelta::from_seconds_f64(TAU),
1039            rtol <= 1e-4
1040        );
1041        assert_approx_eq!(
1042            windows[1].end(),
1043            start + TimeDelta::from_seconds(7),
1044            rtol <= 1e-6
1045        );
1046    }
1047
1048    #[test]
1049    fn test_chain_restricts_evaluation() {
1050        // Chain should only evaluate B within A's windows.
1051        // Use a constant-negative A to prove B is never called.
1052        let (start, interval) = test_interval();
1053        let counter = AtomicUsize::new(0);
1054        let a = make_window_detector(|_t: Time<Tai>| -1.0, TimeDelta::from_seconds(1));
1055        let b = EventsToIntervals::new(RootFindingDetector::new(
1056            CountingDetectFn {
1057                inner: move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
1058                counter: &counter,
1059            },
1060            TimeDelta::from_seconds(1),
1061        ));
1062        let windows = a.chain(b).detect(interval).unwrap();
1063        assert!(windows.is_empty());
1064        assert_eq!(counter.load(Ordering::Relaxed), 0);
1065    }
1066
1067    // -----------------------------------------------------------------------
1068    // Boxed IntervalDetector tests
1069    // -----------------------------------------------------------------------
1070
1071    #[test]
1072    fn test_boxed_interval_detector() {
1073        let (start, interval) = test_interval();
1074        let det: Box<dyn IntervalDetector<Tai>> = Box::new(sin_detector(start));
1075        let windows = det.detect(interval).unwrap();
1076        // sin > 0 on (0, PI) and (TAU, 7)
1077        assert_eq!(windows.len(), 2);
1078        assert_approx_eq!(
1079            windows[0].end(),
1080            start + TimeDelta::from_seconds_f64(PI),
1081            rtol <= 1e-4
1082        );
1083    }
1084
1085    #[test]
1086    fn test_boxed_send_interval_detector() {
1087        let (start, interval) = test_interval();
1088        let det: Box<dyn IntervalDetector<Tai> + Send> = Box::new(sin_detector(start));
1089        let windows = det.detect(interval).unwrap();
1090        assert_eq!(windows.len(), 2);
1091    }
1092
1093    #[test]
1094    fn test_boxed_dynamic_fold() {
1095        // Fold multiple detectors via Box<dyn IntervalDetector>, simulating
1096        // the pattern used in VisibilityAnalysis for occulting bodies.
1097        let (start, interval) = test_interval();
1098
1099        // sin > 0 on (0, PI) ∪ (TAU, 7)
1100        // cos > 0 on [0, PI/2) ∪ (3PI/2, 7]
1101        // intersection: [0, PI/2) ∪ (TAU, 7] (approximately)
1102        let detectors: Vec<Box<dyn IntervalDetector<Tai>>> =
1103            vec![Box::new(sin_detector(start)), Box::new(cos_detector(start))];
1104
1105        let mut combined: Box<dyn IntervalDetector<Tai>> = detectors.into_iter().next().unwrap();
1106
1107        let det = Box::new(cos_detector(start)) as Box<dyn IntervalDetector<Tai>>;
1108        combined = Box::new(combined.intersect(det));
1109
1110        let windows = combined.detect(interval).unwrap();
1111        assert_eq!(windows.len(), 2);
1112        // First window should end around PI/2
1113        assert_approx_eq!(
1114            windows[0].end(),
1115            start + TimeDelta::from_seconds_f64(PI / 2.0),
1116            rtol <= 1e-4
1117        );
1118    }
1119
1120    // -----------------------------------------------------------------------
1121    // Convenience function tests
1122    // -----------------------------------------------------------------------
1123
1124    #[test]
1125    fn test_find_events() {
1126        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
1127        let end = start + TimeDelta::from_seconds(7);
1128        let interval = TimeInterval::new(start, end);
1129        let events = find_events(
1130            |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
1131            interval,
1132            TimeDelta::from_seconds(1),
1133        )
1134        .unwrap();
1135        assert_eq!(events.len(), 2);
1136    }
1137
1138    #[test]
1139    fn test_try_find_events() {
1140        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
1141        let end = start + TimeDelta::from_seconds(7);
1142        let interval = TimeInterval::new(start, end);
1143        let events = try_find_events(
1144            |t: Time<Tai>| {
1145                Ok::<f64, std::convert::Infallible>((t - start).to_seconds().to_f64().sin())
1146            },
1147            interval,
1148            TimeDelta::from_seconds(1),
1149        )
1150        .unwrap();
1151        assert_eq!(events.len(), 2);
1152    }
1153
1154    #[test]
1155    fn test_find_windows() {
1156        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
1157        let end = start + TimeDelta::from_seconds(7);
1158        let interval = TimeInterval::new(start, end);
1159        let windows = find_windows(
1160            |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
1161            interval,
1162            TimeDelta::from_seconds(1),
1163        )
1164        .unwrap();
1165        assert_eq!(windows.len(), 2);
1166        assert_approx_eq!(
1167            windows[0].end(),
1168            start + TimeDelta::from_seconds_f64(PI),
1169            rtol <= 1e-4
1170        );
1171    }
1172
1173    #[test]
1174    fn test_try_find_windows() {
1175        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
1176        let end = start + TimeDelta::from_seconds(7);
1177        let interval = TimeInterval::new(start, end);
1178        let windows = try_find_windows(
1179            |t: Time<Tai>| {
1180                Ok::<f64, std::convert::Infallible>((t - start).to_seconds().to_f64().sin())
1181            },
1182            interval,
1183            TimeDelta::from_seconds(1),
1184        )
1185        .unwrap();
1186        assert_eq!(windows.len(), 2);
1187    }
1188}