use crate::types::RhoPrior;
use ndarray::{Array1, Array2};
/// Rate `theta` of the penalized-complexity prior implied by a tail statement.
///
/// Returns `-ln(tail_prob) / upper`, i.e. the `theta` that satisfies
/// `exp(-theta * upper) == tail_prob`.
///
/// No validation happens here: `scalar_terms` checks `upper > 0` and
/// `0 < tail_prob < 1` before calling this function.
pub(crate) fn pc_prior_rate(upper: f64, tail_prob: f64) -> f64 {
    let neg_log_tail = -tail_prob.ln();
    neg_log_tail / upper
}
/// Cost, gradient and curvature of the penalized-complexity prior at `r`.
///
/// With `e = exp(-r / 2)` the returned tuple is
/// `(r / 2 + theta * e, 1 / 2 - theta * e / 2, theta * e / 4)`:
/// the cost `c(r) = r / 2 + theta * exp(-r / 2)` followed by its first and
/// second derivatives with respect to `r`.
pub(crate) fn pc_prior_terms(theta: f64, r: f64) -> (f64, f64, f64) {
    let decay = (-0.5 * r).exp();
    // Pre-scale `theta` so each product keeps the original evaluation order.
    let half_theta = 0.5 * theta;
    let quarter_theta = 0.25 * theta;
    let cost = 0.5 * r + theta * decay;
    let gradient = 0.5 - half_theta * decay;
    let curvature = quarter_theta * decay;
    (cost, gradient, curvature)
}
/// How `evaluate` reacts when strict evaluation of a prior fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum InvalidPriorPolicy {
/// Return the underlying `RhoPriorError` to the caller.
HardError,
/// Swallow the error and return a "saturated" evaluation instead:
/// `+inf` cost with an all-NaN gradient and an all-NaN Hessian.
Saturate,
}
/// Reasons a rho prior cannot be evaluated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum RhoPriorError {
/// An `Independent` prior lists a different number of component priors
/// than there are entries in `rho`.
DimensionMismatch { reason: String },
/// A prior's hyperparameters failed validation, or an `Independent`
/// prior was found where a scalar prior is required.
ConstraintViolation { reason: String },
}
impl RhoPriorError {
fn dimension_mismatch(reason: String) -> Self {
RhoPriorError::DimensionMismatch { reason }
}
fn constraint_violation(reason: String) -> Self {
RhoPriorError::ConstraintViolation { reason }
}
}
/// Result of evaluating a prior at a vector `rho`.
#[derive(Debug, Clone)]
pub(crate) struct RhoPriorEval {
/// Prior cost summed over all coordinates of `rho`.
pub cost: f64,
/// Derivative of the cost with respect to each coordinate of `rho`.
pub gradient: Array1<f64>,
/// Second derivatives of the cost as a dense matrix; the evaluators in this
/// file only ever fill the diagonal. `None` when every diagonal term is
/// zero (always the case for a `Flat` prior).
pub hessian: Option<Array2<f64>>,
}
/// Evaluates a single scalar prior at one coordinate `r`.
///
/// Returns `(cost, gradient, curvature)`, where `gradient` and `curvature`
/// are the first and second derivatives of `cost` with respect to `r`:
///
/// * `Flat` contributes `(0, 0, 0)`.
/// * `Normal { mean, sd }` uses `cost = (r - mean)^2 / (2 * sd^2)`.
/// * `GammaPrecision { shape, rate }` uses
///   `cost = rate * exp(r) - (shape - 1) * r`.
/// * `PenalizedComplexity { upper, tail_prob }` delegates to
///   `pc_prior_terms` with the rate from `pc_prior_rate`.
///
/// `context` is prepended to every error message so the caller can identify
/// which prior (or which coordinate) was rejected.
///
/// # Errors
///
/// Returns `RhoPriorError::ConstraintViolation` when the hyperparameters
/// fail validation or when `prior` is an `Independent` prior, which is not a
/// scalar prior.
fn scalar_terms(prior: &RhoPrior, r: f64, context: &str) -> Result<(f64, f64, f64), RhoPriorError> {
    match prior {
        RhoPrior::Flat => Ok((0.0, 0.0, 0.0)),
        RhoPrior::Normal { mean, sd } => {
            if !mean.is_finite() || !sd.is_finite() || *sd <= 0.0 {
                return Err(RhoPriorError::constraint_violation(format!(
                    "{context} Normal log-precision prior requires finite mean and sd > 0"
                )));
            }
            let inv_var = 1.0 / (*sd * *sd);
            let delta = r - *mean;
            Ok((0.5 * delta * delta * inv_var, delta * inv_var, inv_var))
        }
        RhoPrior::GammaPrecision { shape, rate } => {
            if !shape.is_finite() || *shape <= 0.0 || !rate.is_finite() || *rate < 0.0 {
                return Err(RhoPriorError::constraint_violation(format!(
                    "{context} Gamma precision prior requires shape > 0 and rate >= 0"
                )));
            }
            // `rate == 0` passes validation. Skip the product in that case:
            // for large `r`, `exp(r)` overflows to +inf and `0.0 * inf` would
            // turn cost, gradient and curvature into NaN even though the
            // rate term is identically zero.
            let rate_term = if *rate == 0.0 { 0.0 } else { *rate * r.exp() };
            let shape_excess = *shape - 1.0;
            Ok((
                rate_term - shape_excess * r,
                rate_term - shape_excess,
                rate_term,
            ))
        }
        RhoPrior::PenalizedComplexity { upper, tail_prob } => {
            if !upper.is_finite() || *upper <= 0.0 {
                return Err(RhoPriorError::constraint_violation(format!(
                    "{context} penalized-complexity prior requires a finite upper > 0"
                )));
            }
            if !tail_prob.is_finite() || *tail_prob <= 0.0 || *tail_prob >= 1.0 {
                return Err(RhoPriorError::constraint_violation(format!(
                    "{context} penalized-complexity prior requires tail probability in (0, 1)"
                )));
            }
            let theta = pc_prior_rate(*upper, *tail_prob);
            Ok(pc_prior_terms(theta, r))
        }
        RhoPrior::Independent(_) => Err(RhoPriorError::constraint_violation(format!(
            "{context} must be a scalar rho prior, not a nested Independent prior"
        ))),
    }
}
/// Builds the evaluation returned under `InvalidPriorPolicy::Saturate` when
/// the prior cannot be evaluated: `+inf` cost, a length-`len` gradient of
/// NaNs and a `len x len` Hessian of NaNs.
fn saturated(len: usize) -> RhoPriorEval {
    let gradient = Array1::from_elem(len, f64::NAN);
    let hessian = Array2::from_elem((len, len), f64::NAN);
    RhoPriorEval {
        cost: f64::INFINITY,
        gradient,
        hessian: Some(hessian),
    }
}
/// Evaluates `prior` at `rho`, applying `policy` if the evaluation fails.
///
/// A successful strict evaluation is returned unchanged under either policy.
/// On failure, `HardError` propagates the error, while `Saturate` replaces
/// any error (including a dimension mismatch) with the saturated evaluation
/// sized to `rho.len()`.
///
/// # Errors
///
/// Only returns `Err` under `InvalidPriorPolicy::HardError`.
pub(crate) fn evaluate(
    prior: &RhoPrior,
    rho: &Array1<f64>,
    policy: InvalidPriorPolicy,
) -> Result<RhoPriorEval, RhoPriorError> {
    evaluate_strict(prior, rho).or_else(|err| match policy {
        InvalidPriorPolicy::HardError => Err(err),
        InvalidPriorPolicy::Saturate => Ok(saturated(rho.len())),
    })
}
/// Evaluates `prior` at `rho`, failing on the first invalid configuration.
///
/// * `Flat` yields zero cost, a zero gradient and no Hessian.
/// * `Normal`, `GammaPrecision` and `PenalizedComplexity` apply the same
///   scalar prior to every coordinate of `rho`.
/// * `Independent` applies the `i`-th component prior to the `i`-th
///   coordinate of `rho`.
///
/// The Hessian is diagonal and is reported as `None` when every diagonal
/// entry is zero.
///
/// # Errors
///
/// Returns `RhoPriorError::DimensionMismatch` when an `Independent` prior
/// does not have exactly `rho.len()` components, and propagates the first
/// `RhoPriorError::ConstraintViolation` raised by `scalar_terms`.
fn evaluate_strict(prior: &RhoPrior, rho: &Array1<f64>) -> Result<RhoPriorEval, RhoPriorError> {
    let len = rho.len();
    match prior {
        RhoPrior::Flat => Ok(RhoPriorEval {
            cost: 0.0,
            gradient: Array1::zeros(len),
            hessian: None,
        }),
        RhoPrior::Normal { .. }
        | RhoPrior::GammaPrecision { .. }
        | RhoPrior::PenalizedComplexity { .. } => accumulate_scalar_terms(
            len,
            rho.iter().map(|&r| scalar_terms(prior, r, "rho prior")),
        ),
        RhoPrior::Independent(priors) => {
            if priors.len() != len {
                return Err(RhoPriorError::dimension_mismatch(format!(
                    "Independent rho prior length mismatch: got {}, expected {}",
                    priors.len(),
                    len
                )));
            }
            accumulate_scalar_terms(
                len,
                priors
                    .iter()
                    .zip(rho.iter())
                    .enumerate()
                    .map(|(idx, (component, &r))| {
                        scalar_terms(component, r, &format!("rho prior coordinate {idx}"))
                    }),
            )
        }
    }
}

/// Folds per-coordinate `(cost, gradient, curvature)` results into a
/// `RhoPriorEval` of dimension `len`, stopping at the first error.
///
/// `terms` is consumed lazily and in order, so the coordinate with the lowest
/// index that fails is the one whose error is returned. It must yield at most
/// `len` items.
fn accumulate_scalar_terms(
    len: usize,
    terms: impl Iterator<Item = Result<(f64, f64, f64), RhoPriorError>>,
) -> Result<RhoPriorEval, RhoPriorError> {
    let mut cost = 0.0;
    let mut gradient = Array1::<f64>::zeros(len);
    let mut hessian = Array2::<f64>::zeros((len, len));
    let mut any_hessian = false;
    for (idx, term) in terms.enumerate() {
        let (c, g, h) = term?;
        cost += c;
        gradient[idx] = g;
        hessian[[idx, idx]] = h;
        // Track whether any curvature is present so an all-zero Hessian can
        // be reported as `None`.
        any_hessian |= h != 0.0;
    }
    Ok(RhoPriorEval {
        cost,
        gradient,
        hessian: any_hessian.then_some(hessian),
    })
}
// Unit tests for rho-prior evaluation: analytic derivatives versus finite
// differences, the invalid-prior policies, and calibration of the
// penalized-complexity (PC) prior.
#[cfg(test)]
mod tests {
use super::*;
// Asserts that `a` and `b` agree to an absolute tolerance of 1e-12.
fn approx(a: f64, b: f64) {
assert!((a - b).abs() <= 1e-12, "expected {a} ~= {b}");
}
// For every valid prior: both policies must return the same evaluation, and
// the analytic gradient / diagonal Hessian must match central finite
// differences of the cost.
#[test]
fn cost_grad_hess_parity_across_valid_priors() {
let rho = Array1::from_vec(vec![-0.5, 0.25, 1.5, 0.7]);
let priors = vec![
RhoPrior::Flat,
RhoPrior::Normal { mean: 0.2, sd: 0.8 },
RhoPrior::GammaPrecision {
shape: 2.0,
rate: 0.5,
},
RhoPrior::PenalizedComplexity {
upper: 0.5,
tail_prob: 0.05,
},
RhoPrior::Independent(vec![
RhoPrior::Flat,
RhoPrior::Normal {
mean: -0.1,
sd: 1.3,
},
RhoPrior::GammaPrecision {
shape: 1.5,
rate: 0.0,
},
RhoPrior::PenalizedComplexity {
upper: 1.2,
tail_prob: 0.01,
},
]),
];
for prior in &priors {
// The policy only matters on failure, so valid priors must agree.
let hard = evaluate(prior, &rho, InvalidPriorPolicy::HardError)
.expect("valid prior must not error under HardError");
let sat =
evaluate(prior, &rho, InvalidPriorPolicy::Saturate).expect("Saturate never errors");
approx(hard.cost, sat.cost);
assert_eq!(hard.gradient, sat.gradient);
assert_eq!(hard.hessian, sat.hessian);
let base = evaluate(prior, &rho, InvalidPriorPolicy::HardError).unwrap();
// Cost after perturbing coordinate `k` of `rho` by `delta`.
let cost_at = |k: usize, delta: f64| -> f64 {
let mut r = rho.clone();
r[k] += delta;
evaluate(prior, &r, InvalidPriorPolicy::HardError)
.unwrap()
.cost
};
// Separate step sizes for the first- and second-difference stencils.
let (h_grad, h_hess) = (1e-6, 1e-4);
for k in 0..rho.len() {
let fd_grad = (cost_at(k, h_grad) - cost_at(k, -h_grad)) / (2.0 * h_grad);
assert!(
(fd_grad - base.gradient[k]).abs() <= 1e-5,
"gradient mismatch at {k}: fd {fd_grad} vs {}",
base.gradient[k]
);
let fd_hess = (cost_at(k, h_hess) - 2.0 * base.cost + cost_at(k, -h_hess))
/ (h_hess * h_hess);
// A missing Hessian is treated as zero curvature.
let analytic_hess = base.hessian.as_ref().map_or(0.0, |h| h[[k, k]]);
assert!(
(fd_hess - analytic_hess).abs() <= 1e-4,
"hessian mismatch at {k}: fd {fd_hess} vs {analytic_hess}"
);
}
}
}
// Invalid priors must error under `HardError` with the expected variant and
// saturate (+inf cost, NaN derivatives) under `Saturate`.
#[test]
fn invalid_prior_policy_branches() {
let rho = Array1::from_vec(vec![0.0, 0.0]);
// Negative standard deviation violates the Normal prior's constraints.
let bad_normal = RhoPrior::Normal {
mean: 0.0,
sd: -1.0,
};
assert!(matches!(
evaluate(&bad_normal, &rho, InvalidPriorPolicy::HardError),
Err(RhoPriorError::ConstraintViolation { .. })
));
let sat = evaluate(&bad_normal, &rho, InvalidPriorPolicy::Saturate).unwrap();
assert!(sat.cost.is_infinite() && sat.cost > 0.0);
assert!(sat.gradient.iter().all(|v| v.is_nan()));
assert!(sat.hessian.unwrap().iter().all(|v| v.is_nan()));
// One component prior for a two-element `rho`.
let bad_len = RhoPrior::Independent(vec![RhoPrior::Flat]);
assert!(matches!(
evaluate(&bad_len, &rho, InvalidPriorPolicy::HardError),
Err(RhoPriorError::DimensionMismatch { .. })
));
// Correct length, but a component is itself an `Independent` prior.
let nested = RhoPrior::Independent(vec![
RhoPrior::Independent(vec![RhoPrior::Flat]),
RhoPrior::Flat,
]);
assert!(matches!(
evaluate(&nested, &rho, InvalidPriorPolicy::HardError),
Err(RhoPriorError::ConstraintViolation { .. })
));
}
// Reference log-density used by the PC tests:
// ln(theta / 2) - r / 2 - theta * exp(-r / 2), with theta from `pc_prior_rate`.
fn pc_log_pdf(upper: f64, tail_prob: f64, r: f64) -> f64 {
let theta = pc_prior_rate(upper, tail_prob);
(0.5 * theta).ln() - 0.5 * r - theta * (-0.5 * r).exp()
}
// The rate must satisfy exp(-theta * upper) == tail probability.
#[test]
fn pc_rate_calibrates_to_tail_statement() {
for &(upper, alpha) in &[(0.5_f64, 0.05_f64), (1.2, 0.01), (3.0, 0.25)] {
let theta = pc_prior_rate(upper, alpha);
let tail = (-theta * upper).exp();
assert!(
(tail - alpha).abs() < 1e-12,
"P(d>U)={tail} vs α={alpha} (U={upper})"
);
}
}
// Numerically integrates the reference density over [-60, 80] with the
// trapezoid rule: total mass must be ~1 and the mass at
// r <= -2 * ln(upper) must be ~alpha.
#[test]
fn pc_density_integrates_to_one_and_matches_tail() {
let upper = 0.5_f64;
let alpha = 0.05_f64;
let (lo, hi, n) = (-60.0_f64, 80.0_f64, 2_000_000usize);
let h = (hi - lo) / n as f64;
// exp(-r / 2) > upper is equivalent to r < -2 * ln(upper).
let tail_boundary = -2.0 * upper.ln();
let mut total = 0.0;
let mut tail = 0.0;
for i in 0..=n {
let r = lo + i as f64 * h;
// Trapezoid weights: half weight on the two end points.
let w = if i == 0 || i == n { 0.5 } else { 1.0 };
let p = pc_log_pdf(upper, alpha, r).exp();
total += w * p;
if r <= tail_boundary {
tail += w * p;
}
}
total *= h;
tail *= h;
assert!((total - 1.0).abs() < 1e-4, "∫ p(ρ) dρ = {total}");
assert!(
(tail - alpha).abs() < 1e-3,
"P(d>U) = {tail} vs α = {alpha}"
);
}
// `pc_prior_terms` must equal the negative reference log-density up to the
// constant ln(theta / 2), with matching first and second derivatives and
// strictly positive curvature.
#[test]
fn pc_terms_are_negative_log_density_derivatives() {
let (upper, alpha) = (0.8_f64, 0.02_f64);
let theta = pc_prior_rate(upper, alpha);
let (h1, h2) = (1e-6, 1e-4);
for &r in &[-2.0_f64, -0.3, 0.0, 1.7, 4.0] {
let (cost, grad, hess) = pc_prior_terms(theta, r);
// cost + log-density leaves only the normalizing constant.
approx(cost + pc_log_pdf(upper, alpha, r), (0.5 * theta).ln());
let dlp =
(pc_log_pdf(upper, alpha, r + h1) - pc_log_pdf(upper, alpha, r - h1)) / (2.0 * h1);
let neg_dlp = -dlp;
assert!(
(grad - neg_dlp).abs() < 1e-5,
"grad {grad} vs {neg_dlp} at r={r}"
);
let d2lp = (pc_log_pdf(upper, alpha, r + h2) - 2.0 * pc_log_pdf(upper, alpha, r)
+ pc_log_pdf(upper, alpha, r - h2))
/ (h2 * h2);
let neg_d2lp = -d2lp;
assert!(
(hess - neg_d2lp).abs() < 1e-4,
"hess {hess} vs {neg_d2lp} at r={r}"
);
assert!(hess > 0.0, "PC curvature must be positive, got {hess}");
}
}
// The PC cost must be higher at rho = -4 than at rho = +4, and its gradient
// must approach 0.5 for large rho.
#[test]
fn pc_prior_pulls_toward_simpler_model() {
let prior = RhoPrior::PenalizedComplexity {
upper: 1.0,
tail_prob: 0.05,
};
let cost = |r: f64| {
evaluate(
&prior,
&Array1::from_vec(vec![r]),
InvalidPriorPolicy::HardError,
)
.unwrap()
.cost
};
assert!(
cost(-4.0) > cost(4.0),
"under-smoothing must cost more: {} vs {}",
cost(-4.0),
cost(4.0)
);
// At r = 25 the exponential term is negligible, leaving the 0.5 slope.
let g_far = evaluate(
&prior,
&Array1::from_vec(vec![25.0]),
InvalidPriorPolicy::HardError,
)
.unwrap()
.gradient[0];
assert!(
(g_far - 0.5).abs() < 1e-3,
"far over-smoothing slope {g_far}"
);
}
// Out-of-range `upper` / `tail_prob` values must be rejected under
// `HardError` and saturate to +inf cost under `Saturate`.
#[test]
fn pc_prior_rejects_invalid_hyperparameters() {
let rho = Array1::from_vec(vec![0.0]);
for bad in [
RhoPrior::PenalizedComplexity {
upper: 0.0,
tail_prob: 0.05,
},
RhoPrior::PenalizedComplexity {
upper: -1.0,
tail_prob: 0.05,
},
RhoPrior::PenalizedComplexity {
upper: 1.0,
tail_prob: 0.0,
},
RhoPrior::PenalizedComplexity {
upper: 1.0,
tail_prob: 1.0,
},
RhoPrior::PenalizedComplexity {
upper: 1.0,
tail_prob: f64::NAN,
},
] {
assert!(matches!(
evaluate(&bad, &rho, InvalidPriorPolicy::HardError),
Err(RhoPriorError::ConstraintViolation { .. })
));
let sat = evaluate(&bad, &rho, InvalidPriorPolicy::Saturate).unwrap();
assert!(sat.cost.is_infinite() && sat.cost > 0.0);
}
}
}