fugue/core/model.rs
1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/core/model.md"))]
2use crate::core::address::Address;
3use crate::core::distribution::{Distribution, LogF64};
4
5/// `Model<A>` represents a probabilistic program that yields a value of type `A` when executed by a handler.
6/// Models are built from four variants: `Pure`, `Sample*`, `Observe*`, and `Factor`.
7///
8/// Example:
9/// ```rust
10/// # use fugue::*;
11/// // Deterministic value
12/// let m = pure(42.0);
13///
14/// // Sample from distribution
15/// let s = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
16///
17/// // Dependent sampling
18/// let chain = s.bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap()));
19/// ```
20pub enum Model<A> {
21 /// A deterministic computation yielding a pure value.
22 Pure(A),
23 /// Sample from an f64 distribution (continuous distributions).
24 SampleF64 {
25 /// Unique identifier for this sampling site.
26 addr: Address,
27 /// Distribution to sample from.
28 dist: Box<dyn Distribution<f64>>,
29 /// Continuation function to apply to the sampled value.
30 k: Box<dyn FnOnce(f64) -> Model<A> + Send + 'static>,
31 },
32 /// Sample from a bool distribution (Bernoulli).
33 SampleBool {
34 /// Unique identifier for this sampling site.
35 addr: Address,
36 /// Distribution to sample from.
37 dist: Box<dyn Distribution<bool>>,
38 /// Continuation function to apply to the sampled value.
39 k: Box<dyn FnOnce(bool) -> Model<A> + Send + 'static>,
40 },
41 /// Sample from a u64 distribution (Poisson, Binomial).
42 SampleU64 {
43 /// Unique identifier for this sampling site.
44 addr: Address,
45 /// Distribution to sample from.
46 dist: Box<dyn Distribution<u64>>,
47 /// Continuation function to apply to the sampled value.
48 k: Box<dyn FnOnce(u64) -> Model<A> + Send + 'static>,
49 },
50 /// Sample from a usize distribution (Categorical).
51 SampleUsize {
52 /// Unique identifier for this sampling site.
53 addr: Address,
54 /// Distribution to sample from.
55 dist: Box<dyn Distribution<usize>>,
56 /// Continuation function to apply to the sampled value.
57 k: Box<dyn FnOnce(usize) -> Model<A> + Send + 'static>,
58 },
59 /// Observe/condition on an f64 value.
60 ObserveF64 {
61 /// Unique identifier for this observation site.
62 addr: Address,
63 /// Distribution that generates the observed value.
64 dist: Box<dyn Distribution<f64>>,
65 /// The observed value to condition on.
66 value: f64,
67 /// Continuation function (always receives unit).
68 k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
69 },
70 /// Observe/condition on a bool value.
71 ObserveBool {
72 /// Unique identifier for this observation site.
73 addr: Address,
74 /// Distribution that generates the observed value.
75 dist: Box<dyn Distribution<bool>>,
76 /// The observed value to condition on.
77 value: bool,
78 /// Continuation function (always receives unit).
79 k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
80 },
81 /// Observe/condition on a u64 value.
82 ObserveU64 {
83 /// Unique identifier for this observation site.
84 addr: Address,
85 /// Distribution that generates the observed value.
86 dist: Box<dyn Distribution<u64>>,
87 /// The observed value to condition on.
88 value: u64,
89 /// Continuation function (always receives unit).
90 k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
91 },
92 /// Observe/condition on a usize value.
93 ObserveUsize {
94 /// Unique identifier for this observation site.
95 addr: Address,
96 /// Distribution that generates the observed value.
97 dist: Box<dyn Distribution<usize>>,
98 /// The observed value to condition on.
99 value: usize,
100 /// Continuation function (always receives unit).
101 k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
102 },
103 /// Add a log-weight factor to the model.
104 Factor {
105 /// Log-weight to add to the model's total weight.
106 logw: LogF64,
107 /// Continuation function (always receives unit).
108 k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
109 },
110}
111
112/// Lift a deterministic value, `a`, into the model monad.
113/// Creates a `Model` that always returns the given value, `a`, without any probabilistic behavior.
114/// This is the unit operation for the model monad.
115///
116/// Example:
117/// ```rust
118/// # use fugue::*;
119///
120/// let model = pure(42.0);
121/// // When executed, this model will always return 42.0
122/// ```
123pub fn pure<A>(a: A) -> Model<A> {
124 Model::Pure(a)
125}
126/// Sample from an f64 distribution (continuous distributions).
127///
128/// Example:
129/// ```rust
130/// # use fugue::*;
131///
132/// let model = sample_f64(addr!("x"), Normal::new(0.0, 1.0).unwrap());
133/// ```
134pub fn sample_f64(addr: Address, dist: impl Distribution<f64> + 'static) -> Model<f64> {
135 Model::SampleF64 {
136 addr,
137 dist: Box::new(dist),
138 k: Box::new(pure),
139 }
140}
141/// Sample from a bool distribution (Bernoulli).
142///
143/// Example:
144/// ```rust
145/// # use fugue::*;
146///
147/// let model = sample_bool(addr!("coin"), Bernoulli::new(0.5).unwrap());
148/// ```
149pub fn sample_bool(addr: Address, dist: impl Distribution<bool> + 'static) -> Model<bool> {
150 Model::SampleBool {
151 addr,
152 dist: Box::new(dist),
153 k: Box::new(pure),
154 }
155}
156/// Sample from a u64 distribution (Poisson, Binomial).
157///
158/// Example:
159/// ```rust
160/// # use fugue::*;
161///
162/// let model = sample_u64(addr!("count"), Poisson::new(3.0).unwrap());
163/// ```
164pub fn sample_u64(addr: Address, dist: impl Distribution<u64> + 'static) -> Model<u64> {
165 Model::SampleU64 {
166 addr,
167 dist: Box::new(dist),
168 k: Box::new(pure),
169 }
170}
171/// Sample from a usize distribution (Categorical).
172///
173/// Example:
174/// ```rust
175/// # use fugue::*;
176///
177/// let model = sample_usize(addr!("choice"), Categorical::new(vec![0.3, 0.5, 0.2]).unwrap());
178/// ```
179pub fn sample_usize(addr: Address, dist: impl Distribution<usize> + 'static) -> Model<usize> {
180 Model::SampleUsize {
181 addr,
182 dist: Box::new(dist),
183 k: Box::new(pure),
184 }
185}
186
187/// Sample from a distribution (generic version - chooses the right variant automatically).
188// This is the main sampling function that works with any distribution type.
189// The return type is inferred from the distribution type.
190///
191/// Type-specific variants:
192/// - `sample_f64` - Sample from f64 distributions (continuous distributions)
193/// - `sample_bool` - Sample from bool distributions (Bernoulli)
194/// - `sample_u64` - Sample from u64 distributions (Poisson, Binomial)
195/// - `sample_usize` - Sample from usize distributions (Categorical)
196///
197/// Example:
198/// ```rust
199/// # use fugue::*;
200/// // Automatically returns f64 for continuous distributions
201/// let normal_sample: Model<f64> = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
202/// // Automatically returns bool for Bernoulli
203/// let coin_flip: Model<bool> = sample(addr!("coin"), Bernoulli::new(0.5).unwrap());
204/// // Automatically returns u64 for Poisson
205/// let count: Model<u64> = sample(addr!("count"), Poisson::new(3.0).unwrap());
206/// // Automatically returns usize for Categorical
207/// let choice: Model<usize> = sample(addr!("choice"),
208/// Categorical::new(vec![0.3, 0.5, 0.2]).unwrap());
209/// ```
210pub fn sample<T>(addr: Address, dist: impl Distribution<T> + 'static) -> Model<T>
211where
212 T: SampleType,
213{
214 T::make_sample_model(addr, Box::new(dist))
215}
216
217/// Trait for types that can be sampled in Models.
218/// This enables automatic dispatch to the right Model variant.
219pub trait SampleType: 'static + Send + Sync + Sized {
220 fn make_sample_model(addr: Address, dist: Box<dyn Distribution<Self>>) -> Model<Self>;
221 fn make_observe_model(
222 addr: Address,
223 dist: Box<dyn Distribution<Self>>,
224 value: Self,
225 ) -> Model<()>;
226}
227impl SampleType for f64 {
228 fn make_sample_model(addr: Address, dist: Box<dyn Distribution<f64>>) -> Model<f64> {
229 Model::SampleF64 {
230 addr,
231 dist,
232 k: Box::new(pure),
233 }
234 }
235 fn make_observe_model(
236 addr: Address,
237 dist: Box<dyn Distribution<f64>>,
238 value: f64,
239 ) -> Model<()> {
240 Model::ObserveF64 {
241 addr,
242 dist,
243 value,
244 k: Box::new(pure),
245 }
246 }
247}
248impl SampleType for bool {
249 fn make_sample_model(addr: Address, dist: Box<dyn Distribution<bool>>) -> Model<bool> {
250 Model::SampleBool {
251 addr,
252 dist,
253 k: Box::new(pure),
254 }
255 }
256 fn make_observe_model(
257 addr: Address,
258 dist: Box<dyn Distribution<bool>>,
259 value: bool,
260 ) -> Model<()> {
261 Model::ObserveBool {
262 addr,
263 dist,
264 value,
265 k: Box::new(pure),
266 }
267 }
268}
269impl SampleType for u64 {
270 fn make_sample_model(addr: Address, dist: Box<dyn Distribution<u64>>) -> Model<u64> {
271 Model::SampleU64 {
272 addr,
273 dist,
274 k: Box::new(pure),
275 }
276 }
277 fn make_observe_model(
278 addr: Address,
279 dist: Box<dyn Distribution<u64>>,
280 value: u64,
281 ) -> Model<()> {
282 Model::ObserveU64 {
283 addr,
284 dist,
285 value,
286 k: Box::new(pure),
287 }
288 }
289}
290impl SampleType for usize {
291 fn make_sample_model(addr: Address, dist: Box<dyn Distribution<usize>>) -> Model<usize> {
292 Model::SampleUsize {
293 addr,
294 dist,
295 k: Box::new(pure),
296 }
297 }
298 fn make_observe_model(
299 addr: Address,
300 dist: Box<dyn Distribution<usize>>,
301 value: usize,
302 ) -> Model<()> {
303 Model::ObserveUsize {
304 addr,
305 dist,
306 value,
307 k: Box::new(pure),
308 }
309 }
310}
311
312/// Observe a value from a distribution (generic version).
313/// This function automatically chooses the right observation variant based on the distribution type and observed value type.
314///
315/// Example:
316/// ```rust
317/// use fugue::*;
318/// // Observe f64 value from continuous distribution
319/// let model = observe(addr!("y"), Normal::new(1.0, 0.5).unwrap(), 2.5);
320/// // Observe bool value from Bernoulli
321/// let model = observe(addr!("coin"), Bernoulli::new(0.6).unwrap(), true);
322/// // Observe u64 count from Poisson
323/// let model = observe(addr!("count"), Poisson::new(3.0).unwrap(), 5u64);
324/// // Observe usize choice from Categorical
325/// let model = observe(addr!("choice"),
326/// Categorical::new(vec![0.3, 0.5, 0.2]).unwrap(), 1usize);
327/// ```
328pub fn observe<T>(addr: Address, dist: impl Distribution<T> + 'static, value: T) -> Model<()>
329where
330 T: SampleType,
331{
332 T::make_observe_model(addr, Box::new(dist), value)
333}
334
335/// Add an unnormalized log-weight `logw` to the model, returning a `Model<()>`.
336///
337/// Factors allow encoding soft constraints or arbitrary log-probability contributions to the model.
338/// They are particularly useful for:
339///
340/// - Encoding constraints that should be "mostly satisfied"
341/// - Adding custom log-likelihood terms
342/// - Implementing rejection sampling (using negative infinity)
343///
344/// Example:
345/// ```rust
346/// # use fugue::*;
347/// // Add positive log-weight (increases probability)
348/// let model = factor(1.0); // Adds log(e) = 1.0 to weight
349/// // Add negative log-weight (decreases probability)
350/// let model = factor(-2.0); // Subtracts 2.0 from log-weight
351/// // Reject/fail (zero probability)
352/// let model = factor(f64::NEG_INFINITY);
353/// // Soft constraint: prefer values near zero
354/// let x = 5.0;
355/// let soft_constraint = factor(-0.5 * x * x); // Gaussian-like penalty
356/// ```
357pub fn factor(logw: LogF64) -> Model<()> {
358 Model::Factor {
359 logw,
360 k: Box::new(pure),
361 }
362}
363
364/// `ModelExt<A>` provides monadic operations for composing `Model<A>` values.
365/// Provides `bind`, `map`, and `and_then` for chaining and transforming probabilistic computations.
366///
367/// Example:
368/// ```rust
369/// # use fugue::*;
370/// // Transform result with map
371/// let transformed = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
372/// .map(|x| x * 2.0);
373///
374/// // Chain dependent computations with bind
375/// let dependent = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
376/// .bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap()));
377/// ```
378pub trait ModelExt<A>: Sized {
379 /// Monadic bind operation (>>=).
380 ///
381 /// Chains two probabilistic computations where the second depends on the result of the first.
382 /// This is the fundamental operation for building complex probabilistic models from simpler parts.
383 /// The function `k` takes the result of this model and returns a new model.
384 ///
385 /// Example:
386 /// ```rust
387 /// # use fugue::*;
388 /// // Dependent sampling: y depends on x
389 /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
390 /// .bind(|x| sample(addr!("y"), Normal::new(x, 0.1).unwrap()));
391 /// ```
392 fn bind<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B>;
393
394 /// Apply a function, `f`, to transform the result of this model.
395 /// This is the functor map operation - it transforms the output of a model without adding any additional probabilistic behavior.
396 ///
397 /// Example:
398 /// ```rust
399 /// # use fugue::*;
400 /// // Transform the sampled value
401 /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
402 /// .map(|x| x.exp()); // Apply exponential function
403 /// ```
404 fn map<B>(self, f: impl FnOnce(A) -> B + Send + 'static) -> Model<B> {
405 self.bind(|a| pure(f(a)))
406 }
407
408 /// Alias for `bind` - chains dependent probabilistic computations.
409 /// This method provides a more familiar interface for Rust developers used to `Option::and_then` and `Result::and_then`.
410 ///
411 /// Example:
412 /// ```rust
413 /// # use fugue::*;
414 /// // Dependent sampling: y depends on x
415 /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
416 /// .and_then(|x| sample(addr!("y"), Normal::new(x, 0.1).unwrap()));
417 /// ```
418 fn and_then<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B> {
419 self.bind(k)
420 }
421}
422impl<A: 'static> ModelExt<A> for Model<A> {
423 fn bind<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B> {
424 match self {
425 Model::Pure(a) => k(a),
426 Model::SampleF64 { addr, dist, k: k1 } => Model::SampleF64 {
427 addr,
428 dist,
429 k: Box::new(move |x| k1(x).bind(k)),
430 },
431 Model::SampleBool { addr, dist, k: k1 } => Model::SampleBool {
432 addr,
433 dist,
434 k: Box::new(move |x| k1(x).bind(k)),
435 },
436 Model::SampleU64 { addr, dist, k: k1 } => Model::SampleU64 {
437 addr,
438 dist,
439 k: Box::new(move |x| k1(x).bind(k)),
440 },
441 Model::SampleUsize { addr, dist, k: k1 } => Model::SampleUsize {
442 addr,
443 dist,
444 k: Box::new(move |x| k1(x).bind(k)),
445 },
446 Model::ObserveF64 {
447 addr,
448 dist,
449 value,
450 k: k1,
451 } => Model::ObserveF64 {
452 addr,
453 dist,
454 value,
455 k: Box::new(move |()| k1(()).bind(k)),
456 },
457 Model::ObserveBool {
458 addr,
459 dist,
460 value,
461 k: k1,
462 } => Model::ObserveBool {
463 addr,
464 dist,
465 value,
466 k: Box::new(move |()| k1(()).bind(k)),
467 },
468 Model::ObserveU64 {
469 addr,
470 dist,
471 value,
472 k: k1,
473 } => Model::ObserveU64 {
474 addr,
475 dist,
476 value,
477 k: Box::new(move |()| k1(()).bind(k)),
478 },
479 Model::ObserveUsize {
480 addr,
481 dist,
482 value,
483 k: k1,
484 } => Model::ObserveUsize {
485 addr,
486 dist,
487 value,
488 k: Box::new(move |()| k1(()).bind(k)),
489 },
490 Model::Factor { logw, k: k1 } => Model::Factor {
491 logw,
492 k: Box::new(move |()| k1(()).bind(k)),
493 },
494 }
495 }
496}
497
498/// Combine two independent models, `ma` and `mb`, into a model of their paired results.
499/// This operation runs both models and combines their results into a tuple.
500/// The models are executed independently (neither depends on the other's result).
501///
502/// Example:
503/// ```rust
504/// # use fugue::*;
505/// // Sample two independent random variables
506/// let x_model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
507/// let y_model = sample(addr!("y"), Uniform::new(0.0, 1.0).unwrap());
508/// let paired = zip(x_model, y_model); // Model<(f64, f64)>
509/// // Can be used with any model types
510/// let mixed = zip(pure(42.0), sample(addr!("z"), Exponential::new(1.0).unwrap()));
511/// ```
512pub fn zip<A: Send + 'static, B: Send + 'static>(ma: Model<A>, mb: Model<B>) -> Model<(A, B)> {
513 ma.bind(|a| mb.map(move |b| (a, b)))
514}
515
516/// Execute a vector of models, `models`, and collect their results into a single model of a vector.
517/// This function takes a collection of independent models and runs them all, collecting their results into a vector.
518/// This is useful for running multiple similar probabilistic computations.
519///
520/// Example:
521/// ```rust
522/// use fugue::*;
523/// // Create multiple independent samples
524/// let models = vec![
525/// sample(addr!("x", 0), Normal::new(0.0, 1.0).unwrap()),
526/// sample(addr!("x", 1), Normal::new(1.0, 1.0).unwrap()),
527/// sample(addr!("x", 2), Normal::new(2.0, 1.0).unwrap()),
528/// ];
529/// let all_samples = sequence_vec(models); // Model<Vec<f64>>
530/// // Mix deterministic and probabilistic models
531/// let mixed_models = vec![
532/// pure(1.0),
533/// sample(addr!("random"), Uniform::new(0.0, 1.0).unwrap()),
534/// pure(3.0),
535/// ];
536/// let results = sequence_vec(mixed_models);
537/// ```
538pub fn sequence_vec<A: Send + 'static>(models: Vec<Model<A>>) -> Model<Vec<A>> {
539 models.into_iter().fold(pure(Vec::new()), |acc, m| {
540 zip(acc, m).map(|(mut v, a)| {
541 v.push(a);
542 v
543 })
544 })
545}
546
547/// Apply a function, `f`, that produces models to each item in a vector, `items`, collecting the results.
548/// This is a higher-order function that maps each item in the input vector through a function that produces a model,
549/// then sequences all the resulting models into a single model of a vector.
550/// This is equivalent to `sequence_vec(items.map(f))` but more convenient.
551///
552/// Example:
553/// ```rust
554/// use fugue::*;
555/// // Add noise to each data point
556/// let data = vec![1.0, 2.0, 3.0];
557/// let noisy_data = traverse_vec(data, |x| {
558/// sample(addr!("noise", x as usize), Normal::new(0.0, 0.1).unwrap())
559/// .map(move |noise| x + noise)
560/// });
561/// // Create observations for each data point
562/// let observations = vec![1.2, 2.1, 2.9];
563/// let model = traverse_vec(observations, |obs| {
564/// observe(addr!("y", obs as usize), Normal::new(2.0, 0.5).unwrap(), obs)
565/// });
566/// ```
567pub fn traverse_vec<T, A: Send + 'static>(
568 items: Vec<T>,
569 f: impl Fn(T) -> Model<A> + Send + Sync + 'static,
570) -> Model<Vec<A>> {
571 sequence_vec(items.into_iter().map(f).collect())
572}
573
574/// Conditional execution: fail with zero probability when predicate is false.
575///
576/// Guards provide a way to enforce hard constraints in probabilistic models.
577/// When the predicate `pred` is true, the model continues normally.
578/// When false, the model receives negative infinite log-weight, effectively ruling out that execution path,
579/// returning a `Model<()>` that fails with zero probability.
580///
581/// Example:
582/// ```rust
583/// # use fugue::*;
584/// // Ensure a sampled value is positive
585/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
586/// .bind(|x| {
587/// guard(x > 0.0).bind(move |_| pure(x))
588/// });
589/// // Multiple constraints
590/// let model = sample(addr!("x"), Uniform::new(-2.0, 2.0).unwrap())
591/// .bind(|x| {
592/// guard(x > -1.0).bind(move |_|
593/// guard(x < 1.0).bind(move |_| pure(x * x))
594/// )
595/// });
596/// ```
597pub fn guard(pred: bool) -> Model<()> {
598 if pred {
599 pure(())
600 } else {
601 factor(f64::NEG_INFINITY)
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::runtime::handler::run;
    use crate::runtime::interpreters::PriorHandler;
    use crate::runtime::trace::Trace;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn pure_and_map_work() {
        let m = pure(2).map(|x| x + 3);
        let (val, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(1),
                trace: Trace::default(),
            },
            m,
        );
        assert_eq!(val, 5);
        assert_eq!(t.choices.len(), 0);
    }

    #[test]
    fn sample_and_observe_sites() {
        let m = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5).map(move |_| x));

        let mut rng = StdRng::seed_from_u64(42);
        let (_val, trace) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            m,
        );
        assert!(trace.choices.contains_key(&addr!("x")));
        // Observation contributes to likelihood but not to choices
        assert!(trace.log_likelihood.is_finite());
    }

    #[test]
    fn factor_and_guard_affect_weight() {
        // factor adds a finite weight
        let m_ok = factor(-1.23);
        let ((), t_ok) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(2),
                trace: Trace::default(),
            },
            m_ok,
        );
        assert!((t_ok.total_log_weight() + 1.23).abs() < 1e-12);

        // guard(false) adds -inf weight via factor
        let m_bad = guard(false);
        let ((), t_bad) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(3),
                trace: Trace::default(),
            },
            m_bad,
        );
        assert!(
            t_bad.total_log_weight().is_infinite() && t_bad.total_log_weight().is_sign_negative()
        );
    }

    #[test]
    fn sequence_and_traverse_vec() {
        let models: Vec<Model<i32>> = (0..5).map(pure).collect();
        let seq = sequence_vec(models);
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(4),
                trace: Trace::default(),
            },
            seq,
        );
        assert_eq!(vals, vec![0, 1, 2, 3, 4]);
        assert_eq!(t.choices.len(), 0);

        let trav = traverse_vec(vec![1, 2, 3], |i| pure(i * 2));
        let (v2, _t2) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(5),
                trace: Trace::default(),
            },
            trav,
        );
        assert_eq!(v2, vec![2, 4, 6]);
    }

    #[test]
    fn sequence_vec_mixes_pure_and_sampled_models() {
        let models = vec![
            pure(1.0),
            sample(addr!("r"), Normal::new(0.0, 1.0).unwrap()),
            pure(3.0),
        ];
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(9),
                trace: Trace::default(),
            },
            sequence_vec(models),
        );
        // Results keep the input order; pure entries pass through unchanged.
        assert_eq!(vals.len(), 3);
        assert_eq!(vals[0], 1.0);
        assert!(vals[1].is_finite());
        assert_eq!(vals[2], 3.0);
        assert!(t.choices.contains_key(&addr!("r")));
    }

    #[test]
    fn traverse_vec_with_many_sample_sites() {
        let n = 200usize;
        let model = traverse_vec((0..n).collect::<Vec<usize>>(), |i| {
            sample(addr!("x", i), Normal::new(0.0, 1.0).unwrap())
        });
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(10),
                trace: Trace::default(),
            },
            model,
        );
        assert_eq!(vals.len(), n);
        assert_eq!(t.choices.len(), n);
    }

    #[test]
    fn zip_and_sequence_empty_and_bind_chaining() {
        // zip
        let m1 = pure(1);
        let m2 = pure(2);
        let (pair, _t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(6),
                trace: Trace::default(),
            },
            zip(m1, m2),
        );
        assert_eq!(pair, (1, 2));

        // sequence empty
        let empty: Vec<Model<i32>> = vec![];
        let (vals, _t2) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(7),
                trace: Trace::default(),
            },
            sequence_vec(empty),
        );
        assert!(vals.is_empty());

        // bind chaining across types
        let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .bind(|x| pure(x > 0.0))
            .bind(|b| if b { pure(1u64) } else { pure(0u64) });
        let (val, t3) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(8),
                trace: Trace::default(),
            },
            model,
        );
        assert!(val <= 1);
        assert!(t3.choices.contains_key(&addr!("x")));
    }
}