use ndarray::{Array1, Array2};
use rand::SeedableRng;
use crate::faer_ndarray::FaerCholesky;
use crate::linalg::triangular::{
back_substitution_lower_transpose_guarded_into, forward_substitution_lower_matrix,
};
use crate::pirls::LinearInequalityConstraints;
/// Total integration time given to each trajectory; it is the initial value of
/// the remaining-time counter in `simulate_constrained_trajectory`.
const TRAVEL_TIME: f64 = std::f64::consts::FRAC_PI_2;
/// Threshold on `sqrt(u² + w²)` in `first_wall_hit`: at or below this value the
/// wall is reported as never hit.
const AMPLITUDE_FLOOR: f64 = 1e-300;
/// Tolerance used twice in `first_wall_hit`: a slack at or below it counts as
/// "already on the wall", and a computed hit time at or below it is pushed one
/// full period (2π) later.
const WALL_SLACK_EPS: f64 = 1e-9;
/// Base bounce budget for one trajectory; the sampler adds `8 * m` to it, where
/// `m` is the number of constraint rows.
const MAX_BOUNCES_BASE: usize = 256;
/// Draws samples from a Gaussian posterior truncated by linear inequality constraints.
///
/// The target is the Gaussian centred at `mode` with covariance
/// `sqrt_phi² · penalized_hessian⁻¹`, restricted to the region
/// `constraints.a · β ≥ constraints.b`.
///
/// The sampler works in whitened coordinates `z`, related to the coefficients by
/// `β = mode + sqrt_phi · L⁻ᵀ z`, where `L` is the lower Cholesky factor of
/// `penalized_hessian`. In those coordinates each constraint row becomes a wall
/// `F_i · z + g_i ≥ 0` with `F = sqrt_phi · A · L⁻ᵀ` and `g = A · mode − b`.
/// Every draw refreshes the velocity with standard normals and runs one
/// trajectory via `simulate_constrained_trajectory`.
///
/// # Arguments
///
/// * `mode` – centre of the untruncated Gaussian, length `p > 0`.
/// * `penalized_hessian` – `p × p` matrix; it must admit a Cholesky factorisation.
/// * `sqrt_phi` – finite, strictly positive scale applied to the whitened draws.
/// * `constraints` – rows of `a` (`m × p`) and entries of `b` (length `m`);
///   `m == 0` yields an unconstrained sampler.
/// * `n_samples` – draws recorded per chain.
/// * `n_chains` – number of chains; chain `c` is seeded with
///   `seed ^ (c · 0x9E37_79B9_7F4A_7C15)` (wrapping multiply).
/// * `seed` – base RNG seed; identical inputs give identical output.
///
/// # Returns
///
/// An `(n_samples · n_chains) × p` matrix. Row `chain * n_samples + draw` holds
/// draw `draw` of chain `chain`. There is no burn-in or thinning: every
/// trajectory end point is recorded.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when
/// * `mode` is empty,
/// * `penalized_hessian` is not `p × p`,
/// * the constraint matrix and right-hand side disagree in row count, or the
///   matrix does not have `p` columns,
/// * `sqrt_phi` is non-finite or not strictly positive,
/// * the requested output size overflows or exceeds the addressable range,
/// * the Cholesky factorisation fails.
///
/// NOTE(review): every chain starts at `z = 0`, i.e. at `β = mode`, and the
/// feasibility of `mode` is not checked here. The caller presumably guarantees
/// `a · mode ≥ b`; confirm against the call sites.
pub fn sample_truncated_gaussian_posterior(
    mode: &Array1<f64>,
    penalized_hessian: &Array2<f64>,
    sqrt_phi: f64,
    constraints: &LinearInequalityConstraints,
    n_samples: usize,
    n_chains: usize,
    seed: u64,
) -> Result<Array2<f64>, String> {
    let p = mode.len();
    if p == 0 {
        return Err(
            "truncated-Gaussian posterior: cannot sample from an empty coefficient vector"
                .to_string(),
        );
    }
    if penalized_hessian.nrows() != p || penalized_hessian.ncols() != p {
        return Err(format!(
            "truncated-Gaussian posterior: penalised Hessian is {}x{}, expected {p}x{p}",
            penalized_hessian.nrows(),
            penalized_hessian.ncols(),
        ));
    }
    let a = &constraints.a;
    let b = &constraints.b;
    let m = a.nrows();
    if m != b.len() {
        return Err(format!(
            "truncated-Gaussian posterior: constraint row mismatch (A has {m} rows, b has {})",
            b.len(),
        ));
    }
    if m > 0 && a.ncols() != p {
        return Err(format!(
            "truncated-Gaussian posterior: constraint matrix has {} columns, expected {p}",
            a.ncols(),
        ));
    }
    if !sqrt_phi.is_finite() || sqrt_phi <= 0.0 {
        return Err(format!(
            "truncated-Gaussian posterior: non-positive or non-finite √φ ({sqrt_phi})"
        ));
    }

    // Size the output with checked arithmetic. A saturating product would hide
    // an overflow and only surface later as a panic inside the allocation.
    let n_total = n_samples
        .checked_mul(n_chains)
        .filter(|&rows| {
            rows.checked_mul(p)
                .and_then(|cells| cells.checked_mul(std::mem::size_of::<f64>()))
                .is_some_and(|bytes| bytes <= isize::MAX as usize)
        })
        .ok_or_else(|| {
            format!(
                "truncated-Gaussian posterior: requested output ({n_samples} draws x \
                 {n_chains} chains x {p} coefficients) is too large to allocate"
            )
        })?;

    let chol = penalized_hessian
        .cholesky(faer::Side::Lower)
        .map_err(|err| {
            format!(
                "truncated-Gaussian posterior: Cholesky of the penalised Hessian failed: {err:?}"
            )
        })?;
    let l = chol.lower_triangular();

    // Walls in whitened coordinates: rows of F, offsets g, and the squared row
    // norms needed by the reflection step.
    let (f_rows, g, f_sq_norm) = if m == 0 {
        (
            Array2::<f64>::zeros((0, p)),
            Array1::<f64>::zeros(0),
            Vec::new(),
        )
    } else {
        let at = a.t().to_owned();
        // Solve L X = Aᵀ, then transpose so that each row is one wall normal.
        let mut f = forward_substitution_lower_matrix(&l, &at).reversed_axes();
        f.mapv_inplace(|v| v * sqrt_phi);
        let g = a.dot(mode) - b;
        let f_sq_norm: Vec<f64> = f.outer_iter().map(|row| row.dot(&row)).collect();
        (f, g, f_sq_norm)
    };

    let mut samples = Array2::<f64>::zeros((n_total, p));
    let max_bounces = MAX_BOUNCES_BASE + 8 * m;

    // Scratch buffers reused across all draws.
    let mut z = Array1::<f64>::zeros(p);
    let mut v = Array1::<f64>::zeros(p);
    let mut beta = Array1::<f64>::zeros(p);

    for chain in 0..n_chains {
        let mut rng = rand::rngs::StdRng::seed_from_u64(
            seed ^ ((chain as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
        );
        // Each chain restarts at the mode (z = 0).
        z.fill(0.0);
        for draw in 0..n_samples {
            for vi in v.iter_mut() {
                *vi = standard_normal(&mut rng);
            }
            simulate_constrained_trajectory(&mut z, &mut v, &f_rows, &g, &f_sq_norm, max_bounces);
            // Map back to coefficient space: solve Lᵀ beta = z.
            back_substitution_lower_transpose_guarded_into(&l, &z, &mut beta);
            // Cannot overflow: the index is below `n_total`, which was checked above.
            let mut out_row = samples.row_mut(chain * n_samples + draw);
            for ((dst, &centre), &step) in out_row.iter_mut().zip(mode.iter()).zip(beta.iter()) {
                *dst = centre + sqrt_phi * step;
            }
        }
    }
    Ok(samples)
}
/// Evolves `(z, v)` in place for a total time of `TRAVEL_TIME`, reflecting the
/// velocity whenever the path reaches one of the walls `f_rows[i] · z + g[i] = 0`.
///
/// * `z`, `v` – position and velocity, both updated in place.
/// * `f_rows` – one wall normal per row.
/// * `g` – wall offsets, indexed like the rows of `f_rows`.
/// * `f_sq_norm` – squared Euclidean norm of each row of `f_rows`; a wall whose
///   entry is not strictly positive leaves the velocity untouched.
/// * `max_bounces` – the trajectory stops early, wherever it currently is, once
///   this many reflections have been performed.
fn simulate_constrained_trajectory(
    z: &mut Array1<f64>,
    v: &mut Array1<f64>,
    f_rows: &Array2<f64>,
    g: &Array1<f64>,
    f_sq_norm: &[f64],
    max_bounces: usize,
) {
    let mut remaining = TRAVEL_TIME;
    let mut reflections = 0usize;

    while remaining > 0.0 {
        // Scan all walls for the earliest contact within the remaining time.
        // The running minimum is passed as the search horizon, so later walls
        // can only win by being hit no later than the current best.
        let mut earliest = remaining;
        let mut struck: Option<usize> = None;
        for (idx, normal) in f_rows.outer_iter().enumerate() {
            let along_z = normal.dot(&*z);
            let along_v = normal.dot(&*v);
            let Some(t) = first_wall_hit(along_z, along_v, g[idx], earliest) else {
                continue;
            };
            // Strictly earlier hits always win; an exact tie is accepted only
            // while no wall has been selected yet.
            if t < earliest || (struck.is_none() && t <= earliest) {
                earliest = t;
                struck = Some(idx);
            }
        }

        let Some(wall) = struck else {
            // Free flight for whatever time is left.
            advance(z, v, remaining);
            return;
        };

        advance(z, v, earliest);
        remaining -= earliest;

        // Mirror the velocity across the wall: v ← v − 2 (f·v / ‖f‖²) f.
        let normal = f_rows.row(wall);
        let norm_sq = f_sq_norm[wall];
        if norm_sq > 0.0 {
            let scale = 2.0 * normal.dot(&*v) / norm_sq;
            for (vk, &nk) in v.iter_mut().zip(normal.iter()) {
                *vk -= scale * nk;
            }
        }

        reflections += 1;
        if reflections >= max_bounces {
            return;
        }
    }
}
/// Returns the first time in `[0, t_max]` at which one wall is reached, if any.
///
/// `u` and `w` are the projections of the current position and velocity onto
/// the wall normal and `g` is the wall offset. Under the rotation applied by
/// `advance`, the slack evolves as `u·cos t + w·sin t + g`.
///
/// * If the current slack `u + g` is at most `WALL_SLACK_EPS`, the point is
///   treated as on the wall: the result is `Some(0.0)` when `w < 0.0` and
///   `None` otherwise.
/// * If the amplitude `sqrt(u² + w²)` is at most `AMPLITUDE_FLOOR`, or the wall
///   lies beyond that amplitude (`-g / amplitude < -1`), the result is `None`.
/// * Otherwise the exit time is `atan2(w, u) + acos(-g / amplitude)` reduced to
///   `[0, 2π)`. A value at most `WALL_SLACK_EPS` is moved one full period later.
///   The time is returned only when it does not exceed `t_max`.
#[inline]
fn first_wall_hit(u: f64, w: f64, g: f64, t_max: f64) -> Option<f64> {
    const FULL_TURN: f64 = 2.0 * std::f64::consts::PI;

    let slack_now = u + g;
    if slack_now <= WALL_SLACK_EPS {
        // Already touching the wall: bounce immediately only if heading outward.
        return if w < 0.0 { Some(0.0) } else { None };
    }

    let amplitude = (u * u + w * w).sqrt();
    if amplitude <= AMPLITUDE_FLOOR {
        return None;
    }

    let ratio = -g / amplitude;
    if ratio < -1.0 {
        // The oscillation never swings far enough to reach the wall.
        return None;
    }

    let phase = w.atan2(u);
    let half_width = ratio.clamp(-1.0, 1.0).acos();
    let wrapped = (phase + half_width).rem_euclid(FULL_TURN);
    let exit_time = if wrapped <= WALL_SLACK_EPS {
        wrapped + FULL_TURN
    } else {
        wrapped
    };

    (exit_time <= t_max).then_some(exit_time)
}
/// Rotates the pair `(z, v)` in place by angle `t`:
/// `z ← z·cos t + v·sin t` and `v ← −z·sin t + v·cos t`, component by component.
///
/// A step of exactly `0.0` leaves both vectors untouched. `z` and `v` are
/// expected to have the same length.
#[inline]
fn advance(z: &mut Array1<f64>, v: &mut Array1<f64>, t: f64) {
    if t == 0.0 {
        return;
    }
    let (sin_t, cos_t) = t.sin_cos();
    for (pos, vel) in z.iter_mut().zip(v.iter_mut()) {
        // Read both old values before overwriting either of them.
        let (pos_old, vel_old) = (*pos, *vel);
        *pos = pos_old * cos_t + vel_old * sin_t;
        *vel = -pos_old * sin_t + vel_old * cos_t;
    }
}
/// Draws one standard-normal variate using the cosine branch of the Box–Muller
/// transform.
///
/// Two uniforms are consumed from `rng` per call. The first is clamped below at
/// `1e-16` so that its logarithm stays finite.
#[inline]
fn standard_normal<R: rand::Rng + ?Sized>(rng: &mut R) -> f64 {
    use rand::RngExt as _;
    const TWO_PI: f64 = 2.0 * std::f64::consts::PI;

    let radial_uniform = rng.random::<f64>().max(1e-16);
    let angular_uniform = rng.random::<f64>();

    let radius = (-2.0 * radial_uniform.ln()).sqrt();
    let angle = TWO_PI * angular_uniform;
    radius * angle.cos()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Builds a constraint set from `a` and `b`, panicking if the constructor
    /// rejects them.
    fn constraints(a: Array2<f64>, b: Array1<f64>) -> LinearInequalityConstraints {
        LinearInequalityConstraints::new(a, b).expect("valid constraints")
    }

    /// Asserts that every row of `samples` satisfies `a · β − b ≥ −1e-8`.
    fn assert_all_feasible(samples: &Array2<f64>, c: &LinearInequalityConstraints) {
        for (k, draw) in samples.outer_iter().enumerate() {
            let beta = draw.to_owned();
            let slack = c.a.dot(&beta) - &c.b;
            for (i, s) in slack.iter().enumerate() {
                assert!(
                    *s >= -1e-8,
                    "draw {k} violates constraint {i}: slack {s} (β = {beta})"
                );
            }
        }
    }

    /// With a bound far from the mode, the draws should reproduce the mean and
    /// covariance of the unconstrained Gaussian.
    #[test]
    fn loose_constraint_recovers_unconstrained_gaussian() {
        let h = array![[4.0, 1.0], [1.0, 3.0]];
        let mode = array![0.5, -0.3];
        let c = constraints(array![[1.0, 0.0]], array![-1000.0]);
        let n = 60_000;
        let s = sample_truncated_gaussian_posterior(&mode, &h, 1.0, &c, n, 1, 20240613)
            .expect("sampler");
        assert_all_feasible(&s, &c);

        let mean = s.mean_axis(ndarray::Axis(0)).unwrap();
        assert!((mean[0] - 0.5).abs() < 0.02, "mean0 {} ", mean[0]);
        assert!((mean[1] + 0.3).abs() < 0.02, "mean1 {}", mean[1]);

        // Closed-form inverse of the 2×2 Hessian.
        let det = 4.0 * 3.0 - 1.0;
        let sigma = array![[3.0 / det, -1.0 / det], [-1.0 / det, 4.0 / det]];

        // Unbiased sample covariance, accumulated entry by entry.
        let (mut s00, mut s01, mut s11) = (0.0_f64, 0.0_f64, 0.0_f64);
        for draw in s.outer_iter() {
            let d0 = draw[0] - mean[0];
            let d1 = draw[1] - mean[1];
            s00 += d0 * d0;
            s01 += d0 * d1;
            s11 += d1 * d1;
        }
        let dof = n as f64 - 1.0;
        let cov = array![[s00 / dof, s01 / dof], [s01 / dof, s11 / dof]];

        for ((i, j), &expected) in sigma.indexed_iter() {
            assert!(
                (cov[(i, j)] - expected).abs() < 0.01,
                "cov[{i},{j}] {} vs Σ {}",
                cov[(i, j)],
                expected
            );
        }
    }

    /// A bound sitting exactly at the mode turns the 1-D posterior into a
    /// half-normal with known mean and variance.
    #[test]
    fn active_lower_bound_is_half_normal() {
        let sigma = 2.0_f64;
        let h = array![[1.0 / (sigma * sigma)]];
        let mode = array![0.0];
        let c = constraints(array![[1.0]], array![0.0]);
        let n = 200_000;
        let s = sample_truncated_gaussian_posterior(&mode, &h, 1.0, &c, n, 1, 7).expect("sampler");
        assert_all_feasible(&s, &c);

        let col = s.column(0);
        let mean = col.mean().unwrap();
        let sum_sq_dev = col.iter().map(|&x| (x - mean).powi(2)).sum::<f64>();
        let var = sum_sq_dev / (n as f64 - 1.0);

        let two_over_pi = 2.0 / std::f64::consts::PI;
        let expect_mean = sigma * two_over_pi.sqrt();
        let expect_var = sigma * sigma * (1.0 - two_over_pi);
        assert!(
            (mean - expect_mean).abs() < 0.02,
            "half-normal mean {mean} vs {expect_mean}"
        );
        assert!(
            (var - expect_var).abs() < 0.05,
            "half-normal var {var} vs {expect_var}"
        );
        assert!(col.iter().all(|&x| x >= 0.0), "a draw escaped β ≥ 0");
    }

    /// `sqrt_phi` scales the posterior standard deviation, so the half-normal
    /// mean scales with it.
    #[test]
    fn dispersion_scales_covariance() {
        let h = array![[1.0]];
        let mode = array![0.0];
        let c = constraints(array![[1.0]], array![0.0]);
        let n = 200_000;
        let sqrt_phi = 2.0;
        let s = sample_truncated_gaussian_posterior(&mode, &h, sqrt_phi, &c, n, 1, 99)
            .expect("sampler");

        let mean = s.column(0).mean().unwrap();
        let expect = 2.0 * (2.0 / std::f64::consts::PI).sqrt();
        assert!(
            (mean - expect).abs() < 0.03,
            "scaled mean {mean} vs {expect}"
        );
    }

    /// Multi-constraint, multi-chain run: coefficients 1..p are bounded below by
    /// zero while coefficient 0 is unconstrained and uncoupled in the Hessian.
    #[test]
    fn monotone_cone_draws_stay_feasible() {
        let p = 6;
        // Diagonal 3.0; 0.7 coupling between neighbours, excluding the (0, 1) pair.
        let h = Array2::<f64>::from_shape_fn((p, p), |(r, c)| {
            if r == c {
                3.0
            } else if r.abs_diff(c) == 1 && r.min(c) >= 1 {
                0.7
            } else {
                0.0
            }
        });
        let mut mode = Array1::<f64>::zeros(p);
        mode[0] = 1.0;

        // Row r selects coefficient r + 1.
        let a = Array2::<f64>::from_shape_fn((p - 1, p), |(r, c)| {
            if c == r + 1 {
                1.0
            } else {
                0.0
            }
        });
        let c = constraints(a, Array1::zeros(p - 1));

        let n = 40_000;
        let chains = 2;
        let s = sample_truncated_gaussian_posterior(&mode, &h, 1.0, &c, n, chains, 31337)
            .expect("sampler");
        assert_eq!(s.dim(), (n * chains, p));
        assert_all_feasible(&s, &c);
        assert!((s.column(0).mean().unwrap() - 1.0).abs() < 0.05);
    }

    /// Two opposing bounds form the interval [0, 1] around a mode at 0.5; every
    /// draw must stay strictly inside it and the mean must stay centred.
    #[test]
    fn interior_box_keeps_draws_in_interval() {
        let h = array![[44.0]];
        let mode = array![0.5];
        let c = constraints(array![[1.0], [-1.0]], array![0.0, -1.0]);
        let n = 80_000;
        let s = sample_truncated_gaussian_posterior(&mode, &h, 1.0, &c, n, 1, 5).expect("sampler");
        assert_all_feasible(&s, &c);

        let col = s.column(0);
        assert!(col.iter().all(|&x| x > 0.0 && x < 1.0));
        assert!((col.mean().unwrap() - 0.5).abs() < 0.01);
    }
}