use crate::PolyaGamma;
use mini_mcmc::core::{ChainRunner, init_det};
use mini_mcmc::distributions::Conditional;
use mini_mcmc::gibbs::GibbsSampler;
use mini_mcmc::stats::RunStats;
use ndarray::{Array1, Array2, Array3};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use statrs::distribution::Normal;
use statrs::function::gamma::ln_gamma;
use std::error::Error;
/// Configuration and data for a multi-chain Gibbs sampler over a
/// negative-binomial regression model.
///
/// The value is consumed by `run`, which builds the per-coordinate conditional
/// sampler from these fields and drives it through `mini_mcmc`'s `GibbsSampler`.
pub struct GibbsNegativeBinomial<R = ChaCha8Rng>
where
R: SeedableRng + Rng + Clone + Send + Sync,
{
/// Design matrix: one row per observation, one column per coefficient.
x: Array2<f64>,
/// Response values, one per row of `x`.
y: Array1<f64>,
/// Prior precision matrix for the coefficients (`new` builds it as
/// `I / prior_variance`).
prior_prec: Array2<f64>,
/// Shape parameter of the prior on the dispersion `r`.
prior_shape: f64,
/// Second parameter of the prior on `r`.
/// NOTE(review): the sampler multiplies `r` by this value in the log-prior,
/// i.e. it is applied like a rate despite the name — confirm intent.
prior_scale: f64,
/// Number of chains to initialise and run.
n_chains: usize,
/// Seed handed to the Gibbs sampler via `set_seed`.
seed: u64,
/// Random number generator moved into the conditional sampler by `run`.
rng: R,
}
/// Output of a negative-binomial Gibbs run: posterior summaries, the raw
/// draws, and optional ground-truth values for comparison.
pub struct NegativeBinomialResults {
/// Mean of each regression coefficient over the draws pooled across chains.
pub posterior_means: Vec<f64>,
/// Standard deviation (computed with `ddof = 1`) of each regression
/// coefficient over the pooled draws.
pub posterior_sds: Vec<f64>,
/// Mean of the dispersion parameter `r` over the pooled draws.
pub posterior_mean_r: f64,
/// Raw draws as returned by the sampler; indexed as
/// `[chain, sample, parameter]` by `get_posterior_samples`.
pub samples: Array3<f64>,
/// Ground-truth coefficients supplied by the caller, if any; only used for
/// display in `summary`.
pub true_coefficients: Option<Vec<f64>>,
/// Ground-truth dispersion supplied by the caller, if any; only used for
/// display in `summary`.
pub true_r: Option<f64>,
/// Run statistics returned by the sampler alongside the draws.
pub run_stats: RunStats,
}
impl GibbsNegativeBinomial<ChaCha8Rng> {
    /// Builds a sampler backed by a `ChaCha8Rng` seeded from `seed`.
    ///
    /// * `x` – design matrix (rows are observations, columns are coefficients).
    /// * `y` – response vector.
    /// * `prior_variance` – common prior variance of the coefficients; the
    ///   stored prior precision is the identity scaled by its reciprocal.
    /// * `prior_shape`, `prior_scale` – parameters of the prior on the
    ///   dispersion `r`, stored as given.
    /// * `n_chains` – number of chains to run.
    /// * `seed` – seeds the internal RNG and is stored for the Gibbs sampler.
    pub fn new(
        x: Array2<f64>,
        y: Array1<f64>,
        prior_variance: f64,
        prior_shape: f64,
        prior_scale: f64,
        n_chains: usize,
        seed: u64,
    ) -> Self {
        // Isotropic prior: every coefficient shares the same precision.
        let precision = 1.0 / prior_variance;
        let prior_prec = Array2::eye(x.ncols()) * precision;

        Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            seed,
            n_chains,
            prior_scale,
            prior_shape,
            prior_prec,
            y,
            x,
        }
    }
}
impl<R: SeedableRng + Rng + Clone + Send + Sync> GibbsNegativeBinomial<R> {
    /// Runs the Gibbs sampler and summarises the pooled draws.
    ///
    /// The sampled state has `p + n + 1` coordinates, where `n` and `p` are the
    /// row and column counts of the design matrix: `p` regression coefficients,
    /// `n` latent per-observation weights, and the dispersion `r` last. The
    /// latent weights and `r` are initialised to `1.0` in every chain; the
    /// coefficients keep whatever `init_det` produced.
    ///
    /// * `burnin` – number of initial draws passed to the sampler as discard.
    /// * `samples` – number of draws to collect per chain.
    /// * `true_coefficients`, `true_r` – optional ground truth, copied into
    ///   the result unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying sampler fails, if the returned draws
    /// cannot be reshaped to `(n_chains * samples, dim)`, or if no draws were
    /// collected (so no summary can be computed).
    ///
    /// # Panics
    ///
    /// Panics if the conditional sampler's constructor rejects the inputs
    /// (row/length mismatch or non-positive prior parameters).
    pub fn run(
        self,
        burnin: usize,
        samples: usize,
        true_coefficients: Option<Vec<f64>>,
        true_r: Option<f64>,
    ) -> Result<NegativeBinomialResults, Box<dyn Error>> {
        // `self` is consumed, so move the data out instead of cloning it.
        let Self {
            x,
            y,
            prior_prec,
            prior_shape,
            prior_scale,
            n_chains,
            seed,
            rng,
        } = self;

        let n = x.nrows();
        let p = x.ncols();
        let dim = p + n + 1;

        let cond = NegativeBinomialConditional::new(x, y, prior_prec, prior_shape, prior_scale, rng);

        let mut init: Vec<Vec<f64>> = init_det(n_chains, dim);
        for state in &mut init {
            // Latent weights (p..p+n) and the dispersion (p+n) start at 1.
            state[p..=p + n].fill(1.0);
        }

        let mut gibbs = GibbsSampler::new(cond, init).set_seed(seed);
        let (all_samples, run_stats) = gibbs.run_progress(samples, burnin)?;

        // Pool all chains into one (draw, parameter) matrix for the summaries.
        let pooled = all_samples.to_shape((n_chains * samples, dim))?;
        if pooled.nrows() == 0 {
            // Without this guard the mean is undefined and `std` with ddof = 1
            // would panic on an empty column.
            return Err("no posterior samples were collected; cannot compute summaries".into());
        }

        let column_mean = |j: usize| {
            pooled
                .column(j)
                .mean()
                .ok_or("posterior sample column is empty")
        };
        let posterior_means = (0..p)
            .map(&column_mean)
            .collect::<Result<Vec<f64>, _>>()?;
        let posterior_sds: Vec<f64> = (0..p).map(|j| pooled.column(j).std(1.0)).collect();
        let posterior_mean_r = column_mean(p + n)?;

        Ok(NegativeBinomialResults {
            posterior_means,
            posterior_sds,
            posterior_mean_r,
            samples: all_samples,
            true_coefficients,
            true_r,
            run_stats,
        })
    }
}
impl NegativeBinomialResults {
    /// Prints a table of posterior means and standard deviations for every
    /// regression coefficient, followed by the posterior mean of `r`.
    ///
    /// The "True Value" column shows the caller-supplied ground truth when it
    /// is available and `N/A` otherwise — including when `true_coefficients`
    /// has fewer entries than there are coefficients.
    pub fn summary(&self) {
        println!("Negative Binomial Regression Results");
        println!("================================");
        println!(
            "{:<10} {:<15} {:<15} {:<15}",
            "Parameter", "Mean", "Std. Dev.", "True Value"
        );
        println!("{}", "-".repeat(55));

        let rows = self.posterior_means.iter().zip(&self.posterior_sds);
        for (i, (mean, sd)) in rows.enumerate() {
            // Checked lookup: a short ground-truth vector must not panic here.
            let true_val = self
                .true_coefficients
                .as_ref()
                .and_then(|truth| truth.get(i))
                .map_or_else(|| "N/A".to_string(), |t| format!("{:.4}", t));
            println!(
                "{:<10} {:<15.4} {:<15.4} {:<15}",
                format!("β{}", i),
                mean,
                sd,
                true_val
            );
        }

        let true_r = self
            .true_r
            .map_or_else(|| "N/A".to_string(), |r| format!("{:.4}", r));
        println!(
            "{:<10} {:<15.4} {:<15} {:<15}",
            "r", self.posterior_mean_r, "-", true_r
        );
    }

    /// Returns every stored draw of parameter `param_idx`, chain by chain
    /// (all draws of chain 0 first, then chain 1, and so on).
    ///
    /// Returns `None` when `param_idx` is not a valid index into the last
    /// axis of `samples`.
    pub fn get_posterior_samples(&self, param_idx: usize) -> Option<Vec<f64>> {
        let (n_chains, n_samples, n_params) = self.samples.dim();
        if param_idx >= n_params {
            return None;
        }

        let mut draws = Vec::with_capacity(n_chains * n_samples);
        draws.extend((0..n_chains).flat_map(|chain| {
            (0..n_samples).map(move |draw| self.samples[[chain, draw, param_idx]])
        }));
        Some(draws)
    }
}
/// Per-coordinate conditional sampler for the negative-binomial model.
///
/// The state vector it operates on is laid out as `p` regression
/// coefficients, then `n` latent per-observation weights, then the
/// dispersion `r`, where `n` and `p` are the row and column counts of `x`.
#[derive(Clone)]
struct NegativeBinomialConditional<R>
where
R: SeedableRng + Rng + Clone + Send + Sync,
{
/// Design matrix: one row per observation, one column per coefficient.
x: Array2<f64>,
/// Response values, one per row of `x`.
y: Array1<f64>,
/// Prior precision of the coefficients; only the diagonal is read when
/// sampling.
prior_prec: Array2<f64>,
/// Shape parameter of the prior on `r`.
prior_shape: f64,
/// Second parameter of the prior on `r`.
/// NOTE(review): multiplied by `r` in the log-prior (rate-like usage) —
/// confirm against the intended parameterisation.
prior_scale: f64,
/// Pólya-Gamma sampler; its shape is reset before every latent-weight draw.
pg: PolyaGamma,
/// Random number generator used for every draw.
rng: R,
}
impl<R> NegativeBinomialConditional<R>
where
    R: SeedableRng + Rng + Clone + Send + Sync,
{
    /// Validates the inputs and assembles the conditional sampler.
    ///
    /// The Pólya-Gamma helper is created with shape `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `x` and `y` disagree on the number of observations, or if
    /// `prior_shape` or `prior_scale` is not strictly positive.
    pub fn new(
        x: Array2<f64>,
        y: Array1<f64>,
        prior_prec: Array2<f64>,
        prior_shape: f64,
        prior_scale: f64,
        rng: R,
    ) -> Self {
        let (n_rows, n_responses) = (x.nrows(), y.len());
        assert_eq!(
            n_rows, n_responses,
            "Number of rows in x must match length of y"
        );
        assert!(prior_shape > 0.0, "prior_shape must be positive");
        assert!(prior_scale > 0.0, "prior_scale must be positive");

        let pg = PolyaGamma::new(1.0);
        Self {
            x,
            y,
            prior_prec,
            prior_shape,
            prior_scale,
            pg,
            rng,
        }
    }
}
impl<R> NegativeBinomialConditional<R>
where
    R: SeedableRng + Rng + Clone + Send + Sync,
{
    /// Draws regression coefficient `i` from a normal distribution whose
    /// precision and mean are accumulated over all observations.
    ///
    /// With `psi = x'β - ln r` for each observation, the draw uses
    /// precision `prior_prec[i,i] + Σ w·x_i²` and unnormalised mean
    /// `Σ x_i·((y - r)/2 - w·(x'β without term i - ln r))`, where `w` are the
    /// latent weights stored in `given[p..p + n]`.
    fn draw_coefficient(&mut self, i: usize, given: &[f64], r: f64) -> f64 {
        let n = self.x.nrows();
        let p = self.x.ncols();
        let ln_r = r.ln();
        let col_i = self.x.column(i);

        let mut precision = self.prior_prec[(i, i)];
        let mut precision_mean = 0.0;
        for row_idx in 0..n {
            let xi = col_i[row_idx];
            let wi = given[p + row_idx];
            precision += wi * xi * xi;

            // Linear predictor with coefficient `i` left out.
            let mut dot_minus_i = 0.0;
            for (k, bj) in given.iter().enumerate().take(p) {
                if k != i {
                    dot_minus_i += self.x[(row_idx, k)] * bj;
                }
            }

            let kappa = (self.y[row_idx] - r) / 2.0;
            precision_mean += xi * (kappa - wi * (dot_minus_i - ln_r));
        }

        let var_i = 1.0 / precision;
        let mean_i = precision_mean * var_i;
        let eps: f64 = self.rng.sample(Normal::standard());
        mean_i + eps * var_i.sqrt()
    }

    /// Draws the latent weight of observation `obs_idx` from the Pólya-Gamma
    /// sampler, with shape `y + r` and `|x'β - ln r|` as the second argument
    /// to `PolyaGamma::draw`.
    fn draw_latent_weight(&mut self, obs_idx: usize, given: &[f64], r: f64) -> f64 {
        let p = self.x.ncols();
        let xb: f64 = self
            .x
            .row(obs_idx)
            .iter()
            .zip(&given[..p])
            .map(|(xij, bj)| xij * bj)
            .sum::<f64>();
        let psi = xb - r.ln();
        self.pg.set_shape(self.y[obs_idx] + r);
        self.pg.draw(&mut self.rng, psi.abs())
    }

    /// Computes `exp(x'β)` for every observation, using the coefficients in
    /// `given[..p]`.
    fn observation_means(&self, given: &[f64]) -> Vec<f64> {
        let p = self.x.ncols();
        let beta = &given[..p];
        (0..self.y.len())
            .map(|idx| {
                self.x
                    .row(idx)
                    .iter()
                    .zip(beta)
                    .map(|(xij, bj)| xij * bj)
                    .sum::<f64>()
                    .exp()
            })
            .collect()
    }

    /// Unnormalised log-posterior of the dispersion at `r_val`, given the
    /// per-observation means `means`; `-inf` for non-positive `r_val`.
    ///
    /// NOTE(review): the prior term is `(shape - 1)·ln r - prior_scale·r`, so
    /// `prior_scale` acts as a rate here — confirm the intended
    /// parameterisation before changing it.
    fn log_posterior_r(&self, r_val: f64, means: &[f64]) -> f64 {
        if r_val <= 0.0 {
            return f64::NEG_INFINITY;
        }
        let mut lp = (self.prior_shape - 1.0) * r_val.ln() - self.prior_scale * r_val;
        for (&yi, &mu) in self.y.iter().zip(means) {
            lp += ln_gamma(yi + r_val) - ln_gamma(r_val) + r_val * r_val.ln()
                - (yi + r_val) * (r_val + mu).ln();
        }
        lp
    }

    /// Updates the dispersion with one Metropolis–Hastings step that
    /// random-walks on `ln r`, returning either the proposal or `current_r`.
    fn draw_dispersion(&mut self, given: &[f64], current_r: f64) -> f64 {
        /// Standard deviation of the normal random walk on `ln r`.
        const PROPOSAL_SD: f64 = 1.0;

        // The means depend only on the coefficients, so compute them once and
        // share them between the two log-posterior evaluations.
        let means = self.observation_means(given);

        let proposal_ln_r =
            current_r.ln() + self.rng.sample(Normal::new(0.0, PROPOSAL_SD).unwrap());
        let proposal_r = proposal_ln_r.exp();

        // The trailing `ln r' - ln r` accounts for proposing on the log scale.
        let log_alpha = self.log_posterior_r(proposal_r, &means)
            - self.log_posterior_r(current_r, &means)
            + proposal_ln_r
            - current_r.ln();

        // The uniform is only drawn when the proposal is not accepted outright.
        if log_alpha >= 0.0 || self.rng.r#gen::<f64>().ln() < log_alpha {
            proposal_r
        } else {
            current_r
        }
    }
}

impl<R> Conditional<f64> for NegativeBinomialConditional<R>
where
    R: SeedableRng + Rng + Clone + Send + Sync,
{
    /// Draws coordinate `i` of the state given the current values in `given`.
    ///
    /// `given` is laid out as `p` regression coefficients, `n` latent weights
    /// and finally the dispersion `r` (with `n`/`p` the rows/columns of `x`):
    ///
    /// * `i < p` – coefficient update,
    /// * `p <= i < p + n` – latent-weight update for observation `i - p`,
    /// * otherwise – dispersion update.
    ///
    /// # Panics
    ///
    /// Panics if `given` has fewer than `p + n + 1` entries.
    fn sample(&mut self, i: usize, given: &[f64]) -> f64 {
        let n = self.x.nrows();
        let p = self.x.ncols();
        let r = given[p + n];

        if i < p {
            self.draw_coefficient(i, given, r)
        } else if i < p + n {
            self.draw_latent_weight(i - p, given, r)
        } else {
            self.draw_dispersion(given, r)
        }
    }
}