fugue/runtime/handler.rs
1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/handler.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::core::model::Model;
6use crate::runtime::trace::Trace;
7
8/// Core trait for interpreting probabilistic model effects.
9///
10/// Handlers define how to interpret the three fundamental effects in probabilistic programming:
11/// sampling, observation, and factoring. Different implementations enable different execution modes.
12///
13/// Example:
14/// ```rust
15/// # use fugue::*;
16/// # use fugue::runtime::interpreters::PriorHandler;
17/// # use rand::rngs::StdRng;
18/// # use rand::SeedableRng;
19///
20/// // Use a built-in handler
21/// let mut rng = StdRng::seed_from_u64(42);
22/// let handler = PriorHandler {
23/// rng: &mut rng,
24/// trace: Trace::default()
25/// };
26/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
27/// let (result, trace) = runtime::handler::run(handler, model);
28/// ```
29pub trait Handler {
30 /// Handle an f64 sampling operation (continuous distributions).
31 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64;
32
33 /// Handle a bool sampling operation (Bernoulli).
34 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool;
35
36 /// Handle a u64 sampling operation (Poisson, Binomial).
37 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64;
38
39 /// Handle a usize sampling operation (Categorical).
40 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize;
41
42 /// Handle an f64 observation operation.
43 fn on_observe_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>, value: f64);
44
45 /// Handle a bool observation operation.
46 fn on_observe_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>, value: bool);
47
48 /// Handle a u64 observation operation.
49 fn on_observe_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>, value: u64);
50
51 /// Handle a usize observation operation.
52 fn on_observe_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>, value: usize);
53
54 /// Handle a factor operation.
55 ///
56 /// This method is called when the model encounters a `factor` operation.
57 /// The handler typically adds the log-weight to the trace.
58 ///
59 /// # Arguments
60 ///
61 /// * `logw` - Log-weight to add to the model's total weight
62 fn on_factor(&mut self, logw: f64);
63
64 /// Finalize the handler and return the accumulated trace.
65 ///
66 /// This method is called after model execution completes to retrieve
67 /// the final trace containing all choices and log-weights.
68 fn finish(self) -> Trace
69 where
70 Self: Sized;
71}
72
73/// Execute a probabilistic model using the provided handler.
74///
75/// This is the core execution engine for probabilistic models. It walks through
76/// the model structure and dispatches effects to the handler, returning both
77/// the model's final result and the accumulated execution trace.
78///
79/// Example:
80/// ```rust
81/// # use fugue::*;
82/// # use fugue::runtime::interpreters::PriorHandler;
83/// # use rand::rngs::StdRng;
84/// # use rand::SeedableRng;
85///
86/// // Create a simple model
87/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
88/// .bind(|x| observe(addr!("y"), Normal::new(x, 0.1).unwrap(), 1.2))
89/// .map(|_| "completed");
90///
91/// let mut rng = StdRng::seed_from_u64(123);
92/// let (result, trace) = runtime::handler::run(
93/// PriorHandler { rng: &mut rng, trace: Trace::default() },
94/// model
95/// );
96/// assert_eq!(result, "completed");
97/// assert!(trace.total_log_weight().is_finite());
98/// ```
99pub fn run<A>(mut h: impl Handler, m: Model<A>) -> (A, Trace) {
100 fn go<A>(h: &mut impl Handler, m: Model<A>) -> A {
101 match m {
102 Model::Pure(a) => a,
103 Model::SampleF64 { addr, dist, k } => {
104 let x = h.on_sample_f64(&addr, &*dist);
105 go(h, k(x))
106 }
107 Model::SampleBool { addr, dist, k } => {
108 let x = h.on_sample_bool(&addr, &*dist);
109 go(h, k(x))
110 }
111 Model::SampleU64 { addr, dist, k } => {
112 let x = h.on_sample_u64(&addr, &*dist);
113 go(h, k(x))
114 }
115 Model::SampleUsize { addr, dist, k } => {
116 let x = h.on_sample_usize(&addr, &*dist);
117 go(h, k(x))
118 }
119 Model::ObserveF64 {
120 addr,
121 dist,
122 value,
123 k,
124 } => {
125 h.on_observe_f64(&addr, &*dist, value);
126 go(h, k(()))
127 }
128 Model::ObserveBool {
129 addr,
130 dist,
131 value,
132 k,
133 } => {
134 h.on_observe_bool(&addr, &*dist, value);
135 go(h, k(()))
136 }
137 Model::ObserveU64 {
138 addr,
139 dist,
140 value,
141 k,
142 } => {
143 h.on_observe_u64(&addr, &*dist, value);
144 go(h, k(()))
145 }
146 Model::ObserveUsize {
147 addr,
148 dist,
149 value,
150 k,
151 } => {
152 h.on_observe_usize(&addr, &*dist, value);
153 go(h, k(()))
154 }
155 Model::Factor { logw, k } => {
156 h.on_factor(logw);
157 go(h, k(()))
158 }
159 }
160 }
161 let a = go(&mut h, m);
162 let t = h.finish();
163 (a, t)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{factor, observe, sample, ModelExt};
    use crate::runtime::interpreters::PriorHandler;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Runs `x ~ Normal(0, 1)`, then observes `y ~ Normal(x, 1)` at `0.5`, then
    /// applies `factor(-1.0)` under `PriorHandler`. The trace must record `x` and
    /// hold finite prior and likelihood terms, plus exactly `-1.0` of factor weight.
    #[test]
    fn run_accumulates_logs_for_sample_observe_factor() {
        let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5))
            .and_then(|_| factor(-1.0));

        let mut rng = StdRng::seed_from_u64(123);
        let handler = PriorHandler {
            rng: &mut rng,
            trace: Trace::default(),
        };
        let (_result, trace) = super::run(handler, model);

        // The sampled site is recorded and its prior term is finite.
        assert!(trace.choices.contains_key(&addr!("x")));
        assert!(trace.log_prior.is_finite());

        // The observation feeds the likelihood term.
        assert!(trace.log_likelihood.is_finite());

        // The factor contributes exactly -1.0.
        let factor_error = (trace.log_factors + 1.0).abs();
        assert!(factor_error < 1e-12);
    }
}
202}