fugue/inference/abc.rs
1//! Approximate Bayesian Computation (ABC) - likelihood-free inference methods.
2//!
3//! ABC methods enable Bayesian inference for models where the likelihood function
4//! is intractable or unavailable, but forward simulation from the model is possible.
5//! Instead of computing likelihoods directly, ABC compares simulated data to observed
6//! data using distance functions and accepts samples that produce "similar" outcomes.
7//!
8//! ## Method Overview
9//!
10//! ABC algorithms follow this general pattern:
11//! 1. Sample parameters from the prior distribution
12//! 2. Simulate data using the model with those parameters
13//! 3. Compare simulated data to observed data using a distance function
//! 4. Accept samples whose distance is at most a threshold ε
15//!
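//! As ε → 0, the ABC posterior approaches the posterior conditioned on the compared
//! quantities; this equals the true posterior only when the full data (or sufficient
//! statistics) are compared.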
17//!
18//! ## Available Methods
19//!
20//! - [`abc_rejection`]: Basic rejection ABC
21//! - [`abc_smc`]: Sequential Monte Carlo ABC for improved efficiency
22//! - [`abc_scalar_summary`]: ABC with scalar summary statistics
23//!
24//! ## Distance Functions
25//!
26//! The quality of ABC inference depends heavily on the choice of distance function:
27//! - [`EuclideanDistance`]: L2 norm for continuous data vectors
28//! - [`ManhattanDistance`]: L1 norm for robust distance computation
29//! - Custom distance functions via the [`DistanceFunction`] trait
30//!
31//! # Examples
32//!
33//! ```rust
34//! use fugue::*;
35//! use rand::rngs::StdRng;
36//! use rand::SeedableRng;
37//! use rand::Rng;
38//!
39//! // Simple ABC example for illustration
40//! let mut rng = StdRng::seed_from_u64(42);
41//! let observed_data = vec![2.0];
42//!
43//! let samples = abc_scalar_summary(
44//! &mut rng,
45//! || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
46//! |trace| {
47//! if let Some(choice) = trace.choices.get(&addr!("mu")) {
48//! if let ChoiceValue::F64(mu) = choice.value {
49//! mu
50//! } else { 0.0 }
51//! } else { 0.0 }
52//! },
53//! 2.0, // observed summary
54//! 0.5, // tolerance
55//! 10 // max samples
56//! );
57//!
58//! assert!(!samples.is_empty());
59//! ```
60
61use crate::core::model::Model;
62use crate::runtime::handler::run;
63use crate::runtime::interpreters::PriorHandler;
64use crate::runtime::trace::Trace;
65use rand::Rng;
66
67/// Trait for computing distances between observed and simulated data.
68///
69/// Distance functions are crucial for ABC methods as they determine how
70/// "similarity" between datasets is measured. The choice of distance function
71/// significantly affects the quality of ABC approximations.
72///
73/// # Type Parameter
74///
75/// * `T` - Type of data being compared (e.g., `Vec<f64>`, scalar values)
76///
77/// # Examples
78///
79/// ```rust
80/// use fugue::*;
81///
82/// // Use built-in Euclidean distance
83/// let euclidean = EuclideanDistance;
84/// let dist = euclidean.distance(&vec![1.0, 2.0], &vec![1.1, 2.1]);
85///
86/// // Implement custom distance function
87/// struct ScalarDistance;
88/// impl DistanceFunction<f64> for ScalarDistance {
89/// fn distance(&self, observed: &f64, simulated: &f64) -> f64 {
90/// (observed - simulated).abs()
91/// }
92/// }
93/// ```
94pub trait DistanceFunction<T> {
95 /// Compute the distance between observed and simulated data.
96 ///
97 /// # Arguments
98 ///
99 /// * `observed` - The actual observed data
100 /// * `simulated` - Data simulated from the model
101 ///
102 /// # Returns
103 ///
104 /// A non-negative distance value. Smaller values indicate greater similarity.
105 fn distance(&self, observed: &T, simulated: &T) -> f64;
106}
107
108/// Euclidean (L2) distance function for vector data.
109///
110/// Computes the standard Euclidean distance between two vectors:
111/// √(Σ(xᵢ - yᵢ)²)
112///
113/// This is appropriate for continuous data where the magnitude of differences
114/// matters and the data dimensions have similar scales.
115///
116/// # Examples
117///
118/// ```rust
119/// use fugue::*;
120///
121/// let euclidean = EuclideanDistance;
122/// let observed = vec![1.0, 2.0, 3.0];
123/// let simulated = vec![1.1, 2.1, 2.9];
124/// let distance = euclidean.distance(&observed, &simulated);
125/// assert!((distance - 0.173).abs() < 0.01); // ≈ 0.173
126/// ```
127pub struct EuclideanDistance;
128
129impl DistanceFunction<Vec<f64>> for EuclideanDistance {
130 fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
131 if observed.len() != simulated.len() {
132 return f64::INFINITY;
133 }
134
135 observed
136 .iter()
137 .zip(simulated.iter())
138 .map(|(&o, &s)| (o - s).powi(2))
139 .sum::<f64>()
140 .sqrt()
141 }
142}
143
144/// Manhattan (L1) distance function for vector data.
145///
146/// Computes the Manhattan distance between two vectors:
147/// Σ|xᵢ - yᵢ|
148///
149/// This distance is more robust to outliers than Euclidean distance and is
150/// appropriate when you want to treat each dimension independently.
151///
152/// # Examples
153///
154/// ```rust
155/// use fugue::inference::abc::{ManhattanDistance, DistanceFunction};
156///
157/// let manhattan = ManhattanDistance;
158/// let observed = vec![1.0, 2.0, 3.0];
159/// let simulated = vec![1.5, 1.5, 3.5];
160/// let distance = manhattan.distance(&observed, &simulated);
161/// assert!((distance - 1.5).abs() < 0.001); // |1.0-1.5| + |2.0-1.5| + |3.0-3.5| = 0.5 + 0.5 + 0.5 = 1.5
162/// ```
163pub struct ManhattanDistance;
164
165impl DistanceFunction<Vec<f64>> for ManhattanDistance {
166 fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
167 if observed.len() != simulated.len() {
168 return f64::INFINITY;
169 }
170
171 observed
172 .iter()
173 .zip(simulated.iter())
174 .map(|(&o, &s)| (o - s).abs())
175 .sum::<f64>()
176 }
177}
178
179/// Summary statistics distance.
180pub struct SummaryStatsDistance {
181 pub weights: Vec<f64>,
182}
183
184impl SummaryStatsDistance {
185 pub fn new(weights: Vec<f64>) -> Self {
186 Self { weights }
187 }
188
189 fn compute_stats(data: &[f64]) -> Vec<f64> {
190 if data.is_empty() {
191 return vec![0.0, 0.0, 0.0];
192 }
193
194 let mean = data.iter().sum::<f64>() / data.len() as f64;
195 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64;
196 let std = variance.sqrt();
197
198 let mut sorted = data.to_vec();
199 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
200 let median = if sorted.len() % 2 == 0 {
201 (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
202 } else {
203 sorted[sorted.len() / 2]
204 };
205
206 vec![mean, std, median]
207 }
208}
209
210impl DistanceFunction<Vec<f64>> for SummaryStatsDistance {
211 fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
212 let obs_stats = Self::compute_stats(observed);
213 let sim_stats = Self::compute_stats(simulated);
214
215 obs_stats
216 .iter()
217 .zip(sim_stats.iter())
218 .zip(&self.weights)
219 .map(|((&o, &s), &w)| w * (o - s).powi(2))
220 .sum::<f64>()
221 .sqrt()
222 }
223}
224
225/// Basic ABC rejection sampling algorithm.
226///
227/// The simplest ABC method: repeatedly sample from the prior, simulate data,
228/// and accept samples where the distance to observed data is below a tolerance.
229/// This method is straightforward but can be inefficient for small tolerances.
230///
231/// # Algorithm
232///
233/// 1. Sample parameters from the prior using `model_fn()`
234/// 2. Simulate data using `simulator(trace)`
235/// 3. Compute distance between simulated and observed data
236/// 4. Accept if distance ≤ tolerance
237/// 5. Repeat until `max_samples` accepted or too many attempts
238///
239/// # Arguments
240///
241/// * `rng` - Random number generator
242/// * `model_fn` - Function that creates a model instance (contains priors)
243/// * `simulator` - Function that simulates data given a trace of parameter values
244/// * `observed_data` - The actual observed data to match
245/// * `distance_fn` - Function for measuring similarity between datasets
246/// * `tolerance` - Maximum allowed distance for acceptance
247/// * `max_samples` - Maximum number of samples to accept
248///
249/// # Returns
250///
251/// Vector of accepted traces (parameter samples that produced similar data).
252///
253/// # Examples
254///
255/// ```rust
256/// use fugue::*;
257/// use rand::rngs::StdRng;
258/// use rand::SeedableRng;
259///
260/// // Simple ABC rejection example
261/// let mut rng = StdRng::seed_from_u64(42);
262/// let observed_data = vec![2.0];
263///
264/// let samples = abc_scalar_summary(
265/// &mut rng,
266/// || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
267/// |trace| {
268/// if let Some(choice) = trace.choices.get(&addr!("mu")) {
269/// if let ChoiceValue::F64(mu) = choice.value {
270/// mu
271/// } else { 0.0 }
272/// } else { 0.0 }
273/// },
274/// 2.0, // observed summary
275/// 0.5, // tolerance
276/// 5 // max samples (small for test)
277/// );
278/// assert!(!samples.is_empty());
279/// ```
280pub fn abc_rejection<A, T, R: Rng>(
281 rng: &mut R,
282 model_fn: impl Fn() -> Model<A>,
283 simulator: impl Fn(&Trace) -> T,
284 observed_data: &T,
285 distance_fn: &dyn DistanceFunction<T>,
286 tolerance: f64,
287 max_samples: usize,
288) -> Vec<Trace> {
289 let mut accepted = Vec::new();
290 let mut attempts = 0;
291
292 while accepted.len() < max_samples && attempts < max_samples * 100 {
293 // Sample from prior
294 let (_a, trace) = run(
295 PriorHandler {
296 rng,
297 trace: Trace::default(),
298 },
299 model_fn(),
300 );
301
302 // Simulate data
303 let simulated_data = simulator(&trace);
304
305 // Check distance
306 let dist = distance_fn.distance(observed_data, &simulated_data);
307
308 if dist <= tolerance {
309 accepted.push(trace);
310 }
311
312 attempts += 1;
313 }
314
315 if accepted.is_empty() {
316 eprintln!(
317 "Warning: No samples accepted in ABC. Consider increasing tolerance or max_samples."
318 );
319 }
320
321 accepted
322}
323
324/// Sequential Monte Carlo ABC with adaptive tolerance scheduling.
325///
326/// An advanced ABC method that uses Sequential Monte Carlo to iteratively
327/// reduce the tolerance, leading to better approximations of the posterior.
328/// SMC-ABC is more efficient than rejection ABC for stringent tolerances.
329///
330/// # Algorithm
331///
332/// 1. Start with initial tolerance and generate particles using rejection ABC
333/// 2. For each subsequent tolerance level:
334/// - Resample particles from the previous population
335/// - Perturb parameters using MCMC moves
336/// - Re-simulate and check new tolerance
337/// 3. Final particles approximate the posterior at the strictest tolerance
338///
339/// # Arguments
340///
341/// * `rng` - Random number generator
342/// * `model_fn` - Function that creates a model instance
343/// * `simulator` - Function that simulates data given a trace
344/// * `observed_data` - The observed data to match
345/// * `distance_fn` - Distance function for comparing datasets
346/// * `initial_tolerance` - Starting tolerance (should be relatively large)
347/// * `tolerance_schedule` - Decreasing sequence of tolerances to use
348/// * `particles_per_round` - Number of particles to maintain in each round
349///
350/// # Returns
351///
352/// Vector of traces from the final SMC population.
353///
354/// # Examples
355///
356/// ```rust
357/// use fugue::{inference::abc::ABCSMCConfig, *};
358/// use rand::rngs::StdRng;
359/// use rand::SeedableRng;
360///
361/// // Simple SMC-ABC example with small numbers for testing
362/// let observed = vec![2.0];
363/// let mut rng = StdRng::seed_from_u64(42);
364///
365/// let samples = abc_smc(
366/// &mut rng,
367/// || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
368/// |trace| {
369/// if let Some(choice) = trace.choices.get(&addr!("mu")) {
370/// if let ChoiceValue::F64(mu) = choice.value {
371/// vec![mu]
372/// } else { vec![0.0] }
373/// } else { vec![0.0] }
374/// },
375/// &observed,
376/// &EuclideanDistance,
377/// ABCSMCConfig {
378/// initial_tolerance: 1.0,
379/// tolerance_schedule: vec![0.5],
380/// particles_per_round: 5,
381/// },
382/// );
383/// assert!(!samples.is_empty());
384/// ```
385/// Configuration for ABC-SMC algorithm.
386#[derive(Debug, Clone)]
387pub struct ABCSMCConfig {
388 /// Initial tolerance for distance threshold
389 pub initial_tolerance: f64,
390 /// Schedule of decreasing tolerances across rounds
391 pub tolerance_schedule: Vec<f64>,
392 /// Number of particles to generate per round
393 pub particles_per_round: usize,
394}
395
396pub fn abc_smc<A, T, R: Rng>(
397 rng: &mut R,
398 model_fn: impl Fn() -> Model<A>,
399 simulator: impl Fn(&Trace) -> T,
400 observed_data: &T,
401 distance_fn: &dyn DistanceFunction<T>,
402 config: ABCSMCConfig,
403) -> Vec<Trace> {
404 let mut current_particles;
405 let mut current_tolerance = config.initial_tolerance;
406
407 // Initial round: ABC rejection
408 current_particles = abc_rejection(
409 rng,
410 &model_fn,
411 &simulator,
412 observed_data,
413 distance_fn,
414 current_tolerance,
415 config.particles_per_round,
416 );
417
418 // Sequential rounds with decreasing tolerance
419 for &new_tolerance in &config.tolerance_schedule {
420 if new_tolerance >= current_tolerance {
421 continue; // Skip if tolerance doesn't decrease
422 }
423
424 let mut new_particles = Vec::new();
425
426 while new_particles.len() < config.particles_per_round {
427 // Sample a particle to perturb
428 let base_idx = rng.gen_range(0..current_particles.len());
429 let base_trace = ¤t_particles[base_idx];
430
431 // Simple perturbation: resample one site
432 let mut perturbed_trace = base_trace.clone();
433 if !perturbed_trace.choices.is_empty() {
434 let sites: Vec<_> = perturbed_trace.choices.keys().cloned().collect();
435 let site_idx = rng.gen_range(0..sites.len());
436 let selected_site = &sites[site_idx];
437
438 // Resample this site from prior (simple perturbation)
439 let (_a, fresh_trace) = run(
440 PriorHandler {
441 rng,
442 trace: Trace::default(),
443 },
444 model_fn(),
445 );
446
447 if let Some(fresh_choice) = fresh_trace.choices.get(selected_site) {
448 perturbed_trace
449 .choices
450 .insert(selected_site.clone(), fresh_choice.clone());
451 }
452 }
453
454 // Check if perturbed trace meets new tolerance
455 let simulated_data = simulator(&perturbed_trace);
456 let dist = distance_fn.distance(observed_data, &simulated_data);
457
458 if dist <= new_tolerance {
459 new_particles.push(perturbed_trace);
460 }
461 }
462
463 current_particles = new_particles;
464 current_tolerance = new_tolerance;
465
466 println!(
467 "ABC SMC: tolerance = {:.4}, accepted = {}",
468 current_tolerance,
469 current_particles.len()
470 );
471 }
472
473 current_particles
474}
475
476/// ABC rejection sampling using scalar summary statistics.
477///
478/// A convenience function for ABC when both observed and simulated data can be
479/// reduced to scalar summary statistics. This is often more efficient than
480/// comparing full datasets and can focus inference on specific aspects of the data.
481///
482/// This function is equivalent to `abc_rejection` but operates on scalar summaries
483/// instead of vector data, making it easier to use for simple cases.
484///
485/// # Arguments
486///
487/// * `rng` - Random number generator
488/// * `model_fn` - Function that creates a model instance
489/// * `simulator` - Function that computes a scalar summary from a trace
490/// * `observed_summary` - Scalar summary of the observed data
491/// * `tolerance` - Maximum allowed absolute difference for acceptance
492/// * `max_samples` - Maximum number of samples to accept
493///
494/// # Returns
495///
496/// Vector of accepted traces that produced summaries within tolerance.
497///
498/// # Examples
499///
500/// ```rust
501/// use fugue::*;
502/// use rand::rngs::StdRng;
503/// use rand::SeedableRng;
504///
505/// // ABC for estimating mean when we only observe sample mean
506/// let observed_mean = 2.0;
507/// let mut rng = StdRng::seed_from_u64(42);
508///
509/// let samples = abc_scalar_summary(
510/// &mut rng,
511/// || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
512/// |trace| {
513/// // Extract mu parameter and return it as summary
514/// if let Some(choice) = trace.choices.get(&addr!("mu")) {
515/// if let ChoiceValue::F64(mu) = choice.value {
516/// mu // The summary statistic is just the parameter
517/// } else { 0.0 }
518/// } else { 0.0 }
519/// },
520/// observed_mean,
521/// 0.5, // tolerance (larger for easier acceptance)
522/// 5, // max samples (small for test)
523/// );
524/// assert!(!samples.is_empty());
525/// ```
526pub fn abc_scalar_summary<A, R: Rng>(
527 rng: &mut R,
528 model_fn: impl Fn() -> Model<A>,
529 simulator: impl Fn(&Trace) -> f64,
530 observed_summary: f64,
531 tolerance: f64,
532 max_samples: usize,
533) -> Vec<Trace> {
534 abc_rejection(
535 rng,
536 model_fn,
537 |trace| vec![simulator(trace)],
538 &vec![observed_summary],
539 &EuclideanDistance,
540 tolerance,
541 max_samples,
542 )
543}
544
545#[cfg(test)]
546mod tests {
547 use super::*;
548 use crate::addr;
549 use crate::core::distribution::*;
550 use crate::core::model::sample;
551
552 use rand::rngs::StdRng;
553 use rand::SeedableRng;
554
555 #[test]
556 fn distance_functions_work() {
557 let eu = EuclideanDistance;
558 let man = ManhattanDistance;
559 let a = vec![1.0, 2.0, 3.0];
560 let b = vec![1.1, 2.1, 2.9];
561 let d_eu = eu.distance(&a, &b);
562 let d_man = man.distance(&a, &b);
563 assert!(d_eu > 0.0);
564 assert!(d_man > 0.0);
565 // Euclidean should be <= Manhattan for same vectors
566 assert!(d_eu <= d_man + 1e-12);
567 }
568
569 #[test]
570 fn abc_scalar_summary_accepts_with_large_tolerance() {
571 let mut rng = StdRng::seed_from_u64(42);
572 let samples = abc_scalar_summary(
573 &mut rng,
574 || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
575 |trace| trace.get_f64(&addr!("mu")).unwrap_or(0.0),
576 0.0, // observed summary
577 10.0, // large tolerance to ensure acceptance
578 3,
579 );
580 assert!(!samples.is_empty());
581 }
582
583 #[test]
584 fn abc_rejection_can_return_empty_with_tight_tolerance() {
585 let mut rng = StdRng::seed_from_u64(43);
586 let observed = vec![1000.0]; // far from prior mean 0
587 let res = abc_rejection(
588 &mut rng,
589 || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
590 |trace| vec![trace.get_f64(&addr!("mu")).unwrap_or(0.0)],
591 &observed,
592 &EuclideanDistance,
593 1e-6, // extremely tight
594 3,
595 );
596 assert!(res.is_empty());
597 }
598
599 #[test]
600 fn abc_smc_respects_tolerance_schedule() {
601 let mut rng = StdRng::seed_from_u64(44);
602 let observed = vec![0.0];
603 let config = ABCSMCConfig {
604 initial_tolerance: 2.0,
605 tolerance_schedule: vec![1.0, 0.5],
606 particles_per_round: 4,
607 };
608 let res = abc_smc(
609 &mut rng,
610 || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
611 |trace| vec![trace.get_f64(&addr!("mu")).unwrap_or(0.0)],
612 &observed,
613 &EuclideanDistance,
614 config,
615 );
616 assert_eq!(res.len(), 4);
617 }
618}