fugue/inference/
smc.rs

1//! Sequential Monte Carlo (SMC) with particle filtering and resampling.
2//!
3//! This module implements Sequential Monte Carlo methods, also known as particle filters.
4//! SMC maintains a population of weighted particles (traces) and uses resampling to
5//! focus computational effort on high-probability regions of the posterior.
6//!
7//! ## Key Features
8//!
9//! - **Multiple resampling methods**: Multinomial, Systematic, Stratified
10//! - **Effective Sample Size (ESS) monitoring**: Automatic resampling triggers
11//! - **Rejuvenation**: Optional MCMC moves to maintain particle diversity
12//! - **Adaptive resampling**: Resample only when ESS drops below threshold
13//!
14//! ## Algorithm Overview
15//!
16//! SMC works by maintaining a population of particles, each representing a possible
17//! state (parameter configuration) with an associated weight:
18//!
19//! 1. **Initialize**: Start with particles from the prior
20//! 2. **Weight**: Compute importance weights based on likelihood
21//! 3. **Resample**: When weights become uneven, resample to maintain diversity
22//! 4. **Rejuvenate**: Optionally apply MCMC moves to particles
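//! 5. **Repeat**: In a fully sequential sampler, steps 2-4 repeat as observations arrive;
//!    `adaptive_smc` in this module performs a single weight / resample / rejuvenate pass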
24//!
25//! ## When to Use SMC
26//!
27//! SMC is particularly effective for:
28//! - Models with many observations that can be processed sequentially
29//! - High-dimensional parameter spaces where MCMC mixes poorly
30//! - Real-time inference where new data arrives continuously
31//! - Situations where you need multiple diverse posterior samples
32//!
33//! # Examples
34//!
35//! ```rust
36//! use fugue::*;
37//! use rand::rngs::StdRng;
38//! use rand::SeedableRng;
39//!
40//! // Define a simple model
41//! let model_fn = || {
42//!     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
43//!         .bind(|mu| {
44//!             observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 2.0)
45//!                 .map(move |_| mu)
46//!         })
47//! };
48//!
49//! // Run SMC (small numbers for testing)
50//! let mut rng = StdRng::seed_from_u64(42);
51//! let config = SMCConfig::default();
52//! let particles = adaptive_smc(&mut rng, 10, model_fn, config);
53//!
54//! // Analyze results
55//! let ess = effective_sample_size(&particles);
56//! assert!(ess > 0.0);
57//! ```
58use crate::core::model::Model;
59use crate::inference::mcmc_utils::DiminishingAdaptation;
60use crate::inference::mh::adaptive_single_site_mh;
61use crate::runtime::handler::run;
62use crate::runtime::interpreters::PriorHandler;
63use crate::runtime::trace::Trace;
64use rand::Rng;
65
66/// A weighted particle in the SMC population.
67///
68/// Each particle represents a possible state (parameter configuration) with
69/// associated weights that reflect its probability relative to other particles.
70/// The weight decomposition into linear and log space enables numerical stability.
71///
72/// # Fields
73///
74/// * `trace` - Execution trace containing parameter values and log-probabilities
75/// * `weight` - Normalized linear weight (used for resampling)
76/// * `log_weight` - Log-space weight (for numerical stability)
77///
78/// # Examples
79///
80/// ```rust
81/// use fugue::*;
82///
83/// // Particles are typically created by SMC algorithms
84/// let particle = Particle {
85///     trace: Trace::default(),
86///     weight: 0.25,           // 25% of total weight
87///     log_weight: -1.386,     // ln(0.25)
88/// };
89///
90/// println!("Particle weight: {:.3}", particle.weight);
91/// ```
92#[derive(Clone, Debug)]
93pub struct Particle {
94    /// Execution trace containing parameter values and log-probabilities.
95    pub trace: Trace,
96    /// Normalized linear weight (used for resampling).
97    pub weight: f64,
98    /// Log-space weight (for numerical stability).
99    pub log_weight: f64,
100}
101
/// Resampling algorithms for particle filters.
///
/// Every method maps a weighted population of `n` particles to `n` ancestor indices,
/// so that low-weight particles tend to be dropped and high-weight particles
/// duplicated. They differ in how the uniform random numbers are drawn:
///
/// * `Multinomial` - `n` independent uniform draws, each inverted through the
///   cumulative weights
/// * `Systematic` - a single uniform draw shared by `n` evenly spaced thresholds
///   (the default in [`SMCConfig`])
/// * `Stratified` - one independent uniform draw inside each of `n` equal strata
///
/// The enum is `Copy` and comparable with `==`, so a configured method can be
/// inspected or asserted on directly.
///
/// # Examples
///
/// ```rust
/// use fugue::*;
///
/// // Configure SMC with different resampling methods
/// let config_systematic = SMCConfig {
///     resampling_method: ResamplingMethod::Systematic,
///     ..Default::default()
/// };
///
/// let config_multinomial = SMCConfig {
///     resampling_method: ResamplingMethod::Multinomial,
///     ..Default::default()
/// };
///
/// assert_eq!(config_systematic.resampling_method, ResamplingMethod::Systematic);
/// assert_ne!(
///     config_systematic.resampling_method,
///     config_multinomial.resampling_method
/// );
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ResamplingMethod {
    /// Simple multinomial resampling with replacement.
    Multinomial,
    /// Low-variance systematic resampling (recommended).
    Systematic,
    /// Stratified resampling with balanced variance.
    Stratified,
}
139
140/// Configuration options for Sequential Monte Carlo.
141///
142/// This struct controls various aspects of the SMC algorithm, allowing fine-tuning
143/// of performance and accuracy trade-offs.
144///
145/// # Fields
146///
147/// * `resampling_method` - Algorithm used for particle resampling
148/// * `ess_threshold` - ESS threshold that triggers resampling (as fraction of N)
149/// * `rejuvenation_steps` - Number of MCMC moves after resampling to increase diversity
150///
151/// # Examples
152///
153/// ```rust
154/// use fugue::*;
155///
156/// // Conservative configuration (less resampling, more rejuvenation)
157/// let conservative_config = SMCConfig {
158///     resampling_method: ResamplingMethod::Systematic,
159///     ess_threshold: 0.2,  // Resample when ESS < 20% of particles
160///     rejuvenation_steps: 5, // 5 MCMC moves after resampling
161/// };
162///
163/// // Aggressive configuration (frequent resampling, no rejuvenation)
164/// let aggressive_config = SMCConfig {
165///     resampling_method: ResamplingMethod::Systematic,
166///     ess_threshold: 0.8,  // Resample when ESS < 80% of particles
167///     rejuvenation_steps: 0, // No rejuvenation
168/// };
169/// ```
170pub struct SMCConfig {
171    /// Algorithm used for particle resampling.
172    pub resampling_method: ResamplingMethod,
173    /// ESS threshold that triggers resampling (as fraction of particle count).
174    pub ess_threshold: f64,
175    /// Number of MCMC moves after resampling to increase diversity.
176    pub rejuvenation_steps: usize,
177}
178
179impl Default for SMCConfig {
180    fn default() -> Self {
181        Self {
182            resampling_method: ResamplingMethod::Systematic,
183            ess_threshold: 0.5,
184            rejuvenation_steps: 0,
185        }
186    }
187}
188
189/// Compute the effective sample size (ESS) of a particle population.
190///
191/// ESS measures how many "effective" independent samples the weighted particle
192/// population represents. It ranges from 1 (all weight on one particle) to N
193/// (uniform weights). Low ESS indicates weight degeneracy and triggers resampling.
194///
195/// **Formula:** ESS = 1 / Σᵢ(wᵢ²) where wᵢ are normalized weights.
196///
197/// # Arguments
198///
199/// * `particles` - Population of weighted particles
200///
201/// # Returns
202///
203/// Effective sample size (1.0 ≤ ESS ≤ N where N = particles.len()).
204///
205/// # Examples
206///
207/// ```rust
208/// use fugue::*;
209///
210/// // Uniform weights -> high ESS
211/// let uniform_particles = vec![
212///     Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
213///     Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
214///     Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
215///     Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
216/// ];
217/// let ess = effective_sample_size(&uniform_particles);
218/// assert!((ess - 4.0).abs() < 0.01); // ESS ≈ 4 (perfect)
219///
220/// // Degenerate weights -> low ESS
221/// let degenerate_particles = vec![
222///     Particle { trace: Trace::default(), weight: 0.99, log_weight: -0.01 },
223///     Particle { trace: Trace::default(), weight: 0.01, log_weight: -4.605 },
224/// ];
225/// let ess = effective_sample_size(&degenerate_particles);
226/// assert!(ess < 1.1); // ESS ≈ 1 (very poor)
227/// ```
228pub fn effective_sample_size(particles: &[Particle]) -> f64 {
229    let sum_sq: f64 = particles.iter().map(|p| p.weight * p.weight).sum();
230    1.0 / sum_sq
231}
232
233/// Systematic resampling.
234pub fn systematic_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
235    let n = particles.len();
236    let mut indices = Vec::with_capacity(n);
237    let u = rng.gen::<f64>() / n as f64;
238
239    let mut cum_weight = 0.0;
240    let mut i = 0;
241
242    for j in 0..n {
243        let threshold = u + j as f64 / n as f64;
244        while cum_weight < threshold && i < n {
245            cum_weight += particles[i].weight;
246            i += 1;
247        }
248        indices.push((i - 1).min(n - 1));
249    }
250    indices
251}
252
253/// Stratified resampling.
254pub fn stratified_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
255    let n = particles.len();
256    let mut indices = Vec::with_capacity(n);
257
258    let mut cum_weight = 0.0;
259    let mut i = 0;
260
261    for j in 0..n {
262        let u = rng.gen::<f64>();
263        let threshold = (j as f64 + u) / n as f64;
264        while cum_weight < threshold && i < n {
265            cum_weight += particles[i].weight;
266            i += 1;
267        }
268        indices.push((i - 1).min(n - 1));
269    }
270    indices
271}
272
273/// Multinomial resampling.
274pub fn multinomial_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
275    let n = particles.len();
276    let mut indices = Vec::with_capacity(n);
277
278    for _ in 0..n {
279        let u = rng.gen::<f64>();
280        let mut cum_weight = 0.0;
281        let mut selected = n - 1;
282
283        for (i, p) in particles.iter().enumerate() {
284            cum_weight += p.weight;
285            if u <= cum_weight {
286                selected = i;
287                break;
288            }
289        }
290        indices.push(selected);
291    }
292    indices
293}
294
295/// Resample particles based on weights.
296pub fn resample_particles<R: Rng>(
297    rng: &mut R,
298    particles: &[Particle],
299    method: ResamplingMethod,
300) -> Vec<Particle> {
301    let indices = match method {
302        ResamplingMethod::Multinomial => multinomial_resample(rng, particles),
303        ResamplingMethod::Systematic => systematic_resample(rng, particles),
304        ResamplingMethod::Stratified => stratified_resample(rng, particles),
305    };
306
307    let n = particles.len();
308    let uniform_weight = 1.0 / n as f64;
309
310    indices
311        .into_iter()
312        .map(|i| {
313            let mut p = particles[i].clone();
314            p.weight = uniform_weight;
315            p.log_weight = uniform_weight.ln();
316            p
317        })
318        .collect()
319}
320
321/// Run adaptive Sequential Monte Carlo with resampling and rejuvenation.
322///
323/// This is the main SMC algorithm that maintains a population of weighted particles
324/// and adaptively resamples when the effective sample size drops below a threshold.
325/// Optional rejuvenation steps help maintain particle diversity after resampling.
326///
327/// # Algorithm
328///
329/// 1. Initialize particles by sampling from the prior
330/// 2. Compute weights and effective sample size
331/// 3. If ESS < threshold × N: resample particles
332/// 4. Apply rejuvenation moves (MCMC) if configured
333/// 5. Return final particle population
334///
335/// # Arguments
336///
337/// * `rng` - Random number generator
338/// * `num_particles` - Size of particle population to maintain
339/// * `model_fn` - Function that creates the model
340/// * `config` - SMC configuration (resampling method, thresholds, etc.)
341///
342/// # Returns
343///
344/// Final population of weighted particles representing the posterior.
345///
346/// # Examples
347///
348/// ```rust
349/// use fugue::*;
350/// use rand::rngs::StdRng;
351/// use rand::SeedableRng;
352///
353/// // Simple model for testing
354/// let model_fn = || {
355///     sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
356///         .bind(|mu| {
357///             observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 1.8)
358///                 .map(move |_| mu)
359///         })
360/// };
361///
362/// // Run SMC with small numbers for testing
363/// let mut rng = StdRng::seed_from_u64(42);
364/// let config = SMCConfig {
365///     resampling_method: ResamplingMethod::Systematic,
366///     ess_threshold: 0.5,
367///     rejuvenation_steps: 1,
368/// };
369///
370/// let particles = adaptive_smc(&mut rng, 5, model_fn, config);
371///
372/// // Analyze posterior
373/// let mu_estimates: Vec<f64> = particles.iter()
374///     .filter_map(|p| p.trace.choices.get(&addr!("mu")))
375///     .filter_map(|choice| match choice.value {
376///         ChoiceValue::F64(mu) => Some(mu),
377///         _ => None,
378///     })
379///     .collect();
380///
381/// assert!(!mu_estimates.is_empty());
382/// ```
383pub fn adaptive_smc<A, R: Rng>(
384    rng: &mut R,
385    num_particles: usize,
386    model_fn: impl Fn() -> Model<A>,
387    config: SMCConfig,
388) -> Vec<Particle> {
389    let mut particles = smc_prior_particles(rng, num_particles, &model_fn);
390
391    // Check if resampling is needed
392    let ess = effective_sample_size(&particles);
393    let ess_ratio = ess / num_particles as f64;
394
395    if ess_ratio < config.ess_threshold {
396        // Resample
397        particles = resample_particles(rng, &particles, config.resampling_method);
398
399        // Optional rejuvenation with MCMC
400        if config.rejuvenation_steps > 0 {
401            let mut adaptation = DiminishingAdaptation::new(0.44, 0.7);
402            for particle in &mut particles {
403                for _ in 0..config.rejuvenation_steps {
404                    let (_, new_trace) =
405                        adaptive_single_site_mh(rng, &model_fn, &particle.trace, &mut adaptation);
406                    particle.trace = new_trace;
407                    particle.log_weight = particle.trace.total_log_weight();
408                }
409            }
410
411            // Renormalize after rejuvenation
412            normalize_particles(&mut particles);
413        }
414    }
415
416    particles
417}
418
419/// Normalize particle weights using numerically stable log-sum-exp.
420///
421/// This function properly handles extreme log-weights without underflow or overflow,
422/// which is critical for reliable SMC performance.
423pub fn normalize_particles(particles: &mut [Particle]) {
424    use crate::core::numerical::log_sum_exp;
425
426    if particles.is_empty() {
427        return;
428    }
429
430    // Collect log weights
431    let log_weights: Vec<f64> = particles.iter().map(|p| p.log_weight).collect();
432
433    // Compute log normalizing constant stably
434    let log_norm = log_sum_exp(&log_weights);
435
436    // Handle degenerate case where all weights are -∞
437    if log_norm.is_infinite() && log_norm < 0.0 {
438        let n = particles.len();
439        for p in particles {
440            p.weight = 1.0 / n as f64; // Uniform weights as fallback
441        }
442        return;
443    }
444
445    // Normalize weights stably
446    for (p, &log_w) in particles.iter_mut().zip(&log_weights) {
447        p.weight = (log_w - log_norm).exp();
448    }
449
450    // Ensure weights sum to 1.0 (handle small numerical errors)
451    let weight_sum: f64 = particles.iter().map(|p| p.weight).sum();
452    if weight_sum > 0.0 {
453        for p in particles {
454            p.weight /= weight_sum;
455        }
456    }
457}
458
459pub fn smc_prior_particles<A, R: Rng>(
460    rng: &mut R,
461    num_particles: usize,
462    model_fn: impl Fn() -> Model<A>,
463) -> Vec<Particle> {
464    let mut particles = Vec::with_capacity(num_particles);
465    for _ in 0..num_particles {
466        let (_a, t) = run(
467            PriorHandler {
468                rng,
469                trace: Trace::default(),
470            },
471            model_fn(),
472        );
473        particles.push(Particle {
474            trace: t.clone(),
475            weight: 0.0, // Will be set by normalization
476            log_weight: t.total_log_weight(),
477        });
478    }
479    normalize_particles(&mut particles);
480    particles
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, sample, ModelExt};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Build a particle with an empty trace and the given linear weight.
    fn weighted(weight: f64) -> Particle {
        Particle {
            trace: Trace::default(),
            weight,
            log_weight: weight.ln(),
        }
    }

    /// True when every index addresses a particle in a population of size `n`.
    fn all_in_range(indices: &[usize], n: usize) -> bool {
        indices.iter().all(|&i| i < n)
    }

    /// True when the indices never decrease, as a single pass over the cumulative weights produces.
    fn is_sorted(indices: &[usize]) -> bool {
        indices.windows(2).all(|pair| pair[0] <= pair[1])
    }

    #[test]
    fn ess_and_resampling_behave() {
        // Construct 4 particles with uneven weights
        let particles: Vec<Particle> = [0.7, 0.2, 0.09, 0.01]
            .iter()
            .map(|&w| weighted(w))
            .collect();
        let n = particles.len();

        let ess_val = effective_sample_size(&particles);
        assert!(ess_val >= 1.0);
        assert!(ess_val < n as f64);

        // Resampling indices should be valid and length preserved
        let mut rng = StdRng::seed_from_u64(1);
        let idx_m = multinomial_resample(&mut rng, &particles);
        assert_eq!(idx_m.len(), n);
        assert!(all_in_range(&idx_m, n));

        let idx_s = systematic_resample(&mut rng, &particles);
        assert_eq!(idx_s.len(), n);
        assert!(all_in_range(&idx_s, n));
        assert!(is_sorted(&idx_s));

        let idx_t = stratified_resample(&mut rng, &particles);
        assert_eq!(idx_t.len(), n);
        assert!(all_in_range(&idx_t, n));
        assert!(is_sorted(&idx_t));

        // Resample and check normalized uniform weights
        let resampled = resample_particles(&mut rng, &particles, ResamplingMethod::Systematic);
        assert_eq!(resampled.len(), n);
        let sum_w: f64 = resampled.iter().map(|p| p.weight).sum();
        assert!((sum_w - 1.0).abs() < 1e-12);
        for p in &resampled {
            assert!((p.weight - 0.25).abs() < 1e-12);
            assert!((p.log_weight - (0.25f64).ln()).abs() < 1e-12);
        }
    }

    #[test]
    fn normalize_particles_handles_neg_inf() {
        let mut particles = vec![
            Particle {
                trace: Trace::default(),
                weight: 0.0,
                log_weight: f64::NEG_INFINITY,
            },
            Particle {
                trace: Trace::default(),
                weight: 0.0,
                log_weight: f64::NEG_INFINITY,
            },
        ];
        normalize_particles(&mut particles);
        // Fallback to uniform
        assert!((particles[0].weight - 0.5).abs() < 1e-12);
        assert!((particles[1].weight - 0.5).abs() < 1e-12);
    }

    #[test]
    fn adaptive_smc_runs_with_small_config() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.5).map(move |_| mu)
            })
        };
        let mut rng = StdRng::seed_from_u64(2);
        let config = SMCConfig {
            resampling_method: ResamplingMethod::Systematic,
            ess_threshold: 0.5,
            rejuvenation_steps: 1,
        };
        let particles = adaptive_smc(&mut rng, 5, model_fn, config);
        assert_eq!(particles.len(), 5);
        // Weights are finite, non-negative and normalized
        assert!(particles
            .iter()
            .all(|p| p.weight.is_finite() && p.weight >= 0.0));
        let sum_w: f64 = particles.iter().map(|p| p.weight).sum();
        assert!((sum_w - 1.0).abs() < 1e-9);
    }
}