use crate::PolyaGamma;
use mini_mcmc::{
core::{ChainRunner, init_det},
distributions::Conditional,
gibbs::GibbsSampler,
stats::RunStats,
};
use ndarray::{Array1, Array2, Array3};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use statrs::distribution::Normal;
use std::error::Error;
/// Gibbs sampler for Bayesian logistic regression using Pólya-Gamma augmentation.
///
/// Holds the data, the prior precision and the random-number generator; call
/// [`GibbsLogit::run`] to draw posterior samples. The generator type defaults
/// to `ChaCha8Rng`.
pub struct GibbsLogit<R: SeedableRng + Rng + Clone + Send + Sync = ChaCha8Rng> {
    /// Generator handed over to the conditional distribution when sampling starts.
    rng: R,
    /// Seed forwarded to `GibbsSampler::set_seed` by `run`.
    seed: u64,
    /// Number of chains passed to the sampler.
    n_chains: usize,
    /// Design matrix, one row per observation and one column per coefficient.
    x: Array2<f64>,
    /// Response vector, one entry per observation.
    y: Array1<f64>,
    /// Prior precision matrix, built as `I / prior_variance` by the constructors.
    prior_prec: Array2<f64>,
}
/// Output of [`GibbsLogit::run`]: pooled coefficient summaries plus the raw draws.
pub struct LogisticRegressionResults {
    /// Every retained draw, indexed `[chain, draw, component]`. As filled in by
    /// `GibbsLogit::run`, the component axis holds the regression coefficients
    /// first, followed by one latent Pólya-Gamma variable per observation.
    pub samples: Array3<f64>,
    /// Posterior mean of each coefficient, pooled over all chains.
    pub posterior_means: Vec<f64>,
    /// Posterior sample standard deviation (`ddof = 1`) of each coefficient,
    /// pooled over all chains.
    pub posterior_sds: Vec<f64>,
    /// Ground-truth coefficients supplied by the caller, if any (used by `summary`).
    pub true_coefficients: Option<Vec<f64>>,
    /// Run statistics returned by the Gibbs sampler's `run_progress`.
    pub run_stats: RunStats,
}
impl GibbsLogit<ChaCha8Rng> {
    /// Builds a sampler backed by a `ChaCha8Rng` seeded from `seed`.
    ///
    /// Convenience wrapper around [`GibbsLogit::from_rng`]. The same `seed`
    /// initialises the generator and is also kept for the Gibbs sampler
    /// itself (`run` passes it to `set_seed`).
    ///
    /// # Arguments
    /// * `x` - design matrix, one row per observation and one column per coefficient.
    /// * `y` - response vector, one entry per row of `x`.
    /// * `prior_variance` - variance of the diagonal Gaussian prior on each coefficient.
    /// * `n_chains` - number of chains passed to the sampler.
    /// * `seed` - seed for the generator and for the sampler.
    pub fn new(
        x: Array2<f64>,
        y: Array1<f64>,
        prior_variance: f64,
        n_chains: usize,
        seed: u64,
    ) -> Self {
        let generator = ChaCha8Rng::seed_from_u64(seed);
        Self::from_rng(generator, x, y, prior_variance, n_chains, seed)
    }
}
impl<R: SeedableRng + Rng + Clone + Send + Sync> GibbsLogit<R> {
    /// Builds a sampler that draws its randomness from a caller-supplied generator.
    ///
    /// The prior precision is `I / prior_variance` (one entry per column of `x`).
    /// `prior_variance = f64::INFINITY` is accepted and yields a zero prior
    /// precision (flat prior).
    ///
    /// # Arguments
    /// * `rng` - generator that will drive every draw made during `run`.
    /// * `x` - design matrix, one row per observation and one column per coefficient.
    /// * `y` - response vector, one entry per row of `x`.
    /// * `prior_variance` - variance of the diagonal Gaussian prior on each coefficient.
    /// * `n_chains` - number of chains passed to the sampler.
    /// * `seed` - seed forwarded to the Gibbs sampler by `run`.
    ///
    /// # Panics
    /// Panics if `x.nrows() != y.len()`, or if `prior_variance` is not strictly
    /// positive (this includes NaN).
    pub fn from_rng(
        rng: R,
        x: Array2<f64>,
        y: Array1<f64>,
        prior_variance: f64,
        n_chains: usize,
        seed: u64,
    ) -> Self {
        // A length mismatch would otherwise either silently drop labels (y too
        // long) or panic deep inside the sampler (y too short).
        assert_eq!(
            x.nrows(),
            y.len(),
            "design matrix has {} rows but the response has {} entries",
            x.nrows(),
            y.len()
        );
        // `> 0.0` is false for NaN, so NaN is rejected too; +inf is allowed (flat prior).
        assert!(
            prior_variance > 0.0,
            "prior_variance must be strictly positive, got {}",
            prior_variance
        );
        let p = x.ncols();
        let prior_prec = Array2::eye(p) * (1.0 / prior_variance);
        Self {
            x,
            y,
            prior_prec,
            n_chains,
            seed,
            rng,
        }
    }
}
impl<R: SeedableRng + Rng + Clone + Send + Sync> GibbsLogit<R> {
    /// Runs the Pólya-Gamma Gibbs sampler and summarises the coefficient draws.
    ///
    /// Each chain's state is laid out as `[β_0, …, β_{p-1}, ω_0, …, ω_{n-1}]`:
    /// the `p` regression coefficients followed by one latent Pólya-Gamma
    /// variable per observation. Coefficients start from the values produced
    /// by `init_det`; every latent ω starts at `1.0`.
    ///
    /// # Arguments
    /// * `burn_in` - number of burn-in iterations passed to `run_progress`.
    /// * `samples` - number of retained draws per chain.
    /// * `true_coefficients` - optional ground-truth coefficients, stored
    ///   verbatim in the result (e.g. for `LogisticRegressionResults::summary`).
    ///
    /// # Returns
    /// The posterior means and sample standard deviations (`ddof = 1`) of the
    /// `p` coefficients, pooled over all chains. Also returned are the raw
    /// draws (including the latent ω columns) and the sampler's run statistics.
    ///
    /// # Errors
    /// Returns an error if there are no chains or `samples == 0` (nothing to
    /// summarise). Errors from the sampler or from reshaping its output are
    /// propagated.
    ///
    /// NOTE(review): the randomness comes from the conditional's own RNG.
    /// `GibbsSampler` presumably clones the conditional once per chain. If so,
    /// every chain starts from the same RNG state and the same initial point,
    /// and produces identical draws. Confirm against mini_mcmc before relying
    /// on multi-chain diagnostics.
    pub fn run(
        self,
        burn_in: usize,
        samples: usize,
        true_coefficients: Option<Vec<f64>>,
    ) -> Result<LogisticRegressionResults, Box<dyn Error>> {
        let Self {
            x,
            y,
            prior_prec,
            n_chains,
            seed,
            rng,
        } = self;
        // Without at least one draw the pooled mean below has nothing to average.
        if n_chains == 0 || samples == 0 {
            return Err("GibbsLogit::run needs at least one chain and one retained sample".into());
        }

        let n = x.nrows();
        let p = x.ncols();
        let dim = p + n;

        // `run` takes `self` by value, so the data is moved into the conditional
        // instead of being cloned.
        let cond = LogitConditional::new(x, y, prior_prec, rng);

        let mut init: Vec<Vec<f64>> = init_det(n_chains, dim);
        for state in &mut init {
            // Overwrite the latent ω entries (indices p..) with 1.0.
            state[p..].fill(1.0);
        }

        let mut gibbs = GibbsSampler::new(cond, init).set_seed(seed);
        let (all_samples, run_stats) = gibbs.run_progress(samples, burn_in)?;

        // Pool chains: (n_chains, samples, dim) -> (n_chains * samples, dim).
        let pooled = all_samples.to_shape((n_chains * samples, dim))?;
        let posterior_means = (0..p)
            .map(|j| pooled.column(j).mean().ok_or("no draws available to average"))
            .collect::<Result<Vec<f64>, _>>()?;
        let posterior_sds: Vec<f64> = (0..p).map(|j| pooled.column(j).std(1.0)).collect();

        Ok(LogisticRegressionResults {
            posterior_means,
            posterior_sds,
            samples: all_samples,
            true_coefficients,
            run_stats,
        })
    }
}
impl LogisticRegressionResults {
    /// Prints a table of posterior summaries to stdout, one row per coefficient.
    ///
    /// The columns are the posterior mean, the posterior standard deviation and
    /// the true value. The true value is printed as `N/A` when
    /// `true_coefficients` is `None` or has no entry for that coefficient.
    pub fn summary(&self) {
        println!(
            "{:<10} {:<15} {:<15} {:<15}",
            "Parameter", "Mean", "Std. Dev.", "True Value"
        );
        println!("{}", "-".repeat(55));
        for (i, (mean, sd)) in self
            .posterior_means
            .iter()
            .zip(&self.posterior_sds)
            .enumerate()
        {
            // `get` rather than `v[i]`: a ground-truth vector shorter than the
            // number of coefficients must not panic the report.
            let true_val = self
                .true_coefficients
                .as_ref()
                .and_then(|v| v.get(i))
                .map_or_else(|| "N/A".to_string(), |t| format!("{:.4}", t));
            println!(
                "{:<10} {:<15.4} {:<15.4} {:<15}",
                format!("β{}", i),
                mean,
                sd,
                true_val
            );
        }
    }

    /// Returns every draw of state component `param_idx`, concatenated chain by chain.
    ///
    /// The draws of chain 0 come first, then those of chain 1, and so on. For
    /// results produced by `GibbsLogit::run`, indices below the number of
    /// coefficients select regression coefficients. Higher indices select the
    /// latent Pólya-Gamma variables stored alongside them.
    ///
    /// Returns `None` if `param_idx` is outside the last axis of `samples`.
    pub fn get_posterior_samples(&self, param_idx: usize) -> Option<Vec<f64>> {
        let (n_chains, n_samples, n_params) = self.samples.dim();
        if param_idx >= n_params {
            return None;
        }
        let mut draws = Vec::with_capacity(n_chains * n_samples);
        // `outer_iter` walks axis 0 (chains) in order, giving chain-major output.
        for chain in self.samples.outer_iter() {
            draws.extend(chain.column(param_idx).iter().copied());
        }
        Some(draws)
    }
}
/// Full conditionals of the Pólya-Gamma augmented logistic-regression posterior.
///
/// The joint state seen by [`Conditional::sample`] holds the regression
/// coefficients (one per column of `x`), followed by one latent Pólya-Gamma
/// variable per row of `x`.
#[derive(Clone)]
struct LogitConditional<R: SeedableRng + Rng + Clone + Send + Sync> {
    /// Source of randomness for every draw made by this conditional.
    rng: R,
    /// Pólya-Gamma sampler used for the latent variables.
    pg: PolyaGamma,
    /// Design matrix, one row per observation.
    x: Array2<f64>,
    /// Responses; enter the coefficient update as `y - 0.5`.
    y: Array1<f64>,
    /// Prior precision matrix; only its diagonal is read.
    prior_prec: Array2<f64>,
}
impl<R: SeedableRng + Rng + Clone + Send + Sync> LogitConditional<R> {
    /// Bundles the data, prior precision and generator into a conditional.
    ///
    /// The Pólya-Gamma helper is created with `PolyaGamma::new(1.0)`. Presumably
    /// `1.0` is the shape `b = 1`, matching one Bernoulli trial per observation;
    /// confirm against `PolyaGamma::new`.
    pub fn new(x: Array2<f64>, y: Array1<f64>, prior_prec: Array2<f64>, rng: R) -> Self {
        let pg = PolyaGamma::new(1.0);
        Self {
            rng,
            pg,
            prior_prec,
            y,
            x,
        }
    }
}
/// Single-site Gibbs updates for the Pólya-Gamma augmented logistic posterior.
///
/// Index layout of `given` (and of `i`): entries `0..p` are the regression
/// coefficients β (`p = x.ncols()`); entries `p..p + n` are the latent
/// Pólya-Gamma variables ω, one per row of `x` (`n = x.nrows()`).
impl<R> Conditional<f64> for LogitConditional<R>
where
    R: SeedableRng + Rng + Clone + Send + Sync,
{
    /// Draws component `i` from its full conditional given the rest of `given`.
    ///
    /// * `i < p`: β_i ~ Normal(mean_i, var_i), where
    ///   `1/var_i = prior_prec[i][i] + Σ_r ω_r · x[r][i]²` and
    ///   `mean_i = var_i · Σ_r x[r][i] · ((y_r − 0.5) − ω_r · Σ_{k≠i} x[r][k] · β_k)`.
    ///   Only the diagonal of `prior_prec` is read and no prior-mean term is
    ///   added, so this corresponds to a zero-mean prior with diagonal precision.
    /// * `i >= p`: ω_{i−p} is drawn via `PolyaGamma::draw` with tilt `|x[i−p] · β|`.
    ///
    /// The current value `given[i]` of the component being updated is never read.
    /// Cost is O(n·p) per coefficient update, because the partial dot product
    /// is recomputed for every row.
    fn sample(&mut self, i: usize, given: &[f64]) -> f64 {
        let n = self.x.nrows();
        let p = self.x.ncols();
        if i < p {
            // --- coefficient update: univariate Gaussian full conditional ---
            let col_i = self.x.column(i);
            let prior_ii = self.prior_prec[(i, i)];
            // Accumulates the conditional precision, starting from the prior's.
            let mut precision = prior_ii;
            // Accumulates precision × mean (the linear term of the Gaussian).
            let mut precision_mean = 0.0;
            for row_idx in 0..n {
                let xi = col_i[row_idx];
                // ω for this observation lives after the p coefficients.
                let wi = given[p + row_idx];
                precision += wi * xi * xi;
                // Linear predictor of this row with β_i left out: Σ_{k≠i} x[r][k] · β_k.
                let mut dot_minus_i = 0.0;
                for (k, bj) in given.iter().enumerate().take(p) {
                    if k != i {
                        dot_minus_i += self.x[(row_idx, k)] * bj;
                    }
                }
                let yi = self.y[row_idx];
                // (y − 0.5) is the Pólya-Gamma κ term, which presumes 0/1-coded responses.
                let resid = (yi - 0.5) - wi * dot_minus_i;
                precision_mean += xi * resid;
            }
            let var_i = 1.0 / precision;
            let mean_i = precision_mean * var_i;
            // mean + sd · N(0, 1); RNG draw order matters for seeded reproducibility.
            let eps: f64 = self.rng.sample(Normal::standard());
            mean_i + eps * var_i.sqrt()
        } else {
            // --- latent update: ω_{i−p} | β ---
            let row = self.x.row(i - p);
            // PG(b, c) depends on c only through |c|; presumably `draw` expects a
            // non-negative tilt — confirm against `PolyaGamma::draw`.
            let xb: f64 = row
                .iter()
                .zip(&given[0..p])
                .map(|(xij, bj)| (xij * bj))
                .sum::<f64>()
                .abs();
            self.pg.draw(&mut self.rng, xb)
        }
    }
}