fugue/runtime/
interpreters.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/interpreters.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::runtime::handler::Handler;
6use crate::runtime::trace::{Choice, ChoiceValue, Trace};
7
8use rand::RngCore;
9
10/// Handler for prior sampling - generates fresh random values from distributions.
11///
12/// This is the foundational interpreter that implements standard "forward sampling"
13/// from probabilistic models. It draws fresh values from distributions and accumulates
14/// log-probabilities in the trace.
15///
16/// Example:
17/// ```rust
18/// # use fugue::*;
19/// # use fugue::runtime::interpreters::PriorHandler;
20/// # use rand::rngs::StdRng;
21/// # use rand::SeedableRng;
22///
23/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
24///     .bind(|x| observe(addr!("y"), Normal::new(x, 0.5).unwrap(), 1.2)
25///         .map(move |_| x));
26///
27/// let mut rng = StdRng::seed_from_u64(42);
28/// let (result, trace) = runtime::handler::run(
29///     PriorHandler { rng: &mut rng, trace: Trace::default() },
30///     model
31/// );
32///
33/// assert!(result.is_finite());
34/// assert!(trace.log_likelihood.is_finite());
35/// ```
36pub struct PriorHandler<'r, R: RngCore> {
37    /// Random number generator for sampling.
38    pub rng: &'r mut R,
39    /// Trace to accumulate execution history.
40    pub trace: Trace,
41}
42impl<'r, R: RngCore> Handler for PriorHandler<'r, R> {
43    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
44        let x = dist.sample(self.rng);
45        let lp = dist.log_prob(&x);
46        self.trace.log_prior += lp;
47        self.trace.choices.insert(
48            addr.clone(),
49            Choice {
50                addr: addr.clone(),
51                value: ChoiceValue::F64(x),
52                logp: lp,
53            },
54        );
55        x
56    }
57
58    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
59        let x = dist.sample(self.rng);
60        let lp = dist.log_prob(&x);
61        self.trace.log_prior += lp;
62        self.trace.choices.insert(
63            addr.clone(),
64            Choice {
65                addr: addr.clone(),
66                value: ChoiceValue::Bool(x),
67                logp: lp,
68            },
69        );
70        x
71    }
72
73    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
74        let x = dist.sample(self.rng);
75        let lp = dist.log_prob(&x);
76        self.trace.log_prior += lp;
77        self.trace.choices.insert(
78            addr.clone(),
79            Choice {
80                addr: addr.clone(),
81                value: ChoiceValue::U64(x),
82                logp: lp,
83            },
84        );
85        x
86    }
87
88    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
89        let x = dist.sample(self.rng);
90        let lp = dist.log_prob(&x);
91        self.trace.log_prior += lp;
92        self.trace.choices.insert(
93            addr.clone(),
94            Choice {
95                addr: addr.clone(),
96                value: ChoiceValue::Usize(x),
97                logp: lp,
98            },
99        );
100        x
101    }
102
103    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
104        self.trace.log_likelihood += dist.log_prob(&value);
105    }
106
107    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
108        self.trace.log_likelihood += dist.log_prob(&value);
109    }
110
111    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
112        self.trace.log_likelihood += dist.log_prob(&value);
113    }
114
115    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
116        self.trace.log_likelihood += dist.log_prob(&value);
117    }
118
119    fn on_factor(&mut self, logw: f64) {
120        self.trace.log_factors += logw;
121    }
122
123    fn finish(self) -> Trace {
124        self.trace
125    }
126}
127
128/// Handler for replaying models with values from an existing trace.
129///
130/// ReplayHandler replays a model execution using stored trace values. When a sampling
131/// site is encountered: if the address exists in the base trace, use that value;
132/// if missing, sample fresh. Essential for MCMC where you replay most choices
133/// but sample new values at specific sites.
134///
135/// Example:
136/// ```rust
137/// # use fugue::*;
138/// # use fugue::runtime::interpreters::*;
139/// # use rand::rngs::StdRng;
140/// # use rand::SeedableRng;
141///
142/// // Create base trace
143/// let mut rng = StdRng::seed_from_u64(42);
144/// let model_fn = || sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
145/// let (original, base_trace) = runtime::handler::run(
146///     PriorHandler { rng: &mut rng, trace: Trace::default() },
147///     model_fn()
148/// );
149///
150/// // Replay using base trace values
151/// let (replayed, _) = runtime::handler::run(
152///     ReplayHandler {
153///         rng: &mut rng,
154///         base: base_trace,
155///         trace: Trace::default()
156///     },
157///     model_fn()
158/// );
159///
160/// assert_eq!(original, replayed); // Same value replayed
161/// ```
162pub struct ReplayHandler<'r, R: RngCore> {
163    /// Random number generator for sampling at addresses not in base trace.
164    pub rng: &'r mut R,
165    /// Base trace containing values to replay.
166    pub base: Trace,
167    /// New trace to accumulate the replay execution.
168    pub trace: Trace,
169}
170impl<'r, R: RngCore> Handler for ReplayHandler<'r, R> {
171    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
172        let x = if let Some(c) = self.base.choices.get(addr) {
173            match c.value {
174                ChoiceValue::F64(v) => v,
175                _ => panic!("expected f64 at {}", addr),
176            }
177        } else {
178            dist.sample(self.rng)
179        };
180        let lp = dist.log_prob(&x);
181        self.trace.log_prior += lp;
182        self.trace.choices.insert(
183            addr.clone(),
184            Choice {
185                addr: addr.clone(),
186                value: ChoiceValue::F64(x),
187                logp: lp,
188            },
189        );
190        x
191    }
192
193    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
194        let x = if let Some(c) = self.base.choices.get(addr) {
195            match c.value {
196                ChoiceValue::Bool(v) => v,
197                _ => panic!("expected bool at {}", addr),
198            }
199        } else {
200            dist.sample(self.rng)
201        };
202        let lp = dist.log_prob(&x);
203        self.trace.log_prior += lp;
204        self.trace.choices.insert(
205            addr.clone(),
206            Choice {
207                addr: addr.clone(),
208                value: ChoiceValue::Bool(x),
209                logp: lp,
210            },
211        );
212        x
213    }
214
215    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
216        let x = if let Some(c) = self.base.choices.get(addr) {
217            match c.value {
218                ChoiceValue::U64(v) => v,
219                _ => panic!("expected u64 at {}", addr),
220            }
221        } else {
222            dist.sample(self.rng)
223        };
224        let lp = dist.log_prob(&x);
225        self.trace.log_prior += lp;
226        self.trace.choices.insert(
227            addr.clone(),
228            Choice {
229                addr: addr.clone(),
230                value: ChoiceValue::U64(x),
231                logp: lp,
232            },
233        );
234        x
235    }
236
237    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
238        let x = if let Some(c) = self.base.choices.get(addr) {
239            match c.value {
240                ChoiceValue::Usize(v) => v,
241                _ => panic!("expected usize at {}", addr),
242            }
243        } else {
244            dist.sample(self.rng)
245        };
246        let lp = dist.log_prob(&x);
247        self.trace.log_prior += lp;
248        self.trace.choices.insert(
249            addr.clone(),
250            Choice {
251                addr: addr.clone(),
252                value: ChoiceValue::Usize(x),
253                logp: lp,
254            },
255        );
256        x
257    }
258
259    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
260        self.trace.log_likelihood += dist.log_prob(&value);
261    }
262
263    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
264        self.trace.log_likelihood += dist.log_prob(&value);
265    }
266
267    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
268        self.trace.log_likelihood += dist.log_prob(&value);
269    }
270
271    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
272        self.trace.log_likelihood += dist.log_prob(&value);
273    }
274
275    fn on_factor(&mut self, logw: f64) {
276        self.trace.log_factors += logw;
277    }
278
279    fn finish(self) -> Trace {
280        self.trace
281    }
282}
283
284/// Handler for scoring a model given a complete trace of fixed choices.
285///
286/// ScoreGivenTrace computes log-probability of a model execution where all random
287/// choices are predetermined. No sampling occurs - values are looked up from the
288/// base trace and their log-probabilities computed. Essential for MCMC acceptance
289/// ratios, importance sampling, and model comparison.
290///
291/// Example:
292/// ```rust
293/// # use fugue::*;
294/// # use fugue::runtime::interpreters::*;
295/// # use rand::rngs::StdRng;
296/// # use rand::SeedableRng;
297///
298/// // Create a complete trace
299/// let mut rng = StdRng::seed_from_u64(42);
300/// let (_, complete_trace) = runtime::handler::run(
301///     PriorHandler { rng: &mut rng, trace: Trace::default() },
302///     sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
303/// );
304///
305/// // Score under different model parameters  
306/// let (value, score_trace) = runtime::handler::run(
307///     ScoreGivenTrace {
308///         base: complete_trace,
309///         trace: Trace::default()
310///     },
311///     sample(addr!("x"), Normal::new(1.0, 2.0).unwrap()) // Different parameters
312/// );
313///
314/// assert!(score_trace.total_log_weight().is_finite());
315/// ```
316pub struct ScoreGivenTrace {
317    /// Base trace containing the fixed choices to score.
318    pub base: Trace,
319    /// New trace to accumulate log-probabilities.
320    pub trace: Trace,
321}
322impl Handler for ScoreGivenTrace {
323    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
324        let c = self
325            .base
326            .choices
327            .get(addr)
328            .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
329        let x = match c.value {
330            ChoiceValue::F64(v) => v,
331            _ => panic!("expected f64 at {}", addr),
332        };
333        let lp = dist.log_prob(&x);
334        self.trace.log_prior += lp;
335        self.trace.choices.insert(addr.clone(), c.clone());
336        x
337    }
338
339    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
340        let c = self
341            .base
342            .choices
343            .get(addr)
344            .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
345        let x = match c.value {
346            ChoiceValue::Bool(v) => v,
347            _ => panic!("expected bool at {}", addr),
348        };
349        let lp = dist.log_prob(&x);
350        self.trace.log_prior += lp;
351        self.trace.choices.insert(addr.clone(), c.clone());
352        x
353    }
354
355    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
356        let c = self
357            .base
358            .choices
359            .get(addr)
360            .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
361        let x = match c.value {
362            ChoiceValue::U64(v) => v,
363            _ => panic!("expected u64 at {}", addr),
364        };
365        let lp = dist.log_prob(&x);
366        self.trace.log_prior += lp;
367        self.trace.choices.insert(addr.clone(), c.clone());
368        x
369    }
370
371    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
372        let c = self
373            .base
374            .choices
375            .get(addr)
376            .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
377        let x = match c.value {
378            ChoiceValue::Usize(v) => v,
379            _ => panic!("expected usize at {}", addr),
380        };
381        let lp = dist.log_prob(&x);
382        self.trace.log_prior += lp;
383        self.trace.choices.insert(addr.clone(), c.clone());
384        x
385    }
386
387    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
388        self.trace.log_likelihood += dist.log_prob(&value);
389    }
390
391    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
392        self.trace.log_likelihood += dist.log_prob(&value);
393    }
394
395    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
396        self.trace.log_likelihood += dist.log_prob(&value);
397    }
398
399    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
400        self.trace.log_likelihood += dist.log_prob(&value);
401    }
402
403    fn on_factor(&mut self, logw: f64) {
404        self.trace.log_factors += logw;
405    }
406
407    fn finish(self) -> Trace {
408        self.trace
409    }
410}
411
412/// Safe version of ReplayHandler that gracefully handles trace inconsistencies.
413///
414/// SafeReplayHandler replays model execution like ReplayHandler, but handles type
415/// mismatches and missing addresses gracefully by logging warnings and sampling
416/// fresh values instead of panicking. Essential for production systems where
417/// trace consistency cannot be guaranteed.
418///
419/// Example:
420/// ```rust
421/// # use fugue::*;
422/// # use fugue::runtime::interpreters::*;
423/// # use rand::rngs::StdRng;
424/// # use rand::SeedableRng;
425///
426/// // Create trace with potential inconsistencies
427/// let mut rng = StdRng::seed_from_u64(42);
428/// let (_, base_trace) = runtime::handler::run(
429///     PriorHandler { rng: &mut rng, trace: Trace::default() },
430///     sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()) // f64 value
431/// );
432///
433/// // Safe replay handles type mismatch gracefully
434/// let (result, trace) = runtime::handler::run(
435///     SafeReplayHandler {
436///         rng: &mut rng,
437///         base: base_trace,
438///         trace: Trace::default(),
439///         warn_on_mismatch: true, // Enable warnings
440///     },
441///     sample(addr!("x"), Bernoulli::new(0.5).unwrap()) // Expects bool
442/// );
443///
444/// assert!(trace.total_log_weight().is_finite()); // Continues execution
445/// ```
446pub struct SafeReplayHandler<'r, R: RngCore> {
447    /// Random number generator for sampling at addresses not in base trace.
448    pub rng: &'r mut R,
449    /// Base trace containing values to replay.
450    pub base: Trace,
451    /// New trace to accumulate the replay execution.
452    pub trace: Trace,
453    /// Whether to log warnings on type mismatches (useful for debugging).
454    pub warn_on_mismatch: bool,
455}
456impl<'r, R: RngCore> Handler for SafeReplayHandler<'r, R> {
457    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
458        let x = match self.base.get_f64(addr) {
459            Some(v) => v,
460            None => {
461                if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
462                    if let Some(choice) = self.base.choices.get(addr) {
463                        eprintln!(
464                            "Warning: Type mismatch at {}: expected f64, found {}",
465                            addr,
466                            choice.value.type_name()
467                        );
468                    }
469                }
470                dist.sample(self.rng)
471            }
472        };
473        let lp = dist.log_prob(&x);
474        self.trace.log_prior += lp;
475        self.trace.choices.insert(
476            addr.clone(),
477            Choice {
478                addr: addr.clone(),
479                value: ChoiceValue::F64(x),
480                logp: lp,
481            },
482        );
483        x
484    }
485
486    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
487        let x = match self.base.get_bool(addr) {
488            Some(v) => v,
489            None => {
490                if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
491                    if let Some(choice) = self.base.choices.get(addr) {
492                        eprintln!(
493                            "Warning: Type mismatch at {}: expected bool, found {}",
494                            addr,
495                            choice.value.type_name()
496                        );
497                    }
498                }
499                dist.sample(self.rng)
500            }
501        };
502        let lp = dist.log_prob(&x);
503        self.trace.log_prior += lp;
504        self.trace.choices.insert(
505            addr.clone(),
506            Choice {
507                addr: addr.clone(),
508                value: ChoiceValue::Bool(x),
509                logp: lp,
510            },
511        );
512        x
513    }
514
515    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
516        let x = match self.base.get_u64(addr) {
517            Some(v) => v,
518            None => {
519                if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
520                    if let Some(choice) = self.base.choices.get(addr) {
521                        eprintln!(
522                            "Warning: Type mismatch at {}: expected u64, found {}",
523                            addr,
524                            choice.value.type_name()
525                        );
526                    }
527                }
528                dist.sample(self.rng)
529            }
530        };
531        let lp = dist.log_prob(&x);
532        self.trace.log_prior += lp;
533        self.trace.choices.insert(
534            addr.clone(),
535            Choice {
536                addr: addr.clone(),
537                value: ChoiceValue::U64(x),
538                logp: lp,
539            },
540        );
541        x
542    }
543
544    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
545        let x = match self.base.get_usize(addr) {
546            Some(v) => v,
547            None => {
548                if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
549                    if let Some(choice) = self.base.choices.get(addr) {
550                        eprintln!(
551                            "Warning: Type mismatch at {}: expected usize, found {}",
552                            addr,
553                            choice.value.type_name()
554                        );
555                    }
556                }
557                dist.sample(self.rng)
558            }
559        };
560        let lp = dist.log_prob(&x);
561        self.trace.log_prior += lp;
562        self.trace.choices.insert(
563            addr.clone(),
564            Choice {
565                addr: addr.clone(),
566                value: ChoiceValue::Usize(x),
567                logp: lp,
568            },
569        );
570        x
571    }
572
573    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
574        self.trace.log_likelihood += dist.log_prob(&value);
575    }
576
577    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
578        self.trace.log_likelihood += dist.log_prob(&value);
579    }
580
581    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
582        self.trace.log_likelihood += dist.log_prob(&value);
583    }
584
585    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
586        self.trace.log_likelihood += dist.log_prob(&value);
587    }
588
589    fn on_factor(&mut self, logw: f64) {
590        self.trace.log_factors += logw;
591    }
592
593    fn finish(self) -> Trace {
594        self.trace
595    }
596}
597
598/// Safe version of ScoreGivenTrace that gracefully handles incomplete traces.
599///
600/// SafeScoreGivenTrace computes log-probability like ScoreGivenTrace, but handles
601/// missing addresses or type mismatches by returning negative infinity log-weight
602/// instead of panicking. Essential for production inference where trace validity
603/// cannot be guaranteed.
604///
605/// Example:
606/// ```rust
607/// # use fugue::*;
608/// # use fugue::runtime::interpreters::*;
609/// # use rand::rngs::StdRng;
610/// # use rand::SeedableRng;
611///
612/// // Create incomplete trace
613/// let mut rng = StdRng::seed_from_u64(42);
614/// let (_, incomplete_trace) = runtime::handler::run(
615///     PriorHandler { rng: &mut rng, trace: Trace::default() },
616///     sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()) // Only has "x"
617/// );
618///
619/// // Safe scoring handles missing address gracefully
620/// let (_, score_trace) = runtime::handler::run(
621///     SafeScoreGivenTrace {
622///         base: incomplete_trace,
623///         trace: Trace::default(),
624///         warn_on_error: true, // Enable warnings
625///     },
626///     sample(addr!("missing"), Normal::new(0.0, 1.0).unwrap()) // Address not in base
627/// );
628///
629/// assert_eq!(score_trace.total_log_weight(), f64::NEG_INFINITY); // Graceful failure
630/// ```
631pub struct SafeScoreGivenTrace {
632    /// Base trace containing the fixed choices to score.
633    pub base: Trace,
634    /// New trace to accumulate log-probabilities.
635    pub trace: Trace,
636    /// Whether to log warnings on missing addresses or type mismatches.
637    pub warn_on_error: bool,
638}
639impl Handler for SafeScoreGivenTrace {
640    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
641        match self.base.get_f64_result(addr) {
642            Ok(x) => {
643                let lp = dist.log_prob(&x);
644                self.trace.log_prior += lp;
645                if let Some(choice) = self.base.choices.get(addr) {
646                    self.trace.choices.insert(addr.clone(), choice.clone());
647                }
648                x
649            }
650            Err(e) => {
651                if self.warn_on_error {
652                    eprintln!("Warning: Failed to get f64 at {}: {}", addr, e);
653                }
654                // Add negative infinity to make this trace invalid
655                self.trace.log_prior += f64::NEG_INFINITY;
656                0.0 // Return a dummy value
657            }
658        }
659    }
660
661    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
662        match self.base.get_bool_result(addr) {
663            Ok(x) => {
664                let lp = dist.log_prob(&x);
665                self.trace.log_prior += lp;
666                if let Some(choice) = self.base.choices.get(addr) {
667                    self.trace.choices.insert(addr.clone(), choice.clone());
668                }
669                x
670            }
671            Err(e) => {
672                if self.warn_on_error {
673                    eprintln!("Warning: Failed to get bool at {}: {}", addr, e);
674                }
675                self.trace.log_prior += f64::NEG_INFINITY;
676                false
677            }
678        }
679    }
680
681    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
682        match self.base.get_u64_result(addr) {
683            Ok(x) => {
684                let lp = dist.log_prob(&x);
685                self.trace.log_prior += lp;
686                if let Some(choice) = self.base.choices.get(addr) {
687                    self.trace.choices.insert(addr.clone(), choice.clone());
688                }
689                x
690            }
691            Err(e) => {
692                if self.warn_on_error {
693                    eprintln!("Warning: Failed to get u64 at {}: {}", addr, e);
694                }
695                self.trace.log_prior += f64::NEG_INFINITY;
696                0
697            }
698        }
699    }
700
701    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
702        match self.base.get_usize_result(addr) {
703            Ok(x) => {
704                let lp = dist.log_prob(&x);
705                self.trace.log_prior += lp;
706                if let Some(choice) = self.base.choices.get(addr) {
707                    self.trace.choices.insert(addr.clone(), choice.clone());
708                }
709                x
710            }
711            Err(e) => {
712                if self.warn_on_error {
713                    eprintln!("Warning: Failed to get usize at {}: {}", addr, e);
714                }
715                self.trace.log_prior += f64::NEG_INFINITY;
716                0
717            }
718        }
719    }
720
721    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
722        self.trace.log_likelihood += dist.log_prob(&value);
723    }
724
725    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
726        self.trace.log_likelihood += dist.log_prob(&value);
727    }
728
729    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
730        self.trace.log_likelihood += dist.log_prob(&value);
731    }
732
733    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
734        self.trace.log_likelihood += dist.log_prob(&value);
735    }
736
737    fn on_factor(&mut self, logw: f64) {
738        self.trace.log_factors += logw;
739    }
740
741    fn finish(self) -> Trace {
742        self.trace
743    }
744}
745
746#[cfg(test)]
747mod tests {
748    use super::*;
749    use crate::addr;
750    use crate::core::distribution::*;
751    use crate::core::model::{observe, sample, ModelExt};
752    use rand::rngs::StdRng;
753    use rand::SeedableRng;
754
755    #[test]
756    fn prior_handler_samples_and_accumulates() {
757        let mut rng = StdRng::seed_from_u64(7);
758        let (_val, trace) = crate::runtime::handler::run(
759            PriorHandler {
760                rng: &mut rng,
761                trace: Trace::default(),
762            },
763            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
764                .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5)),
765        );
766        assert!(trace.choices.contains_key(&addr!("x")));
767        assert!(trace.log_prior.is_finite());
768        assert!(trace.log_likelihood.is_finite());
769    }
770
771    #[test]
772    fn replay_handler_reuses_values() {
773        let mut rng = StdRng::seed_from_u64(8);
774        let ((), base) = crate::runtime::handler::run(
775            PriorHandler {
776                rng: &mut rng,
777                trace: Trace::default(),
778            },
779            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()).map(|_| ()),
780        );
781
782        let ((), replayed) = crate::runtime::handler::run(
783            ReplayHandler {
784                rng: &mut rng,
785                base: base.clone(),
786                trace: Trace::default(),
787            },
788            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()).map(|_| ()),
789        );
790
791        let x_base = base.get_f64(&addr!("x")).unwrap();
792        let x_replay = replayed.get_f64(&addr!("x")).unwrap();
793        assert_eq!(x_base, x_replay);
794    }
795
796    #[test]
797    fn score_given_trace_scores_fixed_values() {
798        let mut rng = StdRng::seed_from_u64(9);
799        let (_a, base) = crate::runtime::handler::run(
800            PriorHandler {
801                rng: &mut rng,
802                trace: Trace::default(),
803            },
804            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
805        );
806
807        let (_a2, scored) = crate::runtime::handler::run(
808            ScoreGivenTrace {
809                base: base.clone(),
810                trace: Trace::default(),
811            },
812            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
813        );
814
815        // Should have same value and finite log_prior
816        assert_eq!(scored.get_f64(&addr!("x")), base.get_f64(&addr!("x")));
817        assert!(scored.log_prior.is_finite());
818    }
819
820    #[test]
821    fn safe_variants_handle_mismatches() {
822        // Build base trace with x as f64, then attempt to replay as bool
823        let mut rng = StdRng::seed_from_u64(10);
824        let (_a, base) = crate::runtime::handler::run(
825            PriorHandler {
826                rng: &mut rng,
827                trace: Trace::default(),
828            },
829            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
830        );
831
832        // SafeReplayHandler should sample fresh value for bool and continue
833        let (_b, t1) = crate::runtime::handler::run(
834            SafeReplayHandler {
835                rng: &mut rng,
836                base: base.clone(),
837                trace: Trace::default(),
838                warn_on_mismatch: true,
839            },
840            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
841        );
842        assert!(t1.log_prior.is_finite());
843
844        // SafeScoreGivenTrace should mark as invalid by adding -inf
845        let (_c, t2) = crate::runtime::handler::run(
846            SafeScoreGivenTrace {
847                base: base.clone(),
848                trace: Trace::default(),
849                warn_on_error: true,
850            },
851            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
852        );
853        assert!(t2.log_prior.is_infinite());
854    }
855
856    #[test]
857    fn handlers_cover_all_types_sample_and_observe() {
858        // Model with multiple types
859        let model = sample(addr!("f"), Normal::new(0.0, 1.0).unwrap())
860            .and_then(|_| sample(addr!("b"), Bernoulli::new(0.6).unwrap()))
861            .and_then(|_| sample(addr!("u64"), Poisson::new(3.0).unwrap()))
862            .and_then(|_| sample(addr!("usz"), Categorical::new(vec![0.3, 0.7]).unwrap()))
863            .and_then(|_| observe(addr!("f_obs"), Normal::new(0.0, 1.0).unwrap(), 0.1))
864            .and_then(|_| observe(addr!("b_obs"), Bernoulli::new(0.4).unwrap(), true))
865            .and_then(|_| observe(addr!("u64_obs"), Poisson::new(2.0).unwrap(), 1))
866            .and_then(|_| {
867                observe(
868                    addr!("usz_obs"),
869                    Categorical::new(vec![0.5, 0.5]).unwrap(),
870                    1,
871                )
872            });
873
874        let (_a, t) = crate::runtime::handler::run(
875            PriorHandler {
876                rng: &mut StdRng::seed_from_u64(100),
877                trace: Trace::default(),
878            },
879            model,
880        );
881        assert!(t.get_f64(&addr!("f")).is_some());
882        assert!(t.get_bool(&addr!("b")).is_some());
883        assert!(t.get_u64(&addr!("u64")).is_some());
884        assert!(t.get_usize(&addr!("usz")).is_some());
885        assert!(t.log_likelihood.is_finite());
886
887        // Build base and score given trace for all types
888        let base = t.clone();
889        let (_sv, scored) = crate::runtime::handler::run(
890            ScoreGivenTrace {
891                base: base.clone(),
892                trace: Trace::default(),
893            },
894            sample(addr!("f"), Normal::new(0.0, 1.0).unwrap())
895                .and_then(|_| sample(addr!("b"), Bernoulli::new(0.6).unwrap()))
896                .and_then(|_| sample(addr!("u64"), Poisson::new(3.0).unwrap()))
897                .and_then(|_| sample(addr!("usz"), Categorical::new(vec![0.3, 0.7]).unwrap())),
898        );
899        assert!(scored.log_prior.is_finite());
900
901        // Safe replay mismatches for integer/categorical types
902        let (_sv2, safe) = crate::runtime::handler::run(
903            SafeReplayHandler {
904                rng: &mut StdRng::seed_from_u64(101),
905                base: base.clone(),
906                trace: Trace::default(),
907                warn_on_mismatch: true,
908            },
909            sample(addr!("u64"), Bernoulli::new(0.5).unwrap()),
910        );
911        assert!(safe.log_prior.is_finite());
912    }
913
914    #[test]
915    fn safe_score_given_trace_warn_flag_branches() {
916        let mut rng = StdRng::seed_from_u64(102);
917        let (_a, base) = crate::runtime::handler::run(
918            PriorHandler {
919                rng: &mut rng,
920                trace: Trace::default(),
921            },
922            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
923        );
924        // warn_on_error = false
925        let (_b, t_false) = crate::runtime::handler::run(
926            SafeScoreGivenTrace {
927                base: base.clone(),
928                trace: Trace::default(),
929                warn_on_error: false,
930            },
931            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
932        );
933        assert!(t_false.log_prior.is_infinite());
934
935        // warn_on_error = true
936        let (_c, t_true) = crate::runtime::handler::run(
937            SafeScoreGivenTrace {
938                base: base.clone(),
939                trace: Trace::default(),
940                warn_on_error: true,
941            },
942            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
943        );
944        assert!(t_true.log_prior.is_infinite());
945    }
946
947    #[test]
948    #[should_panic]
949    fn replay_handler_panics_on_type_mismatch() {
950        // Base has f64, replay expects bool -> panic as designed
951        let mut rng = StdRng::seed_from_u64(103);
952        let (_a, base) = crate::runtime::handler::run(
953            PriorHandler {
954                rng: &mut rng,
955                trace: Trace::default(),
956            },
957            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
958        );
959        let (_b, _t) = crate::runtime::handler::run(
960            ReplayHandler {
961                rng: &mut rng,
962                base: base.clone(),
963                trace: Trace::default(),
964            },
965            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
966        );
967    }
968
969    #[test]
970    fn safe_replay_handler_samples_fresh_for_missing_address() {
971        let mut rng = StdRng::seed_from_u64(104);
972        // Base trace without address "z"
973        let base = Trace::default();
974        let (_a, t) = crate::runtime::handler::run(
975            SafeReplayHandler {
976                rng: &mut rng,
977                base,
978                trace: Trace::default(),
979                warn_on_mismatch: true,
980            },
981            sample(addr!("z"), Normal::new(0.0, 1.0).unwrap()),
982        );
983        assert!(t.get_f64(&addr!("z")).is_some());
984    }
985}