use super::random;
use crate::marshal::{register_typed_fn_1, register_typed_fn_2, register_typed_fn_3};
use crate::module_exports::ModuleExports;
use crate::typed_module_exports::{ConcreteReturn, ConcreteType, TypedReturn};
use rand::Rng;
use rand_chacha::ChaCha8Rng;
use std::sync::Arc;
pub fn create_distributions_intrinsics_module() -> ModuleExports {
let mut module = ModuleExports::new("std::core::intrinsics::distributions");
module.description =
"Statistical distribution sampling intrinsics (uniform, lognormal, exponential, poisson, sample_n)"
.to_string();
register_typed_fn_2::<_, f64, f64>(
&mut module,
"__intrinsic_dist_uniform",
"Sample from a uniform distribution [lo, hi)",
[("lo", "number"), ("hi", "number")],
ConcreteType::Number,
|lo, hi, _ctx| {
if lo >= hi {
return Err(format!(
"__intrinsic_dist_uniform: lo ({}) must be < hi ({})",
lo, hi
));
}
let value = random::with_rng(|rng| sample_uniform(rng, lo, hi));
Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
},
);
register_typed_fn_2::<_, f64, f64>(
&mut module,
"__intrinsic_dist_lognormal",
"Sample from a lognormal distribution",
[("mean", "number"), ("std", "number")],
ConcreteType::Number,
|mean, std, _ctx| {
if std < 0.0 {
return Err("__intrinsic_dist_lognormal: std must be non-negative".to_string());
}
let value = random::with_rng(|rng| sample_lognormal(rng, mean, std));
Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
},
);
register_typed_fn_1::<_, f64>(
&mut module,
"__intrinsic_dist_exponential",
"Sample from an exponential distribution",
"lambda",
"number",
ConcreteType::Number,
|lambda, _ctx| {
if lambda <= 0.0 {
return Err("__intrinsic_dist_exponential: lambda must be positive".to_string());
}
let value = random::with_rng(|rng| sample_exponential(rng, lambda));
Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
},
);
register_typed_fn_1::<_, f64>(
&mut module,
"__intrinsic_dist_poisson",
"Sample from a Poisson distribution",
"lambda",
"number",
ConcreteType::Number,
|lambda, _ctx| {
if lambda < 0.0 {
return Err("__intrinsic_dist_poisson: lambda must be non-negative".to_string());
}
let value = random::with_rng(|rng| sample_poisson(rng, lambda));
Ok(TypedReturn::Concrete(ConcreteReturn::F64(value)))
},
);
register_typed_fn_3::<_, Arc<String>, Arc<Vec<f64>>, i64>(
&mut module,
"__intrinsic_dist_sample_n",
"Sample n values from a named distribution (uniform / lognormal / exponential / poisson)",
[
("dist_name", "string"),
("params", "Array<number>"),
("n", "int"),
],
ConcreteType::ArrayNumber,
|dist_name, params, n, _ctx| {
if n < 0 {
return Err("__intrinsic_dist_sample_n: n must be non-negative".to_string());
}
let n = n as usize;
let p = params.as_slice();
let dist = dist_name.as_str();
match dist {
"uniform" | "lognormal" => {
if p.len() != 2 {
return Err(format!(
"__intrinsic_dist_sample_n: '{}' requires 2 params, got {}",
dist,
p.len()
));
}
}
"exponential" | "poisson" => {
if p.len() != 1 {
return Err(format!(
"__intrinsic_dist_sample_n: '{}' requires 1 param, got {}",
dist,
p.len()
));
}
}
_ => return Err(format!("Unknown distribution: {}", dist)),
}
match dist {
"uniform" if p[0] >= p[1] => {
return Err(format!(
"__intrinsic_dist_sample_n: uniform lo ({}) must be < hi ({})",
p[0], p[1]
));
}
"lognormal" if p[1] < 0.0 => {
return Err(
"__intrinsic_dist_sample_n: lognormal std must be non-negative".to_string(),
);
}
"exponential" if p[0] <= 0.0 => {
return Err(
"__intrinsic_dist_sample_n: exponential lambda must be positive"
.to_string(),
);
}
"poisson" if p[0] < 0.0 => {
return Err(
"__intrinsic_dist_sample_n: poisson lambda must be non-negative"
.to_string(),
);
}
_ => {}
}
let samples: Vec<f64> = random::with_rng(|rng| {
(0..n)
.map(|_| match dist {
"uniform" => sample_uniform(rng, p[0], p[1]),
"lognormal" => sample_lognormal(rng, p[0], p[1]),
"exponential" => sample_exponential(rng, p[0]),
"poisson" => sample_poisson(rng, p[0]),
_ => unreachable!(),
})
.collect()
});
Ok(TypedReturn::Concrete(ConcreteReturn::ArrayF64(samples)))
},
);
module
}
fn sample_uniform(rng: &mut ChaCha8Rng, lo: f64, hi: f64) -> f64 {
let u: f64 = rng.r#gen();
lo + (hi - lo) * u
}
fn sample_lognormal(rng: &mut ChaCha8Rng, mean: f64, std: f64) -> f64 {
let u1: f64 = rng.r#gen();
let u2: f64 = rng.r#gen();
let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
(mean + std * z).exp()
}
fn sample_exponential(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
let u: f64 = rng.r#gen();
-u.ln() / lambda
}
fn sample_poisson(rng: &mut ChaCha8Rng, lambda: f64) -> f64 {
if lambda < 30.0 {
let l = (-lambda).exp();
let mut k = 0;
let mut p = 1.0;
loop {
k += 1;
let u: f64 = rng.r#gen();
p *= u;
if p <= l {
break;
}
}
(k - 1) as f64
} else {
let u1: f64 = rng.r#gen();
let u2: f64 = rng.r#gen();
let z = (-2.0_f64 * u1.ln()).sqrt() * (2.0_f64 * std::f64::consts::PI * u2).cos();
let value = lambda + lambda.sqrt() * z;
value.max(0.0).round()
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::SeedableRng;
use crate::intrinsics::random as random_intrinsics;
fn mean_variance(samples: &[f64]) -> (f64, f64) {
let mean = samples.iter().sum::<f64>() / samples.len() as f64;
let var = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
(mean, var)
}
#[test]
fn test_uniform_mean_variance() {
random_intrinsics::with_rng(|rng| {
*rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
});
let samples: Vec<f64> =
random_intrinsics::with_rng(|rng| (0..20000).map(|_| sample_uniform(rng, 0.0, 2.0)).collect());
let (mean, var) = mean_variance(&samples);
assert!((mean - 1.0).abs() < 0.05);
assert!((var - 1.0 / 3.0).abs() < 0.05);
}
#[test]
fn test_exponential_mean_variance() {
random_intrinsics::with_rng(|rng| {
*rng = rand_chacha::ChaCha8Rng::seed_from_u64(7);
});
let lambda = 2.0;
let samples: Vec<f64> = random_intrinsics::with_rng(|rng| {
(0..20000).map(|_| sample_exponential(rng, lambda)).collect()
});
let (mean, var) = mean_variance(&samples);
assert!((mean - 1.0 / lambda).abs() < 0.05);
assert!((var - 1.0 / (lambda * lambda)).abs() < 0.1);
}
#[test]
fn test_poisson_mean_variance() {
random_intrinsics::with_rng(|rng| {
*rng = rand_chacha::ChaCha8Rng::seed_from_u64(123);
});
let lambda = 12.0;
let samples: Vec<f64> = random_intrinsics::with_rng(|rng| {
(0..20000).map(|_| sample_poisson(rng, lambda)).collect()
});
let (mean, var) = mean_variance(&samples);
assert!((mean - lambda).abs() < 0.3);
assert!((var - lambda).abs() < 0.5);
}
}