1use std::collections::VecDeque;
6use std::fmt::Display;
7
8use itertools::Itertools;
9use lox_math::roots::{BoxedError, Callback, CallbackError, FindBracketedRoot, RootFinderError};
10use lox_time::Time;
11use lox_time::deltas::TimeDelta;
12use lox_time::intervals::TimeInterval;
13use lox_time::time_scales::TimeScale;
14use thiserror::Error;
15
/// Direction in which a detection function passes through zero.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ZeroCrossing {
    /// The function goes from negative to positive.
    Up,
    /// The function goes from positive to negative.
    Down,
}

impl ZeroCrossing {
    /// Classifies the transition between two consecutive sign samples.
    ///
    /// Returns `None` unless the samples lie strictly on opposite sides of zero;
    /// a zero or NaN sample therefore never yields a crossing.
    fn new(s0: f64, s1: f64) -> Option<ZeroCrossing> {
        let rising = s0 < 0.0 && s1 > 0.0;
        let falling = s0 > 0.0 && s1 < 0.0;
        rising
            .then_some(ZeroCrossing::Up)
            .or(falling.then_some(ZeroCrossing::Down))
    }
}

impl Display for ZeroCrossing {
    /// Writes the lowercase direction name (`"up"` or `"down"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Up => "up",
            Self::Down => "down",
        };
        f.write_str(label)
    }
}
50
51#[derive(Clone, Copy, Debug, Eq, PartialEq)]
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54pub struct Event<T: TimeScale> {
55 crossing: ZeroCrossing,
56 time: Time<T>,
57}
58
59impl<T: TimeScale> Event<T> {
60 pub fn new(time: Time<T>, crossing: ZeroCrossing) -> Self {
62 Self { crossing, time }
63 }
64
65 pub fn time(&self) -> Time<T>
67 where
68 T: Copy,
69 {
70 self.time
71 }
72
73 pub fn crossing(&self) -> ZeroCrossing {
75 self.crossing
76 }
77}
78
/// Errors that can occur while detecting events or intervals.
#[derive(Debug, Error)]
pub enum DetectError {
    /// A failure reported by the root finder.
    ///
    /// In this file, errors returned by a [`DetectFn`] during grid sampling are
    /// also reported through this variant, wrapped as `RootFinderError::Callback`.
    #[error(transparent)]
    RootFinder(#[from] RootFinderError),
    /// A boxed error originating from a user callback.
    // NOTE(review): this variant is not constructed anywhere in this file;
    // callback failures are routed through `RootFinder` instead. Confirm whether
    // other modules construct it before relying on it.
    #[error(transparent)]
    Callback(Box<dyn std::error::Error + Send + Sync>),
}

/// A scalar function of time whose zero crossings define events.
///
/// [`EventsToIntervals`] treats positive values as "inside" a window: an upward
/// crossing opens a window and a downward crossing closes it.
pub trait DetectFn<T: TimeScale> {
    /// Error returned when the function cannot be evaluated.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Evaluates the function at `time`.
    fn eval(&self, time: Time<T>) -> Result<f64, Self::Error>;
}

/// Detects discrete zero-crossing events within a time interval.
pub trait EventDetector<T: TimeScale> {
    /// Returns the events found in `interval`.
    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<Event<T>>, DetectError>;
}

/// Detects windows (sub-intervals) within a time interval.
pub trait IntervalDetector<T: TimeScale> {
    /// Returns the windows found in `interval`.
    fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError>;
}
117
118pub(crate) struct DetectCallback<'a, T: TimeScale, F: DetectFn<T>> {
125 func: &'a F,
126 start: Time<T>,
127}
128
129impl<T: TimeScale + Copy, F: DetectFn<T>> Clone for DetectCallback<'_, T, F> {
130 fn clone(&self) -> Self {
131 *self
132 }
133}
134
135impl<T: TimeScale + Copy, F: DetectFn<T>> Copy for DetectCallback<'_, T, F> {}
136
137impl<'a, T: TimeScale + Copy, F: DetectFn<T>> DetectCallback<'a, T, F> {
138 fn new(func: &'a F, start: Time<T>) -> Self {
139 Self { func, start }
140 }
141}
142
143impl<T: TimeScale + Copy, F: DetectFn<T>> Callback for DetectCallback<'_, T, F> {
144 fn call(&self, v: f64) -> Result<f64, CallbackError> {
145 let time = self.start + TimeDelta::from_seconds_f64(v);
146 self.func
147 .eval(time)
148 .map_err(|e| CallbackError::from(Box::new(e) as BoxedError))
149 }
150}
151
152use lox_math::roots::Brent;
157
158pub struct RootFindingDetector<F, R = Brent> {
160 pub(crate) func: F,
161 root_finder: R,
162 step: TimeDelta,
163 coarse_step: Option<TimeDelta>,
164}
165
166impl<F> RootFindingDetector<F, Brent> {
167 pub fn new(func: F, step: TimeDelta) -> Self {
169 Self {
170 func,
171 root_finder: Brent::default(),
172 step,
173 coarse_step: None,
174 }
175 }
176}
177
178impl<F, R> RootFindingDetector<F, R> {
179 pub fn with_root_finder(func: F, step: TimeDelta, root_finder: R) -> Self {
181 Self {
182 func,
183 root_finder,
184 step,
185 coarse_step: None,
186 }
187 }
188
189 pub fn with_coarse_step(mut self, coarse_step: TimeDelta) -> Self {
191 self.coarse_step = Some(coarse_step);
192 self
193 }
194}
195
/// Builds sample offsets (in seconds) `0, step, 2*step, ...` up to `total`.
///
/// The returned grid always ends exactly at `total` and is never empty.
///
/// Nodes are computed as `i * step` rather than by repeated addition. This
/// keeps rounding error from accumulating over long grids. Accumulation could
/// otherwise leave the last regular node a hair below `total`, which produces a
/// near-degenerate final bracket.
///
/// A `step` that is not strictly positive and finite cannot advance the grid.
/// In that case only the endpoints are returned instead of looping forever.
fn build_time_grid(total: f64, step: f64) -> Vec<f64> {
    let mut grid: Vec<f64> = if step > 0.0 && step.is_finite() {
        (0u64..)
            .map(|i| i as f64 * step)
            .take_while(|&t| t <= total)
            .collect()
    } else if 0.0 <= total {
        // Same result the accumulating loop gave for an infinite/NaN step.
        vec![0.0]
    } else {
        Vec::new()
    };
    // Close the grid at `total` unless the last regular node already hit it.
    if grid.last().is_none_or(|&last| last < total) {
        grid.push(total);
    }
    grid
}
210
211impl<F, R> RootFindingDetector<F, R> {
212 pub(crate) fn detect_with_start_sign<T>(
218 &self,
219 interval: TimeInterval<T>,
220 ) -> Result<(Vec<Event<T>>, f64), DetectError>
221 where
222 T: TimeScale + Copy,
223 F: DetectFn<T>,
224 for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
225 {
226 let start = interval.start();
227 let end = interval.end();
228 let total_seconds = (end - start).to_seconds().to_f64();
229 let step_seconds = self.step.to_seconds().to_f64();
230 let callback = DetectCallback::new(&self.func, start);
231
232 match self.coarse_step {
233 Some(coarse_step) => {
234 let coarse_seconds = coarse_step.to_seconds().to_f64();
235 self.detect_two_level(callback, start, total_seconds, step_seconds, coarse_seconds)
236 }
237 None => self.detect_single_level(callback, start, total_seconds, step_seconds),
238 }
239 }
240
241 fn detect_single_level<T>(
243 &self,
244 callback: DetectCallback<'_, T, F>,
245 start: Time<T>,
246 total_seconds: f64,
247 step_seconds: f64,
248 ) -> Result<(Vec<Event<T>>, f64), DetectError>
249 where
250 T: TimeScale + Copy,
251 F: DetectFn<T>,
252 for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
253 {
254 let steps = build_time_grid(total_seconds, step_seconds);
255
256 let mut signs = Vec::with_capacity(steps.len());
257 for &t in &steps {
258 let v = callback
259 .call(t)
260 .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
261 signs.push(v.signum());
262 }
263
264 let start_sign = signs[0];
265
266 if signs.iter().all(|&s| s < 0.0) || signs.iter().all(|&s| s > 0.0) {
267 return Ok((vec![], start_sign));
268 }
269
270 let mut events = Vec::new();
271 for ((&t0, &s0), (&t1, &s1)) in std::iter::zip(&steps, &signs).tuple_windows() {
272 if let Some(crossing) = ZeroCrossing::new(s0, s1) {
273 let t = self
274 .root_finder
275 .find_in_bracket(callback, (t0, t1))
276 .map_err(DetectError::RootFinder)?;
277 let time = start + TimeDelta::from_seconds_f64(t);
278 events.push(Event { crossing, time });
279 }
280 }
281
282 Ok((events, start_sign))
283 }
284
285 fn detect_two_level<T>(
288 &self,
289 callback: DetectCallback<'_, T, F>,
290 start: Time<T>,
291 total_seconds: f64,
292 step_seconds: f64,
293 coarse_seconds: f64,
294 ) -> Result<(Vec<Event<T>>, f64), DetectError>
295 where
296 T: TimeScale + Copy,
297 F: DetectFn<T>,
298 for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
299 {
300 let coarse_grid = build_time_grid(total_seconds, coarse_seconds);
302 let mut coarse_signs = Vec::with_capacity(coarse_grid.len());
303 for &t in &coarse_grid {
304 let v = callback
305 .call(t)
306 .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
307 coarse_signs.push(v.signum());
308 }
309
310 let start_sign = coarse_signs[0];
311
312 let mut events = Vec::new();
314 for ((&tc0, &sc0), (&tc1, &sc1)) in
315 std::iter::zip(&coarse_grid, &coarse_signs).tuple_windows()
316 {
317 if ZeroCrossing::new(sc0, sc1).is_none() {
318 continue;
319 }
320
321 let bracket_len = tc1 - tc0;
324 let fine_grid = build_time_grid(bracket_len, step_seconds);
325
326 let mut fine_times = Vec::with_capacity(fine_grid.len());
327 let mut fine_signs = Vec::with_capacity(fine_grid.len());
328
329 fine_times.push(tc0);
331 fine_signs.push(sc0);
332
333 for &ft in &fine_grid[1..] {
335 let abs_t = tc0 + ft;
336 fine_times.push(abs_t);
337 let v = callback
338 .call(abs_t)
339 .map_err(|e| DetectError::RootFinder(RootFinderError::Callback(e)))?;
340 fine_signs.push(v.signum());
341 }
342
343 for ((&t0, &s0), (&t1, &s1)) in std::iter::zip(&fine_times, &fine_signs).tuple_windows()
345 {
346 if let Some(crossing) = ZeroCrossing::new(s0, s1) {
347 let t = self
348 .root_finder
349 .find_in_bracket(callback, (t0, t1))
350 .map_err(DetectError::RootFinder)?;
351 let time = start + TimeDelta::from_seconds_f64(t);
352 events.push(Event { crossing, time });
353 }
354 }
355 }
356
357 Ok((events, start_sign))
358 }
359}
360
361impl<T, F, R> EventDetector<T> for RootFindingDetector<F, R>
362where
363 T: TimeScale + Copy,
364 F: DetectFn<T>,
365 for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
366{
367 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<Event<T>>, DetectError> {
368 self.detect_with_start_sign(interval)
369 .map(|(events, _)| events)
370 }
371}
372
373pub struct EventsToIntervals<F, R = Brent> {
384 detector: RootFindingDetector<F, R>,
385}
386
387impl<F> EventsToIntervals<F, Brent> {
388 pub fn new(detector: RootFindingDetector<F>) -> Self {
390 Self { detector }
391 }
392}
393
394impl<F, R> EventsToIntervals<F, R> {
395 pub fn with_root_finder(detector: RootFindingDetector<F, R>) -> Self {
397 Self { detector }
398 }
399}
400
401impl<T, F, R> IntervalDetector<T> for EventsToIntervals<F, R>
402where
403 T: TimeScale + Copy,
404 F: DetectFn<T>,
405 for<'a> R: FindBracketedRoot<DetectCallback<'a, T, F>>,
406{
407 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
408 let start = interval.start();
409 let end = interval.end();
410
411 let (events, start_sign) = self.detector.detect_with_start_sign(interval)?;
412 if events.is_empty() {
413 return if start_sign >= 0.0 {
417 Ok(vec![interval])
418 } else {
419 Ok(vec![])
420 };
421 }
422
423 let mut events: VecDeque<Event<T>> = events.into();
424
425 if events.front().unwrap().crossing == ZeroCrossing::Down {
426 events.push_front(Event {
427 crossing: ZeroCrossing::Up,
428 time: start,
429 });
430 }
431
432 if events.back().unwrap().crossing == ZeroCrossing::Up {
433 events.push_back(Event {
434 crossing: ZeroCrossing::Down,
435 time: end,
436 });
437 }
438
439 let mut intervals = Vec::with_capacity(events.len() / 2);
440 for (up, down) in events.into_iter().tuples() {
441 debug_assert!(up.crossing == ZeroCrossing::Up);
442 debug_assert!(down.crossing == ZeroCrossing::Down);
443 intervals.push(TimeInterval::new(up.time, down.time));
444 }
445
446 Ok(intervals)
447 }
448}
449
450pub struct Intersection<A, B> {
456 a: A,
457 b: B,
458}
459
460impl<T, A, B> IntervalDetector<T> for Intersection<A, B>
461where
462 T: TimeScale + Ord + Copy,
463 A: IntervalDetector<T>,
464 B: IntervalDetector<T>,
465{
466 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
467 let a = self.a.detect(interval)?;
468 let b = self.b.detect(interval)?;
469 Ok(lox_time::intervals::intersect_intervals(&a, &b))
470 }
471}
472
473pub struct Union<A, B> {
475 a: A,
476 b: B,
477}
478
479impl<T, A, B> IntervalDetector<T> for Union<A, B>
480where
481 T: TimeScale + Ord + Copy,
482 A: IntervalDetector<T>,
483 B: IntervalDetector<T>,
484{
485 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
486 let a = self.a.detect(interval)?;
487 let b = self.b.detect(interval)?;
488 Ok(lox_time::intervals::union_intervals(&a, &b))
489 }
490}
491
492pub struct Complement<D> {
494 detector: D,
495}
496
497impl<T, D> IntervalDetector<T> for Complement<D>
498where
499 T: TimeScale + Ord + Copy,
500 D: IntervalDetector<T>,
501{
502 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
503 let inner = self.detector.detect(interval)?;
504 Ok(lox_time::intervals::complement_intervals(&inner, interval))
505 }
506}
507
508pub struct Chain<A, B> {
510 a: A,
511 b: B,
512}
513
514impl<T, A, B> IntervalDetector<T> for Chain<A, B>
515where
516 T: TimeScale + Copy,
517 A: IntervalDetector<T>,
518 B: IntervalDetector<T>,
519{
520 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
521 let a_intervals = self.a.detect(interval)?;
522 let mut result = Vec::new();
523 for sub in a_intervals {
524 result.extend(self.b.detect(sub)?);
525 }
526 Ok(result)
527 }
528}
529
530pub trait IntervalDetectorExt<T: TimeScale>: IntervalDetector<T> + Sized {
536 fn intersect<B>(self, other: B) -> Intersection<Self, B> {
538 Intersection { a: self, b: other }
539 }
540
541 fn union<B>(self, other: B) -> Union<Self, B> {
543 Union { a: self, b: other }
544 }
545
546 fn complement(self) -> Complement<Self> {
548 Complement { detector: self }
549 }
550
551 fn chain<B>(self, other: B) -> Chain<Self, B> {
553 Chain { a: self, b: other }
554 }
555}
556
557impl<T: TimeScale, D: IntervalDetector<T>> IntervalDetectorExt<T> for D {}
558
559impl<T: TimeScale> IntervalDetector<T> for Box<dyn IntervalDetector<T> + '_> {
564 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
565 (**self).detect(interval)
566 }
567}
568
569impl<T: TimeScale> IntervalDetector<T> for Box<dyn IntervalDetector<T> + Send + '_> {
570 fn detect(&self, interval: TimeInterval<T>) -> Result<Vec<TimeInterval<T>>, DetectError> {
571 (**self).detect(interval)
572 }
573}
574
575pub struct FnDetect<F>(pub F);
581
582impl<T, F> DetectFn<T> for FnDetect<F>
583where
584 T: TimeScale + Copy,
585 F: Fn(Time<T>) -> f64,
586{
587 type Error = std::convert::Infallible;
588 fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
589 Ok((self.0)(time))
590 }
591}
592
593pub struct TryFnDetect<F>(pub F);
595
596impl<T, F, E> DetectFn<T> for TryFnDetect<F>
597where
598 T: TimeScale + Copy,
599 F: Fn(Time<T>) -> Result<f64, E>,
600 E: std::error::Error + Send + Sync + 'static,
601{
602 type Error = E;
603 fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
604 (self.0)(time)
605 }
606}
607
608pub fn find_events<T, F>(
614 func: F,
615 interval: TimeInterval<T>,
616 step: TimeDelta,
617) -> Result<Vec<Event<T>>, DetectError>
618where
619 T: TimeScale + Copy,
620 F: Fn(Time<T>) -> f64,
621{
622 RootFindingDetector::new(FnDetect(func), step).detect(interval)
623}
624
625pub fn try_find_events<T, F, E>(
627 func: F,
628 interval: TimeInterval<T>,
629 step: TimeDelta,
630) -> Result<Vec<Event<T>>, DetectError>
631where
632 T: TimeScale + Copy,
633 F: Fn(Time<T>) -> Result<f64, E>,
634 E: std::error::Error + Send + Sync + 'static,
635{
636 RootFindingDetector::new(TryFnDetect(func), step).detect(interval)
637}
638
639pub fn find_windows<T, F>(
641 func: F,
642 interval: TimeInterval<T>,
643 step: TimeDelta,
644) -> Result<Vec<TimeInterval<T>>, DetectError>
645where
646 T: TimeScale + Copy,
647 F: Fn(Time<T>) -> f64,
648{
649 let detector = RootFindingDetector::new(FnDetect(func), step);
650 EventsToIntervals::new(detector).detect(interval)
651}
652
653pub fn try_find_windows<T, F, E>(
655 func: F,
656 interval: TimeInterval<T>,
657 step: TimeDelta,
658) -> Result<Vec<TimeInterval<T>>, DetectError>
659where
660 T: TimeScale + Copy,
661 F: Fn(Time<T>) -> Result<f64, E>,
662 E: std::error::Error + Send + Sync + 'static,
663{
664 let detector = RootFindingDetector::new(TryFnDetect(func), step);
665 EventsToIntervals::new(detector).detect(interval)
666}
667
#[cfg(test)]
mod tests {
    use super::*;
    use lox_test_utils::assert_approx_eq;
    use lox_time::time;
    use lox_time::time_scales::Tai;
    use std::f64::consts::{PI, TAU};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test-only [`DetectFn`] that counts how many times the wrapped closure
    /// is evaluated.
    struct CountingDetectFn<'a, F> {
        inner: F,
        counter: &'a AtomicUsize,
    }

    impl<'a, T, F> DetectFn<T> for CountingDetectFn<'a, F>
    where
        T: TimeScale + Copy,
        F: Fn(Time<T>) -> f64,
    {
        type Error = std::convert::Infallible;
        fn eval(&self, time: Time<T>) -> Result<f64, Self::Error> {
            self.counter.fetch_add(1, Ordering::Relaxed);
            Ok((self.inner)(time))
        }
    }

    // sin(t) over [0 s, 7 s] crosses zero downwards at pi and upwards at 2*pi.
    #[test]
    fn test_events() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let detect_fn = FnDetect(|t: Time<Tai>| (t - start).to_seconds().to_f64().sin());
        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
        let events = detector.detect(interval).unwrap();

        assert_eq!(events.len(), 2);
        assert_eq!(events[0].crossing, ZeroCrossing::Down);
        assert_approx_eq!(
            events[0].time,
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-6
        );
        assert_eq!(events[1].crossing, ZeroCrossing::Up);
        assert_approx_eq!(
            events[1].time,
            start + TimeDelta::from_seconds_f64(TAU),
            rtol <= 1e-6
        );
    }

    // sin(t) is positive on [0, pi] and [2*pi, 7]; the first window starts at
    // the interval start because sin(0) = 0 is treated as positive.
    #[test]
    fn test_windows() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let detect_fn = FnDetect(|t: Time<Tai>| (t - start).to_seconds().to_f64().sin());
        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
        let intervals_detector = EventsToIntervals::new(detector);
        let windows = intervals_detector.detect(interval).unwrap();

        assert_eq!(windows.len(), 2);
        assert_eq!(windows[0].start(), start);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-6
        );
    }

    // A constantly negative function yields no windows.
    #[test]
    fn test_windows_no_windows() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let detect_fn = FnDetect(|_t: Time<Tai>| -1.0);
        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
        let intervals_detector = EventsToIntervals::new(detector);
        let windows = intervals_detector.detect(interval).unwrap();

        assert!(windows.is_empty());
    }

    // A constantly positive function yields the whole search interval.
    #[test]
    fn test_windows_full_coverage() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let detect_fn = FnDetect(|_t: Time<Tai>| 1.0);
        let detector = RootFindingDetector::new(detect_fn, TimeDelta::from_seconds(1));
        let intervals_detector = EventsToIntervals::new(detector);
        let windows = intervals_detector.detect(interval).unwrap();

        assert_eq!(windows.len(), 1);
        assert_eq!(windows[0].start(), start);
        assert_eq!(windows[0].end(), end);
    }

    // With crossings spaced wider than the coarse step, two-level sampling
    // must find the same events as single-level sampling.
    #[test]
    fn test_two_level_matches_single_level() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let single = RootFindingDetector::new(
            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
            TimeDelta::from_seconds(1),
        )
        .detect(interval)
        .unwrap();

        let two_level = RootFindingDetector::new(
            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
            TimeDelta::from_seconds(1),
        )
        .with_coarse_step(TimeDelta::from_seconds(3))
        .detect(interval)
        .unwrap();

        assert_eq!(single.len(), two_level.len());
        for (s, tl) in single.iter().zip(&two_level) {
            assert_eq!(s.crossing, tl.crossing);
            assert_approx_eq!(s.time, tl.time, rtol <= 1e-6);
        }
    }

    // sin(t + 0.5) over [0, 10] crosses zero three times (t = pi - 0.5,
    // 2*pi - 0.5, 3*pi - 0.5). The single coarse bracket [0, 10] has endpoint
    // signs + and -, so it is refined and all three crossings are found.
    #[test]
    fn test_two_level_multiple_crossings_in_bracket() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(10);
        let interval = TimeInterval::new(start, end);

        let func = move |t: Time<Tai>| ((t - start).to_seconds().to_f64() + 0.5).sin();

        let single = RootFindingDetector::new(FnDetect(func), TimeDelta::from_seconds(1))
            .detect(interval)
            .unwrap();

        let two_level = RootFindingDetector::new(FnDetect(func), TimeDelta::from_seconds(1))
            .with_coarse_step(TimeDelta::from_seconds(10))
            .detect(interval)
            .unwrap();

        assert_eq!(single.len(), 3, "expected 3 crossings");
        assert_eq!(single.len(), two_level.len());
        for (s, tl) in single.iter().zip(&two_level) {
            assert_eq!(s.crossing, tl.crossing);
            assert_approx_eq!(s.time, tl.time, rtol <= 1e-6);
        }
    }

    // No events in two-level mode; the start sign must still be reported.
    #[test]
    fn test_two_level_no_events() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(10);
        let interval = TimeInterval::new(start, end);

        let det =
            RootFindingDetector::new(FnDetect(|_t: Time<Tai>| -1.0), TimeDelta::from_seconds(1))
                .with_coarse_step(TimeDelta::from_seconds(3));

        let (events, start_sign) = det.detect_with_start_sign(interval).unwrap();
        assert!(events.is_empty());
        assert!(start_sign < 0.0);
    }

    // Windows built from two-level events must match single-level windows.
    #[test]
    fn test_two_level_windows_roundtrip() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);

        let single_windows = EventsToIntervals::new(RootFindingDetector::new(
            FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
            TimeDelta::from_seconds(1),
        ))
        .detect(interval)
        .unwrap();

        let two_level_windows = EventsToIntervals::new(
            RootFindingDetector::new(
                FnDetect(move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin()),
                TimeDelta::from_seconds(1),
            )
            .with_coarse_step(TimeDelta::from_seconds(3)),
        )
        .detect(interval)
        .unwrap();

        assert_eq!(single_windows.len(), two_level_windows.len());
        for (s, tl) in single_windows.iter().zip(&two_level_windows) {
            assert_approx_eq!(s.start(), tl.start(), rtol <= 1e-6);
            assert_approx_eq!(s.end(), tl.end(), rtol <= 1e-6);
        }
    }

    // A slowly varying function over a long interval: two-level sampling only
    // refines the few coarse brackets containing a sign change, so it must use
    // fewer function evaluations than sampling every second.
    #[test]
    fn test_two_level_eval_count_reduction() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(1000);
        let interval = TimeInterval::new(start, end);

        let func = move |t: Time<Tai>| ((t - start).to_seconds().to_f64() / 100.0).sin();

        let counter_single = AtomicUsize::new(0);
        let single = RootFindingDetector::new(
            CountingDetectFn {
                inner: func,
                counter: &counter_single,
            },
            TimeDelta::from_seconds(1),
        )
        .detect(interval)
        .unwrap();

        let counter_two = AtomicUsize::new(0);
        let two_level = RootFindingDetector::new(
            CountingDetectFn {
                inner: func,
                counter: &counter_two,
            },
            TimeDelta::from_seconds(1),
        )
        .with_coarse_step(TimeDelta::from_seconds(150))
        .detect(interval)
        .unwrap();

        assert_eq!(single.len(), two_level.len());

        let single_evals = counter_single.load(Ordering::Relaxed);
        let two_level_evals = counter_two.load(Ordering::Relaxed);

        assert!(
            two_level_evals < single_evals,
            "two-level ({two_level_evals}) should use fewer evals than single-level ({single_evals})"
        );
    }

    /// Builds a window detector for `func` sampled every `step`.
    fn make_window_detector<F: Fn(Time<Tai>) -> f64>(
        func: F,
        step: TimeDelta,
    ) -> EventsToIntervals<FnDetect<F>> {
        EventsToIntervals::new(RootFindingDetector::new(FnDetect(func), step))
    }

    /// Window detector for sin(t - start), sampled every second.
    fn sin_detector(start: Time<Tai>) -> EventsToIntervals<FnDetect<impl Fn(Time<Tai>) -> f64>> {
        make_window_detector(
            move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
            TimeDelta::from_seconds(1),
        )
    }

    /// Window detector for cos(t - start), sampled every second.
    fn cos_detector(start: Time<Tai>) -> EventsToIntervals<FnDetect<impl Fn(Time<Tai>) -> f64>> {
        make_window_detector(
            move |t: Time<Tai>| (t - start).to_seconds().to_f64().cos(),
            TimeDelta::from_seconds(1),
        )
    }

    /// Shared 7-second search interval used by the combinator tests.
    fn test_interval() -> (Time<Tai>, TimeInterval<Tai>) {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        (start, TimeInterval::new(start, end))
    }

    // sin > 0 on [0, pi] and [2*pi, 7]; cos > 0 on [0, pi/2] and [3*pi/2, 7].
    // Their intersection is [0, pi/2] and [2*pi, 7].
    #[test]
    fn test_intersect() {
        let (start, interval) = test_interval();
        let det = sin_detector(start).intersect(cos_detector(start));
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI / 2.0),
            rtol <= 1e-4
        );
    }

    // Union of the same windows: [0, pi] and [3*pi/2, 7].
    #[test]
    fn test_union() {
        let (start, interval) = test_interval();
        let det = sin_detector(start).union(cos_detector(start));
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-4
        );
    }

    // The complement of sin > 0 within [0, 7] is [pi, 2*pi].
    #[test]
    fn test_complement() {
        let (start, interval) = test_interval();
        let det = sin_detector(start).complement();
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 1);
        assert_approx_eq!(
            windows[0].start(),
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-4
        );
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(TAU),
            rtol <= 1e-4
        );
    }

    // Chaining runs the cos detector only inside the sin windows, giving
    // [0, pi/2] and [2*pi, 7].
    #[test]
    fn test_chain() {
        let (start, interval) = test_interval();
        let det = sin_detector(start).chain(cos_detector(start));
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(windows[0].start(), start, rtol <= 1e-6);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI / 2.0),
            rtol <= 1e-4
        );
        assert_approx_eq!(
            windows[1].start(),
            start + TimeDelta::from_seconds_f64(TAU),
            rtol <= 1e-4
        );
        assert_approx_eq!(
            windows[1].end(),
            start + TimeDelta::from_seconds(7),
            rtol <= 1e-6
        );
    }

    // When the first detector finds no windows, the second one must never be
    // evaluated.
    #[test]
    fn test_chain_restricts_evaluation() {
        let (start, interval) = test_interval();
        let counter = AtomicUsize::new(0);
        let a = make_window_detector(|_t: Time<Tai>| -1.0, TimeDelta::from_seconds(1));
        let b = EventsToIntervals::new(RootFindingDetector::new(
            CountingDetectFn {
                inner: move |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
                counter: &counter,
            },
            TimeDelta::from_seconds(1),
        ));
        let windows = a.chain(b).detect(interval).unwrap();
        assert!(windows.is_empty());
        assert_eq!(counter.load(Ordering::Relaxed), 0);
    }

    // A boxed detector forwards to the wrapped detector.
    #[test]
    fn test_boxed_interval_detector() {
        let (start, interval) = test_interval();
        let det: Box<dyn IntervalDetector<Tai>> = Box::new(sin_detector(start));
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-4
        );
    }

    // Same, for the `+ Send` trait object.
    #[test]
    fn test_boxed_send_interval_detector() {
        let (start, interval) = test_interval();
        let det: Box<dyn IntervalDetector<Tai> + Send> = Box::new(sin_detector(start));
        let windows = det.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
    }

    // Boxed detectors can be combined at runtime; here the first boxed
    // detector (sin) is intersected with a boxed cos detector, which matches
    // `test_intersect`.
    #[test]
    fn test_boxed_dynamic_fold() {
        let (start, interval) = test_interval();

        let detectors: Vec<Box<dyn IntervalDetector<Tai>>> =
            vec![Box::new(sin_detector(start)), Box::new(cos_detector(start))];

        let mut combined: Box<dyn IntervalDetector<Tai>> = detectors.into_iter().next().unwrap();

        let det = Box::new(cos_detector(start)) as Box<dyn IntervalDetector<Tai>>;
        combined = Box::new(combined.intersect(det));

        let windows = combined.detect(interval).unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI / 2.0),
            rtol <= 1e-4
        );
    }

    // Convenience-function smoke tests: same sin(t) setup as above.
    #[test]
    fn test_find_events() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);
        let events = find_events(
            |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
            interval,
            TimeDelta::from_seconds(1),
        )
        .unwrap();
        assert_eq!(events.len(), 2);
    }

    #[test]
    fn test_try_find_events() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);
        let events = try_find_events(
            |t: Time<Tai>| {
                Ok::<f64, std::convert::Infallible>((t - start).to_seconds().to_f64().sin())
            },
            interval,
            TimeDelta::from_seconds(1),
        )
        .unwrap();
        assert_eq!(events.len(), 2);
    }

    #[test]
    fn test_find_windows() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);
        let windows = find_windows(
            |t: Time<Tai>| (t - start).to_seconds().to_f64().sin(),
            interval,
            TimeDelta::from_seconds(1),
        )
        .unwrap();
        assert_eq!(windows.len(), 2);
        assert_approx_eq!(
            windows[0].end(),
            start + TimeDelta::from_seconds_f64(PI),
            rtol <= 1e-4
        );
    }

    #[test]
    fn test_try_find_windows() {
        let start = time!(Tai, 2000, 1, 1, 12).unwrap();
        let end = start + TimeDelta::from_seconds(7);
        let interval = TimeInterval::new(start, end);
        let windows = try_find_windows(
            |t: Time<Tai>| {
                Ok::<f64, std::convert::Infallible>((t - start).to_seconds().to_f64().sin())
            },
            interval,
            TimeDelta::from_seconds(1),
        )
        .unwrap();
        assert_eq!(windows.len(), 2);
    }
}