fugue/inference/
vi.rs

//! Variational Inference (VI) with mean-field approximations and ELBO optimization.
//!
//! This module implements variational inference, a deterministic approximate inference
//! method that turns posterior inference into an optimization problem. Instead of sampling
//! from the true posterior, VI finds the best approximation within a chosen family of
//! distributions by maximizing the Evidence Lower BOund (ELBO).
//!
//! ## Method Overview
//!
//! Variational inference works by:
//! 1. Choosing a family of tractable distributions Q(θ; φ) parameterized by φ
//! 2. Finding φ* that minimizes KL(Q(θ; φ) || P(θ|data))
//! 3. Using Q(θ; φ*) as an approximation to the true posterior P(θ|data)
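//!
//! Maximizing the ELBO is equivalent to minimizing that KL divergence, because
//! log P(data) = ELBO(φ) + KL(Q(θ; φ) || P(θ|data)), where
//! ELBO(φ) = E_Q\[log P(θ, data) − log Q(θ; φ)\].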
//!
//! ## Mean-Field Approximation
//!
//! This implementation uses mean-field variational inference, where the posterior
//! is approximated as a product of independent distributions:
//! Q(θ₁, θ₂, ..., θₖ) = Q₁(θ₁) × Q₂(θ₂) × ... × Qₖ(θₖ)
//!
//! ## Advantages of VI
//!
//! - **Deterministic objective**: Optimization-based rather than sampling-based; reproducible for a fixed seed
//! - **Fast**: Typically faster than MCMC for large models
//! - **Scalable**: Handles high-dimensional parameters well
//! - **Convergence detection**: Clear optimization objective to monitor
//!
//! ## Limitations
//!
//! - **Approximation quality**: May underestimate posterior uncertainty
//! - **Local optima**: Gradient-based optimization can get stuck
//! - **Family restrictions**: Posterior must be well-approximated by chosen family
//!
//! # Examples
//!
//! ```rust
//! use fugue::*;
//! use rand::rngs::StdRng;
//! use rand::SeedableRng;
//! use std::collections::HashMap;
//!
//! // Simple VI example
//! let model_fn = || {
//!     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
//!         .bind(|mu| observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 2.0).map(move |_| mu))
//! };
//!
//! // Create mean-field guide manually
//! let mut guide = MeanFieldGuide {
//!     params: HashMap::new()
//! };
//! guide.params.insert(
//!     addr!("mu"),
//!     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 }
//! );
//!
//! // Simple ELBO computation
//! let mut rng = StdRng::seed_from_u64(42);
//! let elbo = elbo_with_guide(&mut rng, &model_fn, &guide, 10);
//! assert!(elbo.is_finite());
//! ```
use crate::core::address::Address;
use crate::core::distribution::*;
use crate::core::model::Model;
use crate::runtime::handler::run;
use crate::runtime::interpreters::{PriorHandler, ScoreGivenTrace};
use crate::runtime::trace::{Choice, ChoiceValue, Trace};
use rand::Rng;
use std::collections::HashMap;

/// Variational distribution parameters for a single random variable.
///
/// Each random variable in the model gets its own variational distribution that
/// approximates its marginal posterior. Scale and shape parameters are stored in
/// log-space for numerical stability and to enforce positivity constraints.
///
/// # Variants
///
/// * `Normal` - Gaussian approximation with mean and log-standard-deviation
/// * `LogNormal` - Log-normal approximation for positive variables
/// * `Beta` - Beta approximation for variables constrained to \[0,1\]
///
/// # Examples
///
/// ```rust
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// // Create variational parameters
/// let normal_param = VariationalParam::Normal {
///     mu: 1.5,
///     log_sigma: -0.693  // sigma = 0.5
/// };
///
/// let beta_param = VariationalParam::Beta {
///     log_alpha: 1.099,  // alpha = 3.0
///     log_beta: 0.693,   // beta = 2.0
/// };
///
/// // Sample from variational distribution
/// let mut rng = StdRng::seed_from_u64(42);
/// let sample = normal_param.sample(&mut rng);
/// let log_prob = normal_param.log_prob(sample);
/// ```
#[derive(Clone, Debug)]
pub enum VariationalParam {
    /// Normal/Gaussian variational distribution.
    Normal {
        /// Mean parameter.
        mu: f64,
        /// Log of standard deviation (for positivity).
        log_sigma: f64,
    },
    /// Log-normal variational distribution for positive variables.
    LogNormal {
        /// Mean of underlying normal.
        mu: f64,
        /// Log of standard deviation of underlying normal.
        log_sigma: f64,
    },
    /// Beta variational distribution for variables in \[0,1\].
    Beta {
        /// Log of first shape parameter (for positivity).
        log_alpha: f64,
        /// Log of second shape parameter (for positivity).
        log_beta: f64,
    },
}

impl VariationalParam {
    /// Sample a value from this variational distribution with numerical stability.
    ///
    /// Generates a random sample using the current variational parameters, validating
    /// them before sampling for numerical stability.
    ///
    /// # Arguments
    ///
    /// * `rng` - Random number generator
    ///
    /// # Returns
    ///
    /// A sample from the variational distribution, or NaN if parameters are invalid.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                if !mu.is_finite() || !sigma.is_finite() || sigma <= 0.0 {
                    return f64::NAN;
                }
                Normal::new(*mu, sigma).unwrap().sample(rng)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                if !mu.is_finite() || !sigma.is_finite() || sigma <= 0.0 {
                    return f64::NAN;
                }
                LogNormal::new(*mu, sigma).unwrap().sample(rng)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                if !alpha.is_finite() || !beta.is_finite() || alpha <= 0.0 || beta <= 0.0 {
                    return f64::NAN;
                }
                Beta::new(alpha, beta).unwrap().sample(rng)
            }
        }
    }

    /// Sample with reparameterization for gradient computation (experimental).
    ///
    /// Returns both the sample and auxiliary information needed for
    /// computing gradients via the reparameterization trick.
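    ///
    /// For the `Normal` variant the returned pair is `(mu + sigma * z, z)`, where `z` is
    /// the standard-normal noise used to produce the sample; the other variants likewise
    /// return their sample together with the underlying standard-normal draw.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// let param = VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 };
    /// let mut rng = StdRng::seed_from_u64(7);
    /// let (value, z) = param.sample_with_aux(&mut rng);
    /// // With mu = 0 and sigma = 1 the sample equals its noise term.
    /// assert!(value.is_finite() && z.is_finite());
    /// ```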
    pub fn sample_with_aux<R: Rng>(&self, rng: &mut R) -> (f64, f64) {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                // Box-Muller transform: standard normal draw from two uniforms
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let value = mu + sigma * z;
                const LN_2PI: f64 = 1.837_877_066_409_345_6;
                let _log_prob = -0.5 * z * z - log_sigma - 0.5 * LN_2PI;
                (value, z)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                // Box-Muller transform: standard normal draw from two uniforms
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let log_value = mu + sigma * z;
                let value = log_value.exp();
                const LN_2PI: f64 = 1.837_877_066_409_345_6;
                let _log_prob = -0.5 * z * z - log_sigma - 0.5 * LN_2PI - log_value;
                (value, z)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                // Use a moment-matched normal approximation for Beta (stable fallback)
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                let approx_mu = alpha / (alpha + beta);
                let approx_var = (alpha * beta) / ((alpha + beta).powi(2) * (alpha + beta + 1.0));
                let approx_sigma = approx_var.sqrt();

                // Box-Muller transform: standard normal draw from two uniforms
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let raw_value = approx_mu + approx_sigma * z;
                let value = raw_value.clamp(0.001, 0.999);

                let _log_prob = Beta::new(alpha, beta).unwrap().log_prob(&value);
                (value, z)
            }
        }
    }

    /// Compute log-probability of a value under this variational distribution.
    ///
    /// This is used for computing entropy terms in the ELBO and for evaluating
    /// the quality of the variational approximation.
    ///
    /// # Arguments
    ///
    /// * `x` - Value to evaluate
    ///
    /// # Returns
    ///
    /// Log-probability density at the given value.
    pub fn log_prob(&self, x: f64) -> f64 {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                Normal::new(*mu, sigma).unwrap().log_prob(&x)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                LogNormal::new(*mu, sigma).unwrap().log_prob(&x)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                Beta::new(alpha, beta).unwrap().log_prob(&x)
            }
        }
    }
}

/// Mean-field variational guide for approximate posterior inference.
///
/// A mean-field guide specifies independent variational distributions for each
/// random variable in the model. This factorization assumption simplifies
/// optimization but cannot capture correlations between variables.
///
/// The guide maps each address (random variable) to its variational parameters,
/// which are optimized to minimize the KL divergence to the true posterior.
///
/// # Fields
///
/// * `params` - Map from addresses to their variational parameters
///
/// # Examples
///
/// ```rust
/// use fugue::*;
///
/// // Create a guide for a two-parameter model
/// let mut guide = MeanFieldGuide::new();
/// guide.params.insert(
///     addr!("mu"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 }
/// );
/// guide.params.insert(
///     addr!("sigma"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: -1.0 }
/// );
///
/// // Check if parameters are specified
/// assert!(guide.params.contains_key(&addr!("mu")));
/// assert!(guide.params.contains_key(&addr!("sigma")));
/// ```
#[derive(Clone, Debug)]
pub struct MeanFieldGuide {
    /// Map from addresses to their variational parameters.
    pub params: HashMap<Address, VariationalParam>,
}

impl Default for MeanFieldGuide {
    fn default() -> Self {
        Self::new()
    }
}

impl MeanFieldGuide {
    /// Create a new empty mean-field guide.
    ///
    /// The guide starts with no variational parameters. You must add parameters for each
    /// random variable in your model, either by inserting into `params` directly or by
    /// building the guide with [`MeanFieldGuide::from_trace`].
    pub fn new() -> Self {
        Self {
            params: HashMap::new(),
        }
    }

    /// Initialize a guide from a prior trace, choosing a variational family for each
    /// recorded choice based on its value type.
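    ///
    /// A sketch of typical usage (illustrative; import paths for `run`, `PriorHandler`,
    /// and `Trace` may differ in your setup, so the snippet is not compiled):
    ///
    /// ```ignore
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// let model_fn = || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap());
    /// let mut rng = StdRng::seed_from_u64(0);
    ///
    /// // Run the model forward once to record its choices...
    /// let (_value, trace) = run(
    ///     PriorHandler { rng: &mut rng, trace: Trace::default() },
    ///     model_fn(),
    /// );
    ///
    /// // ...then derive one variational parameter per recorded address.
    /// let guide = MeanFieldGuide::from_trace(&trace);
    /// assert!(guide.params.contains_key(&addr!("mu")));
    /// ```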
    pub fn from_trace(trace: &Trace) -> Self {
        let mut guide = Self::new();

        for (addr, choice) in &trace.choices {
            let param = match choice.value {
                ChoiceValue::F64(val) => {
                    if val > 0.0 {
                        // Use LogNormal for positive values
                        VariationalParam::LogNormal {
                            mu: val.ln(),
                            log_sigma: 1.0_f64.ln(),
                        }
                    } else {
                        // Use Normal for real values
                        VariationalParam::Normal {
                            mu: val,
                            log_sigma: 1.0_f64.ln(),
                        }
                    }
                }
                ChoiceValue::Bool(_) => {
                    // Use Beta(1,1) = Uniform for boolean (as continuous relaxation)
                    VariationalParam::Beta {
                        log_alpha: 1.0_f64.ln(),
                        log_beta: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::I64(val) => {
                    // Use Normal for integers (continuous relaxation)
                    VariationalParam::Normal {
                        mu: val as f64,
                        log_sigma: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::U64(val) => {
                    // Use LogNormal for unsigned integers (non-negative)
                    VariationalParam::LogNormal {
                        mu: (val as f64 + 1.0).ln(), // +1 to avoid log(0)
                        log_sigma: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::Usize(val) => {
                    // Use LogNormal for categorical indices (always positive)
                    VariationalParam::LogNormal {
                        mu: (val as f64 + 1.0).ln(), // +1 to avoid log(0)
                        log_sigma: 1.0_f64.ln(),
                    }
                }
            };
            guide.params.insert(addr.clone(), param);
        }
        guide
    }

    /// Sample a trace from the guide.
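    ///
    /// Each address is drawn independently from its variational distribution and recorded
    /// as an `F64` choice with its log-density, which accumulates into the trace's log-prior.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// let mut guide = MeanFieldGuide::new();
    /// guide.params.insert(addr!("mu"), VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 });
    ///
    /// // One Monte Carlo draw from the guide: a trace with a single choice at "mu".
    /// let _trace = guide.sample_trace(&mut StdRng::seed_from_u64(7));
    /// ```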
    pub fn sample_trace<R: Rng>(&self, rng: &mut R) -> Trace {
        let mut trace = Trace::default();

        for (addr, param) in &self.params {
            let value = param.sample(rng);
            let log_prob = param.log_prob(value);

            trace.choices.insert(
                addr.clone(),
                Choice {
                    addr: addr.clone(),
                    value: ChoiceValue::F64(value),
                    logp: log_prob,
                },
            );
            trace.log_prior += log_prob;
        }
        trace
    }
}

/// ELBO estimation using a variational guide.
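///
/// Draws `num_samples` traces from the guide, scores the model against each one, and
/// averages `log p(x, z) − log q(z)`; i.e. the Monte Carlo estimate
/// ELBO ≈ (1/N) Σᵢ \[log p(x, zᵢ) − log q(zᵢ)\] with zᵢ ~ q.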
pub fn elbo_with_guide<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    guide: &MeanFieldGuide,
    num_samples: usize,
) -> f64 {
    let mut total_elbo = 0.0;

    for _ in 0..num_samples {
        let guide_trace = guide.sample_trace(rng);
        let (_a, model_trace) = run(
            ScoreGivenTrace {
                base: guide_trace.clone(),
                trace: Trace::default(),
            },
            model_fn(),
        );

        // ELBO = E_q[log p(x,z) - log q(z)]
        let log_joint = model_trace.total_log_weight();
        let log_guide = guide_trace.log_prior;
        total_elbo += log_joint - log_guide;
    }

    total_elbo / num_samples as f64
}

/// Simple VI optimization using coordinate ascent with finite-difference gradient estimates.
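///
/// Each iteration estimates the current ELBO, then, for every guide parameter, nudges the
/// location parameter (`mu`, or `log_alpha` for `Beta`) by a small `eps`, re-estimates the
/// ELBO, and takes a gradient step of size `learning_rate` along the finite-difference slope,
/// clamping the result to a safe range. Scale parameters (`log_sigma`, `log_beta`) are
/// currently left unchanged.
///
/// A sketch of typical usage (illustrative, mirroring this module's tests; not compiled as a
/// doctest):
///
/// ```ignore
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// let model_fn = || {
///     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
///         .bind(|mu| observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.3).map(move |_| mu))
/// };
///
/// let mut guide = MeanFieldGuide::new();
/// guide.params.insert(addr!("mu"), VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 });
///
/// // Small iteration/sample counts keep the example quick; real runs typically need more.
/// let optimized = optimize_meanfield_vi(&mut StdRng::seed_from_u64(11), model_fn, guide, 20, 5, 0.05);
/// assert!(optimized.params.contains_key(&addr!("mu")));
/// ```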
pub fn optimize_meanfield_vi<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    initial_guide: MeanFieldGuide,
    n_iterations: usize,
    n_samples_per_iter: usize,
    learning_rate: f64,
) -> MeanFieldGuide {
    let mut guide = initial_guide;

    for iter in 0..n_iterations {
        let current_elbo = elbo_with_guide(rng, &model_fn, &guide, n_samples_per_iter);

        // Simple gradient ascent (placeholder - would use automatic differentiation in practice)
        let guide_clone = guide.clone();
        for (addr, param) in &mut guide.params {
            match param {
                VariationalParam::Normal { mu, log_sigma: _ } => {
                    // Finite difference gradients (very basic)
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::Normal { mu: mu_plus, .. }) =
                        guide_plus.params.get_mut(addr)
                    {
                        *mu_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_mu = (elbo_plus - current_elbo) / eps;

                    // Skip non-finite gradients/updates for numerical stability
                    if grad_mu.is_finite() {
                        let update = learning_rate * grad_mu;
                        if update.is_finite() {
                            *mu += update;
                            // Clamp to reasonable range to prevent overflow
                            *mu = mu.clamp(-100.0, 100.0);
                        }
                    }
                }
                VariationalParam::LogNormal { mu, log_sigma: _ } => {
                    // Similar finite difference for LogNormal parameters
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::LogNormal { mu: mu_plus, .. }) =
                        guide_plus.params.get_mut(addr)
                    {
                        *mu_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_mu = (elbo_plus - current_elbo) / eps;

                    // Skip non-finite gradients/updates for numerical stability
                    if grad_mu.is_finite() {
                        let update = learning_rate * grad_mu;
                        if update.is_finite() {
                            *mu += update;
                            // Clamp to reasonable range for LogNormal
                            *mu = mu.clamp(-10.0, 10.0);
                        }
                    }
                }
                VariationalParam::Beta {
                    log_alpha,
                    log_beta: _,
                } => {
                    // Basic update for Beta parameters
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::Beta {
                        log_alpha: alpha_plus,
                        ..
                    }) = guide_plus.params.get_mut(addr)
                    {
                        *alpha_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_alpha = (elbo_plus - current_elbo) / eps;

                    // Skip non-finite gradients/updates for numerical stability
                    if grad_alpha.is_finite() {
                        let update = learning_rate * grad_alpha;
                        if update.is_finite() {
                            *log_alpha += update;
                            // Clamp to reasonable range for Beta
                            *log_alpha = log_alpha.clamp(-5.0, 5.0);
                        }
                    }
                }
            }
        }

        if iter % 100 == 0 {
            println!("VI Iteration {}: ELBO = {:.4}", iter, current_elbo);
        }
    }

    guide
}

/// Simple estimator kept for backward compatibility: averages the model's total log-weight
/// over traces sampled from the prior (see [`elbo_with_guide`] for the guide-based ELBO).
pub fn estimate_elbo<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    num_samples: usize,
) -> f64 {
    let mut total = 0.0;
    for _ in 0..num_samples {
        let (_a, prior_t) = run(
            PriorHandler {
                rng,
                trace: Trace::default(),
            },
            model_fn(),
        );
        let (_a2, scored) = run(
            ScoreGivenTrace {
                base: prior_t.clone(),
                trace: Trace::default(),
            },
            model_fn(),
        );
        total += scored.total_log_weight();
    }
    total / (num_samples as f64)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;

    use crate::core::model::{observe, sample, ModelExt};
    use crate::runtime::trace::{Choice, ChoiceValue, Trace};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn variational_param_sampling_and_log_prob() {
        let mut rng = StdRng::seed_from_u64(20);
        let vp_n = VariationalParam::Normal {
            mu: 0.0,
            log_sigma: 0.0,
        };
        let x = vp_n.sample(&mut rng);
        assert!(x.is_finite());
        assert!(vp_n.log_prob(x).is_finite());

        let vp_b = VariationalParam::Beta {
            log_alpha: (2.0f64).ln(),
            log_beta: (3.0f64).ln(),
        };
        let y = vp_b.sample(&mut rng);
        assert!(y > 0.0 && y < 1.0);
        assert!(vp_b.log_prob(y).is_finite());
    }

    #[test]
    fn elbo_computation_is_finite() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.2).map(move |_| mu)
            })
        };

        // Build a simple guide
        let mut guide = MeanFieldGuide::new();
        guide.params.insert(
            addr!("mu"),
            VariationalParam::Normal {
                mu: 0.0,
                log_sigma: 0.0,
            },
        );

        let mut rng = StdRng::seed_from_u64(21);
        let elbo = elbo_with_guide(&mut rng, model_fn, &guide, 5);
        assert!(elbo.is_finite());
    }

    #[test]
    fn meanfield_from_trace_and_sampling() {
        // Create a base trace with mixed types
        let mut base = Trace::default();
        base.choices.insert(
            addr!("neg"),
            Choice {
                addr: addr!("neg"),
                value: ChoiceValue::F64(-1.0),
                logp: -0.1,
            },
        );
        base.choices.insert(
            addr!("bool"),
            Choice {
                addr: addr!("bool"),
                value: ChoiceValue::Bool(true),
                logp: -0.7,
            },
        );
        base.choices.insert(
            addr!("u64"),
            Choice {
                addr: addr!("u64"),
                value: ChoiceValue::U64(3),
                logp: -0.5,
            },
        );

        let guide = MeanFieldGuide::from_trace(&base);
        assert!(!guide.params.is_empty());

        // Sample a trace from the guide
        let t = guide.sample_trace(&mut StdRng::seed_from_u64(22));
        assert!(!t.choices.is_empty());
        assert!(t.log_prior.is_finite());
    }

    #[test]
    fn optimize_vi_updates_parameters_and_is_stable() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.3).map(move |_| mu)
            })
        };

        let mut guide = MeanFieldGuide::new();
        guide.params.insert(
            addr!("mu"),
            VariationalParam::Normal {
                mu: 0.0,
                log_sigma: 0.0,
            },
        );

        let optimized = optimize_meanfield_vi(
            &mut StdRng::seed_from_u64(23),
            model_fn,
            guide.clone(),
            2, // small iterations for speed
            3,
            0.1,
        );

        // Parameter exists and remains within clamped bounds
        if let VariationalParam::Normal { mu, .. } = optimized.params.get(&addr!("mu")).unwrap() {
            assert!(*mu <= 100.0 && *mu >= -100.0);
        } else {
            panic!("expected Normal param");
        }
    }
}