use fastrand::Rng;
use ganesh::{
    algorithms::mcmc::{
        ess::{ESSConfig, ESSInit},
        AutocorrelationTerminator, ESSMove, ESS,
    },
    core::{utils::SampleFloat, Callbacks, MaxSteps},
    traits::{Algorithm, LogDensity},
    DMatrix, DVector, Float,
};
use std::{
    convert::Infallible,
    error::Error,
    fs::File,
    io::{BufWriter, Write},
    path::Path,
};
fn main() -> Result<(), Box<dyn Error>> {
struct Problem;
impl LogDensity<DMatrix<Float>> for Problem {
fn log_density(
&self,
x: &DVector<Float>,
args: &DMatrix<Float>,
) -> Result<Float, Infallible> {
Ok(-0.5 * x.dot(&(args * x)))
}
}
let problem = Problem;
let mut rng = Rng::new();
rng.seed(0);
let x0: Vec<DVector<Float>> = (0..100)
.map(|_| DVector::from_fn(5, |_, _| rng.normal(0.0, 4.0)))
.collect();
let cov_inv = DMatrix::from_fn(5, 5, |i, j| if i == j { 1.0 } else { 0.1 } / rng.float());
println!("Σ⁻¹ = \n{}", cov_inv);
let aco = AutocorrelationTerminator::default()
.with_verbose(true)
.build();
let mut sampler = ESS::default();
let init = ESSInit::new(x0.clone()).unwrap();
let config = ESSConfig::default().with_moves([
ESSMove::gaussian(0.1),
ESSMove::custom_global(0.7, None, Some(0.5), Some(4))?,
ESSMove::differential(0.2),
])?;
let result = sampler.process(
&problem,
&cov_inv,
init,
config,
Callbacks::empty()
.with_terminator(aco.clone())
.with_terminator(MaxSteps(1000)),
)?;
let chains = result.chain;
let taus = aco.lock().taus.clone();
let mut writer = BufWriter::new(File::create(Path::new("data.pkl"))?);
serde_pickle::to_writer(&mut writer, &(chains, taus), Default::default())?;
Ok(())
}