1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/macros/README.md"))]
2
/// Monadic `do`-style notation for chaining probabilistic models.
///
/// The input is a `;`-separated sequence of statements ending in a final
/// expression (no trailing `;` — an empty remainder matches no rule). Each
/// statement form expands as follows:
///
/// - `let x <- m; rest` becomes `m.bind(move |x| prob!(rest))`, binding the
///   value produced by `m` to the identifier `x` for the rest of the chain.
/// - `let x = e; rest` and `let x: T = e; rest` are ordinary (non-monadic)
///   `let` bindings whose scope is `rest`.
/// - `m; rest` becomes `m.bind(move |_| prob!(rest))`, sequencing `m` and
///   discarding its value.
/// - A lone final expression `e` expands to `e` unchanged.
///
/// Note: wrapping the statements in an extra `{ ... }` turns them into a
/// single block expression, which matches the last form and performs no
/// chaining at all.
///
/// Recursive expansions are routed through `$crate::prob!` so the macro also
/// works when invoked by path without being imported.
#[macro_export]
macro_rules! prob {
    // Final expression: the value of the whole chain.
    ($e:expr) => { $e };

    // Monadic bind: feed the result of `$expr` into the rest of the chain.
    (let $var:ident <- $expr:expr; $($rest:tt)*) => {
        $expr.bind(move |$var| $crate::prob!($($rest)*))
    };

    // Plain `let`, scoped over the remainder of the chain.
    (let $var:ident = $expr:expr; $($rest:tt)*) => {
        { let $var = $expr; $crate::prob!($($rest)*) }
    };

    // Plain `let` with an explicit type annotation.
    (let $var:ident : $ty:ty = $expr:expr; $($rest:tt)*) => {
        { let $var: $ty = $expr; $crate::prob!($($rest)*) }
    };

    // Sequencing: run `$expr` for its effect and ignore its value.
    ($expr:expr; $($rest:tt)*) => {
        $expr.bind(move |_| $crate::prob!($($rest)*))
    };
}
35
/// Builds a model for every element of an iterable (a "plate").
///
/// `plate!(pat in iterable => body)` collects `iterable` — anything that
/// implements `IntoIterator`, such as a range, a `Vec`, an array or an
/// iterator — into a `Vec`. It then passes that `Vec` to
/// `$crate::core::model::traverse_vec` together with the closure
/// `move |pat| body`, which builds the model for a single element. See
/// `traverse_vec` for how the per-element models are combined.
///
/// `pat` may be a plain identifier (`i`) or an irrefutable pattern such as
/// `(a, b)` for destructuring tuple elements.
#[macro_export]
macro_rules! plate {
    ($var:pat_param in $range:expr => $body:expr) => {
        $crate::core::model::traverse_vec(
            // `IntoIterator::into_iter` is the identity for iterators, so
            // existing range/iterator call sites behave exactly as before,
            // while collections and arrays are now accepted too.
            ::std::iter::Iterator::collect::<::std::vec::Vec<_>>(
                ::std::iter::IntoIterator::into_iter($range),
            ),
            move |$var| $body,
        )
    };
}
52
/// Builds a `$crate::core::address::Address` qualified by a scope.
///
/// - `scoped_addr!(scope, name)` produces `"{scope}::{name}"`.
/// - `scoped_addr!(scope, name, fmt, args...)` produces
///   `"{scope}::{name}#{suffix}"`, where `suffix` is `format_args!(fmt, args...)`.
///   `fmt` must therefore be a format-string literal, e.g.
///   `scoped_addr!("layer", "w", "{}_{}", i, j)` gives `"layer::w#<i>_<j>"`.
///
/// Both forms accept an optional trailing comma.
#[macro_export]
macro_rules! scoped_addr {
    ($scope:expr, $name:expr $(,)?) => {
        $crate::core::address::Address(::std::format!("{}::{}", $scope, $name))
    };
    ($scope:expr, $name:expr, $($indices:expr),+ $(,)?) => {
        // The index suffix is formatted directly into the final string via
        // `format_args!`, avoiding an intermediate `String` allocation.
        $crate::core::address::Address(::std::format!(
            "{}::{}#{}",
            $scope,
            $name,
            ::std::format_args!($($indices),+)
        ))
    };
}
71
#[cfg(test)]
mod tests {
    // Unit tests for the `prob!`, `plate!` and `scoped_addr!` macros.

    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, pure, sample};
    use crate::runtime::handler::run;
    use crate::runtime::interpreters::PriorHandler;
    use crate::runtime::trace::Trace;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn prob_macro_chains_computations() {
        // The statements are passed to `prob!` directly. Wrapping them in an
        // extra `{ ... }` would make them one block expression: that matches the
        // macro's `$e:expr` rule, so `sample`/`observe` would only be built and
        // dropped instead of being chained through `bind`.
        let model = prob!(
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap());
            let _x = pure(());
            observe(addr!("y"), Normal::new(0.0, 1.0).unwrap(), 0.1);
            pure(1)
        );
        let mut rng = StdRng::seed_from_u64(30);
        let (val, trace) = run(
            PriorHandler {
                rng: &mut rng,
                trace: Trace::default(),
            },
            model,
        );
        assert_eq!(val, 1);
        assert!(trace.log_prior.is_finite());
        assert!(trace.log_likelihood.is_finite());
        // A standard-normal log-density never exceeds -ln(2π)/2 ≈ -0.919, so both
        // totals must be strictly negative once the `sample` and `observe` steps
        // have actually been run by the handler.
        assert!(trace.log_prior < 0.0);
        assert!(trace.log_likelihood < 0.0);
    }

    #[test]
    fn plate_macro_traverses_range() {
        let xs = 0..5;
        let model = plate!(i in xs => pure(i));
        let (vals, _t) = run(
            PriorHandler {
                rng: &mut StdRng::seed_from_u64(31),
                trace: Trace::default(),
            },
            model,
        );
        assert_eq!(vals, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn scoped_addr_formats_with_scope_and_indices() {
        let a = scoped_addr!("scope", "name");
        assert_eq!(a.0, "scope::name");
        let b = scoped_addr!("scope", "name", "{}", 3);
        assert_eq!(b.0, "scope::name#3");
    }
}