fugue/inference/vi.rs
//! Variational Inference (VI) with mean-field approximations and ELBO optimization.
//!
//! This module implements variational inference, a deterministic approximate inference
//! method that turns posterior inference into an optimization problem. Instead of sampling
//! from the true posterior, VI finds the best approximation within a chosen family of
//! distributions by maximizing the Evidence Lower BOund (ELBO).
//!
//! ## Method Overview
//!
//! Variational inference works by:
//! 1. Choosing a family of tractable distributions Q(θ; φ) parameterized by φ
//! 2. Finding φ* that minimizes KL(Q(θ; φ) || P(θ|data))
//! 3. Using Q(θ; φ*) as an approximation to the true posterior P(θ|data)
//!
//! ## Mean-Field Approximation
//!
//! This implementation uses mean-field variational inference, where the posterior
//! is approximated as a product of independent distributions:
//! Q(θ₁, θ₂, ..., θₖ) = Q₁(θ₁) × Q₂(θ₂) × ... × Qₖ(θₖ)
//!
//! ## Advantages of VI
//!
//! - **Deterministic**: No random sampling, reproducible results
//! - **Fast**: Typically faster than MCMC for large models
//! - **Scalable**: Handles high-dimensional parameters well
//! - **Convergence detection**: Clear optimization objective to monitor
//!
//! ## Limitations
//!
//! - **Approximation quality**: May underestimate posterior uncertainty
//! - **Local optima**: Gradient-based optimization can get stuck
//! - **Family restrictions**: Posterior must be well-approximated by chosen family
//!
//! # Examples
//!
//! ```rust
//! use fugue::*;
//! use rand::rngs::StdRng;
//! use rand::SeedableRng;
//! use std::collections::HashMap;
//!
//! // Simple VI example
//! let model_fn = || {
//!     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
//!         .bind(|mu| observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 2.0).map(move |_| mu))
//! };
//!
//! // Create mean-field guide manually
//! let mut guide = MeanFieldGuide {
//!     params: HashMap::new(),
//! };
//! guide.params.insert(
//!     addr!("mu"),
//!     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 },
//! );
//!
//! // Simple ELBO computation
//! let mut rng = StdRng::seed_from_u64(42);
//! let elbo = elbo_with_guide(&mut rng, &model_fn, &guide, 10);
//! assert!(elbo.is_finite());
//! ```
use crate::core::address::Address;
use crate::core::distribution::*;
use crate::core::model::Model;
use crate::runtime::handler::run;
use crate::runtime::interpreters::{PriorHandler, ScoreGivenTrace};
use crate::runtime::trace::{Choice, ChoiceValue, Trace};
use rand::Rng;
use std::collections::HashMap;

/// Variational distribution parameters for a single random variable.
///
/// Each random variable in the model gets its own variational distribution that
/// approximates its marginal posterior. Scale and shape parameters are stored in
/// log-space for numerical stability and to enforce positivity constraints.
///
/// # Variants
///
/// * `Normal` - Gaussian approximation with mean and log-standard-deviation
/// * `LogNormal` - Log-normal approximation for positive variables
/// * `Beta` - Beta approximation for variables constrained to \[0,1\]
///
/// # Examples
///
/// ```rust
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// // Create variational parameters
/// let normal_param = VariationalParam::Normal {
///     mu: 1.5,
///     log_sigma: -0.693, // sigma = 0.5
/// };
///
/// let beta_param = VariationalParam::Beta {
///     log_alpha: 1.099, // alpha = 3.0
///     log_beta: 0.693,  // beta = 2.0
/// };
///
/// // Sample from the variational distribution
/// let mut rng = StdRng::seed_from_u64(42);
/// let sample = normal_param.sample(&mut rng);
/// let log_prob = normal_param.log_prob(sample);
/// ```
#[derive(Clone, Debug)]
pub enum VariationalParam {
    /// Normal/Gaussian variational distribution.
    Normal {
        /// Mean parameter.
        mu: f64,
        /// Log of standard deviation (for positivity).
        log_sigma: f64,
    },
    /// Log-normal variational distribution for positive variables.
    LogNormal {
        /// Mean of underlying normal.
        mu: f64,
        /// Log of standard deviation of underlying normal.
        log_sigma: f64,
    },
    /// Beta variational distribution for variables in \[0,1\].
    Beta {
        /// Log of first shape parameter (for positivity).
        log_alpha: f64,
        /// Log of second shape parameter (for positivity).
        log_beta: f64,
    },
}

impl VariationalParam {
    /// Sample a value from this variational distribution with numerical stability.
    ///
    /// Generates a random sample using the current variational parameters.
    /// This version includes parameter validation and numerical stability checks.
    ///
    /// # Arguments
    ///
    /// * `rng` - Random number generator
    ///
    /// # Returns
    ///
    /// A sample from the variational distribution, or NaN if the parameters are invalid.
    pub fn sample<R: Rng>(&self, rng: &mut R) -> f64 {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                if !mu.is_finite() || !sigma.is_finite() || sigma <= 0.0 {
                    return f64::NAN;
                }
                Normal::new(*mu, sigma).unwrap().sample(rng)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                if !mu.is_finite() || !sigma.is_finite() || sigma <= 0.0 {
                    return f64::NAN;
                }
                LogNormal::new(*mu, sigma).unwrap().sample(rng)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                if !alpha.is_finite() || !beta.is_finite() || alpha <= 0.0 || beta <= 0.0 {
                    return f64::NAN;
                }
                Beta::new(alpha, beta).unwrap().sample(rng)
            }
        }
    }

    /// Sample with reparameterization for gradient computation (experimental).
    ///
    /// Returns both the sample and the auxiliary standard-normal draw `z` needed for
    /// computing gradients via the reparameterization trick.
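    ///
    /// # Examples
    ///
    /// A minimal sketch of the reparameterization identity for the `Normal` variant
    /// (assuming the crate-root re-exports used by the other examples in this module):
    /// the returned value equals `mu + exp(log_sigma) * z`, where `z` is the auxiliary
    /// standard-normal draw.
    ///
    /// ```rust
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// let param = VariationalParam::Normal { mu: 1.0, log_sigma: -1.0 };
    /// let mut rng = StdRng::seed_from_u64(7);
    /// let (value, z) = param.sample_with_aux(&mut rng);
    /// // The sample is a deterministic transform of the auxiliary noise.
    /// assert!((value - (1.0 + (-1.0f64).exp() * z)).abs() < 1e-9);
    /// ```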
    pub fn sample_with_aux<R: Rng>(&self, rng: &mut R) -> (f64, f64) {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                // Standard normal draw via the Box-Muller transform
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let value = mu + sigma * z;
                const LN_2PI: f64 = 1.837_877_066_409_345_6;
                let _log_prob = -0.5 * z * z - log_sigma - 0.5 * LN_2PI;
                (value, z)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                // Standard normal draw via the Box-Muller transform
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let log_value = mu + sigma * z;
                let value = log_value.exp();
                const LN_2PI: f64 = 1.837_877_066_409_345_6;
                let _log_prob = -0.5 * z * z - log_sigma - 0.5 * LN_2PI - log_value;
                (value, z)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                // Use a moment-matched normal approximation for Beta (stable fallback)
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                let approx_mu = alpha / (alpha + beta);
                let approx_var = (alpha * beta) / ((alpha + beta).powi(2) * (alpha + beta + 1.0));
                let approx_sigma = approx_var.sqrt();

                // Standard normal draw via the Box-Muller transform
                let u1: f64 = rng.gen::<f64>().max(1e-10);
                let u2: f64 = rng.gen();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
                let raw_value = approx_mu + approx_sigma * z;
                let value = raw_value.clamp(0.001, 0.999);

                let _log_prob = Beta::new(alpha, beta).unwrap().log_prob(&value);
                (value, z)
            }
        }
    }

    /// Compute log-probability of a value under this variational distribution.
    ///
    /// This is used for computing entropy terms in the ELBO and for evaluating
    /// the quality of the variational approximation. Parameters are exponentiated
    /// from log-space before constructing the underlying distribution.
    ///
    /// # Arguments
    ///
    /// * `x` - Value to evaluate
    ///
    /// # Returns
    ///
    /// Log-probability density at the given value.
    pub fn log_prob(&self, x: f64) -> f64 {
        match self {
            VariationalParam::Normal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                Normal::new(*mu, sigma).unwrap().log_prob(&x)
            }
            VariationalParam::LogNormal { mu, log_sigma } => {
                let sigma = log_sigma.exp();
                LogNormal::new(*mu, sigma).unwrap().log_prob(&x)
            }
            VariationalParam::Beta {
                log_alpha,
                log_beta,
            } => {
                let alpha = log_alpha.exp();
                let beta = log_beta.exp();
                Beta::new(alpha, beta).unwrap().log_prob(&x)
            }
        }
    }
}

/// Mean-field variational guide for approximate posterior inference.
///
/// A mean-field guide specifies independent variational distributions for each
/// random variable in the model. This factorization assumption simplifies
/// optimization but may underestimate correlations between variables.
///
/// The guide maps each address (random variable) to its variational parameters,
/// which are optimized to minimize the KL divergence to the true posterior.
///
/// # Fields
///
/// * `params` - Map from addresses to their variational parameters
///
/// # Examples
///
/// ```rust
/// use fugue::*;
/// use std::collections::HashMap;
///
/// // Create a guide for a two-parameter model
/// let mut guide = MeanFieldGuide::new();
/// guide.params.insert(
///     addr!("mu"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 },
/// );
/// guide.params.insert(
///     addr!("sigma"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: -1.0 },
/// );
///
/// // Check that parameters are specified
/// assert!(guide.params.contains_key(&addr!("mu")));
/// assert!(guide.params.contains_key(&addr!("sigma")));
/// ```
#[derive(Clone, Debug)]
pub struct MeanFieldGuide {
    /// Map from addresses to their variational parameters.
    pub params: HashMap<Address, VariationalParam>,
}

impl Default for MeanFieldGuide {
    fn default() -> Self {
        Self::new()
    }
}

impl MeanFieldGuide {
    /// Create a new empty mean-field guide.
    ///
    /// The guide starts with no variational parameters. You must add parameters for
    /// each random variable in your model, either by inserting into `params` directly
    /// or by initializing from a prior trace with [`MeanFieldGuide::from_trace`].
    pub fn new() -> Self {
        Self {
            params: HashMap::new(),
        }
    }

    /// Initialize a guide from a prior trace, choosing a variational family per value type
    /// (Normal for real and integer values, LogNormal for positive and unsigned values,
    /// Beta(1,1) as a continuous relaxation for booleans).
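    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore` because it assumes the trace types `Trace`,
    /// `Choice`, and `ChoiceValue` are importable from the crate root, as they are
    /// within this crate):
    ///
    /// ```rust,ignore
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// // Build a trace with a single positive value; `from_trace` picks a
    /// // LogNormal factor for it.
    /// let mut trace = Trace::default();
    /// trace.choices.insert(
    ///     addr!("rate"),
    ///     Choice {
    ///         addr: addr!("rate"),
    ///         value: ChoiceValue::F64(2.0),
    ///         logp: -0.5,
    ///     },
    /// );
    ///
    /// let guide = MeanFieldGuide::from_trace(&trace);
    /// assert!(guide.params.contains_key(&addr!("rate")));
    ///
    /// // The guide can immediately propose new traces.
    /// let mut rng = StdRng::seed_from_u64(0);
    /// let proposal = guide.sample_trace(&mut rng);
    /// assert!(proposal.log_prior.is_finite());
    /// ```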
    pub fn from_trace(trace: &Trace) -> Self {
        let mut guide = Self::new();

        for (addr, choice) in &trace.choices {
            let param = match choice.value {
                ChoiceValue::F64(val) => {
                    if val > 0.0 {
                        // Use LogNormal for positive values
                        VariationalParam::LogNormal {
                            mu: val.ln(),
                            log_sigma: 1.0_f64.ln(),
                        }
                    } else {
                        // Use Normal for real values
                        VariationalParam::Normal {
                            mu: val,
                            log_sigma: 1.0_f64.ln(),
                        }
                    }
                }
                ChoiceValue::Bool(_) => {
                    // Use Beta(1,1) = Uniform for booleans (continuous relaxation)
                    VariationalParam::Beta {
                        log_alpha: 1.0_f64.ln(),
                        log_beta: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::I64(val) => {
                    // Use Normal for integers (continuous relaxation)
                    VariationalParam::Normal {
                        mu: val as f64,
                        log_sigma: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::U64(val) => {
                    // Use LogNormal for unsigned integers (always positive)
                    VariationalParam::LogNormal {
                        mu: (val as f64).max(1.0).ln(), // guard against log(0) when val == 0
                        log_sigma: 1.0_f64.ln(),
                    }
                }
                ChoiceValue::Usize(val) => {
                    // Use LogNormal for categorical indices (always positive)
                    VariationalParam::LogNormal {
                        mu: (val as f64 + 1.0).ln(), // +1 to avoid log(0)
                        log_sigma: 1.0_f64.ln(),
                    }
                }
            };
            guide.params.insert(addr.clone(), param);
        }
        guide
    }

    /// Sample a trace from the guide, recording each draw and its log-probability.
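    ///
    /// # Examples
    ///
    /// A small sketch (assuming the same crate-root re-exports used elsewhere in this
    /// module): every address in the guide receives a value, and `log_prior`
    /// accumulates the factorized log-density Σ log qᵢ(θᵢ).
    ///
    /// ```rust
    /// use fugue::*;
    /// use rand::rngs::StdRng;
    /// use rand::SeedableRng;
    ///
    /// let mut guide = MeanFieldGuide::new();
    /// guide.params.insert(
    ///     addr!("mu"),
    ///     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 },
    /// );
    ///
    /// let mut rng = StdRng::seed_from_u64(1);
    /// let trace = guide.sample_trace(&mut rng);
    /// assert!(trace.choices.contains_key(&addr!("mu")));
    /// assert!(trace.log_prior.is_finite());
    /// ```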
    pub fn sample_trace<R: Rng>(&self, rng: &mut R) -> Trace {
        let mut trace = Trace::default();

        for (addr, param) in &self.params {
            let value = param.sample(rng);
            let log_prob = param.log_prob(value);

            trace.choices.insert(
                addr.clone(),
                Choice {
                    addr: addr.clone(),
                    // All guide draws are recorded as F64 (continuous relaxation).
                    value: ChoiceValue::F64(value),
                    logp: log_prob,
                },
            );
            trace.log_prior += log_prob;
        }
        trace
    }
}

/// Estimate the ELBO for a model under a mean-field guide.
///
/// Draws `num_samples` traces from the guide, scores each under the model, and
/// averages `log p(x, z) - log q(z)` (a Monte Carlo estimate of the ELBO).
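///
/// # Examples
///
/// A minimal sketch mirroring the module-level example (assuming the same crate-root
/// re-exports); the exact ELBO value depends on the RNG seed, so only finiteness is
/// checked here.
///
/// ```rust
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// let model_fn = || {
///     sample(addr!("theta"), Normal::new(0.0, 1.0).unwrap())
///         .bind(|theta| observe(addr!("y"), Normal::new(theta, 1.0).unwrap(), 0.5).map(move |_| theta))
/// };
///
/// let mut guide = MeanFieldGuide::new();
/// guide.params.insert(
///     addr!("theta"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 },
/// );
///
/// let mut rng = StdRng::seed_from_u64(123);
/// let elbo = elbo_with_guide(&mut rng, &model_fn, &guide, 20);
/// assert!(elbo.is_finite());
/// ```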
pub fn elbo_with_guide<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    guide: &MeanFieldGuide,
    num_samples: usize,
) -> f64 {
    let mut total_elbo = 0.0;

    for _ in 0..num_samples {
        let guide_trace = guide.sample_trace(rng);
        let (_a, model_trace) = run(
            ScoreGivenTrace {
                base: guide_trace.clone(),
                trace: Trace::default(),
            },
            model_fn(),
        );

        // ELBO = E_q[log p(x, z) - log q(z)]
        let log_joint = model_trace.total_log_weight();
        let log_guide = guide_trace.log_prior;
        total_elbo += log_joint - log_guide;
    }

    total_elbo / num_samples as f64
}

/// Simple VI optimization using coordinate-wise, finite-difference gradient ascent.
///
/// Each iteration perturbs the location parameter of every variational factor,
/// re-estimates the ELBO, and takes a gradient step of size `learning_rate`.
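///
/// # Examples
///
/// A small sketch (assuming `optimize_meanfield_vi` is re-exported from the crate root
/// like the other items used in this module's examples); a handful of iterations keeps
/// the doc test fast.
///
/// ```rust
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// let model_fn = || {
///     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
///         .bind(|mu| observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 1.0).map(move |_| mu))
/// };
///
/// let mut guide = MeanFieldGuide::new();
/// guide.params.insert(
///     addr!("mu"),
///     VariationalParam::Normal { mu: 0.0, log_sigma: 0.0 },
/// );
///
/// let mut rng = StdRng::seed_from_u64(7);
/// let optimized = optimize_meanfield_vi(&mut rng, model_fn, guide, 5, 5, 0.05);
/// assert!(optimized.params.contains_key(&addr!("mu")));
/// ```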
pub fn optimize_meanfield_vi<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    initial_guide: MeanFieldGuide,
    n_iterations: usize,
    n_samples_per_iter: usize,
    learning_rate: f64,
) -> MeanFieldGuide {
    let mut guide = initial_guide;

    for iter in 0..n_iterations {
        let current_elbo = elbo_with_guide(rng, &model_fn, &guide, n_samples_per_iter);

        // Simple gradient ascent (placeholder - would use automatic differentiation in practice)
        let guide_clone = guide.clone();
        for (addr, param) in &mut guide.params {
            match param {
                VariationalParam::Normal { mu, log_sigma: _ } => {
                    // Finite-difference gradient (very basic)
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::Normal { mu: mu_plus, .. }) =
                        guide_plus.params.get_mut(addr)
                    {
                        *mu_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_mu = (elbo_plus - current_elbo) / eps;

                    // Numerical stability checks
                    if grad_mu.is_finite() {
                        let update = learning_rate * grad_mu;
                        if update.is_finite() {
                            *mu += update;
                            // Clamp to a reasonable range to prevent overflow
                            *mu = mu.clamp(-100.0, 100.0);
                        }
                    }
                }
                VariationalParam::LogNormal { mu, log_sigma: _ } => {
                    // Similar finite difference for LogNormal parameters
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::LogNormal { mu: mu_plus, .. }) =
                        guide_plus.params.get_mut(addr)
                    {
                        *mu_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_mu = (elbo_plus - current_elbo) / eps;

                    // Numerical stability checks
                    if grad_mu.is_finite() {
                        let update = learning_rate * grad_mu;
                        if update.is_finite() {
                            *mu += update;
                            // Clamp to a reasonable range for LogNormal
                            *mu = mu.clamp(-10.0, 10.0);
                        }
                    }
                }
                VariationalParam::Beta {
                    log_alpha,
                    log_beta: _,
                } => {
                    // Basic update for Beta parameters
                    let eps = 0.01;
                    let mut guide_plus = guide_clone.clone();
                    if let Some(VariationalParam::Beta {
                        log_alpha: alpha_plus,
                        ..
                    }) = guide_plus.params.get_mut(addr)
                    {
                        *alpha_plus += eps;
                    }
                    let elbo_plus = elbo_with_guide(rng, &model_fn, &guide_plus, 10);
                    let grad_alpha = (elbo_plus - current_elbo) / eps;

                    // Numerical stability checks
                    if grad_alpha.is_finite() {
                        let update = learning_rate * grad_alpha;
                        if update.is_finite() {
                            *log_alpha += update;
                            // Clamp to a reasonable range for Beta
                            *log_alpha = log_alpha.clamp(-5.0, 5.0);
                        }
                    }
                }
            }
        }

        if iter % 100 == 0 {
            println!("VI Iteration {}: ELBO = {:.4}", iter, current_elbo);
        }
    }

    guide
}

/// Average `total_log_weight()` of prior samples scored against the model.
///
/// Kept for backward compatibility. Unlike [`elbo_with_guide`], this draws traces
/// from the prior rather than from a variational guide and does not subtract `log q(z)`.
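///
/// # Examples
///
/// A minimal sketch (assuming `estimate_elbo` is re-exported from the crate root like
/// the other items used in this module's examples):
///
/// ```rust
/// use fugue::*;
/// use rand::rngs::StdRng;
/// use rand::SeedableRng;
///
/// let model_fn = || {
///     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
///         .bind(|mu| observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.0).map(move |_| mu))
/// };
///
/// let mut rng = StdRng::seed_from_u64(9);
/// let estimate = estimate_elbo(&mut rng, model_fn, 20);
/// assert!(estimate.is_finite());
/// ```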
pub fn estimate_elbo<A, R: Rng>(
    rng: &mut R,
    model_fn: impl Fn() -> Model<A>,
    num_samples: usize,
) -> f64 {
    let mut total = 0.0;
    for _ in 0..num_samples {
        let (_a, prior_t) = run(
            PriorHandler {
                rng: &mut *rng, // explicit reborrow so the generator can be reused each iteration
                trace: Trace::default(),
            },
            model_fn(),
        );
        let (_a2, scored) = run(
            ScoreGivenTrace {
                base: prior_t.clone(),
                trace: Trace::default(),
            },
            model_fn(),
        );
        total += scored.total_log_weight();
    }
    total / (num_samples as f64)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;

    use crate::core::model::{observe, sample, ModelExt};
    use crate::runtime::trace::{Choice, ChoiceValue, Trace};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn variational_param_sampling_and_log_prob() {
        let mut rng = StdRng::seed_from_u64(20);
        let vp_n = VariationalParam::Normal {
            mu: 0.0,
            log_sigma: 0.0,
        };
        let x = vp_n.sample(&mut rng);
        assert!(x.is_finite());
        assert!(vp_n.log_prob(x).is_finite());

        let vp_b = VariationalParam::Beta {
            log_alpha: (2.0f64).ln(),
            log_beta: (3.0f64).ln(),
        };
        let y = vp_b.sample(&mut rng);
        assert!(y > 0.0 && y < 1.0);
        assert!(vp_b.log_prob(y).is_finite());
    }

    #[test]
    fn elbo_computation_is_finite() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.2).map(move |_| mu)
            })
        };

        // Build a simple guide
        let mut guide = MeanFieldGuide::new();
        guide.params.insert(
            addr!("mu"),
            VariationalParam::Normal {
                mu: 0.0,
                log_sigma: 0.0,
            },
        );

        let mut rng = StdRng::seed_from_u64(21);
        let elbo = elbo_with_guide(&mut rng, model_fn, &guide, 5);
        assert!(elbo.is_finite());
    }

    #[test]
    fn meanfield_from_trace_and_sampling() {
        // Create a base trace with mixed value types
        let mut base = Trace::default();
        base.choices.insert(
            addr!("pos"),
            Choice {
                addr: addr!("pos"),
                value: ChoiceValue::F64(-1.0),
                logp: -0.1,
            },
        );
        base.choices.insert(
            addr!("bool"),
            Choice {
                addr: addr!("bool"),
                value: ChoiceValue::Bool(true),
                logp: -0.7,
            },
        );
        base.choices.insert(
            addr!("u64"),
            Choice {
                addr: addr!("u64"),
                value: ChoiceValue::U64(3),
                logp: -0.5,
            },
        );

        let guide = MeanFieldGuide::from_trace(&base);
        assert!(!guide.params.is_empty());

        // Sample a trace from the guide
        let t = guide.sample_trace(&mut StdRng::seed_from_u64(22));
        assert!(!t.choices.is_empty());
        assert!(t.log_prior.is_finite());
    }

    #[test]
    fn optimize_vi_updates_parameters_and_is_stable() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.3).map(move |_| mu)
            })
        };

        let mut guide = MeanFieldGuide::new();
        guide.params.insert(
            addr!("mu"),
            VariationalParam::Normal {
                mu: 0.0,
                log_sigma: 0.0,
            },
        );

        let optimized = optimize_meanfield_vi(
            &mut StdRng::seed_from_u64(23),
            model_fn,
            guide.clone(),
            2, // few iterations for speed
            3,
            0.1,
        );

        // Parameter exists and remains within the clamped bounds
        if let VariationalParam::Normal { mu, .. } = optimized.params.get(&addr!("mu")).unwrap() {
            assert!(*mu <= 100.0 && *mu >= -100.0);
        } else {
            panic!("expected Normal param");
        }
    }
}