use crate::{Interval, IntervalError, PosteriorDistribution};
use num_traits::{Float, FromPrimitive};
/// Support of a `PosteriorDistribution`: the knot-delimited intervals whose posterior
/// mass exceeds `1e-12` (the comparison is done in log space by `recompute`), with
/// adjacent above-threshold intervals merged into a single span.
#[derive(Clone, Debug)]
pub(crate) struct SupportSet<T> {
/// One entry per maximal run of consecutive above-threshold mass indices, each spanning
/// from the run's first knot to the knot just after its last interval. Entries are pushed
/// in increasing mass-index order.
active_intervals: Vec<Interval<T>>,
}
impl<T> SupportSet<T>
where
    T: Float + FromPrimitive,
{
    /// Minimum posterior mass (linear scale) an interval must exceed to count as support.
    const MASS_THRESHOLD: f64 = 1e-12;

    /// Builds the support set of `posterior`.
    ///
    /// # Errors
    ///
    /// Propagates any `IntervalError` returned by `Interval::new` while building the
    /// merged intervals (see [`SupportSet::recompute`]).
    ///
    /// # Panics
    ///
    /// Same conditions as [`SupportSet::recompute`].
    pub(crate) fn new(posterior: &PosteriorDistribution<T>) -> Result<Self, IntervalError<T>> {
        let mut support_set = SupportSet {
            active_intervals: Vec::new(),
        };
        support_set.recompute(posterior)?;
        Ok(support_set)
    }

    /// Returns `true` if `x` lies in any active interval (both bounds inclusive).
    ///
    /// Returns `false` for an empty support set, and for a NaN `x` (every comparison fails).
    pub fn contains(&self, x: T) -> bool {
        self.active_intervals
            .iter()
            .any(|interval| interval.lower() <= x && x <= interval.upper())
    }

    /// Returns the midpoint of the widest active interval, or `None` if there are none.
    ///
    /// Widths that cannot be ordered (NaN) are treated as equal. Among equally wide
    /// intervals, the last one wins, which is `Iterator::max_by` semantics.
    pub fn widest_interval_midpoint(&self) -> Option<T> {
        let two = T::one() + T::one();
        self.active_intervals
            .iter()
            .max_by(|a, b| {
                let wa = a.upper() - a.lower();
                let wb = b.upper() - b.lower();
                wa.partial_cmp(&wb).unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|interval| (interval.lower() + interval.upper()) / two)
    }

    /// Rebuilds the active intervals from `posterior`.
    ///
    /// Mass index `i` is taken to span `knots[i]..=knots[i + 1]`. Each maximal run
    /// `start..end` of indices whose log-mass exceeds `ln(1e-12)` becomes one
    /// `Interval::new(knots[start], knots[end])`. Indices with NaN log-mass never
    /// exceed the threshold and are treated as outside the support.
    ///
    /// # Errors
    ///
    /// Propagates the first `IntervalError` returned by `Interval::new`. On error,
    /// `self` is left unchanged; it is never left with a partially rebuilt list.
    ///
    /// # Panics
    ///
    /// Panics on out-of-bounds indexing if an active run ends at mass index `k` while
    /// `posterior.knots` has `k + 1` or fewer entries. This presumably never happens
    /// when `knots.len() == log_interval_mass.len() + 1`; confirm against
    /// `PosteriorDistribution`.
    pub(crate) fn recompute(
        &mut self,
        posterior: &PosteriorDistribution<T>,
    ) -> Result<(), IntervalError<T>> {
        let eps = T::from_f64(Self::MASS_THRESHOLD)
            .expect("a small positive f64 constant is representable by any Float type");
        let log_eps = eps.ln();

        let masses = &posterior.log_interval_mass;
        // Build into a local buffer so an error from `Interval::new` cannot leave
        // `self.active_intervals` half-rebuilt.
        let mut intervals = Vec::new();
        let mut i = 0;
        while i < masses.len() {
            if masses[i] > log_eps {
                let start = i;
                while i < masses.len() && masses[i] > log_eps {
                    i += 1;
                }
                // `i` is now one past the run's last index, so `knots[i]` is the
                // upper knot of that last interval.
                intervals.push(Interval::new(posterior.knots[start], posterior.knots[i])?);
            } else {
                i += 1;
            }
        }

        self.active_intervals = intervals;
        Ok(())
    }
}