use crate::error::PramanaError;
use crate::rng::Rng;
pub use crate::rng::SimpleRng;
use serde::{Deserialize, Serialize};
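/// A Monte Carlo estimate together with its estimated standard error.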
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EstimateResult {
pub value: f64,
pub std_error: f64,
}
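/// Estimates the integral of `f` over `[a, b]` by uniform Monte Carlo sampling.
///
/// Draws `n_samples` points uniformly in `[a, b]`; the estimate is `(b - a)`
/// times the sample mean of `f`, and the reported standard error is
/// `(b - a) * sqrt(var / n)`, where `var` is the sample variance of `f`.
///
/// # Errors
///
/// Returns `PramanaError::InvalidParameter` if `n_samples` is zero or `a >= b`.
///
/// # Examples
///
/// A minimal sketch; the crate name `pramana` and module path `sampling` are
/// assumptions here, so adjust the `use` line to your actual layout.
///
/// ```ignore
/// use pramana::sampling::{monte_carlo_integrate, SimpleRng};
///
/// let mut rng = SimpleRng::new(42);
/// // Integrate x^2 over [0, 1]; the exact value is 1/3.
/// let est = monte_carlo_integrate(|x| x * x, 0.0, 1.0, 100_000, &mut rng).unwrap();
/// assert!((est.value - 1.0 / 3.0).abs() < 0.01);
/// ```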
#[must_use = "returns the integration estimate"]
pub fn monte_carlo_integrate(
f: impl Fn(f64) -> f64,
a: f64,
b: f64,
n_samples: usize,
rng: &mut impl Rng,
) -> Result<EstimateResult, PramanaError> {
if n_samples == 0 {
return Err(PramanaError::InvalidParameter(
"n_samples must be positive".into(),
));
}
if a >= b {
return Err(PramanaError::InvalidParameter(
"lower bound must be less than upper bound".into(),
));
}
let width = b - a;
let mut sum = 0.0;
let mut sum_sq = 0.0;
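    // Accumulate f(x) and f(x)^2 over uniform draws in [a, b] so the mean and
    // variance can be computed in a single pass.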
for _ in 0..n_samples {
let x = a + rng.next_f64() * width;
let val = f(x);
sum += val;
sum_sq += val * val;
}
let n = n_samples as f64;
let mean = sum / n;
let value = width * mean;
    // Population variance of f over the samples; clamp at zero so floating-point
    // cancellation cannot yield a negative variance (and a NaN standard error).
    let var = (sum_sq / n - mean * mean).max(0.0);
    let std_error = width * (var / n).sqrt();
Ok(EstimateResult { value, std_error })
}
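/// Estimates pi by hit-or-miss sampling: points are drawn uniformly in the
/// square `[-1, 1]^2`, and the fraction landing inside the unit circle
/// (which converges to pi / 4) is multiplied by 4.
///
/// # Errors
///
/// Returns `PramanaError::InvalidParameter` if `n_samples` is zero.
///
/// # Examples
///
/// A minimal sketch; the crate name and module path are assumptions, so adjust
/// the `use` line to your actual layout.
///
/// ```ignore
/// use pramana::sampling::{monte_carlo_pi, SimpleRng};
///
/// let mut rng = SimpleRng::new(42);
/// let pi = monte_carlo_pi(100_000, &mut rng).unwrap();
/// assert!((pi - std::f64::consts::PI).abs() < 0.05);
/// ```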
#[must_use = "returns the pi estimate"]
pub fn monte_carlo_pi(n_samples: usize, rng: &mut impl Rng) -> Result<f64, PramanaError> {
if n_samples == 0 {
return Err(PramanaError::InvalidParameter(
"n_samples must be positive".into(),
));
}
let mut inside = 0u64;
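    // Draw points uniformly in [-1, 1]^2 and count hits inside the unit circle;
    // the hit fraction converges to pi / 4.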
for _ in 0..n_samples {
let x = rng.next_f64() * 2.0 - 1.0;
let y = rng.next_f64() * 2.0 - 1.0;
if x * x + y * y <= 1.0 {
inside += 1;
}
}
Ok(4.0 * inside as f64 / n_samples as f64)
}
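/// The output of an MCMC run: the retained post-burn-in states (one `Vec<f64>`
/// per sample) and the fraction of proposals that were accepted.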
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McmcResult {
pub samples: Vec<Vec<f64>>,
pub acceptance_rate: f64,
}
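/// Random-walk Metropolis-Hastings targeting the unnormalized density whose
/// logarithm is `log_target`.
///
/// Each proposal perturbs every coordinate of the current state by an
/// independent Gaussian step with standard deviation `proposal_std`. The first
/// `burn_in` states are discarded and the next `n_samples` states are kept;
/// the acceptance rate is reported over all steps, burn-in included.
///
/// # Errors
///
/// Returns `PramanaError::InvalidParameter` if `initial` is empty,
/// `proposal_std` is not positive, or `n_samples` is zero.
///
/// # Examples
///
/// A minimal sketch targeting a one-dimensional standard normal; the crate
/// name and module path are assumptions, so adjust the `use` line to your
/// actual layout.
///
/// ```ignore
/// use pramana::sampling::{metropolis_hastings, SimpleRng};
///
/// let mut rng = SimpleRng::new(42);
/// // log-density of a standard normal, up to an additive constant.
/// let result =
///     metropolis_hastings(|x| -0.5 * x[0] * x[0], &[0.0], 1.0, 10_000, 1_000, &mut rng).unwrap();
/// assert_eq!(result.samples.len(), 10_000);
/// ```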
pub fn metropolis_hastings(
log_target: impl Fn(&[f64]) -> f64,
initial: &[f64],
proposal_std: f64,
n_samples: usize,
burn_in: usize,
rng: &mut impl Rng,
) -> Result<McmcResult, PramanaError> {
let dim = initial.len();
if dim == 0 {
return Err(PramanaError::InvalidParameter(
"initial point must be non-empty".into(),
));
}
if proposal_std <= 0.0 {
return Err(PramanaError::InvalidParameter(
"proposal_std must be positive".into(),
));
}
if n_samples == 0 {
return Err(PramanaError::InvalidParameter(
"n_samples must be positive".into(),
));
}
let total = n_samples + burn_in;
let mut current = initial.to_vec();
    let mut log_current = log_target(&current);
let mut samples = Vec::with_capacity(n_samples);
let mut accepted: u64 = 0;
for step in 0..total {
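        // Gaussian random-walk proposal: perturb each coordinate by a
        // Box-Muller normal draw scaled by proposal_std.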
let proposal: Vec<f64> = current
.iter()
.map(|&xi| {
let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
let u2 = rng.next_f64();
let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
xi + proposal_std * z
})
.collect();
let log_proposal = log_target(&proposal);
let log_alpha = log_proposal - log_current;
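        // Metropolis acceptance: uphill moves in log density are always taken;
        // downhill moves are taken with probability exp(log_alpha).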
let accept = if log_alpha >= 0.0 {
true
} else {
rng.next_f64() < log_alpha.exp()
};
if accept {
current = proposal;
log_current = log_proposal;
accepted += 1;
}
if step >= burn_in {
samples.push(current.clone());
}
}
let acceptance_rate = accepted as f64 / total as f64;
Ok(McmcResult {
samples,
acceptance_rate,
})
}
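/// Gibbs sampler: each iteration sweeps through the coordinates, replacing
/// coordinate `j` with a draw from `conditionals[j]`, which receives the
/// current state (including coordinates already updated in this sweep) and the
/// RNG.
///
/// The first `burn_in` sweeps are discarded and the next `n_samples` states
/// are kept. Every Gibbs update is accepted, so `acceptance_rate` is always
/// 1.0.
///
/// # Errors
///
/// Returns `PramanaError::InvalidParameter` if `conditionals` is empty or
/// `n_samples` is zero, and `PramanaError::DimensionMismatch` if `initial` and
/// `conditionals` have different lengths.
///
/// # Examples
///
/// A minimal sketch with two conditionals that ignore the rest of the state;
/// the crate name and module paths are assumptions, so adjust the `use` lines
/// to your actual layout.
///
/// ```ignore
/// use pramana::rng::Rng;
/// use pramana::sampling::{gibbs_sampling, SimpleRng};
///
/// let conditionals: Vec<Box<dyn Fn(&[f64], &mut dyn Rng) -> f64>> = vec![
///     Box::new(|_, rng: &mut dyn Rng| rng.next_f64()),
///     Box::new(|_, rng: &mut dyn Rng| rng.next_f64()),
/// ];
/// let mut rng = SimpleRng::new(42);
/// let result = gibbs_sampling(&conditionals, &[0.0, 0.0], 1_000, 100, &mut rng).unwrap();
/// assert_eq!(result.samples.len(), 1_000);
/// ```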
pub fn gibbs_sampling<F>(
conditionals: &[F],
initial: &[f64],
n_samples: usize,
burn_in: usize,
rng: &mut impl Rng,
) -> Result<McmcResult, PramanaError>
where
F: Fn(&[f64], &mut dyn Rng) -> f64,
{
let dim = conditionals.len();
if dim == 0 {
return Err(PramanaError::InvalidParameter(
"need at least one conditional sampler".into(),
));
}
if initial.len() != dim {
return Err(PramanaError::DimensionMismatch(format!(
"initial has length {}, expected {dim}",
initial.len()
)));
}
if n_samples == 0 {
return Err(PramanaError::InvalidParameter(
"n_samples must be positive".into(),
));
}
let total = n_samples + burn_in;
let mut state = initial.to_vec();
let mut samples = Vec::with_capacity(n_samples);
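    // Each sweep resamples every coordinate from its full conditional, using
    // the freshest values of the other coordinates.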
for step in 0..total {
for (j, cond) in conditionals.iter().enumerate() {
state[j] = cond(&state, rng);
}
if step >= burn_in {
samples.push(state.clone());
}
}
Ok(McmcResult {
samples,
        acceptance_rate: 1.0,
    })
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_monte_carlo_pi() {
let mut rng = SimpleRng::new(42);
let pi = monte_carlo_pi(100_000, &mut rng).unwrap();
assert!(
(pi - std::f64::consts::PI).abs() < 0.05,
"pi estimate {pi} too far from actual"
);
}
#[test]
fn test_monte_carlo_integrate_x_squared() {
let mut rng = SimpleRng::new(42);
let result = monte_carlo_integrate(|x| x * x, 0.0, 1.0, 100_000, &mut rng).unwrap();
assert!(
(result.value - 1.0 / 3.0).abs() < 0.01,
"integral estimate {} too far from 1/3",
result.value
);
assert!(result.std_error > 0.0);
}
#[test]
fn test_zero_samples() {
let mut rng = SimpleRng::new(1);
assert!(monte_carlo_pi(0, &mut rng).is_err());
assert!(monte_carlo_integrate(|x| x, 0.0, 1.0, 0, &mut rng).is_err());
}
#[test]
fn serde_roundtrip() {
let est = EstimateResult {
value: 3.15,
std_error: 0.01,
};
let json = serde_json::to_string(&est).unwrap();
let est2: EstimateResult = serde_json::from_str(&json).unwrap();
assert_eq!(est.value, est2.value);
}
#[test]
fn mh_samples_standard_normal() {
let mut rng = SimpleRng::new(42);
let result =
metropolis_hastings(|x| -0.5 * x[0] * x[0], &[0.0], 1.0, 50_000, 5_000, &mut rng)
.unwrap();
assert_eq!(result.samples.len(), 50_000);
assert!(result.acceptance_rate > 0.1 && result.acceptance_rate < 0.9);
let mean: f64 =
result.samples.iter().map(|s| s[0]).sum::<f64>() / result.samples.len() as f64;
assert!(mean.abs() < 0.1, "sample mean = {mean}");
let var: f64 = result
.samples
.iter()
.map(|s| (s[0] - mean).powi(2))
.sum::<f64>()
/ result.samples.len() as f64;
assert!((var - 1.0).abs() < 0.2, "sample variance = {var}");
}
#[test]
fn mh_2d_target() {
let mut rng = SimpleRng::new(123);
let result = metropolis_hastings(
|x| -0.5 * (x[0] * x[0] + x[1] * x[1]),
&[0.0, 0.0],
0.5,
20_000,
2_000,
&mut rng,
)
.unwrap();
assert_eq!(result.samples[0].len(), 2);
let mean_x: f64 =
result.samples.iter().map(|s| s[0]).sum::<f64>() / result.samples.len() as f64;
let mean_y: f64 =
result.samples.iter().map(|s| s[1]).sum::<f64>() / result.samples.len() as f64;
assert!(mean_x.abs() < 0.15, "mean_x = {mean_x}");
assert!(mean_y.abs() < 0.15, "mean_y = {mean_y}");
}
#[test]
fn mh_invalid_params() {
let mut rng = SimpleRng::new(1);
assert!(metropolis_hastings(|_| 0.0, &[], 1.0, 100, 0, &mut rng).is_err());
assert!(metropolis_hastings(|_| 0.0, &[0.0], 0.0, 100, 0, &mut rng).is_err());
assert!(metropolis_hastings(|_| 0.0, &[0.0], 1.0, 0, 0, &mut rng).is_err());
}
#[test]
fn mh_serde_roundtrip() {
let result = McmcResult {
samples: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
acceptance_rate: 0.5,
};
let json = serde_json::to_string(&result).unwrap();
let r2: McmcResult = serde_json::from_str(&json).unwrap();
assert_eq!(result.samples, r2.samples);
assert_eq!(result.acceptance_rate, r2.acceptance_rate);
}
type GibbsCond = Box<dyn Fn(&[f64], &mut dyn Rng) -> f64>;
#[test]
fn gibbs_independent_normals() {
use std::f64::consts::PI;
fn sample_normal(rng: &mut dyn Rng) -> f64 {
let u1 = rng.next_f64().max(f64::MIN_POSITIVE);
let u2 = rng.next_f64();
(-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
}
let conditionals: Vec<GibbsCond> = vec![
Box::new(|_, rng: &mut dyn Rng| sample_normal(rng)),
Box::new(|_, rng: &mut dyn Rng| sample_normal(rng)),
];
let mut rng = SimpleRng::new(42);
let result = gibbs_sampling(&conditionals, &[0.0, 0.0], 20_000, 1_000, &mut rng).unwrap();
assert_eq!(result.samples.len(), 20_000);
assert!((result.acceptance_rate - 1.0).abs() < 1e-10);
let mean_x: f64 =
result.samples.iter().map(|s| s[0]).sum::<f64>() / result.samples.len() as f64;
assert!(mean_x.abs() < 0.1, "mean_x = {mean_x}");
}
#[test]
fn gibbs_invalid_params() {
let mut rng = SimpleRng::new(1);
let conds: Vec<GibbsCond> = vec![];
assert!(gibbs_sampling(&conds, &[], 100, 0, &mut rng).is_err());
let conds: Vec<GibbsCond> = vec![Box::new(|_, _: &mut dyn Rng| 0.0)];
assert!(gibbs_sampling(&conds, &[0.0, 0.0], 100, 0, &mut rng).is_err());
assert!(gibbs_sampling(&conds, &[0.0], 0, 0, &mut rng).is_err());
}
}