use ndarray::{Array1, Array2};
use std::sync::OnceLock;
/// Verdict on whether a point estimate of `rho` can be used as-is, derived from
/// a tail-shape diagnostic `k_hat` by [`RhoCertificate::from_k_hat`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RhoCertificate {
/// `k_hat < 0.5`: the plug-in (point-estimate) treatment is accepted.
PlugInCertified,
/// `0.5 <= k_hat <= 0.7`: the plug-in estimate is usable only with an
/// importance-weighting correction (presumably using the certificate's
/// `weights` — confirm with callers).
ImportanceCorrect,
/// `k_hat > 0.7` or non-finite: a more expensive posterior treatment is needed
/// (presumably via [`RhoPosteriorEscalator::escalate_rho_posterior`] — confirm).
Escalate,
}
impl RhoCertificate {
pub fn from_k_hat(k_hat: f64) -> Self {
if !k_hat.is_finite() || k_hat > 0.7 {
RhoCertificate::Escalate
} else if k_hat < 0.5 {
RhoCertificate::PlugInCertified
} else {
RhoCertificate::ImportanceCorrect
}
}
}
/// Diagnostic summary returned by
/// [`RhoPosteriorEscalator::rho_posterior_certificate`] for a point estimate
/// of `rho`.
#[derive(Debug, Clone)]
pub struct RhoPosteriorCertificate {
/// Tail-shape diagnostic (presumably the PSIS Pareto `k̂` — confirm).
pub k_hat: f64,
/// Classification of `k_hat`. It is expected to equal
/// `RhoCertificate::from_k_hat(k_hat)`, but the type does not enforce this.
pub certificate: RhoCertificate,
/// Number of draws behind the diagnostic (presumably reflects the `n_samples`
/// argument of `rho_posterior_certificate` — confirm).
pub n_samples: usize,
/// Importance weights, presumably one per draw. Whether they are normalized is
/// not visible here — TODO confirm.
pub weights: Array1<f64>,
/// Effective sample size associated with `weights`.
pub effective_sample_size: f64,
}
/// One weighted node of a [`RhoPosteriorMixture`].
#[derive(Debug, Clone)]
pub struct RhoMixtureNode {
/// Location of this node in `rho` space.
pub rho: Array1<f64>,
/// Node weight. It is presumably consistent with `log_weight` (e.g. a
/// normalized `exp(log_weight)`), but the type does not enforce this — confirm.
pub weight: f64,
/// Logarithm of the node weight, presumably kept for numerical stability.
pub log_weight: f64,
/// Cost at this node (presumably the criterion value at `rho` — confirm).
pub cost: f64,
}
/// Discrete weighted-node approximation of the `rho` posterior, carried by
/// [`RhoPosteriorEscalation::Quadrature`].
#[derive(Debug, Clone)]
pub struct RhoPosteriorMixture {
/// The weighted nodes of the approximation.
pub nodes: Vec<RhoMixtureNode>,
/// Posterior mean of `rho` (presumably the node-weighted mean — confirm).
pub mean: Array1<f64>,
/// Posterior covariance of `rho` (presumably node-weighted — confirm).
pub covariance: Array2<f64>,
/// Effective sample size implied by the node weights.
pub effective_sample_size: f64,
}
/// Sampled approximation of the `rho` posterior, carried by
/// [`RhoPosteriorEscalation::Nuts`].
#[derive(Debug, Clone)]
pub struct RhoPosteriorSamples {
/// Posterior draws. Presumably one draw per row and one `rho` component per
/// column — TODO confirm.
pub samples: Array2<f64>,
/// Sample mean of `rho`.
pub mean: Array1<f64>,
/// Sample covariance of `rho`.
pub covariance: Array2<f64>,
/// Convergence diagnostic (presumably split-R̂, where values near 1 indicate
/// good mixing — confirm).
pub rhat: f64,
/// Effective sample size of the draws.
pub ess: f64,
/// Whether the sampler judged itself converged. The criteria are defined by
/// the implementation.
pub converged: bool,
}
/// Result of [`RhoPosteriorEscalator::escalate_rho_posterior`].
#[derive(Debug, Clone)]
pub enum RhoPosteriorEscalation {
/// Posterior approximated by a weighted set of nodes.
Quadrature(RhoPosteriorMixture),
/// Posterior approximated by sampler draws (NUTS, judging by the name).
Nuts(RhoPosteriorSamples),
/// No escalation was performed. `n_params` is the dimension involved and
/// `reason` is a human-readable explanation.
Unavailable { n_params: usize, reason: String },
}
/// Pluggable backend that certifies a point estimate of `rho` and, when
/// needed, builds a richer posterior approximation around it.
///
/// An implementation is registered process-wide with
/// [`set_rho_posterior_escalator`]. The `Send + Sync` supertraits are what make
/// the boxed trait object storable in that `static` `OnceLock`.
pub trait RhoPosteriorEscalator: Send + Sync {
/// Computes a [`RhoPosteriorCertificate`] for the point estimate `rho_hat`.
///
/// * `rho_hat` – point estimate of `rho`.
/// * `outer_hessian` – presumably the Hessian of the outer criterion at
///   `rho_hat` (e.g. for a Gaussian proposal) — TODO confirm.
/// * `criterion` – evaluates the criterion at a given `rho`. A `None` return
///   presumably signals a failed evaluation — confirm.
/// * `n_samples` – requested number of draws. `None` presumably defers to an
///   implementation default — confirm.
///
/// Returns `None` when no certificate could be produced.
fn rho_posterior_certificate(
&self,
rho_hat: &Array1<f64>,
outer_hessian: &Array2<f64>,
criterion: &dyn Fn(&Array1<f64>) -> Option<f64>,
n_samples: Option<usize>,
) -> Option<RhoPosteriorCertificate>;
/// Builds a posterior approximation for `rho` around `rho_hat`.
///
/// * `criterion` – evaluates the criterion at a given `rho`. It is `FnMut`, so
///   it may carry mutable state such as caches.
/// * `criterion_and_grad` – criterion value plus gradient at a given `rho`.
///   The `Send` bound allows the implementation to move it to another thread.
///
/// Returns [`RhoPosteriorEscalation::Unavailable`] when no approximation is
/// produced.
fn escalate_rho_posterior(
&self,
rho_hat: &Array1<f64>,
outer_hessian: &Array2<f64>,
criterion: &mut dyn FnMut(&Array1<f64>) -> Option<f64>,
criterion_and_grad: &mut (dyn FnMut(&Array1<f64>) -> Option<(f64, Array1<f64>)> + Send),
) -> RhoPosteriorEscalation;
}
/// Process-wide, write-once slot holding the registered escalator.
///
/// It is written by [`set_rho_posterior_escalator`] and read by
/// [`rho_posterior_escalator`]. Once set, it cannot be cleared or replaced.
static RHO_POSTERIOR_ESCALATOR: OnceLock<Box<dyn RhoPosteriorEscalator>> = OnceLock::new();
pub fn set_rho_posterior_escalator(
escalator: Box<dyn RhoPosteriorEscalator>,
) -> Result<(), Box<dyn RhoPosteriorEscalator>> {
RHO_POSTERIOR_ESCALATOR.set(escalator)
}
/// Returns the registered process-wide [`RhoPosteriorEscalator`], if any.
///
/// Yields `None` until a call to [`set_rho_posterior_escalator`] has succeeded.
pub fn rho_posterior_escalator() -> Option<&'static dyn RhoPosteriorEscalator> {
    let registered = RHO_POSTERIOR_ESCALATOR.get()?;
    // Reborrow through the `Box` so callers receive the trait object directly.
    Some(&**registered)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_k_hat_below_half_is_plug_in_certified() {
        assert_eq!(
            RhoCertificate::from_k_hat(0.0),
            RhoCertificate::PlugInCertified
        );
        assert_eq!(
            RhoCertificate::from_k_hat(0.499),
            RhoCertificate::PlugInCertified
        );
    }

    #[test]
    fn from_k_hat_negative_is_plug_in_certified() {
        // Negative finite k̂ (very light tails) sits in the lowest band.
        assert_eq!(
            RhoCertificate::from_k_hat(-0.3),
            RhoCertificate::PlugInCertified
        );
        assert_eq!(
            RhoCertificate::from_k_hat(-1.0e3),
            RhoCertificate::PlugInCertified
        );
    }

    #[test]
    fn from_k_hat_between_half_and_point_seven_is_importance_correct() {
        assert_eq!(
            RhoCertificate::from_k_hat(0.5),
            RhoCertificate::ImportanceCorrect
        );
        assert_eq!(
            RhoCertificate::from_k_hat(0.7),
            RhoCertificate::ImportanceCorrect
        );
        assert_eq!(
            RhoCertificate::from_k_hat(0.65),
            RhoCertificate::ImportanceCorrect
        );
    }

    #[test]
    fn from_k_hat_above_point_seven_is_escalate() {
        assert_eq!(
            RhoCertificate::from_k_hat(0.701),
            RhoCertificate::Escalate
        );
        assert_eq!(
            RhoCertificate::from_k_hat(10.0),
            RhoCertificate::Escalate
        );
    }

    #[test]
    fn from_k_hat_nan_is_escalate() {
        assert_eq!(RhoCertificate::from_k_hat(f64::NAN), RhoCertificate::Escalate);
    }

    #[test]
    fn from_k_hat_infinity_is_escalate() {
        assert_eq!(
            RhoCertificate::from_k_hat(f64::INFINITY),
            RhoCertificate::Escalate
        );
    }

    #[test]
    fn from_k_hat_negative_infinity_is_escalate() {
        // -∞ would satisfy `< 0.5`. Only the non-finite guard keeps it from
        // being certified, so pin that behavior explicitly.
        assert_eq!(
            RhoCertificate::from_k_hat(f64::NEG_INFINITY),
            RhoCertificate::Escalate
        );
    }

    #[test]
    fn rho_posterior_escalator_returns_none_when_unregistered() {
        // NOTE: the registry is a process-global `OnceLock` that cannot be
        // reset. This test assumes no other test in the same test binary
        // calls `set_rho_posterior_escalator`; if one does, the result depends
        // on test scheduling.
        assert!(rho_posterior_escalator().is_none());
    }
}