fugue/core/
model.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/core/model.md"))]
2use crate::core::address::Address;
3use crate::core::distribution::{Distribution, LogF64};
4
5/// `Model<A>` represents a probabilistic program that yields a value of type `A` when executed by a handler.
6/// Models are built from four variants: `Pure`, `Sample*`, `Observe*`, and `Factor`.
7///
8/// Example:
9/// ```rust
10/// # use fugue::*;
11/// // Deterministic value
12/// let m = pure(42.0);
13///
14/// // Sample from distribution
15/// let s = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
16///
17/// // Dependent sampling
18/// let chain = s.bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap()));
19/// ```
20pub enum Model<A> {
21    /// A deterministic computation yielding a pure value.
22    Pure(A),
23    /// Sample from an f64 distribution (continuous distributions).
24    SampleF64 {
25        /// Unique identifier for this sampling site.
26        addr: Address,
27        /// Distribution to sample from.
28        dist: Box<dyn Distribution<f64>>,
29        /// Continuation function to apply to the sampled value.
30        k: Box<dyn FnOnce(f64) -> Model<A> + Send + 'static>,
31    },
32    /// Sample from a bool distribution (Bernoulli).
33    SampleBool {
34        /// Unique identifier for this sampling site.
35        addr: Address,
36        /// Distribution to sample from.
37        dist: Box<dyn Distribution<bool>>,
38        /// Continuation function to apply to the sampled value.
39        k: Box<dyn FnOnce(bool) -> Model<A> + Send + 'static>,
40    },
41    /// Sample from a u64 distribution (Poisson, Binomial).
42    SampleU64 {
43        /// Unique identifier for this sampling site.
44        addr: Address,
45        /// Distribution to sample from.
46        dist: Box<dyn Distribution<u64>>,
47        /// Continuation function to apply to the sampled value.
48        k: Box<dyn FnOnce(u64) -> Model<A> + Send + 'static>,
49    },
50    /// Sample from a usize distribution (Categorical).
51    SampleUsize {
52        /// Unique identifier for this sampling site.
53        addr: Address,
54        /// Distribution to sample from.
55        dist: Box<dyn Distribution<usize>>,
56        /// Continuation function to apply to the sampled value.
57        k: Box<dyn FnOnce(usize) -> Model<A> + Send + 'static>,
58    },
59    /// Observe/condition on an f64 value.
60    ObserveF64 {
61        /// Unique identifier for this observation site.
62        addr: Address,
63        /// Distribution that generates the observed value.
64        dist: Box<dyn Distribution<f64>>,
65        /// The observed value to condition on.
66        value: f64,
67        /// Continuation function (always receives unit).
68        k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
69    },
70    /// Observe/condition on a bool value.
71    ObserveBool {
72        /// Unique identifier for this observation site.
73        addr: Address,
74        /// Distribution that generates the observed value.
75        dist: Box<dyn Distribution<bool>>,
76        /// The observed value to condition on.
77        value: bool,
78        /// Continuation function (always receives unit).
79        k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
80    },
81    /// Observe/condition on a u64 value.
82    ObserveU64 {
83        /// Unique identifier for this observation site.
84        addr: Address,
85        /// Distribution that generates the observed value.
86        dist: Box<dyn Distribution<u64>>,
87        /// The observed value to condition on.
88        value: u64,
89        /// Continuation function (always receives unit).
90        k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
91    },
92    /// Observe/condition on a usize value.
93    ObserveUsize {
94        /// Unique identifier for this observation site.
95        addr: Address,
96        /// Distribution that generates the observed value.
97        dist: Box<dyn Distribution<usize>>,
98        /// The observed value to condition on.
99        value: usize,
100        /// Continuation function (always receives unit).
101        k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
102    },
103    /// Add a log-weight factor to the model.
104    Factor {
105        /// Log-weight to add to the model's total weight.
106        logw: LogF64,
107        /// Continuation function (always receives unit).
108        k: Box<dyn FnOnce(()) -> Model<A> + Send + 'static>,
109    },
110}
111
112/// Lift a deterministic value, `a`, into the model monad.
113/// Creates a `Model` that always returns the given value, `a`, without any probabilistic behavior.
114/// This is the unit operation for the model monad.
115///
116/// Example:
117/// ```rust
118/// # use fugue::*;
119///
120/// let model = pure(42.0);
121/// // When executed, this model will always return 42.0
122/// ```
123pub fn pure<A>(a: A) -> Model<A> {
124    Model::Pure(a)
125}
126/// Sample from an f64 distribution (continuous distributions).
127///
128/// Example:
129/// ```rust
130/// # use fugue::*;
131///
132/// let model = sample_f64(addr!("x"), Normal::new(0.0, 1.0).unwrap());
133/// ```
134pub fn sample_f64(addr: Address, dist: impl Distribution<f64> + 'static) -> Model<f64> {
135    Model::SampleF64 {
136        addr,
137        dist: Box::new(dist),
138        k: Box::new(pure),
139    }
140}
141/// Sample from a bool distribution (Bernoulli).
142///
143/// Example:
144/// ```rust
145/// # use fugue::*;
146///
147/// let model = sample_bool(addr!("coin"), Bernoulli::new(0.5).unwrap());
148/// ```
149pub fn sample_bool(addr: Address, dist: impl Distribution<bool> + 'static) -> Model<bool> {
150    Model::SampleBool {
151        addr,
152        dist: Box::new(dist),
153        k: Box::new(pure),
154    }
155}
156/// Sample from a u64 distribution (Poisson, Binomial).
157///
158/// Example:
159/// ```rust
160/// # use fugue::*;
161///
162/// let model = sample_u64(addr!("count"), Poisson::new(3.0).unwrap());
163/// ```
164pub fn sample_u64(addr: Address, dist: impl Distribution<u64> + 'static) -> Model<u64> {
165    Model::SampleU64 {
166        addr,
167        dist: Box::new(dist),
168        k: Box::new(pure),
169    }
170}
171/// Sample from a usize distribution (Categorical).
172///
173/// Example:
174/// ```rust
175/// # use fugue::*;
176///
177/// let model = sample_usize(addr!("choice"), Categorical::new(vec![0.3, 0.5, 0.2]).unwrap());
178/// ```
179pub fn sample_usize(addr: Address, dist: impl Distribution<usize> + 'static) -> Model<usize> {
180    Model::SampleUsize {
181        addr,
182        dist: Box::new(dist),
183        k: Box::new(pure),
184    }
185}
186
187/// Sample from a distribution (generic version - chooses the right variant automatically).
188// This is the main sampling function that works with any distribution type.
189// The return type is inferred from the distribution type.
190///
191/// Type-specific variants:
192/// - `sample_f64` - Sample from f64 distributions (continuous distributions)
193/// - `sample_bool` - Sample from bool distributions (Bernoulli)
194/// - `sample_u64` - Sample from u64 distributions (Poisson, Binomial)
195/// - `sample_usize` - Sample from usize distributions (Categorical)
196///
197/// Example:
198/// ```rust
199/// # use fugue::*;
200/// // Automatically returns f64 for continuous distributions
201/// let normal_sample: Model<f64> = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
202/// // Automatically returns bool for Bernoulli
203/// let coin_flip: Model<bool> = sample(addr!("coin"), Bernoulli::new(0.5).unwrap());
204/// // Automatically returns u64 for Poisson
205/// let count: Model<u64> = sample(addr!("count"), Poisson::new(3.0).unwrap());
206/// // Automatically returns usize for Categorical
207/// let choice: Model<usize> = sample(addr!("choice"),
208///     Categorical::new(vec![0.3, 0.5, 0.2]).unwrap());
209/// ```
210pub fn sample<T>(addr: Address, dist: impl Distribution<T> + 'static) -> Model<T>
211where
212    T: SampleType,
213{
214    T::make_sample_model(addr, Box::new(dist))
215}
216
217/// Trait for types that can be sampled in Models.
218/// This enables automatic dispatch to the right Model variant.
219pub trait SampleType: 'static + Send + Sync + Sized {
220    fn make_sample_model(addr: Address, dist: Box<dyn Distribution<Self>>) -> Model<Self>;
221    fn make_observe_model(
222        addr: Address,
223        dist: Box<dyn Distribution<Self>>,
224        value: Self,
225    ) -> Model<()>;
226}
227impl SampleType for f64 {
228    fn make_sample_model(addr: Address, dist: Box<dyn Distribution<f64>>) -> Model<f64> {
229        Model::SampleF64 {
230            addr,
231            dist,
232            k: Box::new(pure),
233        }
234    }
235    fn make_observe_model(
236        addr: Address,
237        dist: Box<dyn Distribution<f64>>,
238        value: f64,
239    ) -> Model<()> {
240        Model::ObserveF64 {
241            addr,
242            dist,
243            value,
244            k: Box::new(pure),
245        }
246    }
247}
248impl SampleType for bool {
249    fn make_sample_model(addr: Address, dist: Box<dyn Distribution<bool>>) -> Model<bool> {
250        Model::SampleBool {
251            addr,
252            dist,
253            k: Box::new(pure),
254        }
255    }
256    fn make_observe_model(
257        addr: Address,
258        dist: Box<dyn Distribution<bool>>,
259        value: bool,
260    ) -> Model<()> {
261        Model::ObserveBool {
262            addr,
263            dist,
264            value,
265            k: Box::new(pure),
266        }
267    }
268}
269impl SampleType for u64 {
270    fn make_sample_model(addr: Address, dist: Box<dyn Distribution<u64>>) -> Model<u64> {
271        Model::SampleU64 {
272            addr,
273            dist,
274            k: Box::new(pure),
275        }
276    }
277    fn make_observe_model(
278        addr: Address,
279        dist: Box<dyn Distribution<u64>>,
280        value: u64,
281    ) -> Model<()> {
282        Model::ObserveU64 {
283            addr,
284            dist,
285            value,
286            k: Box::new(pure),
287        }
288    }
289}
290impl SampleType for usize {
291    fn make_sample_model(addr: Address, dist: Box<dyn Distribution<usize>>) -> Model<usize> {
292        Model::SampleUsize {
293            addr,
294            dist,
295            k: Box::new(pure),
296        }
297    }
298    fn make_observe_model(
299        addr: Address,
300        dist: Box<dyn Distribution<usize>>,
301        value: usize,
302    ) -> Model<()> {
303        Model::ObserveUsize {
304            addr,
305            dist,
306            value,
307            k: Box::new(pure),
308        }
309    }
310}
311
312/// Observe a value from a distribution (generic version).
313/// This function automatically chooses the right observation variant based on the distribution type and observed value type.
314///
315/// Example:
316/// ```rust
317/// use fugue::*;
318/// // Observe f64 value from continuous distribution
319/// let model = observe(addr!("y"), Normal::new(1.0, 0.5).unwrap(), 2.5);
320/// // Observe bool value from Bernoulli
321/// let model = observe(addr!("coin"), Bernoulli::new(0.6).unwrap(), true);
322/// // Observe u64 count from Poisson
323/// let model = observe(addr!("count"), Poisson::new(3.0).unwrap(), 5u64);
324/// // Observe usize choice from Categorical
325/// let model = observe(addr!("choice"),
326///     Categorical::new(vec![0.3, 0.5, 0.2]).unwrap(), 1usize);
327/// ```
328pub fn observe<T>(addr: Address, dist: impl Distribution<T> + 'static, value: T) -> Model<()>
329where
330    T: SampleType,
331{
332    T::make_observe_model(addr, Box::new(dist), value)
333}
334
335/// Add an unnormalized log-weight `logw` to the model, returning a `Model<()>`.
336///
337/// Factors allow encoding soft constraints or arbitrary log-probability contributions to the model.
338/// They are particularly useful for:
339///
340/// - Encoding constraints that should be "mostly satisfied"
341/// - Adding custom log-likelihood terms
342/// - Implementing rejection sampling (using negative infinity)
343///
344/// Example:
345/// ```rust
346/// # use fugue::*;
347/// // Add positive log-weight (increases probability)
348/// let model = factor(1.0); // Adds log(e) = 1.0 to weight
349/// // Add negative log-weight (decreases probability)
350/// let model = factor(-2.0); // Subtracts 2.0 from log-weight
351/// // Reject/fail (zero probability)
352/// let model = factor(f64::NEG_INFINITY);
353/// // Soft constraint: prefer values near zero
354/// let x = 5.0;
355/// let soft_constraint = factor(-0.5 * x * x); // Gaussian-like penalty
356/// ```
357pub fn factor(logw: LogF64) -> Model<()> {
358    Model::Factor {
359        logw,
360        k: Box::new(pure),
361    }
362}
363
364/// `ModelExt<A>` provides monadic operations for composing `Model<A>` values.
365/// Provides `bind`, `map`, and `and_then` for chaining and transforming probabilistic computations.
366///
367/// Example:
368/// ```rust
369/// # use fugue::*;
370/// // Transform result with map
371/// let transformed = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
372/// .map(|x| x * 2.0);
373///
374/// // Chain dependent computations with bind
375/// let dependent = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
376/// .bind(|x| sample(addr!("y"), Normal::new(x, 0.5).unwrap()));
377/// ```
378pub trait ModelExt<A>: Sized {
379    /// Monadic bind operation (>>=).
380    ///
381    /// Chains two probabilistic computations where the second depends on the result of the first.
382    /// This is the fundamental operation for building complex probabilistic models from simpler parts.
383    /// The function `k` takes the result of this model and returns a new model.
384    ///
385    /// Example:
386    /// ```rust
387    /// # use fugue::*;
388    /// // Dependent sampling: y depends on x
389    /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
390    ///     .bind(|x| sample(addr!("y"), Normal::new(x, 0.1).unwrap()));
391    /// ```
392    fn bind<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B>;
393
394    /// Apply a function, `f`, to transform the result of this model.
395    /// This is the functor map operation - it transforms the output of a model without adding any additional probabilistic behavior.
396    ///
397    /// Example:
398    /// ```rust
399    /// # use fugue::*;
400    /// // Transform the sampled value
401    /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
402    ///     .map(|x| x.exp()); // Apply exponential function
403    /// ```
404    fn map<B>(self, f: impl FnOnce(A) -> B + Send + 'static) -> Model<B> {
405        self.bind(|a| pure(f(a)))
406    }
407
408    /// Alias for `bind` - chains dependent probabilistic computations.
409    /// This method provides a more familiar interface for Rust developers used to `Option::and_then` and `Result::and_then`.
410    ///
411    /// Example:
412    /// ```rust
413    /// # use fugue::*;
414    /// // Dependent sampling: y depends on x
415    /// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
416    ///     .and_then(|x| sample(addr!("y"), Normal::new(x, 0.1).unwrap()));
417    /// ```
418    fn and_then<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B> {
419        self.bind(k)
420    }
421}
422impl<A: 'static> ModelExt<A> for Model<A> {
423    fn bind<B>(self, k: impl FnOnce(A) -> Model<B> + Send + 'static) -> Model<B> {
424        match self {
425            Model::Pure(a) => k(a),
426            Model::SampleF64 { addr, dist, k: k1 } => Model::SampleF64 {
427                addr,
428                dist,
429                k: Box::new(move |x| k1(x).bind(k)),
430            },
431            Model::SampleBool { addr, dist, k: k1 } => Model::SampleBool {
432                addr,
433                dist,
434                k: Box::new(move |x| k1(x).bind(k)),
435            },
436            Model::SampleU64 { addr, dist, k: k1 } => Model::SampleU64 {
437                addr,
438                dist,
439                k: Box::new(move |x| k1(x).bind(k)),
440            },
441            Model::SampleUsize { addr, dist, k: k1 } => Model::SampleUsize {
442                addr,
443                dist,
444                k: Box::new(move |x| k1(x).bind(k)),
445            },
446            Model::ObserveF64 {
447                addr,
448                dist,
449                value,
450                k: k1,
451            } => Model::ObserveF64 {
452                addr,
453                dist,
454                value,
455                k: Box::new(move |()| k1(()).bind(k)),
456            },
457            Model::ObserveBool {
458                addr,
459                dist,
460                value,
461                k: k1,
462            } => Model::ObserveBool {
463                addr,
464                dist,
465                value,
466                k: Box::new(move |()| k1(()).bind(k)),
467            },
468            Model::ObserveU64 {
469                addr,
470                dist,
471                value,
472                k: k1,
473            } => Model::ObserveU64 {
474                addr,
475                dist,
476                value,
477                k: Box::new(move |()| k1(()).bind(k)),
478            },
479            Model::ObserveUsize {
480                addr,
481                dist,
482                value,
483                k: k1,
484            } => Model::ObserveUsize {
485                addr,
486                dist,
487                value,
488                k: Box::new(move |()| k1(()).bind(k)),
489            },
490            Model::Factor { logw, k: k1 } => Model::Factor {
491                logw,
492                k: Box::new(move |()| k1(()).bind(k)),
493            },
494        }
495    }
496}
497
498/// Combine two independent models, `ma` and `mb`, into a model of their paired results.
499/// This operation runs both models and combines their results into a tuple.
500/// The models are executed independently (neither depends on the other's result).
501///
502/// Example:
503/// ```rust
504/// # use fugue::*;
505/// // Sample two independent random variables
506/// let x_model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
507/// let y_model = sample(addr!("y"), Uniform::new(0.0, 1.0).unwrap());
508/// let paired = zip(x_model, y_model); // Model<(f64, f64)>
509/// // Can be used with any model types
510/// let mixed = zip(pure(42.0), sample(addr!("z"), Exponential::new(1.0).unwrap()));
511/// ```
512pub fn zip<A: Send + 'static, B: Send + 'static>(ma: Model<A>, mb: Model<B>) -> Model<(A, B)> {
513    ma.bind(|a| mb.map(move |b| (a, b)))
514}
515
516/// Execute a vector of models, `models`, and collect their results into a single model of a vector.
517/// This function takes a collection of independent models and runs them all, collecting their results into a vector.
518/// This is useful for running multiple similar probabilistic computations.
519///
520/// Example:
521/// ```rust
522/// use fugue::*;
523/// // Create multiple independent samples
524/// let models = vec![
525///     sample(addr!("x", 0), Normal::new(0.0, 1.0).unwrap()),
526///     sample(addr!("x", 1), Normal::new(1.0, 1.0).unwrap()),
527///     sample(addr!("x", 2), Normal::new(2.0, 1.0).unwrap()),
528/// ];
529/// let all_samples = sequence_vec(models); // Model<Vec<f64>>
530/// // Mix deterministic and probabilistic models
531/// let mixed_models = vec![
532///     pure(1.0),
533///     sample(addr!("random"), Uniform::new(0.0, 1.0).unwrap()),
534///     pure(3.0),
535/// ];
536/// let results = sequence_vec(mixed_models);
537/// ```
538pub fn sequence_vec<A: Send + 'static>(models: Vec<Model<A>>) -> Model<Vec<A>> {
539    models.into_iter().fold(pure(Vec::new()), |acc, m| {
540        zip(acc, m).map(|(mut v, a)| {
541            v.push(a);
542            v
543        })
544    })
545}
546
547/// Apply a function, `f`, that produces models to each item in a vector, `items`, collecting the results.
548/// This is a higher-order function that maps each item in the input vector through a function that produces a model,
549/// then sequences all the resulting models into a single model of a vector.
550/// This is equivalent to `sequence_vec(items.map(f))` but more convenient.
551///
552/// Example:
553/// ```rust
554/// use fugue::*;
555/// // Add noise to each data point
556/// let data = vec![1.0, 2.0, 3.0];
557/// let noisy_data = traverse_vec(data, |x| {
558///     sample(addr!("noise", x as usize), Normal::new(0.0, 0.1).unwrap())
559///         .map(move |noise| x + noise)
560/// });
561/// // Create observations for each data point
562/// let observations = vec![1.2, 2.1, 2.9];
563/// let model = traverse_vec(observations, |obs| {
564///     observe(addr!("y", obs as usize), Normal::new(2.0, 0.5).unwrap(), obs)
565/// });
566/// ```
567pub fn traverse_vec<T, A: Send + 'static>(
568    items: Vec<T>,
569    f: impl Fn(T) -> Model<A> + Send + Sync + 'static,
570) -> Model<Vec<A>> {
571    sequence_vec(items.into_iter().map(f).collect())
572}
573
574/// Conditional execution: fail with zero probability when predicate is false.
575///
576/// Guards provide a way to enforce hard constraints in probabilistic models.
577/// When the predicate `pred` is true, the model continues normally.
578/// When false, the model receives negative infinite log-weight, effectively ruling out that execution path,
579/// returning a `Model<()>` that fails with zero probability.
580///
581/// Example:
582/// ```rust
583/// # use fugue::*;
584/// // Ensure a sampled value is positive
585/// let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
586///     .bind(|x| {
587///         guard(x > 0.0).bind(move |_| pure(x))
588///     });
589/// // Multiple constraints
590/// let model = sample(addr!("x"), Uniform::new(-2.0, 2.0).unwrap())
591///     .bind(|x| {
592///         guard(x > -1.0).bind(move |_|
593///             guard(x < 1.0).bind(move |_| pure(x * x))
594///         )
595///     });
596/// ```
597pub fn guard(pred: bool) -> Model<()> {
598    if pred {
599        pure(())
600    } else {
601        factor(f64::NEG_INFINITY)
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::runtime::handler::run;
    use crate::runtime::interpreters::PriorHandler;
    use crate::runtime::trace::Trace;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn pure_and_map_work() {
        let m = pure(2).map(|x| x + 3);
        let (val, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(1),
                trace: Trace::default(),
            },
            m,
        );
        assert_eq!(val, 5);
        assert_eq!(t.choices.len(), 0);
    }

    #[test]
    fn sample_and_observe_sites() {
        let m = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.5).map(move |_| x));

        let mut rng = StdRng::seed_from_u64(42);
        let (_val, trace) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            m,
        );
        assert!(trace.choices.contains_key(&addr!("x")));
        // Observation contributes to likelihood but not to choices
        assert!(trace.log_likelihood.is_finite());
    }

    #[test]
    fn factor_and_guard_affect_weight() {
        // factor adds a finite weight
        let m_ok = factor(-1.23);
        let ((), t_ok) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(2),
                trace: Trace::default(),
            },
            m_ok,
        );
        assert!((t_ok.total_log_weight() + 1.23).abs() < 1e-12);

        // guard(false) adds -inf weight via factor
        let m_bad = guard(false);
        let ((), t_bad) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(3),
                trace: Trace::default(),
            },
            m_bad,
        );
        assert!(
            t_bad.total_log_weight().is_infinite() && t_bad.total_log_weight().is_sign_negative()
        );
    }

    #[test]
    fn sequence_and_traverse_vec() {
        let models: Vec<Model<i32>> = (0..5).map(pure).collect();
        let seq = sequence_vec(models);
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(4),
                trace: Trace::default(),
            },
            seq,
        );
        assert_eq!(vals, vec![0, 1, 2, 3, 4]);
        assert_eq!(t.choices.len(), 0);

        let trav = traverse_vec(vec![1, 2, 3], |i| pure(i * 2));
        let (v2, _t2) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(5),
                trace: Trace::default(),
            },
            trav,
        );
        assert_eq!(v2, vec![2, 4, 6]);
    }

    #[test]
    fn sequence_vec_keeps_order_of_sample_sites() {
        // Nearly-degenerate normals make each draw identify its position in the sequence.
        let models: Vec<Model<f64>> = (0..20usize)
            .map(|i| sample(addr!("x", i), Normal::new(i as f64, 1e-6).unwrap()))
            .collect();
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(9),
                trace: Trace::default(),
            },
            sequence_vec(models),
        );
        assert_eq!(vals.len(), 20);
        assert_eq!(t.choices.len(), 20);
        for (i, v) in vals.iter().enumerate() {
            assert!((v - i as f64).abs() < 1e-3);
        }

        // Pure elements interleaved with a sample site keep their positions.
        let mixed = vec![
            pure(1.0),
            sample(addr!("m"), Normal::new(2.0, 1e-6).unwrap()),
            pure(3.0),
        ];
        let (vals, t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(10),
                trace: Trace::default(),
            },
            sequence_vec(mixed),
        );
        assert_eq!(vals.len(), 3);
        assert_eq!(vals[0], 1.0);
        assert!((vals[1] - 2.0).abs() < 1e-3);
        assert_eq!(vals[2], 3.0);
        assert!(t.choices.contains_key(&addr!("m")));
    }

    #[test]
    fn zip_and_sequence_empty_and_bind_chaining() {
        // zip
        let m1 = pure(1);
        let m2 = pure(2);
        let (pair, _t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(6),
                trace: Trace::default(),
            },
            zip(m1, m2),
        );
        assert_eq!(pair, (1, 2));

        // sequence empty
        let empty: Vec<Model<i32>> = vec![];
        let (vals, _t2) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(7),
                trace: Trace::default(),
            },
            sequence_vec(empty),
        );
        assert!(vals.is_empty());

        // bind chaining across types
        let model = sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
            .bind(|x| pure(x > 0.0))
            .bind(|b| if b { pure(1u64) } else { pure(0u64) });
        let (val, t3) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(8),
                trace: Trace::default(),
            },
            model,
        );
        assert!(val <= 1);
        assert!(t3.choices.contains_key(&addr!("x")));
    }
}