oxilean_std/probabilistic_programming/
functions_2.rs1#![allow(clippy::items_after_test_module)]
5
6use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
7
8use super::functions::*;
9use super::types::{Distribution, Hmc, ImportanceSampler, MeanFieldVI, ParticleFilter, Rng};
10
#[cfg(test)]
mod tests {
    use super::*;

    /// `ln(√(2π))`, the log-normaliser shared by the Gaussian densities in these tests.
    fn half_log_two_pi() -> f64 {
        0.5 * (2.0 * std::f64::consts::PI).ln()
    }

    /// The standard normal's log-density at its mean must equal `-ln(√(2π))`.
    #[test]
    fn test_distribution_normal_log_density() {
        let standard_normal = Distribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let expected = -half_log_two_pi();
        let lp = standard_normal.log_density(0.0);
        assert!((lp - expected).abs() < 1e-8);
    }

    /// The empirical frequency of `1.0` draws from Bernoulli(0.7) should be close to 0.7.
    #[test]
    fn test_distribution_sample_bernoulli() {
        let mut rng = Rng::new(42);
        let coin = Distribution::Bernoulli { p: 0.7 };
        let n = 2000;
        let mut ones = 0usize;
        for _ in 0..n {
            if coin.sample(&mut rng) == 1.0 {
                ones += 1;
            }
        }
        let frac = ones as f64 / n as f64;
        assert!((frac - 0.7).abs() < 0.05, "Bernoulli(0.7) fraction: {frac}");
    }

    /// Estimate `E[x²]` under a standard-normal target (true value 1) by drawing from a
    /// wider Gaussian proposal and reweighting by `log p(x) - log q(x)`.
    #[test]
    fn test_importance_sampling_mean() {
        let mut sampler = ImportanceSampler::new(5000, 99);
        let est = sampler.estimate(
            |x| x * x,
            |rng| rng.normal_mv(0.0, 2.0),
            |x| {
                // Target: N(0, 1).
                let log_target = -0.5 * x * x - half_log_two_pi();
                // Proposal, treated here as N(0, σ = 2).
                let log_proposal =
                    -0.5 * (x / 2.0).powi(2) - (2.0_f64).ln() - half_log_two_pi();
                log_target - log_proposal
            },
        );
        assert!(
            (est - 1.0).abs() < 0.1,
            "IS estimate of E[x^2] should be near 1.0, got {est}"
        );
    }

    /// With every observation equal to 3.0 and particles initialised around 3.0, the
    /// filtered mean should remain near 3.0.
    #[test]
    fn test_particle_filter_constant_state() {
        let mut filter = ParticleFilter::new(500, 7);
        let observations: Vec<f64> = vec![3.0; 10];
        let means = filter.filter_mean(
            &observations,
            |rng| rng.normal_mv(3.0, 1.0),
            |x, rng| rng.normal_mv(x, 0.1),
            |x, y| {
                let residual = (x - y) / 0.5;
                -0.5 * residual * residual
            },
        );
        let &last = means.last().expect("last should succeed");
        assert!(
            (last - 3.0).abs() < 1.0,
            "PF mean should be near 3.0, got {last}"
        );
    }

    /// HMC targeting a unit-variance Gaussian centred at 2, with the gradient supplied
    /// analytically; the sample mean should land near 2.
    #[test]
    fn test_hmc_samples_normal() {
        let mut sampler = Hmc::new(0.2, 5, 1337);
        let draws = sampler.sample(
            vec![0.0],
            1000,
            |q| -0.5 * (q[0] - 2.0).powi(2),
            |q| vec![-(q[0] - 2.0)],
        );
        let total: f64 = draws.iter().map(|draw| draw[0]).sum();
        let mean = total / draws.len() as f64;
        assert!(
            (mean - 2.0).abs() < 0.3,
            "HMC mean should be near 2.0, got {mean}"
        );
    }

    /// Mean-field VI on a log-density peaked at `z = 5`: the fitted mean should move
    /// toward 5 and the final recorded ELBO must be finite.
    #[test]
    fn test_mean_field_vi_converges() {
        let mut vi = MeanFieldVI::new(1, 0.1, 10, 42);
        let elbo_hist = vi.fit(|z| -0.5 * (z[0] - 5.0).powi(2), 200);
        let fitted_mean = vi.mu[0];
        assert!(
            (fitted_mean - 5.0).abs() < 1.5,
            "VI mean should converge near 5.0, got {}",
            fitted_mean
        );
        let final_elbo = elbo_hist.last().expect("last should succeed");
        assert!(final_elbo.is_finite());
    }

    /// The environment builder must register each of the named declarations below.
    #[test]
    fn test_build_probabilistic_programming_env() {
        let mut env = Environment::new();
        build_probabilistic_programming_env(&mut env).expect("env build failed");
        for decl in ["Measure", "ProbabilityMonad", "elbo_lower_bound", "hmc_invariant"] {
            assert!(env.get(&Name::str(decl)).is_some());
        }
    }

    /// Uniform log-weights give an ESS equal to the sample count; a single surviving
    /// weight (all others `-∞`) collapses the ESS to about 1.
    #[test]
    fn test_effective_sample_size() {
        let n = 100;
        let uniform = vec![0.0f64; n];
        let ess = ImportanceSampler::effective_sample_size(&uniform);
        assert!(
            (ess - n as f64).abs() < 1.0,
            "ESS for uniform weights should be {n}, got {ess}"
        );

        let degenerate: Vec<f64> = std::iter::once(0.0)
            .chain(std::iter::repeat(f64::NEG_INFINITY).take(n - 1))
            .collect();
        let ess2 = ImportanceSampler::effective_sample_size(&degenerate);
        assert!(ess2 < 2.0, "degenerate ESS should be ≈ 1, got {ess2}");
    }
}
120pub(super) fn ln_beta(a: f64, b: f64) -> f64 {
122 lgamma(a) + lgamma(b) - lgamma(a + b)
123}
124pub(super) fn lgamma(x: f64) -> f64 {
126 if x < 0.5 {
127 std::f64::consts::PI.ln() - (std::f64::consts::PI * x).sin().ln() - lgamma(1.0 - x)
128 } else {
129 let g = 7.0_f64;
130 let c = [
131 0.999_999_999_999_809_9_f64,
132 676.5203681218851,
133 -1259.1392167224028,
134 771.323_428_777_653_1,
135 -176.615_029_162_140_6,
136 12.507343278686905,
137 -0.13857109526572012,
138 9.984_369_578_019_572e-6,
139 1.5056327351493116e-7,
140 ];
141 let x = x - 1.0;
142 let t = x + g + 0.5;
143 let mut s = c[0];
144 for i in 1..9 {
145 s += c[i] / (x + i as f64);
146 }
147 0.5 * (2.0 * std::f64::consts::PI).ln() + (x + 0.5) * t.ln() - t + s.ln()
148 }
149}