gen_rs/modeling/
triefn.rs

1use std::rc::Rc;
2use std::any::Any;
3use crate::modeling::dists::Distribution;
4use crate::{GLOBAL_RNG, Trie, GenFn, GfDiff, Trace};
5
6
/// Incremental computational state of a `trace` during the execution of the different `GenFn` methods with a `TrieFn`.
///
/// A `TrieFn` body receives a mutable reference to one of these variants and mutates it
/// through `sample_at` and `trace_at`; the active variant decides how each address is handled.
pub enum TrieFnState<A,T> {
    /// State for executing `GenFn::simulate` in a `TrieFn`.
    Simulate {
        /// Trace under construction: each visited address adds a `(value, logp)` entry to
        /// `trace.data` and its log-probability to `trace.logp`.
        trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
    },

    /// State for executing `GenFn::generate` in a `TrieFn`.
    Generate {
        /// Trace under construction, filled with constrained values where available and
        /// freshly sampled values otherwise.
        trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
        /// Accumulated log-weight: `sample_at` adds the log-probability of each constrained
        /// value and `trace_at` adds the weight returned by the nested `generate` call.
        weight: f64,
        /// Constraints that have not been consumed yet; an entry is removed when its address is visited.
        constraints: Trie<Rc<dyn Any>>,
    },

    /// State for executing `GenFn::update` in a `TrieFn`.
    Update {
        /// Previous trace, modified in place as its addresses are revisited.
        trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
        /// Constraints that have not been consumed yet; an entry is removed when its address is visited.
        constraints: Trie<Rc<dyn Any>>,
        /// Accumulated log-weight of the update.
        weight: f64,
        /// Values removed from the trace: previous values replaced by a constraint and,
        /// after `gc`, the values stored at addresses that were not visited.
        discard: Trie<Rc<dyn Any>>,
        /// Addresses visited so far by `sample_at` / `trace_at` during this execution.
        visitor: AddrTrie
    }
}
39
/// Trie for hierarchical address schemas (without values).
///
/// Every leaf stores the unit value `()`, so only the presence of an address is recorded.
pub type AddrTrie = Trie<()>;
42
43impl AddrTrie {
44
45    /// Return the unique `AddrTrie` that contains an `addr` if and only if `data` contains that `addr`.
46    pub fn schema<V>(data: &Trie<V>) -> Self {
47        let mut visitor = Trie::new();
48        for (addr, _) in data.leaf_iter() {
49            visitor.insert_leaf_node(addr, ());
50        }
51        for (addr, inode) in data.internal_iter() {
52            visitor.insert_internal_node(addr, Self::schema(inode));
53        }
54        visitor
55    }
56
57    /// Add an address to the `AddrTrie`.
58    pub fn visit(&mut self, addr: &str) {
59        self.insert_leaf_node(addr, ());
60    }
61
62    /// Return `true` if every `addr` in `data` is also present in `self`.
63    pub fn all_visited<T>(&self, data: &Trie<T>) -> bool {
64        let mut allvisited = true;
65        for (addr, _) in data.leaf_iter() {
66            allvisited = allvisited && self.has_leaf_node(&addr);
67        }
68        for (addr, inode) in data.internal_iter() {
69            if !self.has_leaf_node(&addr) {
70                let subvisited = self.get_internal_node(&addr).unwrap();
71                allvisited = allvisited && subvisited.all_visited(inode)
72            }
73        }
74        allvisited
75    }
76
77    /// Return the `AddrTrie` that contains all addresses present in `data`, but not present in `self`.
78    pub fn get_unvisited<V>(&self, data: &Trie<V>) -> Self {
79        let mut unvisited = Trie::new();
80        for (addr, _) in data.leaf_iter() {
81            if !self.has_leaf_node(&addr) {
82                unvisited.insert_leaf_node(&addr, ());
83            }
84        }
85        for (addr, inode) in data.internal_iter() {
86            if !self.has_leaf_node(&addr) {
87                let subvisited = self.get_internal_node(&addr).unwrap();
88                let sub_unvisited = subvisited.get_unvisited(inode);
89                unvisited.insert_internal_node(&addr, sub_unvisited);
90            }
91        }
92        unvisited
93    }
94
95}
96
97impl<A: 'static,T: 'static> TrieFnState<A,T> {
98    /// Sample a random value from a distribution and insert it into the `self.trace.data` trie as a weighted leaf node.
99    /// 
100    /// Return a clone of the sampled value.
101    pub fn sample_at<
102        V: Clone + 'static,
103        W: Clone + 'static
104    >(&mut self, dist: &impl Distribution<V,W>, args: W, addr: &str) -> V {
105        match self {
106            TrieFnState::Simulate {
107                trace,
108            } => {
109                let x = GLOBAL_RNG.with_borrow_mut(|rng| {
110                    dist.random(rng, args.clone())
111                });
112                let logp = dist.logpdf(&x, args);
113                let data = &mut trace.data;
114                data.insert_leaf_node(addr, (Rc::new(x.clone()), logp));
115                trace.logp += logp;
116                x
117            }
118
119            TrieFnState::Generate {
120                trace,
121                weight,
122                constraints,
123            } => {
124                // check if there are constraints
125                let (x, logp) = match constraints.remove_leaf_node(addr) {
126                    // if None, sample a value and calculate change to trace.logp
127                    None => {
128                        let x = GLOBAL_RNG.with_borrow_mut(|rng| {
129                            dist.random(rng, args.clone())
130                        });
131                        let logp = dist.logpdf(&x, args);
132                        (Rc::new(x), logp)
133                    }
134                    // if Some, cast to type V, calculate change to trace.logp (and add to weight)
135                    Some(call) => {
136                        let x = call.downcast::<V>().ok().unwrap();
137                        let logp = dist.logpdf(x.as_ref(), args);
138                        *weight += logp;
139                        (x, logp)
140                    }
141                };
142                
143                // mutate trace with sampled leaf, increment total trace.logp, and insert in logp_trie.
144                let data = &mut trace.data;
145                data.insert_leaf_node(addr, (x.clone(), logp));
146                trace.logp += logp;
147
148                x.as_ref().clone()
149            }
150
151            TrieFnState::Update {
152                trace,
153                constraints,
154                weight,
155                discard,
156                visitor
157            } => {
158                visitor.visit(addr);
159
160                let data = &mut trace.data;
161                let prev_x: Rc<V>;
162                let x: Rc<V>;
163
164                let has_previous = data.has_leaf_node(addr);
165                let constrained = constraints.has_leaf_node(addr);
166                let logp;
167                let mut prev_logp = 0.;
168                if has_previous {
169                    let val = data.remove_leaf_node(addr).unwrap();
170                    prev_x = val.0.downcast::<V>().ok().unwrap();
171                    prev_logp = val.1;
172                    if constrained {
173                        discard.insert_leaf_node(addr, prev_x);
174                        x = constraints.remove_leaf_node(addr).unwrap().downcast::<V>().ok().unwrap();
175                    } else {
176                        x = prev_x;
177                    }
178                    logp = dist.logpdf(x.as_ref(), args);
179                    *weight += logp;
180                    *weight -= prev_logp;
181                } else {
182                    if constrained {
183                        x = constraints.remove_leaf_node(addr).unwrap().downcast::<V>().ok().unwrap();
184                        logp = dist.logpdf(x.as_ref(), args);
185                        *weight += logp;
186                    } else {
187                        x = Rc::new(GLOBAL_RNG.with_borrow_mut(|rng| {
188                            dist.random(rng, args.clone())
189                        }));
190                        logp = dist.logpdf(x.as_ref(), args);
191                    }
192                }
193
194                data.insert_leaf_node(addr, (x.clone(), logp));
195                trace.logp += logp;
196                trace.logp -= prev_logp;
197
198                x.as_ref().clone()
199            }
200        }
201    }
202
203    /// Recursively sample a trace from another `gen_fn`.
204    /// 
205    /// Insert its `subtrace.data` trie as a weighted internal node of the current `trace.data` trie.
206    /// Insert its `retv` as a (zero-weighted) internal node of the current `trace.data` trie.
207    /// 
208    /// Return a clone of the `retv`.
209    pub fn trace_at<
210        X: Clone + 'static,
211        Y: Clone + 'static
212    >(&mut self, gen_fn: &impl GenFn<X,Trie<(Rc<dyn Any>,f64)>,Y>, args: X, addr: &str) -> Y {
213        match self {
214            TrieFnState::Simulate {
215                trace,
216            } => {
217                let subtrace = gen_fn.simulate(args);
218
219                let data = &mut trace.data;
220                data.insert_internal_node(addr, subtrace.data);
221
222                let retv = subtrace.retv.unwrap();
223                data.insert_leaf_node(addr, (Rc::new(retv.clone()), 0.));
224                trace.logp += subtrace.logp;
225
226                retv
227            }
228
229            TrieFnState::Generate {
230                trace,
231                weight,
232                constraints,
233            } => {
234                let subtrace = match constraints.remove_internal_node(addr) {
235                    None => {
236                        gen_fn.simulate(args)
237                    }
238                    Some(subconstraints) => {
239                        let (subtrace, new_weight) = gen_fn.generate(args, Trie::from_unweighted(subconstraints));
240                        *weight += new_weight;
241                        subtrace
242                    }
243                };
244
245                let data = &mut trace.data;
246                data.insert_internal_node(addr, subtrace.data);
247
248                let retv = subtrace.retv.unwrap().clone();
249                data.insert_leaf_node(addr, (Rc::new(retv.clone()), 0.));
250                trace.logp += subtrace.logp;
251
252                retv
253            },
254
255            TrieFnState::Update {
256                trace,
257                constraints,
258                weight,
259                discard,
260                visitor
261            } => {
262                visitor.visit(addr);
263
264                let data = &mut trace.data;
265                let prev_subtrie: Trie<(Rc<dyn Any>,f64)>;
266                let subtrie: Trie<(Rc<dyn Any>,f64)>;
267                let retv: Rc<Y>;
268
269                let has_previous = data.has_internal_node(addr);
270                let constrained = constraints.has_internal_node(addr);
271                let mut logp = 0.;
272                if has_previous {
273                    prev_subtrie = data.remove_internal_node(addr).unwrap();
274                    if constrained {
275                        let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap());
276                        // in case constraints came from a proposal
277                        constraints.remove_leaf_node(addr);
278                        let prev_logp = prev_subtrie.sum();
279                        // note: the args in the subtrace are technically incorrect (they should be from
280                        // the previous call) and update only works because we completely disregard them.
281                        let subtrace = Trace { args: args.clone(), data: prev_subtrie, retv: None, logp: prev_logp };
282                        let (subtrace, subdiscard, new_weight) = gen_fn.update(subtrace, args, GfDiff::Unknown, subconstraints);
283                        discard.insert_internal_node(addr, subdiscard.into_unweighted());
284                        subtrie = subtrace.data;
285                        retv = Rc::new(subtrace.retv.unwrap());
286                        logp = new_weight;
287                        *weight += new_weight;
288                    } else {
289                        dbg!(prev_subtrie.sum());
290                        subtrie = prev_subtrie;
291                        retv = data.remove_leaf_node(addr).unwrap().0.downcast::<Y>().ok().unwrap();
292                    }
293                    *weight += logp;
294                } else {
295                    if constrained {
296                        let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap());
297                        let (subtrace, new_weight) = gen_fn.generate(args, subconstraints);
298                        subtrie = subtrace.data;
299                        retv = Rc::new(subtrace.retv.unwrap());
300                        logp = new_weight;
301                        *weight += logp;
302                    } else {
303                        let subtrace = gen_fn.simulate(args);
304                        subtrie = subtrace.data;
305                        retv = Rc::new(subtrace.retv.unwrap());
306                        logp = subtrace.logp;
307                    }
308                }
309
310                data.insert_internal_node(addr, subtrie);
311                data.insert_leaf_node(addr, (retv.clone(), 0.));
312                trace.logp += logp;
313
314                retv.as_ref().clone()
315            }
316        }
317    }
318
319    fn _gc(
320        mut trie: Trie<(Rc<dyn Any>,f64)>,
321        unvisited: &AddrTrie,
322    ) -> (Trie<(Rc<dyn Any>,f64)>,Trie<Rc<dyn Any>>,f64) {
323        let mut garbage = Trie::new();
324        let mut garbage_weight = 0.;
325        // todo: profile this and make more efficient (eg. with Merkle trees)
326        if &AddrTrie::schema(&trie) == unvisited {
327            garbage_weight = trie.sum();
328            garbage = trie.into_unweighted();
329            trie = Trie::new();
330        } else if !unvisited.is_empty() {
331            for (addr, _) in unvisited.leaf_iter() {
332                let Some((value, logp)) = trie.remove_leaf_node(addr) else { unreachable!() };
333                garbage.insert_leaf_node(addr, value);
334                garbage_weight += logp;
335            }
336            for (addr, sub_unvisited) in unvisited.internal_iter() {
337                let Some(subtrie) = trie.remove_internal_node(addr) else { unreachable!() };
338                let (subtrie, subgarbage, logp) = Self::_gc(subtrie, sub_unvisited);
339                if !subtrie.is_empty() {
340                    trie.insert_internal_node(addr, subtrie);
341                }
342                if !subgarbage.is_empty() {
343                    garbage.insert_internal_node(addr, subgarbage);
344                }
345                garbage_weight += logp;
346            }
347        }
348        (trie, garbage, garbage_weight)
349    }
350
351    /// For all `addr` present in `self.trace.data`, but not present in `self.visitor`, remove `addr` from `self.trace.data` and merge into `self.discard`.
352    /// 
353    /// Panics if `self` is not the `Self::Update` variant.
354    pub fn gc(self) -> Self {
355        if let Self::Update { trace, constraints, weight, discard, visitor } = self {
356            let unvisited = visitor.get_unvisited(&trace.data);
357            let (data, garbage, garbage_weight) = Self::_gc(trace.data, &unvisited);
358            assert!(visitor.all_visited(&data));  // all unvisited nodes garbage-collected
359            Self::Update {
360                trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp - garbage_weight },
361                constraints,
362                weight: weight - garbage_weight,
363                discard: discard.merge(garbage),
364                visitor
365            }
366        } else { panic!("garbage-collect (gc) called outside of update context") }
367    }
368}
369
370
/// Wrapper struct for functions that use the `TrieFnState` DSL (`sample_at` and `trace_at`) and automatically implement the GFI.
///
/// The `GenFn` implementation for this type runs `func` against the `TrieFnState` variant
/// that matches the requested operation (`simulate`, `generate` or `update`).
pub struct TrieFn<A,T> {
    /// A random function that takes in a mutable reference to a `TrieFnState<A,T>` and some args `A`, effectfully mutates the state, and produces a value `T`.
    pub func: fn(&mut TrieFnState<A,T>, A) -> T,
}
376
377impl<Args,Ret> TrieFn<Args,Ret>{
378    /// Dynamically construct a `TrieFn` from a function at run-time.
379    pub fn new(func: fn(&mut TrieFnState<Args,Ret>, Args) -> Ret) -> Self {
380        TrieFn { func }
381    }
382}
383
384
385impl<Args: Clone + 'static,Ret: 'static> GenFn<Args,Trie<(Rc<dyn Any>,f64)>,Ret> for TrieFn<Args,Ret> {
386    fn simulate(&self, args: Args) -> Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret> {
387        let mut g = TrieFnState::Simulate {
388            trace: Trace { args: args.clone(), data: Trie::new(), retv: None, logp: 0. },
389        };
390        let retv = (self.func)(&mut g, args);
391        let TrieFnState::Simulate {mut trace} = g else { unreachable!() };
392        trace.set_retv(retv);
393        trace
394    }
395
396    fn generate(&self, args: Args, constraints: Trie<(Rc<dyn Any>,f64)>) -> (Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>, f64) {
397        let mut g = TrieFnState::Generate {
398            trace: Trace { args: args.clone(), data: Trie::new(), retv: None, logp: 0. },
399            weight: 0.,
400            constraints: constraints.into_unweighted(),
401        };
402        let retv = (self.func)(&mut g, args);
403        let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
404        assert!(constraints.is_empty());  // all constraints bound to trace
405        trace.set_retv(retv);
406        (trace, weight)
407    }
408
409    fn update(&self,
410        trace: Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>,
411        args: Args,
412        _: GfDiff,
413        constraints: Trie<(Rc<dyn Any>,f64)>
414    ) -> (Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>, Trie<(Rc<dyn Any>,f64)>, f64) {
415        let mut g = TrieFnState::Update {
416            trace,
417            weight: 0.,
418            constraints: constraints.into_unweighted(),
419            discard: Trie::new(),
420            visitor: AddrTrie::new()
421        };
422        let retv = (self.func)(&mut g, args);
423        let g = g.gc();  // add unvisited to discard
424        let TrieFnState::Update {mut trace, weight, constraints, discard, visitor: _visitor} = g else { unreachable!() };
425        assert!(constraints.is_empty());  // all constraints bound to trace
426        trace.set_retv(retv);
427        (trace, Trie::from_unweighted(discard), weight)
428    }
429}