use std::sync::OnceLock;
use gam_linalg::matrix::DesignMatrix;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Outcome of comparing per-direction skewness values against a magnitude threshold.
///
/// Normally produced by [`laplace_trustworthiness_from_skewness`].
#[derive(Debug, Clone)]
pub struct LaplaceTrustworthiness {
    /// The per-direction skewness values that were assessed.
    pub directional_skewness: Array1<f64>,
    /// Indices into `directional_skewness` that were flagged as untrustworthy.
    pub untrustworthy_directions: Vec<usize>,
    /// Threshold applied to the absolute skewness (see [`laplace_skewness_threshold`]).
    pub threshold: f64,
    /// Largest absolute skewness observed; when built by
    /// [`laplace_trustworthiness_from_skewness`], non-finite entries count as zero.
    pub max_abs_skewness: f64,
}

impl LaplaceTrustworthiness {
    /// Returns `true` exactly when at least one direction was flagged, i.e.
    /// `untrustworthy_directions` is non-empty.
    pub fn fallback_required(&self) -> bool {
        match self.untrustworthy_directions.as_slice() {
            [] => false,
            [_, ..] => true,
        }
    }
}
/// Sampled moment estimates that may be attached to a [`BlockSampledMarginal`]
/// (via its `moments` field).
///
/// NOTE(review): the field meanings below are inferred from the field names only —
/// confirm against the sampler implementation.
#[derive(Clone, Debug)]
pub struct BlockSampledMoments {
/// Presumably the sampled expectation E[t] of the block displacement `t`.
pub e_t: Array1<f64>,
/// Presumably the sampled second moment E[t t^T].
pub e_tt: Array2<f64>,
/// Presumably the sampled expectation of the displaced negative score.
pub e_neg_score: Array1<f64>,
/// Presumably the sampled cross moment E[t (neg_score)^T].
pub e_t_neg_score: Array2<f64>,
}
/// Result returned by [`LaplaceMarginalSampler::block_sampled_marginal_correction`].
///
/// NOTE(review): field semantics are inferred from names — confirm against the
/// sampler implementation.
#[derive(Clone, Debug)]
pub struct BlockSampledMarginal {
/// Value of the sampled correction.
pub value: f64,
/// Gradient of `value` with respect to `rho` (presumably of length
/// `BlockExcessTarget::rho_dim()` — confirm).
pub rho_gradient: Array1<f64>,
/// Presumably the effective sample size of the importance weights.
pub importance_ess: f64,
/// Presumably the number of draws used to form the estimate.
pub n_draws: usize,
/// Optional sampled moments; may be `None`.
pub moments: Option<BlockSampledMoments>,
}
/// Result returned by [`GaussianModePosteriorSampler::sample_gaussian_mode_posterior`].
///
/// NOTE(review): field semantics are inferred from names — confirm against the
/// sampler implementation.
#[derive(Clone, Debug)]
pub struct GaussianModePosterior {
/// Posterior draws. NOTE(review): orientation (draws x parameters or the
/// transpose) is not fixed by this declaration — confirm.
pub samples: Array2<f64>,
/// Per-parameter posterior mean (presumably estimated from `samples`).
pub posterior_mean: Array1<f64>,
/// Per-parameter posterior standard deviation (presumably estimated from `samples`).
pub posterior_std: Array1<f64>,
/// Convergence diagnostic; presumably a (max) R-hat — confirm.
pub rhat: f64,
/// Effective sample size; presumably a summary over parameters — confirm.
pub ess: f64,
}
/// Threshold on absolute directional skewness used by
/// [`laplace_trustworthiness_from_skewness`].
///
/// Returns `sqrt((24 / 5) / n_eff)` for strictly positive `n_eff`. Zero, negative and
/// NaN `n_eff` yield `f64::INFINITY`, so no direction can exceed the threshold.
/// The threshold shrinks as `n_eff` grows, reaching `0.0` for `n_eff = +inf`.
pub fn laplace_skewness_threshold(n_eff: f64) -> f64 {
    const NUMERATOR: f64 = 24.0 / 5.0;
    // The positive comparison is false for NaN as well as for n_eff <= 0,
    // so every invalid input falls through to the infinite threshold.
    if n_eff > 0.0 {
        (NUMERATOR / n_eff).sqrt()
    } else {
        f64::INFINITY
    }
}
/// Flags directions whose absolute skewness exceeds [`laplace_skewness_threshold`]`(n_eff)`.
///
/// - Direction `r` is listed in `untrustworthy_directions` (in increasing index order)
///   when `|directional_skewness[r]| > threshold`. The comparison is strict, so a value
///   exactly at the threshold is not flagged.
/// - Non-finite entries are treated as having magnitude zero. They are never flagged
///   and do not raise `max_abs_skewness`, which starts at `0.0`.
/// - When `n_eff` is not strictly positive (or is NaN), the threshold is infinite and
///   nothing is flagged.
/// - The input vector is cloned into the returned value.
pub fn laplace_trustworthiness_from_skewness(
    directional_skewness: &Array1<f64>,
    n_eff: f64,
) -> LaplaceTrustworthiness {
    let threshold = laplace_skewness_threshold(n_eff);
    // Shared magnitude rule so flagging and the running maximum agree on
    // how non-finite values are handled.
    let magnitude = |gamma: f64| if gamma.is_finite() { gamma.abs() } else { 0.0 };

    let untrustworthy_directions: Vec<usize> = directional_skewness
        .iter()
        .enumerate()
        .filter_map(|(idx, &gamma)| (magnitude(gamma) > threshold).then_some(idx))
        .collect();

    let max_abs_skewness = directional_skewness
        .iter()
        .map(|&gamma| magnitude(gamma))
        .fold(0.0_f64, f64::max);

    LaplaceTrustworthiness {
        directional_skewness: directional_skewness.clone(),
        untrustworthy_directions,
        threshold,
        max_abs_skewness,
    }
}
/// Target evaluated at block displacements `t` during a block-sampled marginal correction.
///
/// Implementors supply the required methods. The two provided methods combine
/// [`excess`](Self::excess) and [`displaced_neg_score`](Self::displaced_neg_score) and
/// may be overridden.
pub trait BlockExcessTarget {
    /// Dimension of the block (presumably the expected length of `t` — confirm).
    fn block_dim(&self) -> usize;
    /// Length of the `rho` parameter vector (presumably the length of
    /// [`excess_rho_gradient`](Self::excess_rho_gradient) — confirm).
    fn rho_dim(&self) -> usize;
    /// Curvature values associated with the block.
    fn block_curvatures(&self) -> &Array1<f64>;
    /// Excess at displacement `t`. It may be non-finite; the provided methods then
    /// skip the score evaluation.
    fn excess(&self, t: &Array1<f64>) -> f64;
    /// Gradient of the excess at `t` with respect to `rho`.
    fn excess_rho_gradient(&self, t: &Array1<f64>) -> Array1<f64>;
    /// Negative score evaluated at displacement `t`.
    fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64>;
    /// Negative score at the undisplaced point.
    fn base_neg_score(&self) -> Array1<f64>;

    /// Evaluates [`excess`](Self::excess) at `t` and, only when that value is finite,
    /// [`displaced_neg_score`](Self::displaced_neg_score) at `t` as well.
    ///
    /// Returns `(excess, Some(score))` for a finite excess, otherwise `(excess, None)`.
    /// In the `None` case the score is never computed.
    fn excess_with_displaced_neg_score(&self, t: &Array1<f64>) -> (f64, Option<Array1<f64>>) {
        let value = self.excess(t);
        let score = value.is_finite().then(|| self.displaced_neg_score(t));
        (value, score)
    }

    /// Applies [`excess_with_displaced_neg_score`](Self::excess_with_displaced_neg_score)
    /// to every column of `draws`, treating each column as one displacement `t`.
    ///
    /// Returns one entry per column, in column order. An empty `draws` (zero columns)
    /// yields an empty vector.
    fn excess_with_displaced_neg_score_batch(
        &self,
        draws: &Array2<f64>,
    ) -> Vec<(f64, Option<Array1<f64>>)> {
        // The per-draw methods take `&Array1`, not a view, so each column is copied
        // into one reused owned buffer instead of allocating per draw.
        let mut scratch = Array1::<f64>::zeros(draws.nrows());
        (0..draws.ncols())
            .map(|col| {
                scratch.assign(&draws.column(col));
                self.excess_with_displaced_neg_score(&scratch)
            })
            .collect()
    }
}
/// Pluggable backend providing the directional cubic diagnostic and the
/// block-sampled marginal correction.
///
/// The `Send + Sync` supertraits let a boxed implementation live in the process-wide
/// `static` that [`set_laplace_marginal_sampler`] writes to (a `static` must be `Sync`).
pub trait LaplaceMarginalSampler: Send + Sync {
/// Computes a directional cubic diagnostic from `hessian`, `design` and `c_weights`.
///
/// NOTE(review): presumably the returned `Array1<f64>` holds per-direction skewness
/// suitable for [`laplace_trustworthiness_from_skewness`], and `refine_supremum`
/// requests a more careful search for the maximum — confirm against implementations.
///
/// # Errors
/// Returns an error message when the diagnostic cannot be computed.
fn directional_cubic_diagnostic(
&self,
hessian: &Array2<f64>,
design: &DesignMatrix,
c_weights: &Array1<f64>,
refine_supremum: bool,
) -> Result<(f64, Array1<f64>), String>;
/// Computes a sampling-based marginal correction for the block described by `target`.
///
/// # Errors
/// Returns an error message when the correction cannot be computed.
fn block_sampled_marginal_correction(
&self,
target: &dyn BlockExcessTarget,
) -> Result<BlockSampledMarginal, String>;
}
/// Pluggable backend that draws posterior samples around a mode.
///
/// The `Send + Sync` supertraits let a boxed implementation live in the process-wide
/// `static` that [`set_gaussian_mode_posterior_sampler`] writes to (a `static` must be `Sync`).
pub trait GaussianModePosteriorSampler: Send + Sync {
/// Samples a posterior given a `mode` and a `precision` matrix.
///
/// NOTE(review): presumably a Gaussian centred at `mode` with covariance equal to the
/// inverse of `precision` — confirm against implementations.
///
/// # Errors
/// Returns an error message when sampling fails.
fn sample_gaussian_mode_posterior(
&self,
mode: ArrayView1<f64>,
precision: ArrayView2<f64>,
) -> Result<GaussianModePosterior, String>;
}
/// Process-wide, set-once slot for the registered [`LaplaceMarginalSampler`].
static LAPLACE_MARGINAL_SAMPLER: OnceLock<Box<dyn LaplaceMarginalSampler>> = OnceLock::new();
/// Process-wide, set-once slot for the registered [`GaussianModePosteriorSampler`].
static GAUSSIAN_MODE_POSTERIOR_SAMPLER: OnceLock<Box<dyn GaussianModePosteriorSampler>> =
    OnceLock::new();

/// Registers the global [`LaplaceMarginalSampler`].
///
/// # Errors
///
/// The slot can be filled only once. If a sampler is already registered, the existing
/// one is kept and `sampler` is handed back in `Err`.
pub fn set_laplace_marginal_sampler(
    sampler: Box<dyn LaplaceMarginalSampler>,
) -> Result<(), Box<dyn LaplaceMarginalSampler>> {
    OnceLock::set(&LAPLACE_MARGINAL_SAMPLER, sampler)
}

/// Returns the registered [`LaplaceMarginalSampler`], or `None` if none has been set.
pub fn laplace_marginal_sampler() -> Option<&'static dyn LaplaceMarginalSampler> {
    LAPLACE_MARGINAL_SAMPLER.get().map(Box::as_ref)
}

/// Registers the global [`GaussianModePosteriorSampler`].
///
/// # Errors
///
/// The slot can be filled only once. If a sampler is already registered, the existing
/// one is kept and `sampler` is handed back in `Err`.
pub fn set_gaussian_mode_posterior_sampler(
    sampler: Box<dyn GaussianModePosteriorSampler>,
) -> Result<(), Box<dyn GaussianModePosteriorSampler>> {
    OnceLock::set(&GAUSSIAN_MODE_POSTERIOR_SAMPLER, sampler)
}

/// Returns the registered [`GaussianModePosteriorSampler`], or `None` if none has been set.
pub fn gaussian_mode_posterior_sampler() -> Option<&'static dyn GaussianModePosteriorSampler> {
    GAUSSIAN_MODE_POSTERIOR_SAMPLER.get().map(Box::as_ref)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn threshold_is_infinity_for_zero_n_eff() {
        assert_eq!(laplace_skewness_threshold(0.0), f64::INFINITY);
    }

    #[test]
    fn threshold_is_infinity_for_negative_n_eff() {
        assert_eq!(laplace_skewness_threshold(-5.0), f64::INFINITY);
    }

    #[test]
    fn threshold_is_infinity_for_nan_n_eff() {
        assert_eq!(laplace_skewness_threshold(f64::NAN), f64::INFINITY);
    }

    #[test]
    fn threshold_known_value() {
        let n_eff = 24.0 / 5.0;
        let t = laplace_skewness_threshold(n_eff);
        assert!((t - 1.0).abs() < 1e-14, "threshold={t}");
    }

    #[test]
    fn threshold_decreases_as_n_eff_increases() {
        let t_small = laplace_skewness_threshold(10.0);
        let t_large = laplace_skewness_threshold(1000.0);
        assert!(t_large < t_small, "threshold should decrease with more data");
    }

    #[test]
    fn all_small_skewness_gives_no_untrustworthy_directions() {
        let skewness = array![0.01_f64, -0.02, 0.005];
        let result = laplace_trustworthiness_from_skewness(&skewness, 1000.0);
        assert!(result.untrustworthy_directions.is_empty());
        assert!(!result.fallback_required());
    }

    #[test]
    fn large_skewness_flagged_as_untrustworthy() {
        let skewness = array![0.1_f64, 2.0];
        let result = laplace_trustworthiness_from_skewness(&skewness, 10.0);
        assert!(result.untrustworthy_directions.contains(&1));
        assert!(!result.untrustworthy_directions.contains(&0));
        assert!(result.fallback_required());
    }

    #[test]
    fn negative_skewness_is_flagged_by_magnitude() {
        // threshold == 1.0 exactly for n_eff = 24/5.
        let result = laplace_trustworthiness_from_skewness(&array![-1.5_f64, 0.5], 24.0 / 5.0);
        assert_eq!(result.untrustworthy_directions, vec![0]);
    }

    #[test]
    fn skewness_exactly_at_threshold_is_not_flagged() {
        let result = laplace_trustworthiness_from_skewness(&array![1.0_f64, -1.0], 24.0 / 5.0);
        assert_eq!(result.threshold, 1.0);
        assert!(result.untrustworthy_directions.is_empty());
    }

    #[test]
    fn non_positive_n_eff_flags_nothing() {
        let result = laplace_trustworthiness_from_skewness(&array![100.0_f64, -100.0], 0.0);
        assert!(!result.fallback_required());
        assert_eq!(result.max_abs_skewness, 100.0);
    }

    #[test]
    fn max_abs_skewness_is_largest_abs_value() {
        let skewness = array![1.5_f64, -3.0, 2.0];
        let result = laplace_trustworthiness_from_skewness(&skewness, 1.0);
        assert!((result.max_abs_skewness - 3.0).abs() < 1e-14);
    }

    #[test]
    fn non_finite_skewness_treated_as_zero_for_max_abs() {
        let skewness = array![f64::NAN, 1.0];
        let result = laplace_trustworthiness_from_skewness(&skewness, 1.0);
        assert!((result.max_abs_skewness - 1.0).abs() < 1e-14);
    }

    #[test]
    fn fallback_required_true_when_directions_nonempty() {
        let lt = LaplaceTrustworthiness {
            directional_skewness: array![1.0_f64],
            untrustworthy_directions: vec![0],
            threshold: 0.5,
            max_abs_skewness: 1.0,
        };
        assert!(lt.fallback_required());
    }

    #[test]
    fn fallback_required_false_when_directions_empty() {
        let lt = LaplaceTrustworthiness {
            directional_skewness: array![0.1_f64],
            untrustworthy_directions: vec![],
            threshold: 0.5,
            max_abs_skewness: 0.1,
        };
        assert!(!lt.fallback_required());
    }

    /// Minimal target used to exercise the provided `BlockExcessTarget` methods:
    /// excess = sum(t) (NaN when t[0] < 0) and displaced score = 2 t.
    struct SumTarget {
        curvatures: Array1<f64>,
    }

    impl BlockExcessTarget for SumTarget {
        fn block_dim(&self) -> usize {
            self.curvatures.len()
        }
        fn rho_dim(&self) -> usize {
            0
        }
        fn block_curvatures(&self) -> &Array1<f64> {
            &self.curvatures
        }
        fn excess(&self, t: &Array1<f64>) -> f64 {
            if t[0] < 0.0 {
                f64::NAN
            } else {
                t.iter().sum()
            }
        }
        fn excess_rho_gradient(&self, _t: &Array1<f64>) -> Array1<f64> {
            Array1::zeros(0)
        }
        fn displaced_neg_score(&self, t: &Array1<f64>) -> Array1<f64> {
            t * 2.0
        }
        fn base_neg_score(&self) -> Array1<f64> {
            Array1::zeros(self.curvatures.len())
        }
    }

    #[test]
    fn batch_evaluates_each_column_and_skips_score_for_non_finite_excess() {
        let target = SumTarget {
            curvatures: array![1.0_f64, 1.0],
        };
        // Each column is one displacement draw.
        let draws = array![[1.0_f64, -1.0, 3.0], [2.0, 0.0, 4.0]];
        let out = target.excess_with_displaced_neg_score_batch(&draws);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].0, 3.0);
        assert_eq!(out[0].1, Some(array![2.0_f64, 4.0]));
        assert!(out[1].0.is_nan());
        assert!(out[1].1.is_none());
        assert_eq!(out[2].0, 7.0);
        assert_eq!(out[2].1, Some(array![6.0_f64, 8.0]));
    }

    #[test]
    fn batch_with_no_draws_is_empty() {
        let target = SumTarget {
            curvatures: array![1.0_f64, 1.0],
        };
        let draws = Array2::<f64>::zeros((2, 0));
        assert!(target.excess_with_displaced_neg_score_batch(&draws).is_empty());
    }
}