fugue/macros/
mod.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/macros/README.md"))]
2
/// Probabilistic programming macro, used to define probabilistic programs with do-notation.
///
/// Each statement is desugared into a call to the `bind` method of its right-hand side,
/// so the right-hand side of every `<-` and every bare `expr;` statement must evaluate to
/// a value that has a `bind` method (a model). Supported forms:
///
/// - `let x <- m; rest` — binds the result of `m` to `x`, then continues with `rest`.
/// - `let mut x <- m; rest` — same, but `x` is mutable inside `rest`.
/// - `let x = e; rest` / `let mut x = e; rest` — an ordinary Rust `let` (no `bind`).
/// - `m; rest` — binds `m`, discards its result, then continues with `rest`.
/// - `e` — the final expression, which becomes the value of the whole block.
///
/// The block must end with an expression: a trailing `;` leaves nothing for the last
/// step to expand to and fails to match. The closures generated by the expansion are
/// `move` closures, so earlier bindings are captured by value.
///
/// Example:
/// ```rust
/// # use fugue::*;
///
/// let model = prob! {
///     let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
///     let y <- sample(addr!("y"), Normal::new(x, 1.0).unwrap());
///     pure(y)
/// };
/// ```
#[macro_export]
macro_rules! prob {
    // Final expression: the value of the whole block.
    ($e:expr) => { $e };

    // let mut var <- expr; rest
    (let mut $var:ident <- $expr:expr; $($rest:tt)*) => {
        $expr.bind(move |mut $var| $crate::prob!($($rest)*))
    };

    // let var <- expr; rest
    (let $var:ident <- $expr:expr; $($rest:tt)*) => {
        // Recurse through `$crate::` so the expansion still resolves when callers invoke
        // the macro by path (e.g. `fugue::prob!`) without importing `prob!` itself.
        $expr.bind(move |$var| $crate::prob!($($rest)*))
    };

    // let mut var = expr; rest
    (let mut $var:ident = $expr:expr; $($rest:tt)*) => {
        { let mut $var = $expr; $crate::prob!($($rest)*) }
    };

    // let var = expr; rest
    (let $var:ident = $expr:expr; $($rest:tt)*) => {
        { let $var = $expr; $crate::prob!($($rest)*) }
    };

    // expr; rest
    ($expr:expr; $($rest:tt)*) => {
        $expr.bind(move |_| $crate::prob!($($rest)*))
    };
}
35
/// Plate notation for replicating models over ranges.
///
/// `plate!(i in iterable => body)` collects every item of `iterable` into a `Vec`. It then
/// passes that `Vec` to `core::model::traverse_vec`, together with a `move` closure that
/// binds each item to `i` and evaluates `body`.
///
/// `iterable` may be anything implementing `IntoIterator`: a range, a `Vec`, an array, a
/// slice reference, or an iterator.
///
/// Example:
/// ```rust
/// # use fugue::*;
///
/// let model = plate!(i in 0..10 => {
///     sample(addr!("x", i), Normal::new(0.0, 1.0).unwrap())
/// });
/// ```
#[macro_export]
macro_rules! plate {
    ($var:ident in $range:expr => $body:expr) => {
        $crate::core::model::traverse_vec(
            // Going through `IntoIterator` (instead of calling `.collect()` on `$range`
            // directly) also accepts collections such as `Vec` and arrays, not only
            // iterators; every iterator is trivially `IntoIterator`.
            ::std::iter::IntoIterator::into_iter($range).collect::<::std::vec::Vec<_>>(),
            move |$var| $body,
        )
    };
}
52
/// Enhanced address macro with scoping support.
///
/// - `scoped_addr!(scope, name)` builds the address `"{scope}::{name}"`.
/// - `scoped_addr!(scope, name, fmt, args...)` builds `"{scope}::{name}#{suffix}"`.
///   Here `suffix` is `fmt` formatted with `args` exactly as `format!(fmt, args...)` would,
///   so `fmt` must be a string literal.
///
/// Both forms accept an optional trailing comma.
///
/// Example:
/// ```rust
/// # use fugue::*;
///
/// let a = scoped_addr!("scope", "name");
/// let b = scoped_addr!("scope", "name", "{}", 3);
/// ```
#[macro_export]
macro_rules! scoped_addr {
    ($scope:expr, $name:expr $(,)?) => {
        $crate::core::address::Address(format!("{}::{}", $scope, $name))
    };
    ($scope:expr, $name:expr, $fmt:expr $(, $arg:expr)* $(,)?) => {
        // The suffix is written straight into the final string via `format_args!`,
        // avoiding the intermediate `String` that a nested `format!` would allocate.
        $crate::core::address::Address(format!(
            "{}::{}#{}",
            $scope,
            $name,
            format_args!($fmt $(, $arg)*)
        ))
    };
}
71
#[cfg(test)]
mod tests {

    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, pure, sample};
    use crate::runtime::handler::run;
    use crate::runtime::interpreters::PriorHandler;
    use crate::runtime::trace::Trace;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Log-density of the standard normal distribution at `x`, computed independently of
    /// the crate so it can serve as a reference value for the traces below.
    fn std_normal_log_pdf(x: f64) -> f64 {
        -0.5 * x * x - 0.5 * (2.0 * std::f64::consts::PI).ln()
    }

    #[test]
    fn prob_macro_chains_computations() {
        // Use the do-notation arms directly. Wrapping these statements in a `{ ... }`
        // block would instead match the plain-expression arm: the `sample`/`observe`
        // models would be constructed and then dropped without ever being run.
        let model = prob! {
            let x <- sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let doubled = 2.0 * x;
            observe(addr!("y"), Normal::new(0.0, 1.0).unwrap(), 0.1);
            pure((x, doubled))
        };
        let mut rng = StdRng::seed_from_u64(30);
        let ((x, doubled), trace) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            model,
        );
        // The plain `let` arm passed the sampled value through unchanged.
        assert_eq!(doubled, 2.0 * x);
        // The sampled value was scored under its prior...
        assert!((trace.log_prior - std_normal_log_pdf(x)).abs() < 1e-9);
        // ...and the observation under its likelihood.
        assert!((trace.log_likelihood - std_normal_log_pdf(0.1)).abs() < 1e-9);
    }

    #[test]
    fn plate_macro_traverses_range() {
        let xs = 0..5;
        let model = plate!(i in xs => pure(i));
        let (vals, _t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(31),
                trace: Trace::default(),
            },
            model,
        );
        assert_eq!(vals, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn scoped_addr_formats_with_scope_and_indices() {
        let a = scoped_addr!("scope", "name");
        assert_eq!(a.0, "scope::name");
        let b = scoped_addr!("scope", "name", "{}", 3);
        assert_eq!(b.0, "scope::name#3");
    }
}