1use super::details::{EventArc, EventDetails, EventEdge};
20use crate::errors::{EventError, EventTrajSnafu};
21use crate::linalg::allocator::Allocator;
22use crate::linalg::DefaultAllocator;
23use crate::md::prelude::{Interpolatable, Traj};
24use crate::md::EventEvaluator;
25use crate::time::{Duration, Epoch, TimeSeries, Unit};
26use anise::almanac::Almanac;
27use log::{debug, error, info, warn};
28use rayon::prelude::*;
29use snafu::ResultExt;
30use std::iter::Iterator;
31use std::sync::mpsc::channel;
32use std::sync::Arc;
33
34impl<S: Interpolatable> Traj<S>
35where
36 DefaultAllocator: Allocator<S::VecLength> + Allocator<S::Size> + Allocator<S::Size, S::Size>,
37{
38 #[allow(clippy::identity_op)]
40 pub fn find_bracketed<E>(
41 &self,
42 start: Epoch,
43 end: Epoch,
44 event: &E,
45 almanac: Arc<Almanac>,
46 ) -> Result<EventDetails<S>, EventError>
47 where
48 E: EventEvaluator<S>,
49 {
50 let max_iter = 50;
51
52 let has_converged =
54 |xa: f64, xb: f64| (xa - xb).abs() <= event.epoch_precision().to_seconds();
55 let arrange = |a: f64, ya: f64, b: f64, yb: f64| {
56 if ya.abs() > yb.abs() {
57 (a, ya, b, yb)
58 } else {
59 (b, yb, a, ya)
60 }
61 };
62
63 let xa_e = start;
64 let xb_e = end;
65
66 let mut xa = 0.0;
68 let mut xb = (xb_e - xa_e).to_seconds();
69 let ya_state = self.at(xa_e).context(EventTrajSnafu {})?;
71 let yb_state = self.at(xb_e).context(EventTrajSnafu {})?;
72 let mut ya = event.eval(&ya_state, almanac.clone())?;
73 let mut yb = event.eval(&yb_state, almanac.clone())?;
74
75 if ya.abs() <= event.value_precision().abs() {
77 debug!(
78 "{event} -- found with |{ya}| < {} @ {xa_e}",
79 event.value_precision().abs()
80 );
81 return EventDetails::new(ya_state, ya, event, self, almanac.clone());
82 } else if yb.abs() <= event.value_precision().abs() {
83 debug!(
84 "{event} -- found with |{yb}| < {} @ {xb_e}",
85 event.value_precision().abs()
86 );
87 return EventDetails::new(yb_state, yb, event, self, almanac.clone());
88 }
89
90 let (mut xc, mut yc, mut xd) = (xa, ya, xa);
94 let mut flag = true;
95
96 for _ in 0..max_iter {
97 if ya.abs() < event.value_precision().abs() {
98 let state = self.at(xa_e + xa * Unit::Second).unwrap();
100 debug!(
101 "{event} -- found with |{ya}| < {} @ {}",
102 event.value_precision().abs(),
103 state.epoch(),
104 );
105 return EventDetails::new(state, ya, event, self, almanac.clone());
106 }
107 if yb.abs() < event.value_precision().abs() {
108 let state = self.at(xa_e + xb * Unit::Second).unwrap();
110 debug!(
111 "{event} -- found with |{yb}| < {} @ {}",
112 event.value_precision().abs(),
113 state.epoch()
114 );
115 return EventDetails::new(state, yb, event, self, almanac.clone());
116 }
117 if has_converged(xa, xb) {
118 return Err(EventError::NotFound {
120 start,
121 end,
122 event: format!("{event}"),
123 });
124 }
125 let mut s = if (ya - yc).abs() > f64::EPSILON && (yb - yc).abs() > f64::EPSILON {
126 xa * yb * yc / ((ya - yb) * (ya - yc))
127 + xb * ya * yc / ((yb - ya) * (yb - yc))
128 + xc * ya * yb / ((yc - ya) * (yc - yb))
129 } else {
130 xb - yb * (xb - xa) / (yb - ya)
131 };
132 let cond1 = (s - xb) * (s - (3.0 * xa + xb) / 4.0) > 0.0;
133 let cond2 = flag && (s - xb).abs() >= (xb - xc).abs() / 2.0;
134 let cond3 = !flag && (s - xb).abs() >= (xc - xd).abs() / 2.0;
135 let cond4 = flag && has_converged(xb, xc);
136 let cond5 = !flag && has_converged(xc, xd);
137 if cond1 || cond2 || cond3 || cond4 || cond5 {
138 s = (xa + xb) / 2.0;
139 flag = true;
140 } else {
141 flag = false;
142 }
143 let next_try = self
144 .at(xa_e + s * Unit::Second)
145 .context(EventTrajSnafu {})?;
146 let ys = event.eval(&next_try, almanac.clone())?;
147 xd = xc;
148 xc = xb;
149 yc = yb;
150 if ya * ys < 0.0 {
151 let next_try = self
153 .at(xa_e + xa * Unit::Second)
154 .context(EventTrajSnafu {})?;
155 let ya_p = event.eval(&next_try, almanac.clone())?;
156 let (_a, _ya, _b, _yb) = arrange(xa, ya_p, s, ys);
157 {
158 xa = _a;
159 ya = _ya;
160 xb = _b;
161 yb = _yb;
162 }
163 } else {
164 let next_try = self
166 .at(xa_e + xb * Unit::Second)
167 .context(EventTrajSnafu {})?;
168 let yb_p = event.eval(&next_try, almanac.clone())?;
169 let (_a, _ya, _b, _yb) = arrange(s, ys, xb, yb_p);
170 {
171 xa = _a;
172 ya = _ya;
173 xb = _b;
174 yb = _yb;
175 }
176 }
177 }
178 error!("Brent solver failed after {max_iter} iterations");
179 Err(EventError::NotFound {
180 start,
181 end,
182 event: format!("{event}"),
183 })
184 }
185
186 #[allow(clippy::identity_op)]
200 pub fn find<E>(
201 &self,
202 event: &E,
203 heuristic: Option<Duration>,
204 almanac: Arc<Almanac>,
205 ) -> Result<Vec<EventDetails<S>>, EventError>
206 where
207 E: EventEvaluator<S>,
208 {
209 let start_epoch = self.first().epoch();
210 let end_epoch = self.last().epoch();
211 if start_epoch == end_epoch {
212 return Err(EventError::NotFound {
213 start: start_epoch,
214 end: end_epoch,
215 event: format!("{event}"),
216 });
217 }
218 let heuristic = heuristic.unwrap_or((end_epoch - start_epoch) / 100);
219 info!("Searching for {event} with initial heuristic of {heuristic}");
220
221 let (sender, receiver) = channel();
222
223 let epochs: Vec<Epoch> = TimeSeries::inclusive(start_epoch, end_epoch, heuristic).collect();
224 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
225 if let Ok(event_state) =
226 self.find_bracketed(epoch, epoch + heuristic, event, almanac.clone())
227 {
228 s.send(event_state).unwrap()
229 };
230 });
231
232 let mut states: Vec<_> = receiver.iter().collect();
233
234 if states.is_empty() {
235 warn!("Heuristic failed to find any {event} event, using slower approach");
236 match self.find_minmax(event, Unit::Second, almanac.clone()) {
239 Ok((min_event, max_event)) => {
240 let lower_min_epoch =
241 if min_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
242 self.first().epoch()
243 } else {
244 min_event.epoch() - 1 * Unit::Millisecond
245 };
246
247 let lower_max_epoch =
248 if min_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
249 self.last().epoch()
250 } else {
251 min_event.epoch() + 1 * Unit::Millisecond
252 };
253
254 let upper_min_epoch =
255 if max_event.epoch() - 1 * Unit::Millisecond < self.first().epoch() {
256 self.first().epoch()
257 } else {
258 max_event.epoch() - 1 * Unit::Millisecond
259 };
260
261 let upper_max_epoch =
262 if max_event.epoch() + 1 * Unit::Millisecond > self.last().epoch() {
263 self.last().epoch()
264 } else {
265 max_event.epoch() + 1 * Unit::Millisecond
266 };
267
268 if let Ok(event_state) = self.find_bracketed(
270 lower_min_epoch,
271 lower_max_epoch,
272 event,
273 almanac.clone(),
274 ) {
275 states.push(event_state);
276 };
277
278 if let Ok(event_state) = self.find_bracketed(
280 upper_min_epoch,
281 upper_max_epoch,
282 event,
283 almanac.clone(),
284 ) {
285 states.push(event_state);
286 };
287
288 if states.is_empty() {
290 return Err(EventError::NotFound {
291 start: start_epoch,
292 end: end_epoch,
293 event: format!("{event}"),
294 });
295 }
296 }
297 Err(_) => {
298 return Err(EventError::NotFound {
299 start: start_epoch,
300 end: end_epoch,
301 event: format!("{event}"),
302 });
303 }
304 };
305 }
306 states.sort_by(|s1, s2| s1.state.epoch().partial_cmp(&s2.state.epoch()).unwrap());
308 states.dedup();
309
310 match states.len() {
311 0 => info!("Event {event} not found"),
312 1 => info!("Event {event} found once on {}", states[0].state.epoch()),
313 _ => {
314 info!(
315 "Event {event} found {} times from {} until {}",
316 states.len(),
317 states.first().unwrap().state.epoch(),
318 states.last().unwrap().state.epoch()
319 )
320 }
321 };
322
323 Ok(states)
324 }
325
326 #[allow(clippy::identity_op)]
328 pub fn find_minmax<E>(
329 &self,
330 event: &E,
331 precision: Unit,
332 almanac: Arc<Almanac>,
333 ) -> Result<(S, S), EventError>
334 where
335 E: EventEvaluator<S>,
336 {
337 let step: Duration = 1 * precision;
338 let mut min_val = f64::INFINITY;
339 let mut max_val = f64::NEG_INFINITY;
340 let mut min_state = S::zeros();
341 let mut max_state = S::zeros();
342
343 let (sender, receiver) = channel();
344
345 let epochs: Vec<Epoch> =
346 TimeSeries::inclusive(self.first().epoch(), self.last().epoch(), step).collect();
347
348 epochs.into_par_iter().for_each_with(sender, |s, epoch| {
349 let state = self.at(epoch).unwrap();
351 if let Ok(this_eval) = event.eval(&state, almanac.clone()) {
352 s.send((this_eval, state)).unwrap();
353 }
354 });
355
356 let evald_states: Vec<_> = receiver.iter().collect();
357 for (this_eval, state) in evald_states {
358 if this_eval < min_val {
359 min_val = this_eval;
360 min_state = state;
361 }
362 if this_eval > max_val {
363 max_val = this_eval;
364 max_state = state;
365 }
366 }
367
368 Ok((min_state, max_state))
369 }
370
371 pub fn find_arcs<E>(
396 &self,
397 event: &E,
398 heuristic: Option<Duration>,
399 almanac: Arc<Almanac>,
400 ) -> Result<Vec<EventArc<S>>, EventError>
401 where
402 E: EventEvaluator<S>,
403 {
404 let mut events = match self.find(event, heuristic, almanac.clone()) {
405 Ok(events) => events,
406 Err(_) => {
407 let first_eval = event.eval(self.first(), almanac.clone())?;
410 let last_eval = event.eval(self.last(), almanac.clone())?;
411 if first_eval > 0.0 && last_eval > 0.0 {
412 vec![
416 EventDetails::new(*self.first(), first_eval, event, self, almanac.clone())?,
417 EventDetails::new(*self.last(), last_eval, event, self, almanac.clone())?,
418 ]
419 } else {
420 return Err(EventError::NotFound {
421 start: self.first().epoch(),
422 end: self.last().epoch(),
423 event: format!("{event}"),
424 });
425 }
426 }
427 };
428 events.sort_by_key(|event| event.state.epoch());
429
430 let mut arcs = Vec::new();
432
433 if events.is_empty() {
434 return Ok(arcs);
435 }
436
437 let mut prev_rise = if events[0].edge != EventEdge::Rising {
439 let value = event.eval(self.first(), almanac.clone())?;
440 Some(EventDetails::new(
441 *self.first(),
442 value,
443 event,
444 self,
445 almanac.clone(),
446 )?)
447 } else {
448 Some(events[0].clone())
449 };
450
451 let mut prev_fall = if events[0].edge == EventEdge::Falling {
452 Some(events[0].clone())
453 } else {
454 None
455 };
456
457 for event in events {
458 if event.edge == EventEdge::Rising {
459 if prev_rise.is_none() && prev_fall.is_none() {
460 prev_rise = Some(event.clone());
462 } else if prev_fall.is_some() {
463 if prev_rise.is_some() {
465 let arc = EventArc {
466 rise: prev_rise.clone().unwrap(),
467 fall: prev_fall.clone().unwrap(),
468 };
469 arcs.push(arc);
470 } else {
471 let arc = EventArc {
472 rise: event.clone(),
473 fall: prev_fall.clone().unwrap(),
474 };
475 arcs.push(arc);
476 }
477 prev_fall = None;
478 prev_rise = Some(event.clone());
480 }
481 } else if event.edge == EventEdge::Falling {
482 prev_fall = Some(event.clone());
483 }
484 }
485
486 if prev_rise.is_some() {
488 if prev_fall.is_some() {
489 let arc = EventArc {
490 rise: prev_rise.clone().unwrap(),
491 fall: prev_fall.clone().unwrap(),
492 };
493 arcs.push(arc);
494 } else {
495 let value = event.eval(self.last(), almanac.clone())?;
497 let fall = EventDetails::new(*self.last(), value, event, self, almanac.clone())?;
498 let arc = EventArc {
499 rise: prev_rise.clone().unwrap(),
500 fall,
501 };
502 arcs.push(arc);
503 }
504 }
505
506 Ok(arcs)
507 }
508}