1use std::rc::Rc;
2use std::any::Any;
3use crate::modeling::dists::Distribution;
4use crate::{GLOBAL_RNG, Trie, GenFn, GfDiff, Trace};
5
6
7pub enum TrieFnState<A,T> {
9 Simulate {
11 trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
13 },
14
15 Generate {
17 trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
19 weight: f64,
21 constraints: Trie<Rc<dyn Any>>,
23 },
24
25 Update {
27 trace: Trace<A,Trie<(Rc<dyn Any>,f64)>,T>,
29 constraints: Trie<Rc<dyn Any>>,
31 weight: f64,
33 discard: Trie<Rc<dyn Any>>,
35 visitor: AddrTrie
37 }
38}
39
40pub type AddrTrie = Trie<()>;
42
43impl AddrTrie {
44
45 pub fn schema<V>(data: &Trie<V>) -> Self {
47 let mut visitor = Trie::new();
48 for (addr, _) in data.leaf_iter() {
49 visitor.insert_leaf_node(addr, ());
50 }
51 for (addr, inode) in data.internal_iter() {
52 visitor.insert_internal_node(addr, Self::schema(inode));
53 }
54 visitor
55 }
56
57 pub fn visit(&mut self, addr: &str) {
59 self.insert_leaf_node(addr, ());
60 }
61
62 pub fn all_visited<T>(&self, data: &Trie<T>) -> bool {
64 let mut allvisited = true;
65 for (addr, _) in data.leaf_iter() {
66 allvisited = allvisited && self.has_leaf_node(&addr);
67 }
68 for (addr, inode) in data.internal_iter() {
69 if !self.has_leaf_node(&addr) {
70 let subvisited = self.get_internal_node(&addr).unwrap();
71 allvisited = allvisited && subvisited.all_visited(inode)
72 }
73 }
74 allvisited
75 }
76
77 pub fn get_unvisited<V>(&self, data: &Trie<V>) -> Self {
79 let mut unvisited = Trie::new();
80 for (addr, _) in data.leaf_iter() {
81 if !self.has_leaf_node(&addr) {
82 unvisited.insert_leaf_node(&addr, ());
83 }
84 }
85 for (addr, inode) in data.internal_iter() {
86 if !self.has_leaf_node(&addr) {
87 let subvisited = self.get_internal_node(&addr).unwrap();
88 let sub_unvisited = subvisited.get_unvisited(inode);
89 unvisited.insert_internal_node(&addr, sub_unvisited);
90 }
91 }
92 unvisited
93 }
94
95}
96
97impl<A: 'static,T: 'static> TrieFnState<A,T> {
98 pub fn sample_at<
102 V: Clone + 'static,
103 W: Clone + 'static
104 >(&mut self, dist: &impl Distribution<V,W>, args: W, addr: &str) -> V {
105 match self {
106 TrieFnState::Simulate {
107 trace,
108 } => {
109 let x = GLOBAL_RNG.with_borrow_mut(|rng| {
110 dist.random(rng, args.clone())
111 });
112 let logp = dist.logpdf(&x, args);
113 let data = &mut trace.data;
114 data.insert_leaf_node(addr, (Rc::new(x.clone()), logp));
115 trace.logp += logp;
116 x
117 }
118
119 TrieFnState::Generate {
120 trace,
121 weight,
122 constraints,
123 } => {
124 let (x, logp) = match constraints.remove_leaf_node(addr) {
126 None => {
128 let x = GLOBAL_RNG.with_borrow_mut(|rng| {
129 dist.random(rng, args.clone())
130 });
131 let logp = dist.logpdf(&x, args);
132 (Rc::new(x), logp)
133 }
134 Some(call) => {
136 let x = call.downcast::<V>().ok().unwrap();
137 let logp = dist.logpdf(x.as_ref(), args);
138 *weight += logp;
139 (x, logp)
140 }
141 };
142
143 let data = &mut trace.data;
145 data.insert_leaf_node(addr, (x.clone(), logp));
146 trace.logp += logp;
147
148 x.as_ref().clone()
149 }
150
151 TrieFnState::Update {
152 trace,
153 constraints,
154 weight,
155 discard,
156 visitor
157 } => {
158 visitor.visit(addr);
159
160 let data = &mut trace.data;
161 let prev_x: Rc<V>;
162 let x: Rc<V>;
163
164 let has_previous = data.has_leaf_node(addr);
165 let constrained = constraints.has_leaf_node(addr);
166 let logp;
167 let mut prev_logp = 0.;
168 if has_previous {
169 let val = data.remove_leaf_node(addr).unwrap();
170 prev_x = val.0.downcast::<V>().ok().unwrap();
171 prev_logp = val.1;
172 if constrained {
173 discard.insert_leaf_node(addr, prev_x);
174 x = constraints.remove_leaf_node(addr).unwrap().downcast::<V>().ok().unwrap();
175 } else {
176 x = prev_x;
177 }
178 logp = dist.logpdf(x.as_ref(), args);
179 *weight += logp;
180 *weight -= prev_logp;
181 } else {
182 if constrained {
183 x = constraints.remove_leaf_node(addr).unwrap().downcast::<V>().ok().unwrap();
184 logp = dist.logpdf(x.as_ref(), args);
185 *weight += logp;
186 } else {
187 x = Rc::new(GLOBAL_RNG.with_borrow_mut(|rng| {
188 dist.random(rng, args.clone())
189 }));
190 logp = dist.logpdf(x.as_ref(), args);
191 }
192 }
193
194 data.insert_leaf_node(addr, (x.clone(), logp));
195 trace.logp += logp;
196 trace.logp -= prev_logp;
197
198 x.as_ref().clone()
199 }
200 }
201 }
202
203 pub fn trace_at<
210 X: Clone + 'static,
211 Y: Clone + 'static
212 >(&mut self, gen_fn: &impl GenFn<X,Trie<(Rc<dyn Any>,f64)>,Y>, args: X, addr: &str) -> Y {
213 match self {
214 TrieFnState::Simulate {
215 trace,
216 } => {
217 let subtrace = gen_fn.simulate(args);
218
219 let data = &mut trace.data;
220 data.insert_internal_node(addr, subtrace.data);
221
222 let retv = subtrace.retv.unwrap();
223 data.insert_leaf_node(addr, (Rc::new(retv.clone()), 0.));
224 trace.logp += subtrace.logp;
225
226 retv
227 }
228
229 TrieFnState::Generate {
230 trace,
231 weight,
232 constraints,
233 } => {
234 let subtrace = match constraints.remove_internal_node(addr) {
235 None => {
236 gen_fn.simulate(args)
237 }
238 Some(subconstraints) => {
239 let (subtrace, new_weight) = gen_fn.generate(args, Trie::from_unweighted(subconstraints));
240 *weight += new_weight;
241 subtrace
242 }
243 };
244
245 let data = &mut trace.data;
246 data.insert_internal_node(addr, subtrace.data);
247
248 let retv = subtrace.retv.unwrap().clone();
249 data.insert_leaf_node(addr, (Rc::new(retv.clone()), 0.));
250 trace.logp += subtrace.logp;
251
252 retv
253 },
254
255 TrieFnState::Update {
256 trace,
257 constraints,
258 weight,
259 discard,
260 visitor
261 } => {
262 visitor.visit(addr);
263
264 let data = &mut trace.data;
265 let prev_subtrie: Trie<(Rc<dyn Any>,f64)>;
266 let subtrie: Trie<(Rc<dyn Any>,f64)>;
267 let retv: Rc<Y>;
268
269 let has_previous = data.has_internal_node(addr);
270 let constrained = constraints.has_internal_node(addr);
271 let mut logp = 0.;
272 if has_previous {
273 prev_subtrie = data.remove_internal_node(addr).unwrap();
274 if constrained {
275 let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap());
276 constraints.remove_leaf_node(addr);
278 let prev_logp = prev_subtrie.sum();
279 let subtrace = Trace { args: args.clone(), data: prev_subtrie, retv: None, logp: prev_logp };
282 let (subtrace, subdiscard, new_weight) = gen_fn.update(subtrace, args, GfDiff::Unknown, subconstraints);
283 discard.insert_internal_node(addr, subdiscard.into_unweighted());
284 subtrie = subtrace.data;
285 retv = Rc::new(subtrace.retv.unwrap());
286 logp = new_weight;
287 *weight += new_weight;
288 } else {
289 dbg!(prev_subtrie.sum());
290 subtrie = prev_subtrie;
291 retv = data.remove_leaf_node(addr).unwrap().0.downcast::<Y>().ok().unwrap();
292 }
293 *weight += logp;
294 } else {
295 if constrained {
296 let subconstraints = Trie::from_unweighted(constraints.remove_internal_node(addr).unwrap());
297 let (subtrace, new_weight) = gen_fn.generate(args, subconstraints);
298 subtrie = subtrace.data;
299 retv = Rc::new(subtrace.retv.unwrap());
300 logp = new_weight;
301 *weight += logp;
302 } else {
303 let subtrace = gen_fn.simulate(args);
304 subtrie = subtrace.data;
305 retv = Rc::new(subtrace.retv.unwrap());
306 logp = subtrace.logp;
307 }
308 }
309
310 data.insert_internal_node(addr, subtrie);
311 data.insert_leaf_node(addr, (retv.clone(), 0.));
312 trace.logp += logp;
313
314 retv.as_ref().clone()
315 }
316 }
317 }
318
319 fn _gc(
320 mut trie: Trie<(Rc<dyn Any>,f64)>,
321 unvisited: &AddrTrie,
322 ) -> (Trie<(Rc<dyn Any>,f64)>,Trie<Rc<dyn Any>>,f64) {
323 let mut garbage = Trie::new();
324 let mut garbage_weight = 0.;
325 if &AddrTrie::schema(&trie) == unvisited {
327 garbage_weight = trie.sum();
328 garbage = trie.into_unweighted();
329 trie = Trie::new();
330 } else if !unvisited.is_empty() {
331 for (addr, _) in unvisited.leaf_iter() {
332 let Some((value, logp)) = trie.remove_leaf_node(addr) else { unreachable!() };
333 garbage.insert_leaf_node(addr, value);
334 garbage_weight += logp;
335 }
336 for (addr, sub_unvisited) in unvisited.internal_iter() {
337 let Some(subtrie) = trie.remove_internal_node(addr) else { unreachable!() };
338 let (subtrie, subgarbage, logp) = Self::_gc(subtrie, sub_unvisited);
339 if !subtrie.is_empty() {
340 trie.insert_internal_node(addr, subtrie);
341 }
342 if !subgarbage.is_empty() {
343 garbage.insert_internal_node(addr, subgarbage);
344 }
345 garbage_weight += logp;
346 }
347 }
348 (trie, garbage, garbage_weight)
349 }
350
351 pub fn gc(self) -> Self {
355 if let Self::Update { trace, constraints, weight, discard, visitor } = self {
356 let unvisited = visitor.get_unvisited(&trace.data);
357 let (data, garbage, garbage_weight) = Self::_gc(trace.data, &unvisited);
358 assert!(visitor.all_visited(&data)); Self::Update {
360 trace: Trace { args: trace.args, data, retv: trace.retv, logp: trace.logp - garbage_weight },
361 constraints,
362 weight: weight - garbage_weight,
363 discard: discard.merge(garbage),
364 visitor
365 }
366 } else { panic!("garbage-collect (gc) called outside of update context") }
367 }
368}
369
370
371pub struct TrieFn<A,T> {
373 pub func: fn(&mut TrieFnState<A,T>, A) -> T,
375}
376
377impl<Args,Ret> TrieFn<Args,Ret>{
378 pub fn new(func: fn(&mut TrieFnState<Args,Ret>, Args) -> Ret) -> Self {
380 TrieFn { func }
381 }
382}
383
384
385impl<Args: Clone + 'static,Ret: 'static> GenFn<Args,Trie<(Rc<dyn Any>,f64)>,Ret> for TrieFn<Args,Ret> {
386 fn simulate(&self, args: Args) -> Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret> {
387 let mut g = TrieFnState::Simulate {
388 trace: Trace { args: args.clone(), data: Trie::new(), retv: None, logp: 0. },
389 };
390 let retv = (self.func)(&mut g, args);
391 let TrieFnState::Simulate {mut trace} = g else { unreachable!() };
392 trace.set_retv(retv);
393 trace
394 }
395
396 fn generate(&self, args: Args, constraints: Trie<(Rc<dyn Any>,f64)>) -> (Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>, f64) {
397 let mut g = TrieFnState::Generate {
398 trace: Trace { args: args.clone(), data: Trie::new(), retv: None, logp: 0. },
399 weight: 0.,
400 constraints: constraints.into_unweighted(),
401 };
402 let retv = (self.func)(&mut g, args);
403 let TrieFnState::Generate {mut trace, weight, constraints} = g else { unreachable!() };
404 assert!(constraints.is_empty()); trace.set_retv(retv);
406 (trace, weight)
407 }
408
409 fn update(&self,
410 trace: Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>,
411 args: Args,
412 _: GfDiff,
413 constraints: Trie<(Rc<dyn Any>,f64)>
414 ) -> (Trace<Args,Trie<(Rc<dyn Any>,f64)>,Ret>, Trie<(Rc<dyn Any>,f64)>, f64) {
415 let mut g = TrieFnState::Update {
416 trace,
417 weight: 0.,
418 constraints: constraints.into_unweighted(),
419 discard: Trie::new(),
420 visitor: AddrTrie::new()
421 };
422 let retv = (self.func)(&mut g, args);
423 let g = g.gc(); let TrieFnState::Update {mut trace, weight, constraints, discard, visitor: _visitor} = g else { unreachable!() };
425 assert!(constraints.is_empty()); trace.set_retv(retv);
427 (trace, Trie::from_unweighted(discard), weight)
428 }
429}