gen_rs/modeling/
unfold.rs

1use crate::{Trie,Trace,GenFn,GfDiff,TrieFnState};
2use std::any::Any;
3use std::rc::Rc;
4
5
6/// Combinator struct for kernels that use the `TrieFnState` DSL (`sample_at` and `trace_at`) and automatically implement the GFI.
7/// Supports memory-efficient extension via the `GfDiff::Extend` flag (eg. as passed during a `ParticleSystem::step`).
8pub struct Unfold<State> {
9    /// A random kernel that takes in a mutable reference to a `TrieFnState<A,T>` and some `State`, effectfully mutates it, and produces a new `State`.
10    pub kernel: fn(&mut TrieFnState<(i64,State),State>, (i64,State)) -> State
11}
12
13impl<State> Unfold<State> {
14    /// Dynamically construct an `Unfold` from a kernel at run-time.
15    fn new(kernel: fn(&mut TrieFnState<(i64,State),State>, (i64,State)) -> State) -> Self {
16        Unfold { kernel }
17    }
18}
19
20
21impl<State: Clone> GenFn<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>> for Unfold<State> {
22    fn simulate(&self, T_and_args: (i64, State)) -> Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>> {
23        let (T, mut state) = T_and_args;
24        assert!(T >= 1);
25        let mut vec_trace = Trace { args: (T, state.clone()), data: vec![], retv: Some(vec![]), logp: 0. };
26        for t in 0..T {
27            let mut g = TrieFnState::Simulate {
28                trace: Trace { args: (t as i64, state.clone()), data: Trie::new(), retv: None, logp: 0. },
29            };
30            state = (self.kernel)(&mut g, (t as i64, state.clone()));
31            let TrieFnState::Simulate {mut trace} = g else { unreachable!() };
32            vec_trace.retv.as_mut().unwrap().push(state.clone());
33            vec_trace.data.push(trace.data);
34            vec_trace.logp += trace.logp;
35        }
36        vec_trace
37    }
38
39    fn generate(&self, T_and_args: (i64, State), vec_constraints: Vec<Trie<(Rc<dyn Any>,f64)>>) 
40        -> (Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>, f64)
41    {
42        let (T, mut state) = T_and_args;
43        assert!(T >= 1);
44        let mut vec_trace = Trace { args: (T, state.clone()), data: vec![], retv: Some(vec![]), logp: 0. };
45        let mut gen_weight = 0.;
46        for (t,constraints) in vec_constraints.into_iter().enumerate() {
47            let mut g = TrieFnState::Generate {
48                trace: Trace { args: (t as i64, state.clone()), data: Trie::new(), retv: None, logp: 0. },
49                weight: 0.,
50                constraints: constraints.into_unweighted()
51            };
52            state = (self.kernel)(&mut g, (t as i64, state.clone()));
53            let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
54            assert!(constraints.is_empty());
55            vec_trace.retv.as_mut().unwrap().push(state.clone());
56            vec_trace.data.push(trace.data);
57            vec_trace.logp += trace.logp;
58            gen_weight += weight;
59        }
60        (vec_trace, gen_weight)
61    }
62
63    fn update(&self,
64        mut vec_trace: Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>,
65        T_and_args: (i64, State),
66        diff: GfDiff,
67        vec_constraints: Vec<Trie<(Rc<dyn Any>,f64)>>
68    ) -> (Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>, Vec<Trie<(Rc<dyn Any>,f64)>>, f64) {
69        let (T, _) = T_and_args;
70        assert!(T >= 1);
71        let prev_T = vec_trace.args.0;
72        assert!(T - prev_T == vec_constraints.len() as i64);
73        let mut state = vec_trace.retv.as_ref().unwrap().last().unwrap().clone();
74        let mut update_weight = 0.;
75        match diff {
76            GfDiff::Extend => {
77                for (t,constraints) in vec_constraints.into_iter().enumerate() {
78                    let mut g = TrieFnState::Generate {
79                        trace: Trace { args: (prev_T + (t as i64), state.clone()), data: Trie::new(), retv: None, logp: 0. },
80                        weight: 0.,
81                        constraints: constraints.into_unweighted()
82                    };
83                    state = (self.kernel)(&mut g, (prev_T + (t as i64), state.clone()));
84                    let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
85                    assert!(constraints.is_empty());
86                    vec_trace.args.0 += 1;
87                    vec_trace.retv.as_mut().unwrap().push(state.clone());
88                    vec_trace.data.push(trace.data);
89                    vec_trace.logp += trace.logp;
90                    update_weight += weight;
91                }
92            },
93            _ => { panic!("Can't handle GF change type: {:?}", diff) },
94        }
95        (vec_trace, (prev_T..T).map(|_| Trie::new()).collect::<_>(), update_weight)
96    }
97}