1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/memory.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::runtime::trace::{Choice, ChoiceValue, Trace};
6use std::collections::BTreeMap;
7use std::sync::Arc;
8
9#[derive(Clone, Debug)]
41pub struct CowTrace {
42 choices: Arc<BTreeMap<Address, Choice>>,
43 log_prior: f64,
44 log_likelihood: f64,
45 log_factors: f64,
46}
47
48impl Default for CowTrace {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl CowTrace {
55 pub fn new() -> Self {
57 Self {
58 choices: Arc::new(BTreeMap::new()),
59 log_prior: 0.0,
60 log_likelihood: 0.0,
61 log_factors: 0.0,
62 }
63 }
64
65 pub fn from_trace(trace: Trace) -> Self {
67 Self {
68 choices: Arc::new(trace.choices),
69 log_prior: trace.log_prior,
70 log_likelihood: trace.log_likelihood,
71 log_factors: trace.log_factors,
72 }
73 }
74
75 pub fn to_trace(&self) -> Trace {
77 Trace {
78 choices: (*self.choices).clone(),
79 log_prior: self.log_prior,
80 log_likelihood: self.log_likelihood,
81 log_factors: self.log_factors,
82 }
83 }
84
85 pub fn choices_mut(&mut self) -> &mut BTreeMap<Address, Choice> {
87 if Arc::strong_count(&self.choices) > 1 {
88 self.choices = Arc::new((*self.choices).clone());
90 }
91 Arc::get_mut(&mut self.choices).unwrap()
92 }
93
94 pub fn insert_choice(&mut self, addr: Address, choice: Choice) {
96 self.choices_mut().insert(addr, choice);
97 }
98
99 pub fn choices(&self) -> &BTreeMap<Address, Choice> {
101 &self.choices
102 }
103
104 pub fn total_log_weight(&self) -> f64 {
106 self.log_prior + self.log_likelihood + self.log_factors
107 }
108}
109
110pub struct TraceBuilder {
136 choices: BTreeMap<Address, Choice>,
137 log_prior: f64,
138 log_likelihood: f64,
139 log_factors: f64,
140}
141
142impl Default for TraceBuilder {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148impl TraceBuilder {
149 pub fn new() -> Self {
150 Self {
151 choices: BTreeMap::new(),
152 log_prior: 0.0,
153 log_likelihood: 0.0,
154 log_factors: 0.0,
155 }
156 }
157
158 pub fn with_capacity(_capacity: usize) -> Self {
159 Self::new()
161 }
162
163 pub fn add_sample(&mut self, addr: Address, value: f64, log_prob: f64) {
164 let choice = Choice {
165 addr: addr.clone(),
166 value: ChoiceValue::F64(value),
167 logp: log_prob,
168 };
169 self.choices.insert(addr, choice);
170 self.log_prior += log_prob;
171 }
172
173 pub fn add_sample_bool(&mut self, addr: Address, value: bool, log_prob: f64) {
174 let choice = Choice {
175 addr: addr.clone(),
176 value: ChoiceValue::Bool(value),
177 logp: log_prob,
178 };
179 self.choices.insert(addr, choice);
180 self.log_prior += log_prob;
181 }
182
183 pub fn add_sample_u64(&mut self, addr: Address, value: u64, log_prob: f64) {
184 let choice = Choice {
185 addr: addr.clone(),
186 value: ChoiceValue::U64(value),
187 logp: log_prob,
188 };
189 self.choices.insert(addr, choice);
190 self.log_prior += log_prob;
191 }
192
193 pub fn add_sample_usize(&mut self, addr: Address, value: usize, log_prob: f64) {
194 let choice = Choice {
195 addr: addr.clone(),
196 value: ChoiceValue::Usize(value),
197 logp: log_prob,
198 };
199 self.choices.insert(addr, choice);
200 self.log_prior += log_prob;
201 }
202
203 pub fn add_observation(&mut self, log_likelihood: f64) {
204 self.log_likelihood += log_likelihood;
205 }
206
207 pub fn add_factor(&mut self, log_weight: f64) {
208 self.log_factors += log_weight;
209 }
210
211 pub fn build(self) -> Trace {
212 Trace {
213 choices: self.choices,
214 log_prior: self.log_prior,
215 log_likelihood: self.log_likelihood,
216 log_factors: self.log_factors,
217 }
218 }
219}
220
221pub struct TracePool {
249 available: Vec<Trace>,
250 max_size: usize,
251 min_size: usize,
252 stats: PoolStats,
253}
254
255#[derive(Debug, Clone, Default)]
280pub struct PoolStats {
281 pub hits: u64,
283 pub misses: u64,
285 pub returns: u64,
287 pub drops: u64,
289}
290
291impl PoolStats {
292 pub fn hit_ratio(&self) -> f64 {
294 let total = self.hits + self.misses;
295 if total == 0 {
296 0.0
297 } else {
298 (self.hits as f64 / total as f64) * 100.0
299 }
300 }
301
302 pub fn total_gets(&self) -> u64 {
304 self.hits + self.misses
305 }
306}
307
308impl TracePool {
309 pub fn new(max_size: usize) -> Self {
314 Self {
315 available: Vec::with_capacity(max_size),
316 max_size,
317 min_size: max_size / 4, stats: PoolStats::default(),
319 }
320 }
321
322 pub fn with_bounds(max_size: usize, min_size: usize) -> Self {
324 assert!(min_size <= max_size, "min_size must be <= max_size");
325 Self {
326 available: Vec::with_capacity(max_size),
327 max_size,
328 min_size,
329 stats: PoolStats::default(),
330 }
331 }
332
333 pub fn get(&mut self) -> Trace {
337 if let Some(trace) = self.available.pop() {
338 self.stats.hits += 1;
339 trace
340 } else {
341 self.stats.misses += 1;
342 Trace::default()
343 }
344 }
345
346 pub fn return_trace(&mut self, mut trace: Trace) {
351 if self.available.len() < self.max_size {
352 trace.choices.clear();
354 trace.log_prior = 0.0;
355 trace.log_likelihood = 0.0;
356 trace.log_factors = 0.0;
357 self.available.push(trace);
358 self.stats.returns += 1;
359 } else {
360 self.stats.drops += 1;
361 }
362 }
363
364 pub fn shrink(&mut self) {
369 if self.available.len() > self.min_size {
370 self.available.truncate(self.min_size);
371 self.available.shrink_to_fit();
372 }
373 }
374
375 pub fn shrink_to(&mut self, target_size: usize) {
377 let target = target_size.min(self.max_size);
378 if self.available.len() > target {
379 self.available.truncate(target);
380 self.available.shrink_to_fit();
381 }
382 }
383
384 pub fn clear(&mut self) {
386 self.available.clear();
387 }
388
389 pub fn stats(&self) -> &PoolStats {
391 &self.stats
392 }
393
394 pub fn reset_stats(&mut self) {
396 self.stats = PoolStats::default();
397 }
398
399 pub fn len(&self) -> usize {
401 self.available.len()
402 }
403
404 pub fn is_empty(&self) -> bool {
406 self.available.is_empty()
407 }
408
409 pub fn capacity(&self) -> usize {
411 self.max_size
412 }
413
414 pub fn min_capacity(&self) -> usize {
416 self.min_size
417 }
418}
419
/// Handler that runs a model by drawing every sample site from its distribution,
/// accumulating choices and log-weights in a [`TraceBuilder`].
///
/// NOTE(review): `pool` is borrowed for the handler's whole lifetime, but none of
/// the `Handler` methods implemented for this type read or write it. `finish`
/// builds the trace from `trace_builder` alone, so callers must return finished
/// traces to the pool themselves.
pub struct PooledPriorHandler<'a, R: rand::RngCore> {
    /// Random source passed to each distribution's `sample`.
    pub rng: &'a mut R,
    /// Collects sampled choices, observation log-likelihoods and factors.
    pub trace_builder: TraceBuilder,
    /// Trace pool held by the handler. Not currently used by its `Handler` methods.
    pub pool: &'a mut TracePool,
}
456
457impl<'a, R: rand::RngCore> crate::runtime::handler::Handler for PooledPriorHandler<'a, R> {
458 fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
459 let x = dist.sample(self.rng);
460 let lp = dist.log_prob(&x);
461 self.trace_builder.add_sample(addr.clone(), x, lp);
462 x
463 }
464
465 fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
466 let x = dist.sample(self.rng);
467 let lp = dist.log_prob(&x);
468 self.trace_builder.add_sample_bool(addr.clone(), x, lp);
469 x
470 }
471
472 fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
473 let x = dist.sample(self.rng);
474 let lp = dist.log_prob(&x);
475 self.trace_builder.add_sample_u64(addr.clone(), x, lp);
476 x
477 }
478
479 fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
480 let x = dist.sample(self.rng);
481 let lp = dist.log_prob(&x);
482 self.trace_builder.add_sample_usize(addr.clone(), x, lp);
483 x
484 }
485
486 fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
487 let log_likelihood = dist.log_prob(&value);
488 self.trace_builder.add_observation(log_likelihood);
489 }
490
491 fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
492 let log_likelihood = dist.log_prob(&value);
493 self.trace_builder.add_observation(log_likelihood);
494 }
495
496 fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
497 let log_likelihood = dist.log_prob(&value);
498 self.trace_builder.add_observation(log_likelihood);
499 }
500
501 fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
502 let log_likelihood = dist.log_prob(&value);
503 self.trace_builder.add_observation(log_likelihood);
504 }
505
506 fn on_factor(&mut self, logw: f64) {
507 self.trace_builder.add_factor(logw);
508 }
509
510 fn finish(self) -> Trace {
511 self.trace_builder.build()
512 }
513}
514
// Unit tests for `CowTrace`, `TracePool`/`PoolStats` and `TraceBuilder`.
// As a child module these tests can read private fields such as
// `CowTrace::choices`, which lets them check `Arc` sharing directly.
#[cfg(test)]
mod memory_tests {
    use super::*;
    use crate::addr;
    use std::time::Instant;

    // Clones share the choice map; a write to one clone gives it a private copy.
    #[test]
    fn test_cow_trace_efficiency() {
        let mut trace1 = CowTrace::new();
        trace1.insert_choice(
            addr!("x"),
            Choice {
                addr: addr!("x"),
                value: ChoiceValue::F64(1.0),
                logp: -0.5,
            },
        );

        // A plain clone points at the same map allocation.
        let trace2 = trace1.clone();
        assert!(Arc::ptr_eq(&trace1.choices, &trace2.choices));

        // Inserting into a clone forces a copy, detaching it from `trace1`.
        let mut trace3 = trace2.clone();
        trace3.insert_choice(
            addr!("y"),
            Choice {
                addr: addr!("y"),
                value: ChoiceValue::F64(2.0),
                logp: -1.0,
            },
        );

        assert!(!Arc::ptr_eq(&trace1.choices, &trace3.choices));
    }

    // Empty-pool gets are misses; returned traces come back cleared as hits.
    #[test]
    fn test_trace_pool_basic() {
        let mut pool = TracePool::new(3);

        let trace1 = pool.get();
        let trace2 = pool.get();

        assert_eq!(pool.stats().misses, 2);
        assert_eq!(pool.stats().hits, 0);

        pool.return_trace(trace1);
        pool.return_trace(trace2);
        assert_eq!(pool.stats().returns, 2);

        let trace3 = pool.get();
        assert_eq!(trace3.choices.len(), 0);
        assert_eq!(pool.stats().hits, 1);
    }

    // Walks a capacity-2 pool through hits, returns and one drop when full.
    #[test]
    fn test_trace_pool_stats() {
        let mut pool = TracePool::new(2);

        // Pool starts empty: both gets miss.
        let t1 = pool.get();
        let t2 = pool.get();
        assert_eq!(pool.stats().misses, 2);
        assert_eq!(pool.stats().hit_ratio(), 0.0);

        // Return one, then immediately reuse it: one hit.
        pool.return_trace(t1);
        let _t3 = pool.get();
        assert_eq!(pool.stats().hits, 1);
        assert_eq!(pool.stats().returns, 1);
        assert!(pool.stats().hit_ratio() > 0.0);

        // Pool holds 1 after returning t2; get/return cycles it; `extra_trace`
        // fills it to max_size (2), so `dummy_trace` is dropped.
        pool.return_trace(t2);
        let another_trace = pool.get();
        pool.return_trace(another_trace);
        let extra_trace = Trace::default();
        pool.return_trace(extra_trace);
        let dummy_trace = Trace {
            log_prior: 1.0,
            ..Trace::default()
        };
        pool.return_trace(dummy_trace);
        assert_eq!(pool.stats().drops, 1);
    }

    // `shrink` trims to min_size; `shrink_to` trims to an explicit target.
    #[test]
    fn test_trace_pool_shrinking() {
        let mut pool = TracePool::with_bounds(10, 3);

        for _ in 0..8 {
            pool.return_trace(Trace::default());
        }
        assert_eq!(pool.len(), 8);

        pool.shrink();
        assert_eq!(pool.len(), 3);

        for _ in 0..5 {
            pool.return_trace(Trace::default());
        }
        assert_eq!(pool.len(), 8);
        pool.shrink_to(2);
        assert_eq!(pool.len(), 2);
    }

    // 1000 distinct addresses, each with logp -0.5, sum to a log-prior of -500.
    #[test]
    fn test_trace_builder_efficiency() {
        let mut builder = TraceBuilder::new();

        for i in 0..1000 {
            builder.add_sample(addr!("x", i), i as f64, -0.5);
        }

        let trace = builder.build();
        assert_eq!(trace.choices.len(), 1000);
        assert!((trace.log_prior - (-500.0)).abs() < 1e-10);
    }

    // Builds a 10k-choice trace and prints the elapsed time. Only the choice count
    // is asserted, so timing never causes a failure.
    #[test]
    fn test_address_optimization() {
        let start = Instant::now();
        let mut builder = TraceBuilder::new();

        for i in 0..10000 {
            let addr = addr!("test", i);
            builder.add_sample(addr, i as f64, -0.5);
        }

        let trace = builder.build();
        let duration = start.elapsed();

        assert_eq!(trace.choices.len(), 10000);
        println!("Built trace with 10k choices in {:?}", duration);
    }

    // Each `add_sample*` variant stores the matching `ChoiceValue` variant.
    #[test]
    fn test_mixed_value_types() {
        let mut builder = TraceBuilder::new();

        builder.add_sample(addr!("f64"), 1.5, -0.5);
        builder.add_sample_bool(addr!("bool"), true, -0.693);
        builder.add_sample_u64(addr!("u64"), 42, -1.0);
        builder.add_sample_usize(addr!("usize"), 3, -1.2);

        let trace = builder.build();
        assert_eq!(trace.choices.len(), 4);

        assert_eq!(trace.choices[&addr!("f64")].value, ChoiceValue::F64(1.5));
        assert_eq!(trace.choices[&addr!("bool")].value, ChoiceValue::Bool(true));
        assert_eq!(trace.choices[&addr!("u64")].value, ChoiceValue::U64(42));
        assert_eq!(trace.choices[&addr!("usize")].value, ChoiceValue::Usize(3));
    }

    // 100 clones of a 1000-choice trace share one map; modifying one clone leaves
    // the others (and the base) pointing at the shared map.
    #[test]
    fn test_cow_trace_memory_sharing() {
        let mut base = Trace::default();
        for i in 0..1000 {
            base.insert_choice(addr!("x", i), ChoiceValue::F64(i as f64), -0.5);
        }
        let cow_base = CowTrace::from_trace(base);

        let mut clones = Vec::new();
        for _ in 0..100 {
            clones.push(cow_base.clone());
        }

        for clone in &clones {
            assert!(Arc::ptr_eq(&cow_base.choices, &clone.choices));
        }

        let mut modified = clones[0].clone();
        modified.insert_choice(
            addr!("new"),
            Choice {
                addr: addr!("new"),
                value: ChoiceValue::F64(999.0),
                logp: -2.0,
            },
        );

        assert!(!Arc::ptr_eq(&cow_base.choices, &modified.choices));
        assert!(Arc::ptr_eq(&cow_base.choices, &clones[1].choices));
    }

    // Exact counter arithmetic for a capacity-5 pool.
    #[test]
    fn test_pool_stats_accuracy() {
        let mut pool = TracePool::new(5);

        // Empty pool: 10 misses.
        for _ in 0..10 {
            pool.get();
        }

        // Fill the pool exactly to max_size: 5 returns, no drops.
        for _ in 0..5 {
            pool.return_trace(Trace::default());
        }

        // First 5 gets hit, the remaining 5 miss.
        for _ in 0..10 {
            pool.get();
        }

        let stats = pool.stats();
        assert_eq!(stats.misses, 15);
        assert_eq!(stats.hits, 5);
        assert_eq!(stats.returns, 5);
        assert_eq!(stats.drops, 0);
        assert_eq!(stats.total_gets(), 20);
        // hit_ratio is a percentage: 5 / 20 * 100 = 25.
        assert!((stats.hit_ratio() - 25.0).abs() < 1e-10);
    }
}
752
// End-to-end check that `PooledPriorHandler` runs a small model and that the
// resulting trace can be handed back to the pool.
#[cfg(test)]
mod pooled_tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, sample, ModelExt};
    use crate::runtime::handler::run;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn pooled_prior_handler_builds_trace_and_updates_pool() {
        let mut pool = TracePool::new(4);
        // Fixed seed keeps the sampled `x` reproducible.
        let mut rng = StdRng::seed_from_u64(40);
        // Model: x ~ Normal(0, 1), then y = 0.3 observed under Normal(x, 1).
        let (_val, trace) = run(
            PooledPriorHandler {
                rng: &mut rng,
                trace_builder: TraceBuilder::new(),
                pool: &mut pool,
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
                .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.3)),
        );
        // `x` is a sample site, so it is stored as a choice; the observation feeds
        // the log-likelihood.
        assert!(trace.choices.contains_key(&addr!("x")));
        assert!(trace.log_likelihood.is_finite());

        // The handler never returns the trace itself, so the test does it manually.
        let before_returns = pool.stats().returns;
        pool.return_trace(trace);
        assert_eq!(pool.stats().returns, before_returns + 1);
    }
}