fugue/inference/
abc.rs

1//! Approximate Bayesian Computation (ABC) - likelihood-free inference methods.
2//!
3//! ABC methods enable Bayesian inference for models where the likelihood function
4//! is intractable or unavailable, but forward simulation from the model is possible.
5//! Instead of computing likelihoods directly, ABC compares simulated data to observed
6//! data using distance functions and accepts samples that produce "similar" outcomes.
7//!
8//! ## Method Overview
9//!
10//! ABC algorithms follow this general pattern:
11//! 1. Sample parameters from the prior distribution
12//! 2. Simulate data using the model with those parameters
13//! 3. Compare simulated data to observed data using a distance function
14//! 4. Accept samples where the distance is below a threshold ε
15//!
16//! As ε → 0, the ABC posterior approaches the true posterior distribution.
17//!
18//! ## Available Methods
19//!
20//! - [`abc_rejection`]: Basic rejection ABC
21//! - [`abc_smc`]: Sequential Monte Carlo ABC for improved efficiency
22//! - [`abc_scalar_summary`]: ABC with scalar summary statistics
23//!
24//! ## Distance Functions
25//!
26//! The quality of ABC inference depends heavily on the choice of distance function:
27//! - [`EuclideanDistance`]: L2 norm for continuous data vectors
28//! - [`ManhattanDistance`]: L1 norm for robust distance computation
29//! - Custom distance functions via the [`DistanceFunction`] trait
30//!
31//! # Examples
32//!
33//! ```rust
34//! use fugue::*;
35//! use rand::rngs::StdRng;
36//! use rand::SeedableRng;
37//! use rand::Rng;
38//!
39//! // Simple ABC example for illustration
40//! let mut rng = StdRng::seed_from_u64(42);
41//! let observed_data = vec![2.0];
42//!
43//! let samples = abc_scalar_summary(
44//!     &mut rng,
45//!     || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
46//!     |trace| {
47//!         if let Some(choice) = trace.choices.get(&addr!("mu")) {
48//!             if let ChoiceValue::F64(mu) = choice.value {
49//!                 mu
50//!             } else { 0.0 }
51//!         } else { 0.0 }
52//!     },
53//!     2.0, // observed summary
54//!     0.5, // tolerance
55//!     10   // max samples
56//! );
57//!
58//! assert!(!samples.is_empty());
59//! ```
60
61use crate::core::model::Model;
62use crate::runtime::handler::run;
63use crate::runtime::interpreters::PriorHandler;
64use crate::runtime::trace::Trace;
65use rand::Rng;
66
/// Measures how far a simulated dataset lies from the observed one.
///
/// Every ABC routine in this module decides whether to keep a parameter draw
/// by comparing the value returned from [`DistanceFunction::distance`] with a
/// tolerance, so the implementation chosen here defines what "similar data"
/// means for the whole inference run.
///
/// # Type Parameter
///
/// * `T` - The data representation being compared (for example `Vec<f64>`
///   or a plain scalar).
///
/// # Examples
///
/// ```rust
/// use fugue::*;
///
/// // Use built-in Euclidean distance
/// let euclidean = EuclideanDistance;
/// let dist = euclidean.distance(&vec![1.0, 2.0], &vec![1.1, 2.1]);
///
/// // Implement custom distance function
/// struct ScalarDistance;
/// impl DistanceFunction<f64> for ScalarDistance {
///     fn distance(&self, observed: &f64, simulated: &f64) -> f64 {
///         (observed - simulated).abs()
///     }
/// }
/// ```
pub trait DistanceFunction<T> {
    /// Returns the dissimilarity between `observed` and `simulated`.
    ///
    /// # Arguments
    ///
    /// * `observed` - The real data the inference is conditioned on
    /// * `simulated` - Data produced by running the model forward
    ///
    /// # Returns
    ///
    /// A distance that implementations are expected to keep non-negative;
    /// the smaller the value, the closer the two datasets are.
    fn distance(&self, observed: &T, simulated: &T) -> f64;
}
107
108/// Euclidean (L2) distance function for vector data.
109///
110/// Computes the standard Euclidean distance between two vectors:
111/// √(Σ(xᵢ - yᵢ)²)
112///
113/// This is appropriate for continuous data where the magnitude of differences
114/// matters and the data dimensions have similar scales.
115///
116/// # Examples
117///
118/// ```rust
119/// use fugue::*;
120///
121/// let euclidean = EuclideanDistance;
122/// let observed = vec![1.0, 2.0, 3.0];
123/// let simulated = vec![1.1, 2.1, 2.9];
124/// let distance = euclidean.distance(&observed, &simulated);
125/// assert!((distance - 0.173).abs() < 0.01); // ≈ 0.173
126/// ```
127pub struct EuclideanDistance;
128
129impl DistanceFunction<Vec<f64>> for EuclideanDistance {
130    fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
131        if observed.len() != simulated.len() {
132            return f64::INFINITY;
133        }
134
135        observed
136            .iter()
137            .zip(simulated.iter())
138            .map(|(&o, &s)| (o - s).powi(2))
139            .sum::<f64>()
140            .sqrt()
141    }
142}
143
144/// Manhattan (L1) distance function for vector data.
145///
146/// Computes the Manhattan distance between two vectors:
147/// Σ|xᵢ - yᵢ|
148///
149/// This distance is more robust to outliers than Euclidean distance and is
150/// appropriate when you want to treat each dimension independently.
151///
152/// # Examples
153///
154/// ```rust
155/// use fugue::inference::abc::{ManhattanDistance, DistanceFunction};
156///
157/// let manhattan = ManhattanDistance;
158/// let observed = vec![1.0, 2.0, 3.0];
159/// let simulated = vec![1.5, 1.5, 3.5];
160/// let distance = manhattan.distance(&observed, &simulated);
161/// assert!((distance - 1.5).abs() < 0.001); // |1.0-1.5| + |2.0-1.5| + |3.0-3.5| = 0.5 + 0.5 + 0.5 = 1.5
162/// ```
163pub struct ManhattanDistance;
164
165impl DistanceFunction<Vec<f64>> for ManhattanDistance {
166    fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
167        if observed.len() != simulated.len() {
168            return f64::INFINITY;
169        }
170
171        observed
172            .iter()
173            .zip(simulated.iter())
174            .map(|(&o, &s)| (o - s).abs())
175            .sum::<f64>()
176    }
177}
178
179/// Summary statistics distance.
180pub struct SummaryStatsDistance {
181    pub weights: Vec<f64>,
182}
183
184impl SummaryStatsDistance {
185    pub fn new(weights: Vec<f64>) -> Self {
186        Self { weights }
187    }
188
189    fn compute_stats(data: &[f64]) -> Vec<f64> {
190        if data.is_empty() {
191            return vec![0.0, 0.0, 0.0];
192        }
193
194        let mean = data.iter().sum::<f64>() / data.len() as f64;
195        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64;
196        let std = variance.sqrt();
197
198        let mut sorted = data.to_vec();
199        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
200        let median = if sorted.len() % 2 == 0 {
201            (sorted[sorted.len() / 2 - 1] + sorted[sorted.len() / 2]) / 2.0
202        } else {
203            sorted[sorted.len() / 2]
204        };
205
206        vec![mean, std, median]
207    }
208}
209
210impl DistanceFunction<Vec<f64>> for SummaryStatsDistance {
211    fn distance(&self, observed: &Vec<f64>, simulated: &Vec<f64>) -> f64 {
212        let obs_stats = Self::compute_stats(observed);
213        let sim_stats = Self::compute_stats(simulated);
214
215        obs_stats
216            .iter()
217            .zip(sim_stats.iter())
218            .zip(&self.weights)
219            .map(|((&o, &s), &w)| w * (o - s).powi(2))
220            .sum::<f64>()
221            .sqrt()
222    }
223}
224
225/// Basic ABC rejection sampling algorithm.
226///
227/// The simplest ABC method: repeatedly sample from the prior, simulate data,
228/// and accept samples where the distance to observed data is below a tolerance.
229/// This method is straightforward but can be inefficient for small tolerances.
230///
231/// # Algorithm
232///
233/// 1. Sample parameters from the prior using `model_fn()`
234/// 2. Simulate data using `simulator(trace)`
235/// 3. Compute distance between simulated and observed data
236/// 4. Accept if distance ≤ tolerance
237/// 5. Repeat until `max_samples` accepted or too many attempts
238///
239/// # Arguments
240///
241/// * `rng` - Random number generator
242/// * `model_fn` - Function that creates a model instance (contains priors)
243/// * `simulator` - Function that simulates data given a trace of parameter values
244/// * `observed_data` - The actual observed data to match
245/// * `distance_fn` - Function for measuring similarity between datasets
246/// * `tolerance` - Maximum allowed distance for acceptance
247/// * `max_samples` - Maximum number of samples to accept
248///
249/// # Returns
250///
251/// Vector of accepted traces (parameter samples that produced similar data).
252///
253/// # Examples
254///
255/// ```rust
256/// use fugue::*;
257/// use rand::rngs::StdRng;
258/// use rand::SeedableRng;
259///
260/// // Simple ABC rejection example
261/// let mut rng = StdRng::seed_from_u64(42);
262/// let observed_data = vec![2.0];
263///
264/// let samples = abc_scalar_summary(
265///     &mut rng,
266///     || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
267///     |trace| {
268///         if let Some(choice) = trace.choices.get(&addr!("mu")) {
269///             if let ChoiceValue::F64(mu) = choice.value {
270///                 mu
271///             } else { 0.0 }
272///         } else { 0.0 }
273///     },
274///     2.0, // observed summary
275///     0.5, // tolerance
276///     5    // max samples (small for test)
277/// );
278/// assert!(!samples.is_empty());
279/// ```
280pub fn abc_rejection<A, T, R: Rng>(
281    rng: &mut R,
282    model_fn: impl Fn() -> Model<A>,
283    simulator: impl Fn(&Trace) -> T,
284    observed_data: &T,
285    distance_fn: &dyn DistanceFunction<T>,
286    tolerance: f64,
287    max_samples: usize,
288) -> Vec<Trace> {
289    let mut accepted = Vec::new();
290    let mut attempts = 0;
291
292    while accepted.len() < max_samples && attempts < max_samples * 100 {
293        // Sample from prior
294        let (_a, trace) = run(
295            PriorHandler {
296                rng,
297                trace: Trace::default(),
298            },
299            model_fn(),
300        );
301
302        // Simulate data
303        let simulated_data = simulator(&trace);
304
305        // Check distance
306        let dist = distance_fn.distance(observed_data, &simulated_data);
307
308        if dist <= tolerance {
309            accepted.push(trace);
310        }
311
312        attempts += 1;
313    }
314
315    if accepted.is_empty() {
316        eprintln!(
317            "Warning: No samples accepted in ABC. Consider increasing tolerance or max_samples."
318        );
319    }
320
321    accepted
322}
323
324/// Sequential Monte Carlo ABC with adaptive tolerance scheduling.
325///
326/// An advanced ABC method that uses Sequential Monte Carlo to iteratively
327/// reduce the tolerance, leading to better approximations of the posterior.
328/// SMC-ABC is more efficient than rejection ABC for stringent tolerances.
329///
330/// # Algorithm
331///
332/// 1. Start with initial tolerance and generate particles using rejection ABC
333/// 2. For each subsequent tolerance level:
334///    - Resample particles from the previous population
335///    - Perturb parameters using MCMC moves
336///    - Re-simulate and check new tolerance
337/// 3. Final particles approximate the posterior at the strictest tolerance
338///
339/// # Arguments
340///
341/// * `rng` - Random number generator
342/// * `model_fn` - Function that creates a model instance
343/// * `simulator` - Function that simulates data given a trace
344/// * `observed_data` - The observed data to match
345/// * `distance_fn` - Distance function for comparing datasets
346/// * `initial_tolerance` - Starting tolerance (should be relatively large)
347/// * `tolerance_schedule` - Decreasing sequence of tolerances to use
348/// * `particles_per_round` - Number of particles to maintain in each round
349///
350/// # Returns
351///
352/// Vector of traces from the final SMC population.
353///
354/// # Examples
355///
356/// ```rust
357/// use fugue::{inference::abc::ABCSMCConfig, *};
358/// use rand::rngs::StdRng;
359/// use rand::SeedableRng;
360///
361/// // Simple SMC-ABC example with small numbers for testing
362/// let observed = vec![2.0];
363/// let mut rng = StdRng::seed_from_u64(42);
364///
365/// let samples = abc_smc(
366///     &mut rng,
367///     || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
368///     |trace| {
369///         if let Some(choice) = trace.choices.get(&addr!("mu")) {
370///             if let ChoiceValue::F64(mu) = choice.value {
371///                 vec![mu]
372///             } else { vec![0.0] }
373///         } else { vec![0.0] }
374///     },
375///     &observed,
376///     &EuclideanDistance,
377///     ABCSMCConfig {
378///         initial_tolerance: 1.0,
379///         tolerance_schedule: vec![0.5],
380///         particles_per_round: 5,
381///     },
382/// );
383/// assert!(!samples.is_empty());
384/// ```
385/// Configuration for ABC-SMC algorithm.
386#[derive(Debug, Clone)]
387pub struct ABCSMCConfig {
388    /// Initial tolerance for distance threshold
389    pub initial_tolerance: f64,
390    /// Schedule of decreasing tolerances across rounds  
391    pub tolerance_schedule: Vec<f64>,
392    /// Number of particles to generate per round
393    pub particles_per_round: usize,
394}
395
396pub fn abc_smc<A, T, R: Rng>(
397    rng: &mut R,
398    model_fn: impl Fn() -> Model<A>,
399    simulator: impl Fn(&Trace) -> T,
400    observed_data: &T,
401    distance_fn: &dyn DistanceFunction<T>,
402    config: ABCSMCConfig,
403) -> Vec<Trace> {
404    let mut current_particles;
405    let mut current_tolerance = config.initial_tolerance;
406
407    // Initial round: ABC rejection
408    current_particles = abc_rejection(
409        rng,
410        &model_fn,
411        &simulator,
412        observed_data,
413        distance_fn,
414        current_tolerance,
415        config.particles_per_round,
416    );
417
418    // Sequential rounds with decreasing tolerance
419    for &new_tolerance in &config.tolerance_schedule {
420        if new_tolerance >= current_tolerance {
421            continue; // Skip if tolerance doesn't decrease
422        }
423
424        let mut new_particles = Vec::new();
425
426        while new_particles.len() < config.particles_per_round {
427            // Sample a particle to perturb
428            let base_idx = rng.gen_range(0..current_particles.len());
429            let base_trace = &current_particles[base_idx];
430
431            // Simple perturbation: resample one site
432            let mut perturbed_trace = base_trace.clone();
433            if !perturbed_trace.choices.is_empty() {
434                let sites: Vec<_> = perturbed_trace.choices.keys().cloned().collect();
435                let site_idx = rng.gen_range(0..sites.len());
436                let selected_site = &sites[site_idx];
437
438                // Resample this site from prior (simple perturbation)
439                let (_a, fresh_trace) = run(
440                    PriorHandler {
441                        rng,
442                        trace: Trace::default(),
443                    },
444                    model_fn(),
445                );
446
447                if let Some(fresh_choice) = fresh_trace.choices.get(selected_site) {
448                    perturbed_trace
449                        .choices
450                        .insert(selected_site.clone(), fresh_choice.clone());
451                }
452            }
453
454            // Check if perturbed trace meets new tolerance
455            let simulated_data = simulator(&perturbed_trace);
456            let dist = distance_fn.distance(observed_data, &simulated_data);
457
458            if dist <= new_tolerance {
459                new_particles.push(perturbed_trace);
460            }
461        }
462
463        current_particles = new_particles;
464        current_tolerance = new_tolerance;
465
466        println!(
467            "ABC SMC: tolerance = {:.4}, accepted = {}",
468            current_tolerance,
469            current_particles.len()
470        );
471    }
472
473    current_particles
474}
475
476/// ABC rejection sampling using scalar summary statistics.
477///
478/// A convenience function for ABC when both observed and simulated data can be
479/// reduced to scalar summary statistics. This is often more efficient than
480/// comparing full datasets and can focus inference on specific aspects of the data.
481///
482/// This function is equivalent to `abc_rejection` but operates on scalar summaries
483/// instead of vector data, making it easier to use for simple cases.
484///
485/// # Arguments
486///
487/// * `rng` - Random number generator
488/// * `model_fn` - Function that creates a model instance
489/// * `simulator` - Function that computes a scalar summary from a trace
490/// * `observed_summary` - Scalar summary of the observed data
491/// * `tolerance` - Maximum allowed absolute difference for acceptance
492/// * `max_samples` - Maximum number of samples to accept
493///
494/// # Returns
495///
496/// Vector of accepted traces that produced summaries within tolerance.
497///
498/// # Examples
499///
500/// ```rust
501/// use fugue::*;
502/// use rand::rngs::StdRng;
503/// use rand::SeedableRng;
504///
505/// // ABC for estimating mean when we only observe sample mean
506/// let observed_mean = 2.0;
507/// let mut rng = StdRng::seed_from_u64(42);
508///
509/// let samples = abc_scalar_summary(
510///     &mut rng,
511///     || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
512///     |trace| {
513///         // Extract mu parameter and return it as summary
514///         if let Some(choice) = trace.choices.get(&addr!("mu")) {
515///             if let ChoiceValue::F64(mu) = choice.value {
516///                 mu // The summary statistic is just the parameter
517///             } else { 0.0 }
518///         } else { 0.0 }
519///     },
520///     observed_mean,
521///     0.5, // tolerance (larger for easier acceptance)
522///     5,   // max samples (small for test)
523/// );
524/// assert!(!samples.is_empty());
525/// ```
526pub fn abc_scalar_summary<A, R: Rng>(
527    rng: &mut R,
528    model_fn: impl Fn() -> Model<A>,
529    simulator: impl Fn(&Trace) -> f64,
530    observed_summary: f64,
531    tolerance: f64,
532    max_samples: usize,
533) -> Vec<Trace> {
534    abc_rejection(
535        rng,
536        model_fn,
537        |trace| vec![simulator(trace)],
538        &vec![observed_summary],
539        &EuclideanDistance,
540        tolerance,
541        max_samples,
542    )
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::sample;

    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Both built-in vector distances are positive for differing inputs and
    /// obey the L2 ≤ L1 ordering.
    #[test]
    fn distance_functions_work() {
        let observed = vec![1.0, 2.0, 3.0];
        let simulated = vec![1.1, 2.1, 2.9];

        let l2 = EuclideanDistance.distance(&observed, &simulated);
        let l1 = ManhattanDistance.distance(&observed, &simulated);

        assert!(l2 > 0.0);
        assert!(l1 > 0.0);
        // Euclidean should be <= Manhattan for same vectors
        assert!(l2 <= l1 + 1e-12);
    }

    /// A tolerance far wider than the prior must let draws through.
    #[test]
    fn abc_scalar_summary_accepts_with_large_tolerance() {
        let mut rng = StdRng::seed_from_u64(42);
        let observed_summary = 0.0;
        let tolerance = 10.0; // large tolerance to ensure acceptance
        let max_samples = 3;

        let accepted = abc_scalar_summary(
            &mut rng,
            || sample(addr!("mu"), Normal::new(0.0, 2.0).unwrap()),
            |trace| trace.get_f64(&addr!("mu")).unwrap_or(0.0),
            observed_summary,
            tolerance,
            max_samples,
        );

        assert!(!accepted.is_empty());
    }

    /// With observations far from the prior and a tiny tolerance the attempt
    /// budget runs out and nothing is accepted.
    #[test]
    fn abc_rejection_can_return_empty_with_tight_tolerance() {
        let mut rng = StdRng::seed_from_u64(43);
        let observed = vec![1000.0]; // far from prior mean 0
        let tolerance = 1e-6; // extremely tight
        let max_samples = 3;

        let accepted = abc_rejection(
            &mut rng,
            || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
            |trace| vec![trace.get_f64(&addr!("mu")).unwrap_or(0.0)],
            &observed,
            &EuclideanDistance,
            tolerance,
            max_samples,
        );

        assert!(accepted.is_empty());
    }

    /// Running through the whole schedule keeps the population at the
    /// configured size.
    #[test]
    fn abc_smc_respects_tolerance_schedule() {
        let mut rng = StdRng::seed_from_u64(44);
        let observed = vec![0.0];

        let population = abc_smc(
            &mut rng,
            || sample(addr!("mu"), Normal::new(0.0, 1.0).unwrap()),
            |trace| vec![trace.get_f64(&addr!("mu")).unwrap_or(0.0)],
            &observed,
            &EuclideanDistance,
            ABCSMCConfig {
                initial_tolerance: 2.0,
                tolerance_schedule: vec![1.0, 0.5],
                particles_per_round: 4,
            },
        );

        assert_eq!(population.len(), 4);
    }
}