use std::sync::OnceLock;
use gam_linalg::matrix::DesignMatrix;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Result of screening per-direction skewness values against a threshold.
///
/// Built by [`laplace_trustworthiness_from_skewness`]; the field descriptions
/// below describe how that function populates them.
#[derive(Clone, Debug)]
pub struct LaplaceTrustworthiness {
    /// Copy of the skewness values that were screened, one per direction.
    /// Non-finite entries are kept here unchanged.
    pub directional_skewness: Array1<f64>,
    /// Indices into `directional_skewness` whose absolute value is strictly
    /// greater than `threshold`, in increasing order. Non-finite entries are
    /// never listed.
    pub untrustworthy_directions: Vec<usize>,
    /// Threshold used for the screen (see [`laplace_skewness_threshold`]);
    /// `f64::INFINITY` when the effective sample size was not strictly positive.
    pub threshold: f64,
    /// Largest absolute value among the finite entries of
    /// `directional_skewness`, or `0.0` if there are none.
    pub max_abs_skewness: f64,
}

impl LaplaceTrustworthiness {
    /// Returns `true` when at least one direction was flagged as untrustworthy.
    ///
    /// NOTE(review): callers presumably use this to decide whether to fall
    /// back from the Laplace approximation; confirm at call sites.
    // A pure predicate: discarding its result is always a mistake.
    #[must_use]
    pub fn fallback_required(&self) -> bool {
        !self.untrustworthy_directions.is_empty()
    }
}
/// Sampled moment summaries for one block.
///
/// NOTE(review): the field meanings below are inferred from the field names
/// only (`e_` read as "expectation of"); confirm against the sampler
/// implementation that fills them.
#[derive(Clone, Debug)]
pub struct BlockSampledMoments {
/// Presumably E[t] over the draws.
pub e_t: Array1<f64>,
/// Presumably E[t t^T] over the draws.
pub e_tt: Array2<f64>,
/// Presumably E[negative score] over the draws.
pub e_neg_score: Array1<f64>,
/// Presumably E[t * neg_score^T] over the draws; confirm the orientation.
pub e_t_neg_score: Array2<f64>,
}
/// Output of `LaplaceMarginalSampler::block_sampled_marginal_correction`.
///
/// NOTE(review): field semantics are defined by the sampler implementation;
/// the notes below are inferred from names and should be confirmed.
#[derive(Clone, Debug)]
pub struct BlockSampledMarginal {
/// Scalar value of the sampled correction.
pub value: f64,
/// Presumably the gradient of `value` with respect to rho (length
/// `BlockExcessTarget::rho_dim()`?); confirm.
pub rho_gradient: Array1<f64>,
/// Presumably the importance-sampling effective sample size of the draws.
pub importance_ess: f64,
/// Number of draws used by the sampler.
pub n_draws: usize,
/// Optional moment summaries; `None` when the sampler did not provide them.
pub moments: Option<BlockSampledMoments>,
}
/// Draws and summary statistics returned by a `GaussianModePosteriorSampler`.
///
/// NOTE(review): the summaries are inferred from field names; confirm against
/// the sampler implementation.
#[derive(Clone, Debug)]
pub struct GaussianModePosterior {
/// Raw posterior draws. NOTE(review): whether draws are rows or columns is
/// not visible here; confirm.
pub samples: Array2<f64>,
/// Presumably the per-coordinate mean of `samples`.
pub posterior_mean: Array1<f64>,
/// Presumably the per-coordinate standard deviation of `samples`.
pub posterior_std: Array1<f64>,
/// Presumably the R-hat convergence diagnostic reported by the sampler.
pub rhat: f64,
/// Presumably the effective sample size reported by the sampler.
pub ess: f64,
}
/// Skewness magnitude above which a direction is considered untrustworthy.
///
/// Returns `sqrt((24/5) / n_eff)`. When `n_eff` is not strictly positive
/// (zero, negative, or NaN), returns `f64::INFINITY`, so no direction can
/// exceed it. An infinite `n_eff` yields `0.0`.
pub fn laplace_skewness_threshold(n_eff: f64) -> f64 {
    const SCALE_NUMERATOR: f64 = 24.0 / 5.0;
    // `n_eff > 0.0` is false for NaN, so NaN falls through to INFINITY.
    if n_eff > 0.0 {
        (SCALE_NUMERATOR / n_eff).sqrt()
    } else {
        f64::INFINITY
    }
}
/// Screens per-direction skewness values against [`laplace_skewness_threshold`].
///
/// # Arguments
/// * `directional_skewness` - one skewness value per direction; cloned into
///   the result.
/// * `n_eff` - effective sample size forwarded to `laplace_skewness_threshold`.
///
/// # Returns
/// A [`LaplaceTrustworthiness`] in which:
/// * `untrustworthy_directions` lists, in increasing index order, every
///   direction whose magnitude is strictly greater than the threshold;
/// * `max_abs_skewness` is the largest magnitude seen, starting from `0.0`.
///
/// Non-finite entries (NaN, +/-inf) are assigned magnitude `0.0`. They never
/// raise `max_abs_skewness` and are never flagged.
/// NOTE(review): confirm that treating an infinite skewness as trustworthy
/// is intended.
pub fn laplace_trustworthiness_from_skewness(
    directional_skewness: &Array1<f64>,
    n_eff: f64,
) -> LaplaceTrustworthiness {
    let threshold = laplace_skewness_threshold(n_eff);
    let magnitude = |gamma: f64| if gamma.is_finite() { gamma.abs() } else { 0.0 };

    let max_abs_skewness = directional_skewness
        .iter()
        .map(|&gamma| magnitude(gamma))
        .fold(0.0_f64, f64::max);

    let untrustworthy_directions = directional_skewness
        .iter()
        .enumerate()
        .filter_map(|(index, &gamma)| (magnitude(gamma) > threshold).then_some(index))
        .collect();

    LaplaceTrustworthiness {
        directional_skewness: directional_skewness.clone(),
        untrustworthy_directions,
        threshold,
        max_abs_skewness,
    }
}
/// A per-block target that can report an "excess" value and a negative score
/// at a displacement `t`.
///
/// NOTE(review): the mathematical meaning of `t`, the excess, the curvatures
/// and the score is defined by implementors elsewhere. The descriptions below
/// restate what the signatures and default methods show; anything marked
/// "presumably" is inferred from names.
pub trait BlockExcessTarget {
    /// Dimension of the block; presumably the expected length of `t`.
    fn block_dim(&self) -> usize;
    /// Number of rho coordinates; presumably the length returned by
    /// `excess_rho_gradient`.
    fn rho_dim(&self) -> usize;
    /// Per-coordinate curvatures of the block, borrowed from the implementor.
    fn block_curvatures(&self) -> &Array1<f64>;
    /// Excess value at displacement `t`. May be non-finite; the default
    /// methods check for that.
    fn excess(&self, t: &Array1<f64>) -> f64;
    /// Presumably the gradient of `excess` with respect to rho, evaluated at `t`.
    fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64>;
    /// Negative score evaluated at displacement `t`.
    fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64>;
    /// Negative score at the undisplaced point.
    fn base_neg_score(&self) -> Array1<f64>;

    /// Evaluates `excess(t)`, then evaluates `displaced_neg_score(t)` only if
    /// the excess is finite.
    ///
    /// Returns `(excess, Some(score))` for a finite excess and
    /// `(excess, None)` otherwise. The score is not requested for a point
    /// whose excess is non-finite.
    fn excess_with_displaced_neg_score(&self, t: &Array1<f64>) -> (f64, Option<Array1<f64>>) {
        let excess = self.excess(t);
        let neg_score = excess.is_finite().then(|| self.displaced_neg_score(t));
        (excess, neg_score)
    }

    /// Applies [`Self::excess_with_displaced_neg_score`] to every column of
    /// `draws`.
    ///
    /// Each column is one displacement `t` (length `draws.nrows()`), and the
    /// results are returned in column order. Each column is copied into one
    /// reused scratch vector before evaluation.
    fn excess_with_displaced_neg_score_batch(
        &self,
        draws: &Array2<f64>,
    ) -> Vec<(f64, Option<Array1<f64>>)> {
        let mut scratch = Array1::<f64>::zeros(draws.nrows());
        (0..draws.ncols())
            .map(|col| {
                scratch.assign(&draws.column(col));
                self.excess_with_displaced_neg_score(&scratch)
            })
            .collect()
    }
}
/// Pluggable backend for Laplace-marginal diagnostics and sampled corrections.
///
/// `Send + Sync` is required because an implementation is stored as a
/// `Box<dyn LaplaceMarginalSampler>` inside a process-wide `static OnceLock`
/// in this module, and a `static` must be `Sync`.
pub trait LaplaceMarginalSampler: Send + Sync {
/// Computes a directional cubic (skewness-type) diagnostic.
///
/// NOTE(review): returns a scalar and a vector. Judging by the name, the
/// vector is presumably per-direction skewness suitable for
/// `laplace_trustworthiness_from_skewness`, and `refine_supremum`
/// presumably requests extra work to refine a supremum estimate. Confirm
/// both against implementations.
///
/// # Errors
/// Returns `Err` with a message when the implementation fails.
fn directional_cubic_diagnostic(
&self,
hessian: &Array2<f64>,
design: &DesignMatrix,
c_weights: &Array1<f64>,
refine_supremum: bool,
) -> Result<(f64, Array1<f64>), String>;
/// Computes a sampled marginal correction for a single block `target`.
///
/// # Errors
/// Returns `Err` with a message when the implementation fails.
fn block_sampled_marginal_correction(
&self,
target: &dyn BlockExcessTarget,
) -> Result<BlockSampledMarginal, String>;
}
/// Pluggable backend that samples a posterior around a mode, given a precision matrix.
///
/// `Send + Sync` is required because an implementation is stored in a
/// process-wide `static OnceLock` in this module.
pub trait GaussianModePosteriorSampler: Send + Sync {
/// Draws posterior samples around `mode` using `precision`.
///
/// NOTE(review): presumably `precision` is square with side `mode.len()`;
/// confirm against implementations.
///
/// # Errors
/// Returns `Err` with a message when the implementation fails.
fn sample_gaussian_mode_posterior(
&self,
mode: ArrayView1<f64>,
precision: ArrayView2<f64>,
) -> Result<GaussianModePosterior, String>;
}
// Process-wide slot for the installed `LaplaceMarginalSampler`.
// `OnceLock` allows at most one successful `set`; later attempts are rejected.
static LAPLACE_MARGINAL_SAMPLER: OnceLock<Box<dyn LaplaceMarginalSampler>> = OnceLock::new();
// Process-wide slot for the installed `GaussianModePosteriorSampler`
// (also set-once via `OnceLock`).
static GAUSSIAN_MODE_POSTERIOR_SAMPLER: OnceLock<Box<dyn GaussianModePosteriorSampler>> =
OnceLock::new();
pub fn set_laplace_marginal_sampler(
sampler: Box<dyn LaplaceMarginalSampler>,
) -> Result<(), Box<dyn LaplaceMarginalSampler>> {
LAPLACE_MARGINAL_SAMPLER.set(sampler)
}
/// Returns the installed [`LaplaceMarginalSampler`], or `None` if
/// `set_laplace_marginal_sampler` has not succeeded yet.
pub fn laplace_marginal_sampler() -> Option<&'static dyn LaplaceMarginalSampler> {
    let installed = LAPLACE_MARGINAL_SAMPLER.get()?;
    Some(&**installed)
}
pub fn set_gaussian_mode_posterior_sampler(
sampler: Box<dyn GaussianModePosteriorSampler>,
) -> Result<(), Box<dyn GaussianModePosteriorSampler>> {
GAUSSIAN_MODE_POSTERIOR_SAMPLER.set(sampler)
}
/// Returns the installed [`GaussianModePosteriorSampler`], or `None` if
/// `set_gaussian_mode_posterior_sampler` has not succeeded yet.
pub fn gaussian_mode_posterior_sampler() -> Option<&'static dyn GaussianModePosteriorSampler> {
    let installed = GAUSSIAN_MODE_POSTERIOR_SAMPLER.get()?;
    Some(&**installed)
}