1use crate::{Trie,Trace,GenFn,GfDiff,TrieFnState};
2use std::any::Any;
3use std::rc::Rc;
4
5
6pub struct Unfold<State> {
9 pub kernel: fn(&mut TrieFnState<(i64,State),State>, (i64,State)) -> State
11}
12
13impl<State> Unfold<State> {
14 fn new(kernel: fn(&mut TrieFnState<(i64,State),State>, (i64,State)) -> State) -> Self {
16 Unfold { kernel }
17 }
18}
19
20
21impl<State: Clone> GenFn<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>> for Unfold<State> {
22 fn simulate(&self, T_and_args: (i64, State)) -> Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>> {
23 let (T, mut state) = T_and_args;
24 assert!(T >= 1);
25 let mut vec_trace = Trace { args: (T, state.clone()), data: vec![], retv: Some(vec![]), logp: 0. };
26 for t in 0..T {
27 let mut g = TrieFnState::Simulate {
28 trace: Trace { args: (t as i64, state.clone()), data: Trie::new(), retv: None, logp: 0. },
29 };
30 state = (self.kernel)(&mut g, (t as i64, state.clone()));
31 let TrieFnState::Simulate {mut trace} = g else { unreachable!() };
32 vec_trace.retv.as_mut().unwrap().push(state.clone());
33 vec_trace.data.push(trace.data);
34 vec_trace.logp += trace.logp;
35 }
36 vec_trace
37 }
38
39 fn generate(&self, T_and_args: (i64, State), vec_constraints: Vec<Trie<(Rc<dyn Any>,f64)>>)
40 -> (Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>, f64)
41 {
42 let (T, mut state) = T_and_args;
43 assert!(T >= 1);
44 let mut vec_trace = Trace { args: (T, state.clone()), data: vec![], retv: Some(vec![]), logp: 0. };
45 let mut gen_weight = 0.;
46 for (t,constraints) in vec_constraints.into_iter().enumerate() {
47 let mut g = TrieFnState::Generate {
48 trace: Trace { args: (t as i64, state.clone()), data: Trie::new(), retv: None, logp: 0. },
49 weight: 0.,
50 constraints: constraints.into_unweighted()
51 };
52 state = (self.kernel)(&mut g, (t as i64, state.clone()));
53 let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
54 assert!(constraints.is_empty());
55 vec_trace.retv.as_mut().unwrap().push(state.clone());
56 vec_trace.data.push(trace.data);
57 vec_trace.logp += trace.logp;
58 gen_weight += weight;
59 }
60 (vec_trace, gen_weight)
61 }
62
63 fn update(&self,
64 mut vec_trace: Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>,
65 T_and_args: (i64, State),
66 diff: GfDiff,
67 vec_constraints: Vec<Trie<(Rc<dyn Any>,f64)>>
68 ) -> (Trace<(i64,State),Vec<Trie<(Rc<dyn Any>,f64)>>,Vec<State>>, Vec<Trie<(Rc<dyn Any>,f64)>>, f64) {
69 let (T, _) = T_and_args;
70 assert!(T >= 1);
71 let prev_T = vec_trace.args.0;
72 assert!(T - prev_T == vec_constraints.len() as i64);
73 let mut state = vec_trace.retv.as_ref().unwrap().last().unwrap().clone();
74 let mut update_weight = 0.;
75 match diff {
76 GfDiff::Extend => {
77 for (t,constraints) in vec_constraints.into_iter().enumerate() {
78 let mut g = TrieFnState::Generate {
79 trace: Trace { args: (prev_T + (t as i64), state.clone()), data: Trie::new(), retv: None, logp: 0. },
80 weight: 0.,
81 constraints: constraints.into_unweighted()
82 };
83 state = (self.kernel)(&mut g, (prev_T + (t as i64), state.clone()));
84 let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
85 assert!(constraints.is_empty());
86 vec_trace.args.0 += 1;
87 vec_trace.retv.as_mut().unwrap().push(state.clone());
88 vec_trace.data.push(trace.data);
89 vec_trace.logp += trace.logp;
90 update_weight += weight;
91 }
92 },
93 _ => { panic!("Can't handle GF change type: {:?}", diff) },
94 }
95 (vec_trace, (prev_T..T).map(|_| Trie::new()).collect::<_>(), update_weight)
96 }
97}