1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/interpreters.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::runtime::handler::Handler;
6use crate::runtime::trace::{Choice, ChoiceValue, Trace};
7
8use rand::RngCore;
9
10pub struct PriorHandler<'r, R: RngCore> {
37 pub rng: &'r mut R,
39 pub trace: Trace,
41}
42impl<'r, R: RngCore> Handler for PriorHandler<'r, R> {
43 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
44 let x = dist.sample(self.rng);
45 let lp = dist.log_prob(&x);
46 self.trace.log_prior += lp;
47 self.trace.choices.insert(
48 addr.clone(),
49 Choice {
50 addr: addr.clone(),
51 value: ChoiceValue::F64(x),
52 logp: lp,
53 },
54 );
55 x
56 }
57
58 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
59 let x = dist.sample(self.rng);
60 let lp = dist.log_prob(&x);
61 self.trace.log_prior += lp;
62 self.trace.choices.insert(
63 addr.clone(),
64 Choice {
65 addr: addr.clone(),
66 value: ChoiceValue::Bool(x),
67 logp: lp,
68 },
69 );
70 x
71 }
72
73 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
74 let x = dist.sample(self.rng);
75 let lp = dist.log_prob(&x);
76 self.trace.log_prior += lp;
77 self.trace.choices.insert(
78 addr.clone(),
79 Choice {
80 addr: addr.clone(),
81 value: ChoiceValue::U64(x),
82 logp: lp,
83 },
84 );
85 x
86 }
87
88 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
89 let x = dist.sample(self.rng);
90 let lp = dist.log_prob(&x);
91 self.trace.log_prior += lp;
92 self.trace.choices.insert(
93 addr.clone(),
94 Choice {
95 addr: addr.clone(),
96 value: ChoiceValue::Usize(x),
97 logp: lp,
98 },
99 );
100 x
101 }
102
103 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
104 self.trace.log_likelihood += dist.log_prob(&value);
105 }
106
107 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
108 self.trace.log_likelihood += dist.log_prob(&value);
109 }
110
111 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
112 self.trace.log_likelihood += dist.log_prob(&value);
113 }
114
115 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
116 self.trace.log_likelihood += dist.log_prob(&value);
117 }
118
119 fn on_factor(&mut self, logw: f64) {
120 self.trace.log_factors += logw;
121 }
122
123 fn finish(self) -> Trace {
124 self.trace
125 }
126}
127
128pub struct ReplayHandler<'r, R: RngCore> {
163 pub rng: &'r mut R,
165 pub base: Trace,
167 pub trace: Trace,
169}
170impl<'r, R: RngCore> Handler for ReplayHandler<'r, R> {
171 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
172 let x = if let Some(c) = self.base.choices.get(addr) {
173 match c.value {
174 ChoiceValue::F64(v) => v,
175 _ => panic!("expected f64 at {}", addr),
176 }
177 } else {
178 dist.sample(self.rng)
179 };
180 let lp = dist.log_prob(&x);
181 self.trace.log_prior += lp;
182 self.trace.choices.insert(
183 addr.clone(),
184 Choice {
185 addr: addr.clone(),
186 value: ChoiceValue::F64(x),
187 logp: lp,
188 },
189 );
190 x
191 }
192
193 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
194 let x = if let Some(c) = self.base.choices.get(addr) {
195 match c.value {
196 ChoiceValue::Bool(v) => v,
197 _ => panic!("expected bool at {}", addr),
198 }
199 } else {
200 dist.sample(self.rng)
201 };
202 let lp = dist.log_prob(&x);
203 self.trace.log_prior += lp;
204 self.trace.choices.insert(
205 addr.clone(),
206 Choice {
207 addr: addr.clone(),
208 value: ChoiceValue::Bool(x),
209 logp: lp,
210 },
211 );
212 x
213 }
214
215 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
216 let x = if let Some(c) = self.base.choices.get(addr) {
217 match c.value {
218 ChoiceValue::U64(v) => v,
219 _ => panic!("expected u64 at {}", addr),
220 }
221 } else {
222 dist.sample(self.rng)
223 };
224 let lp = dist.log_prob(&x);
225 self.trace.log_prior += lp;
226 self.trace.choices.insert(
227 addr.clone(),
228 Choice {
229 addr: addr.clone(),
230 value: ChoiceValue::U64(x),
231 logp: lp,
232 },
233 );
234 x
235 }
236
237 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
238 let x = if let Some(c) = self.base.choices.get(addr) {
239 match c.value {
240 ChoiceValue::Usize(v) => v,
241 _ => panic!("expected usize at {}", addr),
242 }
243 } else {
244 dist.sample(self.rng)
245 };
246 let lp = dist.log_prob(&x);
247 self.trace.log_prior += lp;
248 self.trace.choices.insert(
249 addr.clone(),
250 Choice {
251 addr: addr.clone(),
252 value: ChoiceValue::Usize(x),
253 logp: lp,
254 },
255 );
256 x
257 }
258
259 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
260 self.trace.log_likelihood += dist.log_prob(&value);
261 }
262
263 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
264 self.trace.log_likelihood += dist.log_prob(&value);
265 }
266
267 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
268 self.trace.log_likelihood += dist.log_prob(&value);
269 }
270
271 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
272 self.trace.log_likelihood += dist.log_prob(&value);
273 }
274
275 fn on_factor(&mut self, logw: f64) {
276 self.trace.log_factors += logw;
277 }
278
279 fn finish(self) -> Trace {
280 self.trace
281 }
282}
283
284pub struct ScoreGivenTrace {
317 pub base: Trace,
319 pub trace: Trace,
321}
322impl Handler for ScoreGivenTrace {
323 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
324 let c = self
325 .base
326 .choices
327 .get(addr)
328 .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
329 let x = match c.value {
330 ChoiceValue::F64(v) => v,
331 _ => panic!("expected f64 at {}", addr),
332 };
333 let lp = dist.log_prob(&x);
334 self.trace.log_prior += lp;
335 self.trace.choices.insert(addr.clone(), c.clone());
336 x
337 }
338
339 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
340 let c = self
341 .base
342 .choices
343 .get(addr)
344 .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
345 let x = match c.value {
346 ChoiceValue::Bool(v) => v,
347 _ => panic!("expected bool at {}", addr),
348 };
349 let lp = dist.log_prob(&x);
350 self.trace.log_prior += lp;
351 self.trace.choices.insert(addr.clone(), c.clone());
352 x
353 }
354
355 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
356 let c = self
357 .base
358 .choices
359 .get(addr)
360 .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
361 let x = match c.value {
362 ChoiceValue::U64(v) => v,
363 _ => panic!("expected u64 at {}", addr),
364 };
365 let lp = dist.log_prob(&x);
366 self.trace.log_prior += lp;
367 self.trace.choices.insert(addr.clone(), c.clone());
368 x
369 }
370
371 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
372 let c = self
373 .base
374 .choices
375 .get(addr)
376 .unwrap_or_else(|| panic!("missing value for site {} in base trace", addr));
377 let x = match c.value {
378 ChoiceValue::Usize(v) => v,
379 _ => panic!("expected usize at {}", addr),
380 };
381 let lp = dist.log_prob(&x);
382 self.trace.log_prior += lp;
383 self.trace.choices.insert(addr.clone(), c.clone());
384 x
385 }
386
387 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
388 self.trace.log_likelihood += dist.log_prob(&value);
389 }
390
391 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
392 self.trace.log_likelihood += dist.log_prob(&value);
393 }
394
395 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
396 self.trace.log_likelihood += dist.log_prob(&value);
397 }
398
399 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
400 self.trace.log_likelihood += dist.log_prob(&value);
401 }
402
403 fn on_factor(&mut self, logw: f64) {
404 self.trace.log_factors += logw;
405 }
406
407 fn finish(self) -> Trace {
408 self.trace
409 }
410}
411
412pub struct SafeReplayHandler<'r, R: RngCore> {
447 pub rng: &'r mut R,
449 pub base: Trace,
451 pub trace: Trace,
453 pub warn_on_mismatch: bool,
455}
456impl<'r, R: RngCore> Handler for SafeReplayHandler<'r, R> {
457 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
458 let x = match self.base.get_f64(addr) {
459 Some(v) => v,
460 None => {
461 if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
462 if let Some(choice) = self.base.choices.get(addr) {
463 eprintln!(
464 "Warning: Type mismatch at {}: expected f64, found {}",
465 addr,
466 choice.value.type_name()
467 );
468 }
469 }
470 dist.sample(self.rng)
471 }
472 };
473 let lp = dist.log_prob(&x);
474 self.trace.log_prior += lp;
475 self.trace.choices.insert(
476 addr.clone(),
477 Choice {
478 addr: addr.clone(),
479 value: ChoiceValue::F64(x),
480 logp: lp,
481 },
482 );
483 x
484 }
485
486 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
487 let x = match self.base.get_bool(addr) {
488 Some(v) => v,
489 None => {
490 if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
491 if let Some(choice) = self.base.choices.get(addr) {
492 eprintln!(
493 "Warning: Type mismatch at {}: expected bool, found {}",
494 addr,
495 choice.value.type_name()
496 );
497 }
498 }
499 dist.sample(self.rng)
500 }
501 };
502 let lp = dist.log_prob(&x);
503 self.trace.log_prior += lp;
504 self.trace.choices.insert(
505 addr.clone(),
506 Choice {
507 addr: addr.clone(),
508 value: ChoiceValue::Bool(x),
509 logp: lp,
510 },
511 );
512 x
513 }
514
515 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
516 let x = match self.base.get_u64(addr) {
517 Some(v) => v,
518 None => {
519 if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
520 if let Some(choice) = self.base.choices.get(addr) {
521 eprintln!(
522 "Warning: Type mismatch at {}: expected u64, found {}",
523 addr,
524 choice.value.type_name()
525 );
526 }
527 }
528 dist.sample(self.rng)
529 }
530 };
531 let lp = dist.log_prob(&x);
532 self.trace.log_prior += lp;
533 self.trace.choices.insert(
534 addr.clone(),
535 Choice {
536 addr: addr.clone(),
537 value: ChoiceValue::U64(x),
538 logp: lp,
539 },
540 );
541 x
542 }
543
544 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
545 let x = match self.base.get_usize(addr) {
546 Some(v) => v,
547 None => {
548 if self.warn_on_mismatch && self.base.choices.contains_key(addr) {
549 if let Some(choice) = self.base.choices.get(addr) {
550 eprintln!(
551 "Warning: Type mismatch at {}: expected usize, found {}",
552 addr,
553 choice.value.type_name()
554 );
555 }
556 }
557 dist.sample(self.rng)
558 }
559 };
560 let lp = dist.log_prob(&x);
561 self.trace.log_prior += lp;
562 self.trace.choices.insert(
563 addr.clone(),
564 Choice {
565 addr: addr.clone(),
566 value: ChoiceValue::Usize(x),
567 logp: lp,
568 },
569 );
570 x
571 }
572
573 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
574 self.trace.log_likelihood += dist.log_prob(&value);
575 }
576
577 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
578 self.trace.log_likelihood += dist.log_prob(&value);
579 }
580
581 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
582 self.trace.log_likelihood += dist.log_prob(&value);
583 }
584
585 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
586 self.trace.log_likelihood += dist.log_prob(&value);
587 }
588
589 fn on_factor(&mut self, logw: f64) {
590 self.trace.log_factors += logw;
591 }
592
593 fn finish(self) -> Trace {
594 self.trace
595 }
596}
597
598pub struct SafeScoreGivenTrace {
632 pub base: Trace,
634 pub trace: Trace,
636 pub warn_on_error: bool,
638}
639impl Handler for SafeScoreGivenTrace {
640 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
641 match self.base.get_f64_result(addr) {
642 Ok(x) => {
643 let lp = dist.log_prob(&x);
644 self.trace.log_prior += lp;
645 if let Some(choice) = self.base.choices.get(addr) {
646 self.trace.choices.insert(addr.clone(), choice.clone());
647 }
648 x
649 }
650 Err(e) => {
651 if self.warn_on_error {
652 eprintln!("Warning: Failed to get f64 at {}: {}", addr, e);
653 }
654 self.trace.log_prior += f64::NEG_INFINITY;
656 0.0 }
658 }
659 }
660
661 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
662 match self.base.get_bool_result(addr) {
663 Ok(x) => {
664 let lp = dist.log_prob(&x);
665 self.trace.log_prior += lp;
666 if let Some(choice) = self.base.choices.get(addr) {
667 self.trace.choices.insert(addr.clone(), choice.clone());
668 }
669 x
670 }
671 Err(e) => {
672 if self.warn_on_error {
673 eprintln!("Warning: Failed to get bool at {}: {}", addr, e);
674 }
675 self.trace.log_prior += f64::NEG_INFINITY;
676 false
677 }
678 }
679 }
680
681 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
682 match self.base.get_u64_result(addr) {
683 Ok(x) => {
684 let lp = dist.log_prob(&x);
685 self.trace.log_prior += lp;
686 if let Some(choice) = self.base.choices.get(addr) {
687 self.trace.choices.insert(addr.clone(), choice.clone());
688 }
689 x
690 }
691 Err(e) => {
692 if self.warn_on_error {
693 eprintln!("Warning: Failed to get u64 at {}: {}", addr, e);
694 }
695 self.trace.log_prior += f64::NEG_INFINITY;
696 0
697 }
698 }
699 }
700
701 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
702 match self.base.get_usize_result(addr) {
703 Ok(x) => {
704 let lp = dist.log_prob(&x);
705 self.trace.log_prior += lp;
706 if let Some(choice) = self.base.choices.get(addr) {
707 self.trace.choices.insert(addr.clone(), choice.clone());
708 }
709 x
710 }
711 Err(e) => {
712 if self.warn_on_error {
713 eprintln!("Warning: Failed to get usize at {}: {}", addr, e);
714 }
715 self.trace.log_prior += f64::NEG_INFINITY;
716 0
717 }
718 }
719 }
720
721 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
722 self.trace.log_likelihood += dist.log_prob(&value);
723 }
724
725 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
726 self.trace.log_likelihood += dist.log_prob(&value);
727 }
728
729 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
730 self.trace.log_likelihood += dist.log_prob(&value);
731 }
732
733 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
734 self.trace.log_likelihood += dist.log_prob(&value);
735 }
736
737 fn on_factor(&mut self, logw: f64) {
738 self.trace.log_factors += logw;
739 }
740
741 fn finish(self) -> Trace {
742 self.trace
743 }
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, sample, ModelExt};
    use crate::runtime::handler::run;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn prior_handler_samples_and_accumulates() {
        let mut rng = StdRng::seed_from_u64(7);
        let (_val, trace) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
                .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5)),
        );
        assert!(trace.choices.contains_key(&addr!("x")));
        // Observed sites are scored but never recorded as choices.
        assert!(!trace.choices.contains_key(&addr!("y")));
        assert!(trace.log_prior.is_finite());
        assert!(trace.log_likelihood.is_finite());
    }

    #[test]
    fn replay_handler_reuses_values() {
        let mut rng = StdRng::seed_from_u64(8);
        let ((), base) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()).map(|_| ()),
        );

        let ((), replayed) = run(
            ReplayHandler {
                rng: &mut rng,
                base: base.clone(),
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()).map(|_| ()),
        );

        let x_base = base.get_f64(&addr!("x")).unwrap();
        let x_replay = replayed.get_f64(&addr!("x")).unwrap();
        assert_eq!(x_base, x_replay);
        // Same value under the same distribution must score identically.
        assert_eq!(base.log_prior, replayed.log_prior);
    }

    #[test]
    fn score_given_trace_scores_fixed_values() {
        let mut rng = StdRng::seed_from_u64(9);
        let (_a, base) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        );

        let (_a2, scored) = run(
            ScoreGivenTrace {
                base: base.clone(),
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        );

        assert_eq!(scored.get_f64(&addr!("x")), base.get_f64(&addr!("x")));
        assert!(scored.log_prior.is_finite());
        assert_eq!(scored.log_prior, base.log_prior);
    }

    #[test]
    fn safe_variants_handle_mismatches() {
        let mut rng = StdRng::seed_from_u64(10);
        // Base trace records an f64 at "x"; the models below ask for a bool there.
        let (_a, base) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        );

        // SafeReplayHandler falls back to a fresh bool draw instead of panicking.
        let (_b, t1) = run(
            SafeReplayHandler {
                rng: &mut rng,
                base: base.clone(),
                trace: Trace::default(),
                warn_on_mismatch: true,
            },
            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
        );
        assert!(t1.log_prior.is_finite());
        assert!(t1.get_bool(&addr!("x")).is_some());

        // SafeScoreGivenTrace marks the trace impossible instead of panicking.
        let (_c, t2) = run(
            SafeScoreGivenTrace {
                base: base.clone(),
                trace: Trace::default(),
                warn_on_error: true,
            },
            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
        );
        assert_eq!(t2.log_prior, f64::NEG_INFINITY);
        assert!(t2.choices.is_empty());
    }

    #[test]
    fn handlers_cover_all_types_sample_and_observe() {
        let model = sample(addr!("f"), Normal::new(0.0, 1.0).unwrap())
            .and_then(|_| sample(addr!("b"), Bernoulli::new(0.6).unwrap()))
            .and_then(|_| sample(addr!("u64"), Poisson::new(3.0).unwrap()))
            .and_then(|_| sample(addr!("usz"), Categorical::new(vec![0.3, 0.7]).unwrap()))
            .and_then(|_| observe(addr!("f_obs"), Normal::new(0.0, 1.0).unwrap(), 0.1))
            .and_then(|_| observe(addr!("b_obs"), Bernoulli::new(0.4).unwrap(), true))
            .and_then(|_| observe(addr!("u64_obs"), Poisson::new(2.0).unwrap(), 1))
            .and_then(|_| {
                observe(
                    addr!("usz_obs"),
                    Categorical::new(vec![0.5, 0.5]).unwrap(),
                    1,
                )
            });

        let (_a, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(100),
                trace: Trace::default(),
            },
            model,
        );
        assert!(t.get_f64(&addr!("f")).is_some());
        assert!(t.get_bool(&addr!("b")).is_some());
        assert!(t.get_u64(&addr!("u64")).is_some());
        assert!(t.get_usize(&addr!("usz")).is_some());
        assert!(t.log_likelihood.is_finite());

        let base = t.clone();
        let (_sv, scored) = run(
            ScoreGivenTrace {
                base: base.clone(),
                trace: Trace::default(),
            },
            sample(addr!("f"), Normal::new(0.0, 1.0).unwrap())
                .and_then(|_| sample(addr!("b"), Bernoulli::new(0.6).unwrap()))
                .and_then(|_| sample(addr!("u64"), Poisson::new(3.0).unwrap()))
                .and_then(|_| sample(addr!("usz"), Categorical::new(vec![0.3, 0.7]).unwrap())),
        );
        assert!(scored.log_prior.is_finite());
        // Scoring must not alter any of the recorded values.
        assert_eq!(scored.get_f64(&addr!("f")), base.get_f64(&addr!("f")));
        assert_eq!(scored.get_bool(&addr!("b")), base.get_bool(&addr!("b")));
        assert_eq!(scored.get_u64(&addr!("u64")), base.get_u64(&addr!("u64")));
        assert_eq!(scored.get_usize(&addr!("usz")), base.get_usize(&addr!("usz")));

        // "u64" holds a U64 in base; asking for a bool there triggers a fresh draw.
        let (_sv2, safe) = run(
            SafeReplayHandler {
                rng: &mut StdRng::seed_from_u64(101),
                base: base.clone(),
                trace: Trace::default(),
                warn_on_mismatch: true,
            },
            sample(addr!("u64"), Bernoulli::new(0.5).unwrap()),
        );
        assert!(safe.log_prior.is_finite());
        assert!(safe.get_bool(&addr!("u64")).is_some());
    }

    #[test]
    fn safe_score_given_trace_warn_flag_branches() {
        let mut rng = StdRng::seed_from_u64(102);
        let (_a, base) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        );
        let (_b, t_false) = run(
            SafeScoreGivenTrace {
                base: base.clone(),
                trace: Trace::default(),
                warn_on_error: false,
            },
            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
        );
        assert_eq!(t_false.log_prior, f64::NEG_INFINITY);

        let (_c, t_true) = run(
            SafeScoreGivenTrace {
                base: base.clone(),
                trace: Trace::default(),
                warn_on_error: true,
            },
            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
        );
        assert_eq!(t_true.log_prior, f64::NEG_INFINITY);
    }

    #[test]
    #[should_panic(expected = "expected bool at")]
    fn replay_handler_panics_on_type_mismatch() {
        let mut rng = StdRng::seed_from_u64(103);
        // Base records an f64 at "x"; replaying it as a bool must panic.
        let (_a, base) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap()),
        );
        let (_b, _t) = run(
            ReplayHandler {
                rng: &mut rng,
                base: base.clone(),
                trace: Trace::default(),
            },
            sample(addr!("x"), Bernoulli::new(0.5).unwrap()),
        );
    }

    #[test]
    fn safe_replay_handler_samples_fresh_for_missing_address() {
        let mut rng = StdRng::seed_from_u64(104);
        let base = Trace::default();
        // An empty base cannot supply "z", so a fresh value must be drawn.
        let (_a, t) = run(
            SafeReplayHandler {
                rng: &mut rng,
                base,
                trace: Trace::default(),
                warn_on_mismatch: true,
            },
            sample(addr!("z"), Normal::new(0.0, 1.0).unwrap()),
        );
        assert!(t.get_f64(&addr!("z")).is_some());
        assert!(t.log_prior.is_finite());
    }
}