fugue/inference/smc.rs
1//! Sequential Monte Carlo (SMC) with particle filtering and resampling.
2//!
3//! This module implements Sequential Monte Carlo methods, also known as particle filters.
4//! SMC maintains a population of weighted particles (traces) and uses resampling to
5//! focus computational effort on high-probability regions of the posterior.
6//!
7//! ## Key Features
8//!
9//! - **Multiple resampling methods**: Multinomial, Systematic, Stratified
10//! - **Effective Sample Size (ESS) monitoring**: Automatic resampling triggers
11//! - **Rejuvenation**: Optional MCMC moves to maintain particle diversity
12//! - **Adaptive resampling**: Resample only when ESS drops below threshold
13//!
14//! ## Algorithm Overview
15//!
16//! SMC works by maintaining a population of particles, each representing a possible
17//! state (parameter configuration) with an associated weight:
18//!
19//! 1. **Initialize**: Start with particles from the prior
20//! 2. **Weight**: Compute importance weights based on likelihood
21//! 3. **Resample**: When weights become uneven, resample to maintain diversity
22//! 4. **Rejuvenate**: Optionally apply MCMC moves to particles
//! 5. **Repeat**: In general SMC, steps 2–4 repeat as more data is incorporated; [`adaptive_smc`] currently performs a single weight/resample/rejuvenate pass
24//!
25//! ## When to Use SMC
26//!
27//! SMC is particularly effective for:
28//! - Models with many observations that can be processed sequentially
29//! - High-dimensional parameter spaces where MCMC mixes poorly
30//! - Real-time inference where new data arrives continuously
31//! - Situations where you need multiple diverse posterior samples
32//!
33//! # Examples
34//!
35//! ```rust
36//! use fugue::*;
37//! use rand::rngs::StdRng;
38//! use rand::SeedableRng;
39//!
40//! // Define a simple model
41//! let model_fn = || {
42//! sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
43//! .bind(|mu| {
44//! observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 2.0)
45//! .map(move |_| mu)
46//! })
47//! };
48//!
49//! // Run SMC (small numbers for testing)
50//! let mut rng = StdRng::seed_from_u64(42);
51//! let config = SMCConfig::default();
52//! let particles = adaptive_smc(&mut rng, 10, model_fn, config);
53//!
54//! // Analyze results
55//! let ess = effective_sample_size(&particles);
56//! assert!(ess > 0.0);
57//! ```
58use crate::core::model::Model;
59use crate::inference::mcmc_utils::DiminishingAdaptation;
60use crate::inference::mh::adaptive_single_site_mh;
61use crate::runtime::handler::run;
62use crate::runtime::interpreters::PriorHandler;
63use crate::runtime::trace::Trace;
64use rand::Rng;
65
66/// A weighted particle in the SMC population.
67///
68/// Each particle represents a possible state (parameter configuration) with
69/// associated weights that reflect its probability relative to other particles.
70/// The weight decomposition into linear and log space enables numerical stability.
71///
72/// # Fields
73///
74/// * `trace` - Execution trace containing parameter values and log-probabilities
75/// * `weight` - Normalized linear weight (used for resampling)
76/// * `log_weight` - Log-space weight (for numerical stability)
77///
78/// # Examples
79///
80/// ```rust
81/// use fugue::*;
82///
83/// // Particles are typically created by SMC algorithms
84/// let particle = Particle {
85/// trace: Trace::default(),
86/// weight: 0.25, // 25% of total weight
87/// log_weight: -1.386, // ln(0.25)
88/// };
89///
90/// println!("Particle weight: {:.3}", particle.weight);
91/// ```
92#[derive(Clone, Debug)]
93pub struct Particle {
94 /// Execution trace containing parameter values and log-probabilities.
95 pub trace: Trace,
96 /// Normalized linear weight (used for resampling).
97 pub weight: f64,
98 /// Log-space weight (for numerical stability).
99 pub log_weight: f64,
100}
101
/// Resampling algorithms for particle filters.
///
/// Different resampling methods offer trade-offs between computational efficiency,
/// variance, and implementation complexity. All of them replace low-weight
/// particles with copies of high-weight particles.
///
/// # Variants
///
/// * `Multinomial` - `N` independent weighted draws (highest variance)
/// * `Systematic` - one uniform offset shared by `N` evenly spaced positions
///   (lowest variance, recommended)
/// * `Stratified` - one uniform draw inside each of `N` equal strata
///
/// The enum implements `PartialEq`, `Eq` and `Hash`, so configurations can be
/// compared or used as map keys.
///
/// # Examples
///
/// ```rust
/// use fugue::*;
///
/// // Configure SMC with different resampling methods
/// let config_systematic = SMCConfig {
///     resampling_method: ResamplingMethod::Systematic,
///     ..Default::default()
/// };
///
/// let config_multinomial = SMCConfig {
///     resampling_method: ResamplingMethod::Multinomial,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ResamplingMethod {
    /// Simple multinomial resampling with replacement.
    Multinomial,
    /// Low-variance systematic resampling (recommended).
    Systematic,
    /// Stratified resampling: one uniform draw per stratum.
    Stratified,
}
139
140/// Configuration options for Sequential Monte Carlo.
141///
142/// This struct controls various aspects of the SMC algorithm, allowing fine-tuning
143/// of performance and accuracy trade-offs.
144///
145/// # Fields
146///
147/// * `resampling_method` - Algorithm used for particle resampling
148/// * `ess_threshold` - ESS threshold that triggers resampling (as fraction of N)
149/// * `rejuvenation_steps` - Number of MCMC moves after resampling to increase diversity
150///
151/// # Examples
152///
153/// ```rust
154/// use fugue::*;
155///
156/// // Conservative configuration (less resampling, more rejuvenation)
157/// let conservative_config = SMCConfig {
158/// resampling_method: ResamplingMethod::Systematic,
159/// ess_threshold: 0.2, // Resample when ESS < 20% of particles
160/// rejuvenation_steps: 5, // 5 MCMC moves after resampling
161/// };
162///
163/// // Aggressive configuration (frequent resampling, no rejuvenation)
164/// let aggressive_config = SMCConfig {
165/// resampling_method: ResamplingMethod::Systematic,
166/// ess_threshold: 0.8, // Resample when ESS < 80% of particles
167/// rejuvenation_steps: 0, // No rejuvenation
168/// };
169/// ```
170pub struct SMCConfig {
171 /// Algorithm used for particle resampling.
172 pub resampling_method: ResamplingMethod,
173 /// ESS threshold that triggers resampling (as fraction of particle count).
174 pub ess_threshold: f64,
175 /// Number of MCMC moves after resampling to increase diversity.
176 pub rejuvenation_steps: usize,
177}
178
179impl Default for SMCConfig {
180 fn default() -> Self {
181 Self {
182 resampling_method: ResamplingMethod::Systematic,
183 ess_threshold: 0.5,
184 rejuvenation_steps: 0,
185 }
186 }
187}
188
189/// Compute the effective sample size (ESS) of a particle population.
190///
191/// ESS measures how many "effective" independent samples the weighted particle
192/// population represents. It ranges from 1 (all weight on one particle) to N
193/// (uniform weights). Low ESS indicates weight degeneracy and triggers resampling.
194///
195/// **Formula:** ESS = 1 / Σᵢ(wᵢ²) where wᵢ are normalized weights.
196///
197/// # Arguments
198///
199/// * `particles` - Population of weighted particles
200///
201/// # Returns
202///
203/// Effective sample size (1.0 ≤ ESS ≤ N where N = particles.len()).
204///
205/// # Examples
206///
207/// ```rust
208/// use fugue::*;
209///
210/// // Uniform weights -> high ESS
211/// let uniform_particles = vec![
212/// Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
213/// Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
214/// Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
215/// Particle { trace: Trace::default(), weight: 0.25, log_weight: -1.386 },
216/// ];
217/// let ess = effective_sample_size(&uniform_particles);
218/// assert!((ess - 4.0).abs() < 0.01); // ESS ≈ 4 (perfect)
219///
220/// // Degenerate weights -> low ESS
221/// let degenerate_particles = vec![
222/// Particle { trace: Trace::default(), weight: 0.99, log_weight: -0.01 },
223/// Particle { trace: Trace::default(), weight: 0.01, log_weight: -4.605 },
224/// ];
225/// let ess = effective_sample_size(°enerate_particles);
226/// assert!(ess < 1.1); // ESS ≈ 1 (very poor)
227/// ```
228pub fn effective_sample_size(particles: &[Particle]) -> f64 {
229 let sum_sq: f64 = particles.iter().map(|p| p.weight * p.weight).sum();
230 1.0 / sum_sq
231}
232
233/// Systematic resampling.
234pub fn systematic_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
235 let n = particles.len();
236 let mut indices = Vec::with_capacity(n);
237 let u = rng.gen::<f64>() / n as f64;
238
239 let mut cum_weight = 0.0;
240 let mut i = 0;
241
242 for j in 0..n {
243 let threshold = u + j as f64 / n as f64;
244 while cum_weight < threshold && i < n {
245 cum_weight += particles[i].weight;
246 i += 1;
247 }
248 indices.push((i - 1).min(n - 1));
249 }
250 indices
251}
252
253/// Stratified resampling.
254pub fn stratified_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
255 let n = particles.len();
256 let mut indices = Vec::with_capacity(n);
257
258 let mut cum_weight = 0.0;
259 let mut i = 0;
260
261 for j in 0..n {
262 let u = rng.gen::<f64>();
263 let threshold = (j as f64 + u) / n as f64;
264 while cum_weight < threshold && i < n {
265 cum_weight += particles[i].weight;
266 i += 1;
267 }
268 indices.push((i - 1).min(n - 1));
269 }
270 indices
271}
272
273/// Multinomial resampling.
274pub fn multinomial_resample<R: Rng>(rng: &mut R, particles: &[Particle]) -> Vec<usize> {
275 let n = particles.len();
276 let mut indices = Vec::with_capacity(n);
277
278 for _ in 0..n {
279 let u = rng.gen::<f64>();
280 let mut cum_weight = 0.0;
281 let mut selected = n - 1;
282
283 for (i, p) in particles.iter().enumerate() {
284 cum_weight += p.weight;
285 if u <= cum_weight {
286 selected = i;
287 break;
288 }
289 }
290 indices.push(selected);
291 }
292 indices
293}
294
295/// Resample particles based on weights.
296pub fn resample_particles<R: Rng>(
297 rng: &mut R,
298 particles: &[Particle],
299 method: ResamplingMethod,
300) -> Vec<Particle> {
301 let indices = match method {
302 ResamplingMethod::Multinomial => multinomial_resample(rng, particles),
303 ResamplingMethod::Systematic => systematic_resample(rng, particles),
304 ResamplingMethod::Stratified => stratified_resample(rng, particles),
305 };
306
307 let n = particles.len();
308 let uniform_weight = 1.0 / n as f64;
309
310 indices
311 .into_iter()
312 .map(|i| {
313 let mut p = particles[i].clone();
314 p.weight = uniform_weight;
315 p.log_weight = uniform_weight.ln();
316 p
317 })
318 .collect()
319}
320
321/// Run adaptive Sequential Monte Carlo with resampling and rejuvenation.
322///
323/// This is the main SMC algorithm that maintains a population of weighted particles
324/// and adaptively resamples when the effective sample size drops below a threshold.
325/// Optional rejuvenation steps help maintain particle diversity after resampling.
326///
327/// # Algorithm
328///
329/// 1. Initialize particles by sampling from the prior
330/// 2. Compute weights and effective sample size
331/// 3. If ESS < threshold × N: resample particles
332/// 4. Apply rejuvenation moves (MCMC) if configured
333/// 5. Return final particle population
334///
335/// # Arguments
336///
337/// * `rng` - Random number generator
338/// * `num_particles` - Size of particle population to maintain
339/// * `model_fn` - Function that creates the model
340/// * `config` - SMC configuration (resampling method, thresholds, etc.)
341///
342/// # Returns
343///
344/// Final population of weighted particles representing the posterior.
345///
346/// # Examples
347///
348/// ```rust
349/// use fugue::*;
350/// use rand::rngs::StdRng;
351/// use rand::SeedableRng;
352///
353/// // Simple model for testing
354/// let model_fn = || {
355/// sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap())
356/// .bind(|mu| {
357/// observe(addr!("y"), Normal::new(mu, 0.5).unwrap(), 1.8)
358/// .map(move |_| mu)
359/// })
360/// };
361///
362/// // Run SMC with small numbers for testing
363/// let mut rng = StdRng::seed_from_u64(42);
364/// let config = SMCConfig {
365/// resampling_method: ResamplingMethod::Systematic,
366/// ess_threshold: 0.5,
367/// rejuvenation_steps: 1,
368/// };
369///
370/// let particles = adaptive_smc(&mut rng, 5, model_fn, config);
371///
372/// // Analyze posterior
373/// let mu_estimates: Vec<f64> = particles.iter()
374/// .filter_map(|p| p.trace.choices.get(&addr!("mu")))
375/// .filter_map(|choice| match choice.value {
376/// ChoiceValue::F64(mu) => Some(mu),
377/// _ => None,
378/// })
379/// .collect();
380///
381/// assert!(!mu_estimates.is_empty());
382/// ```
383pub fn adaptive_smc<A, R: Rng>(
384 rng: &mut R,
385 num_particles: usize,
386 model_fn: impl Fn() -> Model<A>,
387 config: SMCConfig,
388) -> Vec<Particle> {
389 let mut particles = smc_prior_particles(rng, num_particles, &model_fn);
390
391 // Check if resampling is needed
392 let ess = effective_sample_size(&particles);
393 let ess_ratio = ess / num_particles as f64;
394
395 if ess_ratio < config.ess_threshold {
396 // Resample
397 particles = resample_particles(rng, &particles, config.resampling_method);
398
399 // Optional rejuvenation with MCMC
400 if config.rejuvenation_steps > 0 {
401 let mut adaptation = DiminishingAdaptation::new(0.44, 0.7);
402 for particle in &mut particles {
403 for _ in 0..config.rejuvenation_steps {
404 let (_, new_trace) =
405 adaptive_single_site_mh(rng, &model_fn, &particle.trace, &mut adaptation);
406 particle.trace = new_trace;
407 particle.log_weight = particle.trace.total_log_weight();
408 }
409 }
410
411 // Renormalize after rejuvenation
412 normalize_particles(&mut particles);
413 }
414 }
415
416 particles
417}
418
419/// Normalize particle weights using numerically stable log-sum-exp.
420///
421/// This function properly handles extreme log-weights without underflow or overflow,
422/// which is critical for reliable SMC performance.
423pub fn normalize_particles(particles: &mut [Particle]) {
424 use crate::core::numerical::log_sum_exp;
425
426 if particles.is_empty() {
427 return;
428 }
429
430 // Collect log weights
431 let log_weights: Vec<f64> = particles.iter().map(|p| p.log_weight).collect();
432
433 // Compute log normalizing constant stably
434 let log_norm = log_sum_exp(&log_weights);
435
436 // Handle degenerate case where all weights are -∞
437 if log_norm.is_infinite() && log_norm < 0.0 {
438 let n = particles.len();
439 for p in particles {
440 p.weight = 1.0 / n as f64; // Uniform weights as fallback
441 }
442 return;
443 }
444
445 // Normalize weights stably
446 for (p, &log_w) in particles.iter_mut().zip(&log_weights) {
447 p.weight = (log_w - log_norm).exp();
448 }
449
450 // Ensure weights sum to 1.0 (handle small numerical errors)
451 let weight_sum: f64 = particles.iter().map(|p| p.weight).sum();
452 if weight_sum > 0.0 {
453 for p in particles {
454 p.weight /= weight_sum;
455 }
456 }
457}
458
459pub fn smc_prior_particles<A, R: Rng>(
460 rng: &mut R,
461 num_particles: usize,
462 model_fn: impl Fn() -> Model<A>,
463) -> Vec<Particle> {
464 let mut particles = Vec::with_capacity(num_particles);
465 for _ in 0..num_particles {
466 let (_a, t) = run(
467 PriorHandler {
468 rng,
469 trace: Trace::default(),
470 },
471 model_fn(),
472 );
473 particles.push(Particle {
474 trace: t.clone(),
475 weight: 0.0, // Will be set by normalization
476 log_weight: t.total_log_weight(),
477 });
478 }
479 normalize_particles(&mut particles);
480 particles
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, sample, ModelExt};
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Particle with an empty trace, the given linear weight and `ln(weight)`
    /// as its log-weight (so a weight of 0.0 gives a log-weight of -∞).
    fn weighted(weight: f64) -> Particle {
        Particle {
            trace: Trace::default(),
            weight,
            log_weight: weight.ln(),
        }
    }

    #[test]
    fn ess_and_resampling_behave() {
        // Four particles with deliberately uneven weights.
        let particles: Vec<Particle> = [0.7, 0.2, 0.09, 0.01]
            .iter()
            .copied()
            .map(weighted)
            .collect();
        let count = particles.len();
        assert!(effective_sample_size(&particles) < count as f64);

        // Every resampler must return exactly one valid index per particle.
        let mut rng = StdRng::seed_from_u64(1);
        assert_eq!(multinomial_resample(&mut rng, &particles).len(), count);
        assert_eq!(systematic_resample(&mut rng, &particles).len(), count);
        assert_eq!(stratified_resample(&mut rng, &particles).len(), count);

        // After resampling, weights are uniform and sum to one.
        let resampled = resample_particles(&mut rng, &particles, ResamplingMethod::Systematic);
        let total: f64 = resampled.iter().map(|p| p.weight).sum();
        assert!((total - 1.0).abs() < 1e-12);
        assert!(resampled.iter().all(|p| (p.weight - 0.25).abs() < 1e-12));
    }

    #[test]
    fn normalize_particles_handles_neg_inf() {
        // Both log-weights are -∞, so normalization must fall back to uniform.
        let mut particles = vec![weighted(0.0), weighted(0.0)];
        normalize_particles(&mut particles);
        for p in &particles {
            assert!((p.weight - 0.5).abs() < 1e-12);
        }
    }

    #[test]
    fn adaptive_smc_runs_with_small_config() {
        let model_fn = || {
            sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()).and_then(|mu| {
                observe(addr!("y"), Normal::new(mu, 1.0).unwrap(), 0.5).map(move |_| mu)
            })
        };
        let config = SMCConfig {
            resampling_method: ResamplingMethod::Systematic,
            ess_threshold: 0.5,
            rejuvenation_steps: 1,
        };
        let mut rng = StdRng::seed_from_u64(2);

        let population = adaptive_smc(&mut rng, 5, model_fn, config);
        assert_eq!(population.len(), 5);

        // Weights must stay normalized.
        let total: f64 = population.iter().map(|p| p.weight).sum();
        assert!((total - 1.0).abs() < 1e-9);
    }
}