use super::{AnalysisError, Event};
use hifitime::{Epoch, Unit};
/// Finds the epoch within `[start_epoch, end_epoch]` at which `evaluator` crosses zero using
/// Brent's method: inverse quadratic interpolation or a secant step, safeguarded by bisection.
///
/// Abscissae are handled as `f64` offsets in seconds from `start_epoch`. The search stops when
/// the bracketing interval is no wider than `event.epoch_precision` (returning the bracket
/// endpoint with the smaller `|f|`), or as soon as a sampled value has magnitude
/// `<= f64::EPSILON` (returning that sample's epoch).
///
/// # Errors
///
/// * Any error returned by `evaluator` is propagated as-is.
/// * [`AnalysisError::EventNotFound`] when the values at the two bounds do not have strictly
///   opposite signs (no bracketed root, or a NaN value), or when the method has not converged
///   after 50 iterations.
pub fn brent_solver<F>(
    evaluator: F,
    event: &Event,
    start_epoch: Epoch,
    end_epoch: Epoch,
) -> Result<Epoch, AnalysisError>
where
    F: Fn(Epoch) -> Result<f64, AnalysisError>,
{
    const MAX_ITER: usize = 50;
    let tolerance_s = event.epoch_precision.to_seconds();
    let has_converged = |xa: f64, xb: f64| (xa - xb).abs() <= tolerance_s;
    let not_found = || AnalysisError::EventNotFound {
        start: start_epoch,
        end: end_epoch,
        event: Box::new(event.clone()),
    };
    // All abscissae below are offsets in seconds from `start_epoch`.
    let epoch_at = |offset_s: f64| start_epoch + offset_s * Unit::Second;

    let mut xa = 0.0;
    let mut xb = (end_epoch - start_epoch).to_seconds();
    let mut ya = evaluator(start_epoch)?;
    if ya.abs() <= f64::EPSILON {
        return Ok(start_epoch);
    }
    let mut yb = evaluator(end_epoch)?;
    if yb.abs() <= f64::EPSILON {
        return Ok(end_epoch);
    }
    // Require a strict sign change. The explicit NaN checks matter: `NaN * x >= 0.0` is false,
    // so without them a NaN would slip into the iteration below.
    if ya.is_nan() || yb.is_nan() || ya * yb >= 0.0 {
        return Err(not_found());
    }
    // Brent's invariant: `b` is the best estimate so far, i.e. |f(b)| <= |f(a)|.
    if ya.abs() < yb.abs() {
        std::mem::swap(&mut xa, &mut xb);
        std::mem::swap(&mut ya, &mut yb);
    }

    // `c` holds the previous `b`; `d` the one before that (only read once a non-bisection
    // step has happened, so its initial value is never used).
    let (mut xc, mut yc, mut xd) = (xa, ya, xa);
    // `true` when the previous step was a bisection.
    let mut bisected = true;
    for _ in 0..MAX_ITER {
        if has_converged(xa, xb) {
            return Ok(epoch_at(xb));
        }

        let mut s = if (ya - yc).abs() > f64::EPSILON && (yb - yc).abs() > f64::EPSILON {
            // Inverse quadratic interpolation through (a, f(a)), (b, f(b)), (c, f(c)).
            xa * yb * yc / ((ya - yb) * (ya - yc))
                + xb * ya * yc / ((yb - ya) * (yb - yc))
                + xc * ya * yb / ((yc - ya) * (yc - yb))
        } else {
            // Secant step; f(a) != f(b) because they have opposite signs.
            xb - yb * (xb - xa) / (yb - ya)
        };

        // Brent's safeguards: bisect if `s` is not between (3a + b) / 4 and b, or if the
        // interpolation steps are not shrinking fast enough compared with earlier steps.
        let outside = (s - xb) * (s - (3.0 * xa + xb) / 4.0) > 0.0;
        let slow_after_bisect = bisected && (s - xb).abs() >= (xb - xc).abs() / 2.0;
        let slow_after_interp = !bisected && (s - xb).abs() >= (xc - xd).abs() / 2.0;
        let tiny_after_bisect = bisected && has_converged(xb, xc);
        let tiny_after_interp = !bisected && has_converged(xc, xd);
        bisected = outside
            || slow_after_bisect
            || slow_after_interp
            || tiny_after_bisect
            || tiny_after_interp;
        if bisected {
            s = (xa + xb) / 2.0;
        }

        let s_epoch = epoch_at(s);
        let ys = evaluator(s_epoch)?;
        if ys.abs() <= f64::EPSILON {
            return Ok(s_epoch);
        }

        xd = xc;
        xc = xb;
        yc = yb;
        // Keep the sign change inside [a, b].
        if ya * ys < 0.0 {
            xb = s;
            yb = ys;
        } else {
            xa = s;
            ya = ys;
        }
        // Restore the invariant |f(b)| <= |f(a)|.
        if ya.abs() < yb.abs() {
            std::mem::swap(&mut xa, &mut xb);
            std::mem::swap(&mut ya, &mut yb);
        }
    }

    // The final iteration shrank the bracket without re-testing convergence; do so now rather
    // than discarding a converged result.
    if has_converged(xa, xb) {
        return Ok(epoch_at(xb));
    }
    Err(not_found())
}
/// Scans `[start_epoch, end_epoch]` for sign changes of `evaluator` and returns the consecutive
/// sample pairs `(t, t + step)` across which the sign changed, e.g. for refinement with
/// [`brent_solver`].
///
/// `min_step` is `event.epoch_precision` (floored at 1 ns) and the first step is
/// `max_step = 10_000 × min_step`. Steps are always truncated so that no sample lies past
/// `end_epoch`.
///
/// * Angle scalars (`event.scalar.is_angle()`): the step is never adapted, and a sign change
///   counts only when the two samples differ by less than 180.
/// * Other scalars: with `ratio = |Δy| / Δt` (Δt in seconds), an interval with `ratio > 1.1`
///   is resampled with step `Δt / ratio` (never below `min_step`) while the step is still
///   larger than `min_step`. After an accepted sample, the next step is `Δt / ratio` clamped
///   to `[min_step, max_step]`.
///
/// # Errors
///
/// Propagates the evaluator's error at `start_epoch`. An error at any later sample ends the
/// scan early, and the brackets found so far are returned.
pub fn adaptive_step_scanner<F>(
    evaluator: F,
    event: &Event,
    start_epoch: Epoch,
    end_epoch: Epoch,
) -> Result<Vec<(Epoch, Epoch)>, AnalysisError>
where
    F: Fn(Epoch) -> Result<f64, AnalysisError>,
{
    // A zero or negative precision would make every step non-positive, and the loop would
    // never reach `end_epoch`; floor it at one nanosecond.
    let min_step = event.epoch_precision.max(Unit::Nanosecond * 1);
    let max_step = min_step * 10_000;
    let mut brackets = Vec::new();
    let mut t = start_epoch;
    let mut y_prev = evaluator(t)?;
    let mut step = max_step;
    while t < end_epoch {
        // Never sample past the end of the window.
        step = step.min(end_epoch - t);
        // Best effort: stop at the first sample the evaluator cannot compute, and return the
        // brackets found so far.
        let y_next = match evaluator(t + step) {
            Ok(val) => val,
            Err(_) => break,
        };
        let delta = (y_next - y_prev).abs();

        let next_step = if event.scalar.is_angle() {
            // A jump of 180 or more between samples is not counted as a crossing
            // (presumably an angle wrap-around in degrees — confirm against `is_angle` users).
            if y_prev.signum() != y_next.signum() && delta < 180.0 {
                brackets.push((t, t + step));
            }
            step
        } else {
            let delta_ratio = delta / step.to_seconds();
            let adapted = (step.to_seconds() / delta_ratio) * Unit::Second;
            if delta_ratio > 1.1 && step > min_step {
                // Changing too quickly: resample the same interval with a smaller step. The
                // floor plus the strict `>` above guarantee this refinement terminates.
                step = adapted.max(min_step);
                continue;
            }
            if y_prev * y_next < 0.0 {
                brackets.push((t, t + step));
            }
            adapted.clamp(min_step, max_step)
        };

        // Advance by the step that was actually sampled, so `y_prev` stays the value at `t`;
        // only then switch to the adapted step for the next sample.
        y_prev = y_next;
        t += step;
        step = next_step;
    }
    Ok(brackets)
}