fugue/runtime/
handler.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/handler.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::core::model::Model;
6use crate::runtime::trace::Trace;
7
8/// Core trait for interpreting probabilistic model effects.
9///
10/// Handlers define how to interpret the three fundamental effects in probabilistic programming:
11/// sampling, observation, and factoring. Different implementations enable different execution modes.
12///
13/// Example:
14/// ```rust
15/// # use fugue::*;
16/// # use fugue::runtime::interpreters::PriorHandler;
17/// # use rand::rngs::StdRng;
18/// # use rand::SeedableRng;
19///
20/// // Use a built-in handler
21/// let mut rng = StdRng::seed_from_u64(42);
22/// let handler = PriorHandler {
23///     rng: &mut rng,
24///     trace: Trace::default()
25/// };
26/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
27/// let (result, trace) = runtime::handler::run(handler, model);
28/// ```
29pub trait Handler {
30    /// Handle an f64 sampling operation (continuous distributions).
31    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64;
32
33    /// Handle a bool sampling operation (Bernoulli).
34    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool;
35
36    /// Handle a u64 sampling operation (Poisson, Binomial).
37    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64;
38
39    /// Handle a usize sampling operation (Categorical).
40    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize;
41
42    /// Handle an f64 observation operation.
43    fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64);
44
45    /// Handle a bool observation operation.
46    fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool);
47
48    /// Handle a u64 observation operation.
49    fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64);
50
51    /// Handle a usize observation operation.
52    fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize);
53
54    /// Handle a factor operation.
55    ///
56    /// This method is called when the model encounters a `factor` operation.
57    /// The handler typically adds the log-weight to the trace.
58    ///
59    /// # Arguments
60    ///
61    /// * `logw` - Log-weight to add to the model's total weight
62    fn on_factor(&mut self, logw: f64);
63
64    /// Finalize the handler and return the accumulated trace.
65    ///
66    /// This method is called after model execution completes to retrieve
67    /// the final trace containing all choices and log-weights.
68    fn finish(self) -> Trace
69    where
70        Self: Sized;
71}
72
73/// Execute a probabilistic model using the provided handler.
74///
75/// This is the core execution engine for probabilistic models. It walks through
76/// the model structure and dispatches effects to the handler, returning both
77/// the model's final result and the accumulated execution trace.
78///
79/// Example:
80/// ```rust
81/// # use fugue::*;
82/// # use fugue::runtime::interpreters::PriorHandler;
83/// # use rand::rngs::StdRng;
84/// # use rand::SeedableRng;
85///
86/// // Create a simple model
87/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
88///     .bind(|x| observe(addr!("y"), Normal::new(x, 0.1).unwrap(), 1.2))
89///     .map(|_| "completed");
90///
91/// let mut rng = StdRng::seed_from_u64(123);
92/// let (result, trace) = runtime::handler::run(
93///     PriorHandler { rng: &mut rng, trace: Trace::default() },
94///     model
95/// );
96/// assert_eq!(result, "completed");
97/// assert!(trace.total_log_weight().is_finite());
98/// ```
99pub fn run<A>(mut h: impl Handler, m: Model<A>) -> (A, Trace) {
100    fn go<A>(h: &mut impl Handler, m: Model<A>) -> A {
101        match m {
102            Model::Pure(a) => a,
103            Model::SampleF64 { addr, dist, k } => {
104                let x = h.on_sample_f64(&addr, &*dist);
105                go(h, k(x))
106            }
107            Model::SampleBool { addr, dist, k } => {
108                let x = h.on_sample_bool(&addr, &*dist);
109                go(h, k(x))
110            }
111            Model::SampleU64 { addr, dist, k } => {
112                let x = h.on_sample_u64(&addr, &*dist);
113                go(h, k(x))
114            }
115            Model::SampleUsize { addr, dist, k } => {
116                let x = h.on_sample_usize(&addr, &*dist);
117                go(h, k(x))
118            }
119            Model::ObserveF64 {
120                addr,
121                dist,
122                value,
123                k,
124            } => {
125                h.on_observe_f64(&addr, &*dist, value);
126                go(h, k(()))
127            }
128            Model::ObserveBool {
129                addr,
130                dist,
131                value,
132                k,
133            } => {
134                h.on_observe_bool(&addr, &*dist, value);
135                go(h, k(()))
136            }
137            Model::ObserveU64 {
138                addr,
139                dist,
140                value,
141                k,
142            } => {
143                h.on_observe_u64(&addr, &*dist, value);
144                go(h, k(()))
145            }
146            Model::ObserveUsize {
147                addr,
148                dist,
149                value,
150                k,
151            } => {
152                h.on_observe_usize(&addr, &*dist, value);
153                go(h, k(()))
154            }
155            Model::Factor { logw, k } => {
156                h.on_factor(logw);
157                go(h, k(()))
158            }
159        }
160    }
161    let a = go(&mut h, m);
162    let t = h.finish();
163    (a, t)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{factor, observe, sample, ModelExt};
    use crate::runtime::interpreters::PriorHandler;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn run_accumulates_logs_for_sample_observe_factor() {
        // The model has one effect of each kind:
        //   x ~ Normal(0, 1)
        //   observe y = 0.5 under Normal(x, 1)
        //   factor(-1.0)
        let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5))
            .and_then(|_| factor(-1.0));

        let mut rng = StdRng::seed_from_u64(123);
        let handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };
        let (_result, trace) = super::run(handler, model);

        // The latent draw is recorded, and its prior term is finite.
        assert!(trace.choices.contains_key(&addr!("x")));
        assert!(trace.log_prior.is_finite());

        // The observation feeds the likelihood term.
        assert!(trace.log_likelihood.is_finite());

        // The explicit factor contributes exactly -1.0.
        let factor_error = (trace.log_factors + 1.0).abs();
        assert!(factor_error < 1e-12);
    }
}