use std::{
collections::HashMap,
f64,
time::{Duration, Instant},
};
use anyhow::Result;
use nuts_rs::{
CpuLogpFunc, CpuMath, CpuMathError, CsvConfig, DiagNutsSettings, LogpError, Model, Sampler,
SamplerWaitResult, Storable,
};
use nuts_storable::{HasDims, Value};
use rand::{Rng, RngExt};
use thiserror::Error;
/// A multivariate normal distribution, parameterized by its mean vector
/// and precision (inverse covariance) matrix.
#[derive(Clone, Debug)]
struct MultivariateNormal {
    /// Mean vector of the distribution.
    mean: Vec<f64>,
    /// Precision matrix, stored row-major as nested vectors.
    precision: Vec<Vec<f64>>,
}

impl MultivariateNormal {
    /// Builds a distribution from a mean vector and a precision matrix.
    fn new(mean: Vec<f64>, precision: Vec<Vec<f64>>) -> Self {
        MultivariateNormal { mean, precision }
    }
}
/// Errors that can occur during log-density evaluation.
///
/// The sampler distinguishes recoverable failures (the proposal is
/// rejected and sampling continues) from non-recoverable ones (the chain
/// aborts); see the `LogpError` impl below.
#[allow(dead_code)]
#[derive(Debug, Error)]
enum MyLogpError {
    // Transient numerical problem at a particular point (e.g. overflow);
    // the sampler may reject the draw and keep going.
    #[error("Recoverable error in logp calculation: {0}")]
    Recoverable(String),
    // Fatal problem that invalidates the whole chain.
    #[error("Non-recoverable error in logp calculation: {0}")]
    NonRecoverable(String),
}
impl LogpError for MyLogpError {
    /// Tells the sampler whether it may reject the current proposal and
    /// continue (`true`) or must abort the chain (`false`).
    fn is_recoverable(&self) -> bool {
        match self {
            MyLogpError::Recoverable(_) => true,
            MyLogpError::NonRecoverable(_) => false,
        }
    }
}
/// Log-density function wrapper around a `MultivariateNormal` model,
/// implementing the sampler's `CpuLogpFunc` interface.
#[derive(Clone)]
struct MvnLogp {
    // The target distribution whose log density and gradient we evaluate.
    model: MultivariateNormal,
}
impl HasDims for MvnLogp {
    /// Reports the size of each named dimension; "param" has one entry per
    /// component of the model's mean vector.
    fn dim_sizes(&self) -> HashMap<String, u64> {
        HashMap::from([("param".to_string(), self.model.mean.len() as u64)])
    }

    /// Coordinate labels for the "param" dimension ("mu1", "mu2", ...).
    ///
    /// Generated from the model dimension so the labels always agree with
    /// `dim_sizes` — previously this was a hard-coded two-element list that
    /// would silently mismatch for any dimension other than 2.
    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
        let labels = (1..=self.model.mean.len())
            .map(|i| format!("mu{i}"))
            .collect();
        HashMap::from([("param".to_string(), Value::Strings(labels))])
    }
}
/// A single stored draw: the raw parameter vector, tagged with the "param"
/// dimension so the storage backend can attach coordinate labels.
#[derive(Storable)]
struct ExpandedDraw {
    // Position of the draw; length must match the "param" dimension size.
    #[storable(dims("param"))]
    parameters: Vec<f64>,
}
impl CpuLogpFunc for MvnLogp {
    type LogpError = MyLogpError;
    type FlowParameters = ();
    type ExpandedVector = ExpandedDraw;

    /// Dimensionality of the parameter space.
    fn dim(&self) -> usize {
        self.model.mean.len()
    }

    /// Evaluates the unnormalized log density
    /// `-0.5 * (x - mu)^T P (x - mu)` and writes the gradient
    /// `-P (x - mu)` into `grad`.
    fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let n = x.len();
        // Residual vector (x - mu).
        let diff: Vec<f64> = x
            .iter()
            .zip(self.model.mean.iter())
            .map(|(xi, mi)| xi - mi)
            .collect();
        let mut quad = 0.0;
        for i in 0..n {
            // Row i of P times diff, accumulated into both the gradient
            // and the quadratic form in a single pass.
            let mut pdot = 0.0;
            for j in 0..n {
                let pij = self.model.precision[i][j];
                pdot += pij * diff[j];
                quad += diff[i] * pij * diff[j];
            }
            grad[i] = -pdot;
        }
        Ok(-0.5 * quad)
    }

    /// Converts a raw position vector into the storable draw type.
    fn expand_vector<R: Rng + ?Sized>(
        &mut self,
        _rng: &mut R,
        array: &[f64],
    ) -> Result<Self::ExpandedVector, CpuMathError> {
        Ok(ExpandedDraw {
            parameters: array.to_vec(),
        })
    }

    /// Coordinate labels for the raw draw vector ("mu1", "mu2", ...).
    ///
    /// Generated from the model dimension so they stay consistent with
    /// `dim()` — previously a hard-coded two-element list that would be
    /// wrong for any dimension other than 2.
    fn vector_coord(&self) -> Option<Value> {
        let labels = (1..=self.model.mean.len())
            .map(|i| format!("mu{i}"))
            .collect();
        Some(Value::Strings(labels))
    }
}
/// Sampler-facing model: owns the CPU math backend wrapping the
/// multivariate-normal log-density function.
struct MvnModel {
    // Cloned once per chain in `Model::math`.
    math: CpuMath<MvnLogp>,
}
impl Model for MvnModel {
    type Math<'model>
        = CpuMath<MvnLogp>
    where
        Self: 'model;

    /// Hands the sampler a math backend; the wrapped logp function is
    /// cheap to clone, so each chain receives its own copy.
    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
        Ok(self.math.clone())
    }

    /// Initializes every coordinate of the starting position uniformly
    /// at random in (-2, 2).
    fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
        position.fill_with(|| rng.random_range(-2.0..2.0));
        Ok(())
    }
}
/// Entry point: runs NUTS on a 2-D correlated normal across several chains,
/// storing draws as CmdStan-compatible CSV files and printing progress
/// updates while the sampler runs in the background.
fn main() -> Result<()> {
    println!("=== Multivariate Normal MCMC with CSV Storage ===\n");
    // Target distribution: zero mean; the 0.5 off-diagonal precision
    // entries make the two coordinates correlated.
    let mean = vec![0.0, 0.0];
    let precision = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
    let mvn = MultivariateNormal::new(mean, precision);
    println!("Model: 2D Multivariate Normal");
    println!("Mean: {:?}", mvn.mean);
    println!("Precision matrix: {:?}\n", mvn.precision);
    // Directory the CSV trace backend writes into (one file per chain).
    let output_path = "csv_output";
    println!("Output will be saved to: {}/\n", output_path);
    let num_chains = 4;
    let num_tune = 500;
    let num_draws = 500;
    println!("Sampling configuration:");
    println!(" Chains: {}", num_chains);
    println!(" Warmup samples: {}", num_tune);
    println!(" Sampling draws: {}", num_draws);
    // Diagonal-mass-matrix NUTS settings; fixed seed for reproducibility.
    let mut settings = DiagNutsSettings::default();
    settings.num_chains = num_chains as _;
    settings.num_tune = num_tune;
    settings.num_draws = num_draws as _;
    settings.seed = 54;
    // CSV output with 6-digit precision; warmup draws are stored too.
    let csv_config = CsvConfig::new(output_path)
        .with_precision(6)
        .store_warmup(true);
    let model = MvnModel {
        math: CpuMath::new(MvnLogp { model: mvn }),
    };
    println!("\nStarting MCMC sampling...\n");
    let start = Instant::now();
    // NOTE(review): the literal `4` is the sampler's core/thread count and
    // happens to equal `num_chains` — confirm whether it should track it.
    let mut sampler = Some(Sampler::new(model, settings, csv_config, 4, None)?);
    let mut num_progress_updates = 0;
    // Poll the background sampler: on timeout print per-chain progress and
    // keep waiting; on completion report the output files; on error bail.
    while let Some(sampler_) = sampler.take() {
        match sampler_.wait_timeout(Duration::from_millis(50)) {
            // Sampling finished: list CSV files and print usage hints.
            SamplerWaitResult::Trace(_) => {
                println!("✓ Sampling completed in {:?}", start.elapsed());
                println!("✓ Traces written to CSV format in '{}'", output_path);
                if let Ok(entries) = std::fs::read_dir(output_path) {
                    println!("\nOutput files:");
                    for entry in entries.flatten() {
                        if let Some(name) = entry.file_name().to_str() {
                            if name.ends_with(".csv") {
                                println!(" - {}", name);
                            }
                        }
                    }
                }
                println!("\n=== Next Steps ===");
                println!("The CSV files are compatible with CmdStan format and can be read by:");
                println!(" - R: posterior package, bayesplot, etc.");
                println!(" - Python: arviz.from_cmdstanpy() or pandas.read_csv()");
                println!(" - Stan ecosystem tools");
                println!("\nExample usage in Python:");
                println!(" import pandas as pd");
                println!(" import arviz as az");
                println!(" # Read individual chains");
                println!(" chain0 = pd.read_csv('{}/chain_0.csv')", output_path);
                println!(" # Or use arviz to read all chains");
                println!(" # (Note: arviz.from_cmdstanpy might need adaptation)");
                println!("\nExample usage in R:");
                println!(" library(posterior)");
                println!(" draws <- read_cmdstan_csv(c(");
                // Comma-separate all but the last chain file in the R snippet.
                for i in 0..num_chains {
                    let comma = if i == num_chains - 1 { "" } else { "," };
                    println!(" '{}/chain_{}.csv'{}", output_path, i, comma);
                }
                println!(" ))");
                println!(" summarise_draws(draws)");
                break;
            }
            // Still running: print a progress snapshot and put the sampler
            // back so the loop polls again.
            SamplerWaitResult::Timeout(mut sampler_) => {
                num_progress_updates += 1;
                println!("Progress update {}:", num_progress_updates);
                let progress = sampler_.progress()?;
                for (i, chain) in progress.iter().enumerate() {
                    let phase = if chain.tuning { "warmup" } else { "sampling" };
                    println!(
                        " Chain {}: {} samples ({} divergences), step size: {:.6} [{}]",
                        i, chain.finished_draws, chain.divergences, chain.step_size, phase
                    );
                }
                println!();
                sampler = Some(sampler_);
            }
            // Sampling failed: report and propagate the error.
            SamplerWaitResult::Err(err, _) => {
                eprintln!("✗ Sampling failed: {}", err);
                return Err(err);
            }
        }
    }
    Ok(())
}