fugue/runtime/
memory.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/runtime/memory.md"))]
2
3use crate::core::address::Address;
4use crate::core::distribution::Distribution;
5use crate::runtime::trace::{Choice, ChoiceValue, Trace};
6use std::collections::BTreeMap;
7use std::sync::Arc;
8
/// Copy-on-write trace for efficient memory sharing in MCMC operations.
///
/// Most MCMC operations modify only a small number of choices, so CowTrace
/// shares the majority of trace data between states using `Arc<BTreeMap>`.
///
/// Example:
/// ```rust
/// # use fugue::*;
/// # use fugue::runtime::memory::CowTrace;
///
/// // Convert from regular trace
/// # let mut rng = rand::thread_rng();
/// # let (_, trace) = runtime::handler::run(
/// #     PriorHandler { rng: &mut rng, trace: Trace::default() },
/// #     sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
/// # );
/// let cow_trace = CowTrace::from_trace(trace);
///
/// // Clone is very efficient (shares memory)
/// let clone1 = cow_trace.clone();
/// let clone2 = cow_trace.clone();
///
/// // Modification triggers copy-on-write only when needed
/// let mut modified = clone1.clone();
/// modified.insert_choice(addr!("new"), Choice {
///     addr: addr!("new"),
///     value: ChoiceValue::F64(42.0),
///     logp: -1.0,
/// });
/// // Now `modified` has its own copy, others still share
/// ```
#[derive(Clone, Debug)]
pub struct CowTrace {
    // Shared choice map; cloned lazily the first time a holder mutates it.
    choices: Arc<BTreeMap<Address, Choice>>,
    // Accumulated log-probability of sampled choices (mirrors `Trace::log_prior`).
    log_prior: f64,
    // Accumulated log-likelihood of observations (mirrors `Trace::log_likelihood`).
    log_likelihood: f64,
    // Accumulated log-weight from factor statements (mirrors `Trace::log_factors`).
    log_factors: f64,
}
47
impl Default for CowTrace {
    /// Equivalent to [`CowTrace::new`]: an empty trace with zero log-weights.
    fn default() -> Self {
        Self::new()
    }
}
53
54impl CowTrace {
55    /// Create a new copy-on-write trace.
56    pub fn new() -> Self {
57        Self {
58            choices: Arc::new(BTreeMap::new()),
59            log_prior: 0.0,
60            log_likelihood: 0.0,
61            log_factors: 0.0,
62        }
63    }
64
65    /// Convert from regular trace.
66    pub fn from_trace(trace: Trace) -> Self {
67        Self {
68            choices: Arc::new(trace.choices),
69            log_prior: trace.log_prior,
70            log_likelihood: trace.log_likelihood,
71            log_factors: trace.log_factors,
72        }
73    }
74
75    /// Convert to regular trace (may involve copying).
76    pub fn to_trace(&self) -> Trace {
77        Trace {
78            choices: (*self.choices).clone(),
79            log_prior: self.log_prior,
80            log_likelihood: self.log_likelihood,
81            log_factors: self.log_factors,
82        }
83    }
84
85    /// Get mutable access to choices, copying if necessary.
86    pub fn choices_mut(&mut self) -> &mut BTreeMap<Address, Choice> {
87        if Arc::strong_count(&self.choices) > 1 {
88            // Need to copy - other references exist
89            self.choices = Arc::new((*self.choices).clone());
90        }
91        Arc::get_mut(&mut self.choices).unwrap()
92    }
93
94    /// Insert a choice, copying the map if needed.
95    pub fn insert_choice(&mut self, addr: Address, choice: Choice) {
96        self.choices_mut().insert(addr, choice);
97    }
98
99    /// Get read-only access to choices.
100    pub fn choices(&self) -> &BTreeMap<Address, Choice> {
101        &self.choices
102    }
103
104    /// Total log weight.
105    pub fn total_log_weight(&self) -> f64 {
106        self.log_prior + self.log_likelihood + self.log_factors
107    }
108}
109
/// Efficient trace builder that minimizes allocations during construction.
///
/// TraceBuilder uses pre-allocated collections and provides type-specific
/// methods to build traces efficiently with minimal memory overhead.
///
/// Example:
/// ```rust
/// # use fugue::*;
/// # use fugue::runtime::memory::TraceBuilder;
///
/// let mut builder = TraceBuilder::new();
///
/// // Add different types of samples efficiently
/// builder.add_sample(addr!("x"), 1.5, -0.5);
/// builder.add_sample_bool(addr!("flag"), true, -0.693);
/// builder.add_sample_u64(addr!("count"), 42, -1.0);
///
/// // Add observations and factors
/// builder.add_observation(-2.3); // Likelihood contribution
/// builder.add_factor(-0.1);      // Soft constraint
///
/// // Build final trace
/// let trace = builder.build();
/// assert_eq!(trace.choices.len(), 3);
/// ```
pub struct TraceBuilder {
    // Choices recorded so far, keyed by address.
    choices: BTreeMap<Address, Choice>,
    // Running sum of sample log-probabilities.
    log_prior: f64,
    // Running sum of observation log-likelihoods.
    log_likelihood: f64,
    // Running sum of factor log-weights.
    log_factors: f64,
}
141
impl Default for TraceBuilder {
    /// Equivalent to [`TraceBuilder::new`]: an empty builder with zeroed accumulators.
    fn default() -> Self {
        Self::new()
    }
}
147
148impl TraceBuilder {
149    pub fn new() -> Self {
150        Self {
151            choices: BTreeMap::new(),
152            log_prior: 0.0,
153            log_likelihood: 0.0,
154            log_factors: 0.0,
155        }
156    }
157
158    pub fn with_capacity(_capacity: usize) -> Self {
159        // BTreeMap doesn't have with_capacity, but we can pre-allocate differently
160        Self::new()
161    }
162
163    pub fn add_sample(&mut self, addr: Address, value: f64, log_prob: f64) {
164        let choice = Choice {
165            addr: addr.clone(),
166            value: ChoiceValue::F64(value),
167            logp: log_prob,
168        };
169        self.choices.insert(addr, choice);
170        self.log_prior += log_prob;
171    }
172
173    pub fn add_sample_bool(&mut self, addr: Address, value: bool, log_prob: f64) {
174        let choice = Choice {
175            addr: addr.clone(),
176            value: ChoiceValue::Bool(value),
177            logp: log_prob,
178        };
179        self.choices.insert(addr, choice);
180        self.log_prior += log_prob;
181    }
182
183    pub fn add_sample_u64(&mut self, addr: Address, value: u64, log_prob: f64) {
184        let choice = Choice {
185            addr: addr.clone(),
186            value: ChoiceValue::U64(value),
187            logp: log_prob,
188        };
189        self.choices.insert(addr, choice);
190        self.log_prior += log_prob;
191    }
192
193    pub fn add_sample_usize(&mut self, addr: Address, value: usize, log_prob: f64) {
194        let choice = Choice {
195            addr: addr.clone(),
196            value: ChoiceValue::Usize(value),
197            logp: log_prob,
198        };
199        self.choices.insert(addr, choice);
200        self.log_prior += log_prob;
201    }
202
203    pub fn add_observation(&mut self, log_likelihood: f64) {
204        self.log_likelihood += log_likelihood;
205    }
206
207    pub fn add_factor(&mut self, log_weight: f64) {
208        self.log_factors += log_weight;
209    }
210
211    pub fn build(self) -> Trace {
212        Trace {
213            choices: self.choices,
214            log_prior: self.log_prior,
215            log_likelihood: self.log_likelihood,
216            log_factors: self.log_factors,
217        }
218    }
219}
220
/// Memory pool for reusing trace allocations to reduce overhead.
///
/// TracePool maintains a collection of cleared Trace objects that can be
/// reused to reduce allocation overhead in MCMC and other inference algorithms.
///
/// Example:
/// ```rust
/// # use fugue::*;
/// # use fugue::runtime::memory::TracePool;
///
/// let mut pool = TracePool::new(10); // Pool up to 10 traces
///
/// // Get traces from pool (creates new ones initially)
/// let trace1 = pool.get();
/// let trace2 = pool.get();
/// assert_eq!(pool.stats().misses, 2); // Both were cache misses
///
/// // Return traces to pool for reuse
/// pool.return_trace(trace1);
/// pool.return_trace(trace2);
/// assert_eq!(pool.stats().returns, 2);
///
/// // Next gets will reuse pooled traces (cache hits)
/// let trace3 = pool.get();
/// assert_eq!(pool.stats().hits, 1);
/// assert_eq!(trace3.choices.len(), 0); // Trace was cleared
/// ```
pub struct TracePool {
    // Cleared traces ready for reuse; `get` pops from the back.
    available: Vec<Trace>,
    // Hard cap: `return_trace` drops traces once this many are pooled.
    max_size: usize,
    // Target that `shrink` truncates down to.
    min_size: usize,
    // Hit/miss/return/drop counters.
    stats: PoolStats,
}
254
/// Usage statistics for a `TracePool`.
///
/// Records how often `get` was served from the pool (`hits`) versus by a
/// fresh allocation (`misses`), plus how many traces were accepted back
/// into the pool (`returns`) or rejected because it was full (`drops`).
///
/// Example:
/// ```rust
/// # use fugue::runtime::memory::*;
///
/// let mut pool = TracePool::new(5);
///
/// let t1 = pool.get();   // miss
/// let _t2 = pool.get();  // miss
/// pool.return_trace(t1);
/// let _t3 = pool.get();  // hit (reuses t1)
///
/// let stats = pool.stats();
/// println!("Hit ratio: {:.1}%", stats.hit_ratio());
/// println!("Total operations: {}", stats.total_gets());
/// assert_eq!(stats.hits, 1);
/// assert_eq!(stats.misses, 2);
/// ```
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of successful gets from the pool (cache hits).
    pub hits: u64,
    /// Number of gets that required new allocation (cache misses).
    pub misses: u64,
    /// Number of traces returned to the pool.
    pub returns: u64,
    /// Number of traces dropped due to pool being full.
    pub drops: u64,
}

impl PoolStats {
    /// Percentage of `get` calls served from the pool; `0.0` before any gets.
    pub fn hit_ratio(&self) -> f64 {
        match self.total_gets() {
            0 => 0.0,
            total => (self.hits as f64 / total as f64) * 100.0,
        }
    }

    /// Combined number of `get` calls (hits plus misses).
    pub fn total_gets(&self) -> u64 {
        self.hits + self.misses
    }
}
307
308impl TracePool {
309    /// Create a new trace pool with the specified capacity bounds.
310    ///
311    /// - `max_size`: Maximum number of traces to keep in the pool
312    /// - `min_size`: Minimum number of traces to maintain (for shrinking)
313    pub fn new(max_size: usize) -> Self {
314        Self {
315            available: Vec::with_capacity(max_size),
316            max_size,
317            min_size: max_size / 4, // Keep at least 25% of max capacity
318            stats: PoolStats::default(),
319        }
320    }
321
322    /// Create a new trace pool with custom capacity bounds.
323    pub fn with_bounds(max_size: usize, min_size: usize) -> Self {
324        assert!(min_size <= max_size, "min_size must be <= max_size");
325        Self {
326            available: Vec::with_capacity(max_size),
327            max_size,
328            min_size,
329            stats: PoolStats::default(),
330        }
331    }
332
333    /// Get a trace from the pool or create new one.
334    ///
335    /// Returns a cleared trace ready for use. Updates hit/miss statistics.
336    pub fn get(&mut self) -> Trace {
337        if let Some(trace) = self.available.pop() {
338            self.stats.hits += 1;
339            trace
340        } else {
341            self.stats.misses += 1;
342            Trace::default()
343        }
344    }
345
346    /// Return a trace to the pool for reuse.
347    ///
348    /// The trace will be cleared and made available for future gets.
349    /// If the pool is full, the trace will be dropped.
350    pub fn return_trace(&mut self, mut trace: Trace) {
351        if self.available.len() < self.max_size {
352            // Clear the trace for reuse
353            trace.choices.clear();
354            trace.log_prior = 0.0;
355            trace.log_likelihood = 0.0;
356            trace.log_factors = 0.0;
357            self.available.push(trace);
358            self.stats.returns += 1;
359        } else {
360            self.stats.drops += 1;
361        }
362    }
363
364    /// Shrink the pool to the minimum size if it's grown too large.
365    ///
366    /// This can be called periodically to reclaim memory when the pool
367    /// has accumulated more traces than needed.
368    pub fn shrink(&mut self) {
369        if self.available.len() > self.min_size {
370            self.available.truncate(self.min_size);
371            self.available.shrink_to_fit();
372        }
373    }
374
375    /// Force shrink to a specific size.
376    pub fn shrink_to(&mut self, target_size: usize) {
377        let target = target_size.min(self.max_size);
378        if self.available.len() > target {
379            self.available.truncate(target);
380            self.available.shrink_to_fit();
381        }
382    }
383
384    /// Clear all traces from the pool.
385    pub fn clear(&mut self) {
386        self.available.clear();
387    }
388
389    /// Get current pool statistics.
390    pub fn stats(&self) -> &PoolStats {
391        &self.stats
392    }
393
394    /// Reset statistics counters.
395    pub fn reset_stats(&mut self) {
396        self.stats = PoolStats::default();
397    }
398
399    /// Current number of available traces in the pool.
400    pub fn len(&self) -> usize {
401        self.available.len()
402    }
403
404    /// Check if the pool is empty.
405    pub fn is_empty(&self) -> bool {
406        self.available.is_empty()
407    }
408
409    /// Maximum capacity of the pool.
410    pub fn capacity(&self) -> usize {
411        self.max_size
412    }
413
414    /// Minimum size maintained during shrinking.
415    pub fn min_capacity(&self) -> usize {
416        self.min_size
417    }
418}
419
/// Optimized handler that uses memory pooling for zero-allocation inference.
///
/// PooledPriorHandler combines TraceBuilder efficiency with TracePool reuse
/// to achieve zero-allocation execution after pool warm-up.
///
/// Example:
/// ```rust
/// # use fugue::*;
/// # use fugue::runtime::memory::*;
/// # use rand::rngs::StdRng;
/// # use rand::SeedableRng;
///
/// let mut pool = TracePool::new(10);
/// let mut rng = StdRng::seed_from_u64(42);
///
/// // Run model with pooled handler
/// let (result, trace) = runtime::handler::run(
///     PooledPriorHandler {
///         rng: &mut rng,
///         trace_builder: TraceBuilder::new(),
///         pool: &mut pool,
///     },
///     sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
/// );
///
/// // Return trace to pool for reuse
/// pool.return_trace(trace);
///
/// // Subsequent runs will reuse pooled traces (zero allocations)
/// assert!(result.is_finite());
/// ```
pub struct PooledPriorHandler<'a, R: rand::RngCore> {
    /// Random-number generator used to draw samples.
    pub rng: &'a mut R,
    /// Accumulates choices and log-weights as the model executes.
    pub trace_builder: TraceBuilder,
    /// Pool for trace reuse. NOTE(review): the `Handler` impl in this file
    /// never reads this field — traces are returned to the pool by the
    /// caller after `finish` — so it currently only ties the pool's
    /// lifetime to the handler's. Confirm whether it is still needed.
    pub pool: &'a mut TracePool,
}
456
457impl<'a, R: rand::RngCore> crate::runtime::handler::Handler for PooledPriorHandler<'a, R> {
458    fn on_sample_f64(&mut self, addr: &Address, dist: &dyn Distribution<f64>) -> f64 {
459        let x = dist.sample(self.rng);
460        let lp = dist.log_prob(&x);
461        self.trace_builder.add_sample(addr.clone(), x, lp);
462        x
463    }
464
465    fn on_sample_bool(&mut self, addr: &Address, dist: &dyn Distribution<bool>) -> bool {
466        let x = dist.sample(self.rng);
467        let lp = dist.log_prob(&x);
468        self.trace_builder.add_sample_bool(addr.clone(), x, lp);
469        x
470    }
471
472    fn on_sample_u64(&mut self, addr: &Address, dist: &dyn Distribution<u64>) -> u64 {
473        let x = dist.sample(self.rng);
474        let lp = dist.log_prob(&x);
475        self.trace_builder.add_sample_u64(addr.clone(), x, lp);
476        x
477    }
478
479    fn on_sample_usize(&mut self, addr: &Address, dist: &dyn Distribution<usize>) -> usize {
480        let x = dist.sample(self.rng);
481        let lp = dist.log_prob(&x);
482        self.trace_builder.add_sample_usize(addr.clone(), x, lp);
483        x
484    }
485
486    fn on_observe_f64(&mut self, _: &Address, dist: &dyn Distribution<f64>, value: f64) {
487        let log_likelihood = dist.log_prob(&value);
488        self.trace_builder.add_observation(log_likelihood);
489    }
490
491    fn on_observe_bool(&mut self, _: &Address, dist: &dyn Distribution<bool>, value: bool) {
492        let log_likelihood = dist.log_prob(&value);
493        self.trace_builder.add_observation(log_likelihood);
494    }
495
496    fn on_observe_u64(&mut self, _: &Address, dist: &dyn Distribution<u64>, value: u64) {
497        let log_likelihood = dist.log_prob(&value);
498        self.trace_builder.add_observation(log_likelihood);
499    }
500
501    fn on_observe_usize(&mut self, _: &Address, dist: &dyn Distribution<usize>, value: usize) {
502        let log_likelihood = dist.log_prob(&value);
503        self.trace_builder.add_observation(log_likelihood);
504    }
505
506    fn on_factor(&mut self, logw: f64) {
507        self.trace_builder.add_factor(logw);
508    }
509
510    fn finish(self) -> Trace {
511        self.trace_builder.build()
512    }
513}
514
#[cfg(test)]
mod memory_tests {
    use super::*;
    use crate::addr;
    use std::time::Instant;

    // CowTrace clones share the underlying Arc until one side mutates.
    #[test]
    fn test_cow_trace_efficiency() {
        let mut trace1 = CowTrace::new();
        trace1.insert_choice(
            addr!("x"),
            Choice {
                addr: addr!("x"),
                value: ChoiceValue::F64(1.0),
                logp: -0.5,
            },
        );

        // Clone should be efficient (no copying yet)
        let trace2 = trace1.clone();
        assert!(Arc::ptr_eq(&trace1.choices, &trace2.choices));

        // Modifying one should trigger copy
        let mut trace3 = trace2.clone();
        trace3.insert_choice(
            addr!("y"),
            Choice {
                addr: addr!("y"),
                value: ChoiceValue::F64(2.0),
                logp: -1.0,
            },
        );

        // Now they should have different underlying data
        assert!(!Arc::ptr_eq(&trace1.choices, &trace3.choices));
    }

    // Basic get/return cycle updates hit, miss, and return counters.
    #[test]
    fn test_trace_pool_basic() {
        let mut pool = TracePool::new(3);

        // Get traces from pool
        let trace1 = pool.get();
        let trace2 = pool.get();

        // Should be cache misses initially
        assert_eq!(pool.stats().misses, 2);
        assert_eq!(pool.stats().hits, 0);

        // Return to pool
        pool.return_trace(trace1);
        pool.return_trace(trace2);
        assert_eq!(pool.stats().returns, 2);

        // Should reuse returned traces (cache hits)
        let trace3 = pool.get();
        assert_eq!(trace3.choices.len(), 0); // Should be cleared
        assert_eq!(pool.stats().hits, 1);
    }

    // Hit-ratio math plus drop counting once the pool reaches capacity.
    #[test]
    fn test_trace_pool_stats() {
        let mut pool = TracePool::new(2);

        // Test hit/miss tracking
        let t1 = pool.get(); // miss
        let t2 = pool.get(); // miss
        assert_eq!(pool.stats().misses, 2);
        assert_eq!(pool.stats().hit_ratio(), 0.0);

        pool.return_trace(t1); // return
        let _t3 = pool.get(); // hit
        assert_eq!(pool.stats().hits, 1);
        assert_eq!(pool.stats().returns, 1);
        assert!(pool.stats().hit_ratio() > 0.0);

        // Test overflow (drop) - need to fill pool first
        pool.return_trace(t2); // return (pool now has 1 item)
        let another_trace = pool.get(); // get the returned trace (hit)
        pool.return_trace(another_trace); // return it (pool now has 1 item)

        // Add one more to make pool full (capacity 2)
        let extra_trace = Trace::default();
        pool.return_trace(extra_trace); // pool now has 2 items (full)

        // Now this should be dropped
        let dummy_trace = Trace {
            log_prior: 1.0, // Make it non-empty
            ..Trace::default()
        };
        pool.return_trace(dummy_trace); // should be dropped because pool is full
        assert_eq!(pool.stats().drops, 1);
    }

    // shrink() truncates to min_size; shrink_to() truncates to an explicit target.
    #[test]
    fn test_trace_pool_shrinking() {
        let mut pool = TracePool::with_bounds(10, 3);

        // Fill pool beyond minimum
        for _ in 0..8 {
            pool.return_trace(Trace::default());
        }
        assert_eq!(pool.len(), 8);

        // Shrink should reduce to minimum
        pool.shrink();
        assert_eq!(pool.len(), 3);

        // Shrink to specific size
        for _ in 0..5 {
            pool.return_trace(Trace::default());
        }
        assert_eq!(pool.len(), 8); // 3 + 5
        pool.shrink_to(2);
        assert_eq!(pool.len(), 2);
    }

    // Builder accumulates many choices and sums their log-probs into the prior.
    #[test]
    fn test_trace_builder_efficiency() {
        let mut builder = TraceBuilder::new();

        // Add many choices efficiently
        for i in 0..1000 {
            builder.add_sample(addr!("x", i), i as f64, -0.5);
        }

        let trace = builder.build();
        assert_eq!(trace.choices.len(), 1000);
        assert!((trace.log_prior - (-500.0)).abs() < 1e-10);
    }

    // Smoke test: building 10k choices completes and stores every entry.
    #[test]
    fn test_address_optimization() {
        // Test that the new TraceBuilder implementation doesn't create
        // unnecessary address clones
        let start = Instant::now();
        let mut builder = TraceBuilder::new();

        for i in 0..10000 {
            let addr = addr!("test", i);
            builder.add_sample(addr, i as f64, -0.5);
        }

        let trace = builder.build();
        let duration = start.elapsed();

        assert_eq!(trace.choices.len(), 10000);
        // This is a smoke test - in practice you'd compare with a baseline
        println!("Built trace with 10k choices in {:?}", duration);
    }

    // Each typed add_sample_* variant stores the matching ChoiceValue variant.
    #[test]
    fn test_mixed_value_types() {
        let mut builder = TraceBuilder::new();

        // Test all supported value types
        builder.add_sample(addr!("f64"), 1.5, -0.5);
        builder.add_sample_bool(addr!("bool"), true, -0.693);
        builder.add_sample_u64(addr!("u64"), 42, -1.0);
        builder.add_sample_usize(addr!("usize"), 3, -1.2);

        let trace = builder.build();
        assert_eq!(trace.choices.len(), 4);

        // Verify values are stored correctly
        assert_eq!(trace.choices[&addr!("f64")].value, ChoiceValue::F64(1.5));
        assert_eq!(trace.choices[&addr!("bool")].value, ChoiceValue::Bool(true));
        assert_eq!(trace.choices[&addr!("u64")].value, ChoiceValue::U64(42));
        assert_eq!(trace.choices[&addr!("usize")].value, ChoiceValue::Usize(3));
    }

    // Many clones share one Arc; mutating one clone detaches only that clone.
    #[test]
    fn test_cow_trace_memory_sharing() {
        // Create a large base trace
        let mut base = Trace::default();
        for i in 0..1000 {
            base.insert_choice(addr!("x", i), ChoiceValue::F64(i as f64), -0.5);
        }
        let cow_base = CowTrace::from_trace(base);

        // Create many clones (should share memory)
        let mut clones = Vec::new();
        for _ in 0..100 {
            clones.push(cow_base.clone());
        }

        // All clones should share the same Arc
        for clone in &clones {
            assert!(Arc::ptr_eq(&cow_base.choices, &clone.choices));
        }

        // Modifying one clone should not affect others
        let mut modified = clones[0].clone();
        modified.insert_choice(
            addr!("new"),
            Choice {
                addr: addr!("new"),
                value: ChoiceValue::F64(999.0),
                logp: -2.0,
            },
        );

        // The modified clone should have different data
        assert!(!Arc::ptr_eq(&cow_base.choices, &modified.choices));
        // But other clones should still share with base
        assert!(Arc::ptr_eq(&cow_base.choices, &clones[1].choices));
    }

    // Counters stay exact across a miss/return/hit workload.
    #[test]
    fn test_pool_stats_accuracy() {
        let mut pool = TracePool::new(5);

        // Pattern: get 10, return 5, get 10 more
        // First 10 gets: all misses
        for _ in 0..10 {
            pool.get(); // 10 misses
        }

        // Return 5 traces (pool capacity is 5, so all should be accepted)
        for _ in 0..5 {
            pool.return_trace(Trace::default()); // 5 returns
        }

        // Next 10 gets: first 5 should be hits, next 5 should be misses
        for _ in 0..10 {
            pool.get(); // 5 hits + 5 misses
        }

        let stats = pool.stats();
        assert_eq!(stats.misses, 15); // 10 + 5
        assert_eq!(stats.hits, 5);
        assert_eq!(stats.returns, 5);
        assert_eq!(stats.drops, 0);
        assert_eq!(stats.total_gets(), 20);
        assert!((stats.hit_ratio() - 25.0).abs() < 1e-10);
    }
}
752
#[cfg(test)]
mod pooled_tests {
    use super::*;
    use crate::addr;
    use crate::core::distribution::*;
    use crate::core::model::{observe, sample, ModelExt};
    use crate::runtime::handler::run;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Running a model through PooledPriorHandler records the sampled choice
    // and a finite likelihood; the finished trace can be returned to the pool.
    #[test]
    fn pooled_prior_handler_builds_trace_and_updates_pool() {
        let mut pool = TracePool::new(4);
        let mut rng = StdRng::seed_from_u64(40);
        let (_val, trace) = run(
            PooledPriorHandler {
                rng: &mut rng,
                trace_builder: TraceBuilder::new(),
                pool: &mut pool,
            },
            sample(addr!("x"), Normal::new(0.0, 1.0).unwrap())
                .and_then(|x| observe(addr!("y"), Normal::new(x, 1.0).unwrap(), 0.3)),
        );
        assert!(trace.choices.contains_key(&addr!("x")));
        assert!(trace.log_likelihood.is_finite());

        // Return a trace and check stats update when pool accepts
        let before_returns = pool.stats().returns;
        pool.return_trace(trace);
        assert_eq!(pool.stats().returns, before_returns + 1);
    }
}