// oxilean_std/probabilistic_programming/functions_2.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4#![allow(clippy::items_after_test_module)]
5
6use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
7
8use super::functions::*;
9use super::types::{Distribution, Hmc, ImportanceSampler, MeanFieldVI, ParticleFilter, Rng};
10
#[cfg(test)]
mod tests {
    use super::*;

    // Standard normal log-density at the mean must equal -0.5 * ln(2π).
    #[test]
    fn test_distribution_normal_log_density() {
        let d = Distribution::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let lp = d.log_density(0.0);
        assert!((lp - (-0.5 * (2.0 * std::f64::consts::PI).ln())).abs() < 1e-8);
    }

    // Empirical success fraction of Bernoulli(0.7) samples should be near 0.7.
    #[test]
    fn test_distribution_sample_bernoulli() {
        let mut rng = Rng::new(42);
        let d = Distribution::Bernoulli { p: 0.7 };
        let n = 2000;
        let ones = (0..n).filter(|_| d.sample(&mut rng) == 1.0).count();
        let frac = ones as f64 / n as f64;
        assert!((frac - 0.7).abs() < 0.05, "Bernoulli(0.7) fraction: {frac}");
    }

    // Importance sampling of E[x^2] under N(0,1) using a N(0,2) proposal;
    // the log-weight is log p(x) - log q(x).
    #[test]
    fn test_importance_sampling_mean() {
        let mut is = ImportanceSampler::new(5000, 99);
        let est = is.estimate(
            |x| x * x,
            |rng| rng.normal_mv(0.0, 2.0),
            |x| {
                let lp = -0.5 * x * x - 0.5 * (2.0 * std::f64::consts::PI).ln();
                let lq = -0.5 * (x / 2.0).powi(2)
                    - (2.0_f64).ln()
                    - 0.5 * (2.0 * std::f64::consts::PI).ln();
                lp - lq
            },
        );
        assert!(
            (est - 1.0).abs() < 0.1,
            "IS estimate of E[x^2] should be near 1.0, got {est}"
        );
    }

    // A particle filter tracking a near-constant state (dynamics noise 0.1,
    // observation noise 0.5) should keep its posterior mean near the truth.
    #[test]
    fn test_particle_filter_constant_state() {
        let mut pf = ParticleFilter::new(500, 7);
        let obs: Vec<f64> = vec![3.0; 10];
        let means = pf.filter_mean(
            &obs,
            |rng| rng.normal_mv(3.0, 1.0),
            |x, rng| rng.normal_mv(x, 0.1),
            |x, y| {
                let z = (x - y) / 0.5;
                -0.5 * z * z
            },
        );
        let last = *means.last().expect("last should succeed");
        assert!(
            (last - 3.0).abs() < 1.0,
            "PF mean should be near 3.0, got {last}"
        );
    }

    // HMC on a unit-variance Gaussian centered at 2.0: the sample mean of the
    // chain should recover the target mean.
    #[test]
    fn test_hmc_samples_normal() {
        let mut hmc = Hmc::new(0.2, 5, 1337);
        let samples = hmc.sample(
            vec![0.0],
            1000,
            |q| -0.5 * (q[0] - 2.0).powi(2),
            |q| vec![-(q[0] - 2.0)],
        );
        let mean = samples.iter().map(|s| s[0]).sum::<f64>() / samples.len() as f64;
        assert!(
            (mean - 2.0).abs() < 0.3,
            "HMC mean should be near 2.0, got {mean}"
        );
    }

    // Mean-field VI against a Gaussian log-density centered at 5.0 should move
    // the variational mean toward 5.0 and produce a finite final ELBO.
    #[test]
    fn test_mean_field_vi_converges() {
        let mut vi = MeanFieldVI::new(1, 0.1, 10, 42);
        let elbo_hist = vi.fit(|z| -0.5 * (z[0] - 5.0).powi(2), 200);
        assert!(
            (vi.mu[0] - 5.0).abs() < 1.5,
            "VI mean should converge near 5.0, got {}",
            vi.mu[0]
        );
        assert!(elbo_hist.last().expect("last should succeed").is_finite());
    }

    // The environment builder must register the expected declarations.
    #[test]
    fn test_build_probabilistic_programming_env() {
        let mut env = Environment::new();
        build_probabilistic_programming_env(&mut env).expect("env build failed");
        assert!(env.get(&Name::str("Measure")).is_some());
        assert!(env.get(&Name::str("ProbabilityMonad")).is_some());
        assert!(env.get(&Name::str("elbo_lower_bound")).is_some());
        assert!(env.get(&Name::str("hmc_invariant")).is_some());
    }

    // ESS of uniform log-weights is n; with all mass on one particle it is ≈ 1.
    #[test]
    fn test_effective_sample_size() {
        let n = 100;
        let log_weights = vec![0.0f64; n];
        let ess = ImportanceSampler::effective_sample_size(&log_weights);
        assert!(
            (ess - n as f64).abs() < 1.0,
            "ESS for uniform weights should be {n}, got {ess}"
        );
        let mut lw = vec![f64::NEG_INFINITY; n];
        lw[0] = 0.0;
        let ess2 = ImportanceSampler::effective_sample_size(&lw);
        assert!(ess2 < 2.0, "degenerate ESS should be ≈ 1, got {ess2}");
    }
}
120/// Natural log of the Beta function: ln B(a,b) = lgamma(a) + lgamma(b) - lgamma(a+b).
121pub(super) fn ln_beta(a: f64, b: f64) -> f64 {
122    lgamma(a) + lgamma(b) - lgamma(a + b)
123}
124/// Stirling approximation of ln Γ(x).
125pub(super) fn lgamma(x: f64) -> f64 {
126    if x < 0.5 {
127        std::f64::consts::PI.ln() - (std::f64::consts::PI * x).sin().ln() - lgamma(1.0 - x)
128    } else {
129        let g = 7.0_f64;
130        let c = [
131            0.999_999_999_999_809_9_f64,
132            676.5203681218851,
133            -1259.1392167224028,
134            771.323_428_777_653_1,
135            -176.615_029_162_140_6,
136            12.507343278686905,
137            -0.13857109526572012,
138            9.984_369_578_019_572e-6,
139            1.5056327351493116e-7,
140        ];
141        let x = x - 1.0;
142        let t = x + g + 0.5;
143        let mut s = c[0];
144        for i in 1..9 {
145            s += c[i] / (x + i as f64);
146        }
147        0.5 * (2.0 * std::f64::consts::PI).ln() + (x + 0.5) * t.ln() - t + s.ln()
148    }
149}