use crate::{
Interval, IntervalError, MeetError, ObjectiveSign, PBError, PosteriorDistribution,
PosteriorError, RootSide, SequentialInterval, SupportSet,
};
use confi::ConfidenceLevel;
use num_traits::{Float, FromPrimitive};
use std::ops::Range;
/// Errors that can be returned while updating an [`InferenceState`] or while
/// computing a confidence snapshot with [`compute_snapshot`].
///
/// The `Meet`, `Posterior` and `Interval` variants wrap errors from other
/// parts of the crate; the `#[from]` attributes let `?` convert them
/// automatically.
#[derive(thiserror::Error, Debug)]
pub enum BisectionError<T> {
/// Wraps a [`MeetError`].
#[error("Semi meet error: {0}")]
Meet(#[from] MeetError<T>),
/// Wraps a [`PosteriorError`], e.g. from updating the posterior or
/// recomputing the support set in `InferenceState::observe`.
#[error("Posterior error: {0}")]
Posterior(#[from] PosteriorError<T>),
/// Wraps an [`IntervalError`], e.g. from `Interval::new` in
/// [`compute_snapshot`].
#[error("Interval error: {0}")]
Interval(#[from] IntervalError<T>),
/// Returned by [`compute_snapshot`] when no posterior interval has a log
/// density above the threshold, so there is nothing to build an interval from.
#[error("Empty hull...")]
EmptyHull,
}
/// Mutable state carried across observations: the posterior, its support set,
/// the running confidence interval and a few bookkeeping flags.
#[derive(Clone, Debug)]
pub struct InferenceState<T> {
// Number of observations successfully processed so far; incremented at the
// end of `observe` and passed to `compute_snapshot` as `n`.
iter: usize,
// Set to `true` by `sign_is_indeterminate`; nothing in this file clears it.
sign_indeterminate: bool,
// Posterior distribution, updated on every call to `observe`.
posterior: PosteriorDistribution<T>,
// Support set, recomputed from `posterior` after every observation.
support: SupportSet<T>,
// Running confidence interval, combined with each new snapshot via
// `meet_or_keep`.
confidence: SequentialInterval<T>,
// `None` until `set_slope_sign` is called.
slope_sign: Option<ObjectiveSign>,
// Becomes `true` the first time `meet_or_keep` reports that no meet took
// place; it is not reset in this file.
sequential_stalled: bool,
// Number of observations for which `meet_or_keep` reported no meet.
empty_meet_count: usize,
}
/// `Display` for [`InferenceState`] simply emits the derived `Debug`
/// representation of the whole state.
impl<T> std::fmt::Display for InferenceState<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A fresh `{:?}` format spec is used on purpose: flags given to the
        // outer `Display` call are not forwarded to the `Debug` output.
        f.write_fmt(format_args!("{:?}", self))
    }
}
impl<T> InferenceState<T> {
    /// Creates the initial state over `domain`.
    ///
    /// The posterior is built from the domain bounds and `max_knots`, the
    /// support set is derived from that posterior, and the running confidence
    /// interval is instantiated from `domain`. All counters and flags start
    /// cleared and no slope sign is known.
    ///
    /// # Errors
    ///
    /// Propagates any failure from `PosteriorDistribution::new` or
    /// `SupportSet::new`, converted into [`PBError`].
    pub(crate) fn new(domain: Range<T>, max_knots: usize) -> Result<Self, PBError<T>>
    where
        T: Float + FromPrimitive,
    {
        let posterior = PosteriorDistribution::new(domain.start, domain.end, max_knots)?;
        let support = SupportSet::new(&posterior)?;
        Ok(Self {
            iter: 0,
            sign_indeterminate: false,
            posterior,
            support,
            confidence: SequentialInterval::instantiate(domain),
            slope_sign: None,
            sequential_stalled: false,
            empty_meet_count: 0,
        })
    }

    /// Records the slope sign, replacing any previously stored value.
    pub(crate) fn set_slope_sign(&mut self, sign: ObjectiveSign) {
        self.slope_sign = Some(sign);
    }

    /// Returns the stored slope sign, or `None` if it has not been set yet.
    pub(crate) fn slope_sign(&self) -> Option<ObjectiveSign> {
        self.slope_sign
    }

    /// Returns `true` once at least one observation produced a snapshot for
    /// which `meet_or_keep` reported that no meet took place.
    pub(crate) fn sequential_stalled(&self) -> bool {
        self.sequential_stalled
    }

    /// Returns whether the sign has been flagged as indeterminate.
    pub(crate) fn sign_indeterminate(&self) -> bool {
        self.sign_indeterminate
    }

    /// Flags the sign as indeterminate. This impl offers no way to clear the
    /// flag again.
    pub(crate) fn sign_is_indeterminate(&mut self) {
        self.sign_indeterminate = true;
    }

    /// Borrows the current posterior distribution.
    pub(crate) fn posterior(&self) -> &PosteriorDistribution<T> {
        &self.posterior
    }

    /// Borrows the current support set.
    pub(crate) fn support(&self) -> &SupportSet<T> {
        &self.support
    }

    /// Borrows the running confidence interval.
    pub(crate) fn confidence(&self) -> &SequentialInterval<T> {
        &self.confidence
    }

    /// Returns the width of the running confidence interval, as reported by
    /// `SequentialInterval::width`.
    pub(crate) fn width(&self) -> T
    where
        T: Float,
    {
        self.confidence.width()
    }

    /// Incorporates one observation at `x` into the state.
    ///
    /// Steps, in order:
    /// 1. update the posterior with `(x, root_side, conf)`;
    /// 2. recompute the support set from the updated posterior;
    /// 3. compute a snapshot interval for the current iteration via
    ///    [`compute_snapshot`];
    /// 4. combine it with the running confidence interval using
    ///    `meet_or_keep`, recording a stall when no meet took place;
    /// 5. advance the iteration counter.
    ///
    /// # Errors
    ///
    /// Returns a [`BisectionError`] if the posterior update, the support
    /// recomputation or the snapshot computation fails. Changes made by
    /// earlier steps are not rolled back in that case, and the iteration
    /// counter is not advanced.
    pub(crate) fn observe(
        &mut self,
        x: T,
        root_side: RootSide,
        conf: ConfidenceLevel<T>,
    ) -> Result<(), BisectionError<T>>
    where
        T: Float + FromPrimitive + std::iter::Sum + std::ops::AddAssign + std::fmt::Debug,
    {
        tracing::debug!("updating posterior");
        self.posterior.observe(x, root_side, conf)?;
        tracing::debug!("recomputing support");
        self.support.recompute(&self.posterior)?;

        let candidate = compute_snapshot(&self.posterior, conf, self.iter)?;
        // `meet_or_keep` is called on a clone and hands back the interval to
        // keep, together with a flag telling whether a meet actually happened.
        let (next_confidence, met) = self.confidence.clone().meet_or_keep(candidate);
        self.confidence = next_confidence;
        if !met {
            self.sequential_stalled = true;
            self.empty_meet_count += 1;
        }

        self.iter += 1;
        Ok(())
    }
}
/// Computes the snapshot interval for iteration `n` from the posterior.
///
/// With `c` the confidence level and `alpha` its significance, the log-density
/// threshold is
///
/// ```text
/// d    = c * ln(2c) + (1 - c) * ln(2(1 - c))
/// beta = ln(c / (1 - c))
/// b    = (n + 1) * d - sqrt(n + 1) * sqrt(-0.5 * ln(alpha / 2)) * beta
/// ```
///
/// Every posterior interval whose log density exceeds `b` qualifies. The
/// result spans from the left knot of the first qualifying interval to the
/// right knot of the last one.
///
/// # Errors
///
/// * [`BisectionError::EmptyHull`] if no interval exceeds the threshold.
/// * [`BisectionError::Interval`] if `Interval::new` rejects the bounds.
///
/// # Panics
///
/// Panics if `n + 1` cannot be represented in `T`. It also indexes
/// `posterior.knots` directly, so it presumably relies on there being one more
/// knot than there are intervals (to be confirmed against
/// `PosteriorDistribution`).
pub fn compute_snapshot<T>(
    posterior: &PosteriorDistribution<T>,
    confidence: ConfidenceLevel<T>,
    n: usize,
) -> Result<Interval<T>, BisectionError<T>>
where
    T: Float + FromPrimitive,
{
    let alpha = confidence.significance().into_inner();
    let n1 = T::from_usize(n + 1).expect("iteration count must be representable in the float type");
    let c = confidence.into_inner();
    let one_minus_c = T::one() - c;
    let two = T::one() + T::one();

    let d = c * (two * c).ln() + one_minus_c * (two * one_minus_c).ln();
    let beta = (c / one_minus_c).ln();
    let b = n1 * d - n1.sqrt() * (-(T::one() / two) * (alpha / two).ln()).sqrt() * beta;

    // Both sides of the comparison are shifted by the maximum log density; the
    // arithmetic is kept exactly as before so rounding behaves identically.
    let max_log_density = posterior.max_log_interval_density();
    let b_shifted = b - max_log_density;
    let exceeds_threshold =
        |i: usize| posterior.log_interval_density(i) - max_log_density > b_shifted;

    // Only the first and last qualifying indices are needed, so search from
    // each end instead of collecting every match into a vector.
    let interval_count = posterior.log_interval_mass.len();
    let start = (0..interval_count)
        .find(|&i| exceeds_threshold(i))
        .ok_or(BisectionError::EmptyHull)?;
    let end = (start..interval_count)
        .rfind(|&i| exceeds_threshold(i))
        .ok_or(BisectionError::EmptyHull)?;

    let candidate = Interval::new(posterior.knots[start], posterior.knots[end + 1])?;
    Ok(candidate)
}