use std::{
collections::HashMap,
f64,
time::{Duration, Instant},
};
use anyhow::Result;
use nuts_rs::{
ArrowConfig, CpuLogpFunc, CpuMath, CpuMathError, DiagNutsSettings, LogpError, Model, Sampler,
SamplerWaitResult, Storable,
};
use nuts_storable::{HasDims, Value};
use rand::{Rng, RngExt};
use thiserror::Error;
/// A multivariate normal distribution parameterized by its mean vector and
/// precision (inverse covariance) matrix.
#[derive(Clone, Debug)]
struct MultivariateNormal {
    // Mean vector; its length defines the dimensionality of the distribution.
    mean: Vec<f64>,
    // Precision matrix stored as dense rows.
    // NOTE(review): the gradient in `MvnLogp::logp` treats this as symmetric
    // (it uses -P * diff) — confirm symmetry for any new model.
    precision: Vec<Vec<f64>>,
}
impl MultivariateNormal {
fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
Self { mean, precision }
}
}
/// Errors that the log-density evaluation can raise.
#[allow(dead_code)]
#[derive(Debug, Error)]
enum MyLogpError {
    // The sampler may reject the proposal and keep sampling.
    #[error("Recoverable error in logp calculation: {0}")]
    Recoverable(String),
    // The sampler must abort the chain.
    #[error("Non-recoverable error in logp calculation: {0}")]
    NonRecoverable(String),
}
impl LogpError for MyLogpError {
    /// Only `Recoverable` errors allow sampling to continue; any other
    /// variant aborts the chain.
    fn is_recoverable(&self) -> bool {
        match self {
            MyLogpError::Recoverable(_) => true,
            MyLogpError::NonRecoverable(_) => false,
        }
    }
}
/// Log-density function object for the multivariate normal target.
#[derive(Clone)]
struct MvnLogp {
    // The target distribution.
    model: MultivariateNormal,
    // Scratch space reused by `logp` to hold `x - mean`, avoiding a
    // per-call allocation. Its length must equal the model dimension.
    buffer: Vec<f64>,
}
impl HasDims for MvnLogp {
    /// Sizes of the named dimensions used by stored values: the "x"
    /// dimension spans the model's parameter vector.
    fn dim_sizes(&self) -> HashMap<String, u64> {
        HashMap::from([("x".to_string(), self.model.mean.len() as u64)])
    }

    /// Coordinate labels for each dimension.
    ///
    /// Labels are generated as `x1..xN` from the model dimension instead of
    /// being hard-coded for N = 2, so they always stay consistent with
    /// `dim_sizes`.
    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
        let labels: Vec<String> = (1..=self.model.mean.len())
            .map(|i| format!("x{i}"))
            .collect();
        HashMap::from([("x".to_string(), Value::Strings(labels))])
    }
}
/// Extra per-draw quantities computed from each accepted position and
/// stored alongside the trace.
#[derive(Storable)]
struct ExpandedDraw {
    // Copy of the position vector, stored along the "x" dimension.
    // NOTE(review): the field is named `prec` but `expand_vector` fills it
    // with the draw itself (`array.to_vec()`), not precision values —
    // consider renaming.
    #[storable(dims("x"))]
    prec: Vec<f64>,
    // Difference between the second and first coordinate of the draw.
    diff: f64,
}
impl CpuLogpFunc for MvnLogp {
    type LogpError = MyLogpError;
    type FlowParameters = ();
    type ExpandedVector = ExpandedDraw;

    /// Dimensionality of the parameter space.
    fn dim(&self) -> usize {
        self.model.mean.len()
    }

    /// Unnormalized log density `-0.5 * (x - mu)' P (x - mu)`, with the
    /// gradient written into `grad`.
    ///
    /// The gradient is computed as `-P (x - mu)`, which equals the true
    /// gradient only when the precision matrix `P` is symmetric.
    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let n = x.len();
        // Reuse the scratch buffer for x - mu to avoid allocating per call.
        let diff = &mut self.buffer;
        for i in 0..n {
            diff[i] = x[i] - self.model.mean[i];
        }
        let mut quad = 0.0;
        for i in 0..n {
            // Hoist the row lookup out of the inner loop.
            let row = &self.model.precision[i];
            let mut pdot = 0.0;
            for j in 0..n {
                pdot += row[j] * diff[j];
            }
            // Row i's contribution to the quadratic form: diff_i * (P diff)_i.
            quad += diff[i] * pdot;
            grad[i] = -pdot;
        }
        Ok(-0.5 * quad)
    }

    /// Derived per-draw quantities stored with each accepted position.
    fn expand_vector<R: Rng + ?Sized>(
        &mut self,
        _rng: &mut R,
        array: &[f64],
    ) -> Result<Self::ExpandedVector, CpuMathError> {
        // Guard against 1-D models instead of panicking on `array[1]`;
        // the coordinate difference is defined as 0 when there is only
        // one coordinate.
        let diff = if array.len() >= 2 {
            array[1] - array[0]
        } else {
            0.0
        };
        Ok(ExpandedDraw {
            prec: array.to_vec(),
            diff,
        })
    }

    /// Coordinate labels for the position vector, generated as `x1..xN`
    /// so they always match `dim()` (previously hard-coded to 2 labels).
    fn vector_coord(&self) -> Option<Value> {
        let labels = (1..=self.dim()).map(|i| format!("x{i}")).collect();
        Some(Value::Strings(labels))
    }
}
/// `Model` implementation wrapping the CPU math backend for the MVN target.
struct MvnModel {
    // Cloned once per chain in `Model::math`.
    math: CpuMath<MvnLogp>,
}
impl Model for MvnModel {
    type Math<'model>
        = CpuMath<MvnLogp>
    where
        Self: 'model;

    /// Hand each chain its own clone of the math backend.
    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
        Ok(self.math.clone())
    }

    /// Initialize every coordinate uniformly at random in [-2, 2).
    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
        position
            .iter_mut()
            .for_each(|p| *p = rng.random_range(-2.0..2.0));
        Ok(())
    }
}
/// Entry point: configures a 2-D multivariate normal model, runs several
/// NUTS chains via `nuts-rs`, prints periodic progress while waiting, and
/// summarizes the resulting Arrow-format traces.
fn main() -> Result<()> {
    println!("=== Multivariate Normal MCMC with Arrow Storage ===\n");
    // Target distribution: zero mean, with correlation induced by the
    // off-diagonal 0.5 entries of the precision matrix.
    let mean = vec![0.0, 0.0];
    let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
    let mvn = MultivariateNormal::new(mean, precision);
    println!("Model: 2D Multivariate Normal");
    println!("Mean: {:?}", mvn.mean);
    println!("Precision matrix: {:?}\n", mvn.precision);
    // NOTE(review): this path is only printed — nothing below writes any
    // files. Confirm whether on-disk output was intended.
    let output_path = "mcmc_output";
    println!("Output will be saved to: {}/\n", output_path);
    let num_chains = 4;
    let num_tune = 500;
    let num_draws = 500;
    println!("Sampling configuration:");
    println!(" Chains: {}", num_chains);
    println!(" Warmup samples: {}", num_tune);
    println!(" Sampling draws: {}", num_draws);
    // Diagonal-mass-matrix NUTS settings; fixed seed for reproducibility.
    let mut settings = DiagNutsSettings::default();
    settings.num_chains = num_chains as _;
    settings.num_tune = num_tune;
    settings.num_draws = num_draws as _;
    settings.seed = 54;
    let model = MvnModel {
        math: CpuMath::new(MvnLogp {
            model: mvn,
            // Scratch buffer sized to the model dimension (2).
            buffer: vec![0.0; 2],
        }),
    };
    println!("\nStarting MCMC sampling...\n");
    let start = Instant::now();
    let arrow_config = ArrowConfig::new();
    // NOTE(review): the `4` is presumably the worker-thread/core count —
    // confirm against `Sampler::new`'s signature.
    let mut sampler = Some(Sampler::new(model, settings, arrow_config, 4, None)?);
    let mut num_progress_updates = 0;
    // `wait_timeout` consumes the sampler, so it lives in an Option: taken
    // out before each wait and put back whenever the wait times out.
    while let Some(sampler_) = sampler.take() {
        match sampler_.wait_timeout(Duration::from_millis(50)) {
            // Sampling finished: report timing and describe the traces.
            SamplerWaitResult::Trace(traces) => {
                println!("✓ Sampling completed in {:?}", start.elapsed());
                println!("✓ MCMC traces stored in Arrow format");
                println!("\nTrace summary:");
                println!(" Number of chains: {}", traces.len());
                // Shape and schema details are shown for the first chain
                // only; all chains share the same schema.
                if let Some(first_trace) = traces.first() {
                    println!(
                        " Posterior samples: {} rows, {} columns",
                        first_trace.posterior.num_rows(),
                        first_trace.posterior.num_columns()
                    );
                    println!(
                        " Sample stats: {} rows, {} columns",
                        first_trace.sample_stats.num_rows(),
                        first_trace.sample_stats.num_columns()
                    );
                    println!("\n Posterior columns:");
                    for field in first_trace.posterior.schema().fields() {
                        println!(
                            " {} ({} {:?})",
                            field.name(),
                            field.data_type(),
                            field.metadata(),
                        );
                    }
                    println!("\n Sample stats columns:");
                    for field in first_trace.sample_stats.schema().fields() {
                        println!(
                            " {} ({} {:?})",
                            field.name(),
                            field.data_type(),
                            field.metadata(),
                        );
                    }
                }
                break;
            }
            // Still running: print per-chain progress, then stash the
            // sampler back into the Option for the next iteration.
            SamplerWaitResult::Timeout(mut sampler_) => {
                num_progress_updates += 1;
                println!("Progress update {}:", num_progress_updates);
                let progress = sampler_.progress()?;
                for (i, chain) in progress.iter().enumerate() {
                    let phase = if chain.tuning { "warmup" } else { "sampling" };
                    println!(
                        " Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
                        i, chain.finished_draws, chain.divergences, chain.step_size, phase
                    );
                }
                println!();
                sampler = Some(sampler_);
            }
            // A chain failed: surface the error and exit non-successfully.
            SamplerWaitResult::Err(err, _) => {
                eprintln!("✗ Sampling failed: {}", err);
                return Err(err);
            }
        }
    }
    Ok(())
}