// anomaly_grid/memory_pool.rs

1//! Memory pooling for efficient allocation management
2//!
3//! This module provides object pooling for frequently allocated types
4//! to reduce allocation overhead and memory fragmentation during training.
5
6use crate::context_tree::ContextNode;
7use crate::context_trie::{TrieNode, NodeId};
8use crate::string_interner::{StateId, StringInterner};
9use smallvec::SmallVec;
10use std::sync::Arc;
11
/// Memory pool for managing object allocations efficiently.
///
/// Frequently allocated training objects are recycled instead of dropped,
/// reducing allocator pressure and fragmentation. Obtain objects via the
/// `get_*` methods and hand them back via the matching `return_*` methods;
/// each pool retains at most its initial capacity of returned objects.
#[derive(Debug)]
pub struct MemoryPool {
    /// Pool of reusable ContextNode objects
    context_nodes: Vec<ContextNode>,
    /// Indices of free ContextNode objects
    free_context_nodes: Vec<usize>,
    
    /// Pool of reusable TrieNode objects
    trie_nodes: Vec<TrieNode>,
    /// Indices of free TrieNode objects
    free_trie_nodes: Vec<usize>,
    
    /// Pool of reusable SmallVec objects for transition counts
    small_vecs: Vec<SmallVec<[(StateId, usize); 4]>>,
    /// Indices of free SmallVec objects
    free_small_vecs: Vec<usize>,
    
    /// Pool statistics
    stats: PoolStats,
}
33
/// Statistics for memory pool usage.
///
/// Counters are cumulative since pool creation (or the last call to
/// `MemoryPool::reset_stats`); a "hit" is a request served by recycling a
/// pooled object rather than by a fresh allocation.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total ContextNode allocations requested
    pub context_node_requests: usize,
    /// ContextNode allocations served from pool
    pub context_node_hits: usize,
    
    /// Total TrieNode allocations requested
    pub trie_node_requests: usize,
    /// TrieNode allocations served from pool
    pub trie_node_hits: usize,
    
    /// Total SmallVec allocations requested
    pub small_vec_requests: usize,
    /// SmallVec allocations served from pool
    pub small_vec_hits: usize,
    
    /// Largest number of pooled ContextNode objects observed
    pub peak_context_nodes: usize,
    /// Largest number of pooled TrieNode objects observed
    pub peak_trie_nodes: usize,
    /// Largest number of pooled SmallVec objects observed
    pub peak_small_vecs: usize,
}
57
58impl MemoryPool {
59    /// Create a new memory pool with default capacity
60    pub fn new() -> Self {
61        Self::with_capacity(64, 256, 128)
62    }
63    
64    /// Create a new memory pool with specified initial capacities
65    pub fn with_capacity(
66        context_nodes: usize,
67        trie_nodes: usize,
68        small_vecs: usize,
69    ) -> Self {
70        Self {
71            context_nodes: Vec::with_capacity(context_nodes),
72            free_context_nodes: Vec::with_capacity(context_nodes),
73            trie_nodes: Vec::with_capacity(trie_nodes),
74            free_trie_nodes: Vec::with_capacity(trie_nodes),
75            small_vecs: Vec::with_capacity(small_vecs),
76            free_small_vecs: Vec::with_capacity(small_vecs),
77            stats: PoolStats::default(),
78        }
79    }
80    
81    /// Get a ContextNode from the pool or create a new one
82    pub fn get_context_node(&mut self, interner: Arc<StringInterner>) -> ContextNode {
83        self.stats.context_node_requests += 1;
84        
85        if let Some(index) = self.free_context_nodes.pop() {
86            self.stats.context_node_hits += 1;
87            
88            // Reset the node for reuse
89            let mut node = std::mem::take(&mut self.context_nodes[index]);
90            node.reset(interner);
91            node
92        } else {
93            // Create new node
94            ContextNode::new(interner)
95        }
96    }
97    
98    /// Return a ContextNode to the pool for reuse
99    pub fn return_context_node(&mut self, mut node: ContextNode) {
100        // Clear the node data but keep the allocation
101        node.clear();
102        
103        // Add to pool if we have space
104        if self.context_nodes.len() < self.context_nodes.capacity() {
105            let index = self.context_nodes.len();
106            self.context_nodes.push(node);
107            self.free_context_nodes.push(index);
108            
109            // Update peak statistics
110            if self.context_nodes.len() > self.stats.peak_context_nodes {
111                self.stats.peak_context_nodes = self.context_nodes.len();
112            }
113        }
114        // Otherwise, let it drop (pool is full)
115    }
116    
117    /// Get a TrieNode from the pool or create a new one
118    pub fn get_trie_node(&mut self, parent: Option<NodeId>, state_from_parent: Option<StateId>) -> TrieNode {
119        self.stats.trie_node_requests += 1;
120        
121        if let Some(index) = self.free_trie_nodes.pop() {
122            self.stats.trie_node_hits += 1;
123            
124            // Reset the node for reuse
125            let mut node = std::mem::take(&mut self.trie_nodes[index]);
126            node.reset(parent, state_from_parent);
127            node
128        } else {
129            // Create new node
130            TrieNode::new(parent, state_from_parent)
131        }
132    }
133    
134    /// Return a TrieNode to the pool for reuse
135    pub fn return_trie_node(&mut self, mut node: TrieNode) {
136        // Clear the node data but keep the allocation
137        node.clear();
138        
139        // Add to pool if we have space
140        if self.trie_nodes.len() < self.trie_nodes.capacity() {
141            let index = self.trie_nodes.len();
142            self.trie_nodes.push(node);
143            self.free_trie_nodes.push(index);
144            
145            // Update peak statistics
146            if self.trie_nodes.len() > self.stats.peak_trie_nodes {
147                self.stats.peak_trie_nodes = self.trie_nodes.len();
148            }
149        }
150        // Otherwise, let it drop (pool is full)
151    }
152    
153    /// Get a SmallVec from the pool or create a new one
154    pub fn get_small_vec(&mut self) -> SmallVec<[(StateId, usize); 4]> {
155        self.stats.small_vec_requests += 1;
156        
157        if let Some(index) = self.free_small_vecs.pop() {
158            self.stats.small_vec_hits += 1;
159            
160            // Take the SmallVec and clear it
161            let mut vec = std::mem::take(&mut self.small_vecs[index]);
162            vec.clear();
163            vec
164        } else {
165            // Create new SmallVec
166            SmallVec::new()
167        }
168    }
169    
170    /// Return a SmallVec to the pool for reuse
171    pub fn return_small_vec(&mut self, mut vec: SmallVec<[(StateId, usize); 4]>) {
172        // Clear the vector but keep the allocation
173        vec.clear();
174        
175        // Add to pool if we have space
176        if self.small_vecs.len() < self.small_vecs.capacity() {
177            let index = self.small_vecs.len();
178            self.small_vecs.push(vec);
179            self.free_small_vecs.push(index);
180            
181            // Update peak statistics
182            if self.small_vecs.len() > self.stats.peak_small_vecs {
183                self.stats.peak_small_vecs = self.small_vecs.len();
184            }
185        }
186        // Otherwise, let it drop (pool is full)
187    }
188    
189    /// Get pool statistics
190    pub fn stats(&self) -> &PoolStats {
191        &self.stats
192    }
193    
194    /// Reset pool statistics
195    pub fn reset_stats(&mut self) {
196        self.stats = PoolStats::default();
197    }
198    
199    /// Get current pool sizes
200    pub fn pool_sizes(&self) -> (usize, usize, usize) {
201        (
202            self.context_nodes.len(),
203            self.trie_nodes.len(),
204            self.small_vecs.len(),
205        )
206    }
207    
208    /// Calculate hit rates for each pool type
209    pub fn hit_rates(&self) -> (f64, f64, f64) {
210        let context_hit_rate = if self.stats.context_node_requests > 0 {
211            self.stats.context_node_hits as f64 / self.stats.context_node_requests as f64
212        } else {
213            0.0
214        };
215        
216        let trie_hit_rate = if self.stats.trie_node_requests > 0 {
217            self.stats.trie_node_hits as f64 / self.stats.trie_node_requests as f64
218        } else {
219            0.0
220        };
221        
222        let small_vec_hit_rate = if self.stats.small_vec_requests > 0 {
223            self.stats.small_vec_hits as f64 / self.stats.small_vec_requests as f64
224        } else {
225            0.0
226        };
227        
228        (context_hit_rate, trie_hit_rate, small_vec_hit_rate)
229    }
230    
231    /// Estimate memory usage of the pool
232    pub fn memory_usage(&self) -> usize {
233        let mut total = std::mem::size_of::<Self>();
234        
235        // ContextNode pool
236        total += self.context_nodes.capacity() * std::mem::size_of::<ContextNode>();
237        total += self.free_context_nodes.capacity() * std::mem::size_of::<usize>();
238        
239        // TrieNode pool
240        total += self.trie_nodes.capacity() * std::mem::size_of::<TrieNode>();
241        total += self.free_trie_nodes.capacity() * std::mem::size_of::<usize>();
242        
243        // SmallVec pool
244        total += self.small_vecs.capacity() * std::mem::size_of::<SmallVec<[(StateId, usize); 4]>>();
245        total += self.free_small_vecs.capacity() * std::mem::size_of::<usize>();
246        
247        total
248    }
249    
250    /// Auto-tune pool sizes based on usage patterns
251    pub fn auto_tune(&mut self) {
252        // Increase pool sizes if hit rates are low
253        let (context_hit_rate, trie_hit_rate, small_vec_hit_rate) = self.hit_rates();
254        
255        // If hit rate is below 80%, consider increasing pool size
256        const MIN_HIT_RATE: f64 = 0.8;
257        const GROWTH_FACTOR: f64 = 1.5;
258        
259        if context_hit_rate < MIN_HIT_RATE && self.stats.context_node_requests > 10 {
260            let new_capacity = (self.context_nodes.capacity() as f64 * GROWTH_FACTOR) as usize;
261            self.context_nodes.reserve(new_capacity - self.context_nodes.capacity());
262            self.free_context_nodes.reserve(new_capacity - self.free_context_nodes.capacity());
263        }
264        
265        if trie_hit_rate < MIN_HIT_RATE && self.stats.trie_node_requests > 10 {
266            let new_capacity = (self.trie_nodes.capacity() as f64 * GROWTH_FACTOR) as usize;
267            self.trie_nodes.reserve(new_capacity - self.trie_nodes.capacity());
268            self.free_trie_nodes.reserve(new_capacity - self.free_trie_nodes.capacity());
269        }
270        
271        if small_vec_hit_rate < MIN_HIT_RATE && self.stats.small_vec_requests > 10 {
272            let new_capacity = (self.small_vecs.capacity() as f64 * GROWTH_FACTOR) as usize;
273            self.small_vecs.reserve(new_capacity - self.small_vecs.capacity());
274            self.free_small_vecs.reserve(new_capacity - self.free_small_vecs.capacity());
275        }
276    }
277}
278
279impl Default for MemoryPool {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285impl PoolStats {
286    /// Calculate overall hit rate across all pool types
287    pub fn overall_hit_rate(&self) -> f64 {
288        let total_requests = self.context_node_requests + self.trie_node_requests + self.small_vec_requests;
289        let total_hits = self.context_node_hits + self.trie_node_hits + self.small_vec_hits;
290        
291        if total_requests > 0 {
292            total_hits as f64 / total_requests as f64
293        } else {
294            0.0
295        }
296    }
297    
298    /// Get a summary string of pool statistics
299    pub fn summary(&self) -> String {
300        format!(
301            "Pool Stats: Overall hit rate: {:.1}%, Context: {}/{} ({:.1}%), Trie: {}/{} ({:.1}%), SmallVec: {}/{} ({:.1}%)",
302            self.overall_hit_rate() * 100.0,
303            self.context_node_hits, self.context_node_requests,
304            if self.context_node_requests > 0 { self.context_node_hits as f64 / self.context_node_requests as f64 * 100.0 } else { 0.0 },
305            self.trie_node_hits, self.trie_node_requests,
306            if self.trie_node_requests > 0 { self.trie_node_hits as f64 / self.trie_node_requests as f64 * 100.0 } else { 0.0 },
307            self.small_vec_hits, self.small_vec_requests,
308            if self.small_vec_requests > 0 { self.small_vec_hits as f64 / self.small_vec_requests as f64 * 100.0 } else { 0.0 }
309        )
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_interner::StringInterner;
    
    // Pools start empty regardless of requested capacity; capacity is a
    // lower bound (Vec::with_capacity may over-allocate).
    #[test]
    fn test_memory_pool_creation() {
        let pool = MemoryPool::new();
        assert_eq!(pool.pool_sizes(), (0, 0, 0));
        
        let pool = MemoryPool::with_capacity(10, 20, 15);
        assert_eq!(pool.pool_sizes(), (0, 0, 0));
        assert!(pool.context_nodes.capacity() >= 10);
        assert!(pool.trie_nodes.capacity() >= 20);
        assert!(pool.small_vecs.capacity() >= 15);
    }
    
    // A get from an empty pool is a miss; after a return, the next get is
    // a hit served from the pool.
    #[test]
    fn test_context_node_pooling() {
        let mut pool = MemoryPool::new();
        let interner = Arc::new(StringInterner::new());
        
        // Get a node from empty pool (should create new)
        let node1 = pool.get_context_node(Arc::clone(&interner));
        assert_eq!(pool.stats().context_node_requests, 1);
        assert_eq!(pool.stats().context_node_hits, 0);
        
        // Return the node
        pool.return_context_node(node1);
        assert_eq!(pool.pool_sizes().0, 1);
        
        // Get another node (should come from pool)
        let _node2 = pool.get_context_node(Arc::clone(&interner));
        assert_eq!(pool.stats().context_node_requests, 2);
        assert_eq!(pool.stats().context_node_hits, 1);
    }
    
    // Same miss/return/hit cycle for the TrieNode pool.
    #[test]
    fn test_trie_node_pooling() {
        let mut pool = MemoryPool::new();
        
        // Get a node from empty pool
        let node1 = pool.get_trie_node(None, None);
        assert_eq!(pool.stats().trie_node_requests, 1);
        assert_eq!(pool.stats().trie_node_hits, 0);
        
        // Return the node
        pool.return_trie_node(node1);
        assert_eq!(pool.pool_sizes().1, 1);
        
        // Get another node (should come from pool)
        let _node2 = pool.get_trie_node(Some(NodeId::new(1)), Some(StateId::new(1)));
        assert_eq!(pool.stats().trie_node_requests, 2);
        assert_eq!(pool.stats().trie_node_hits, 1);
    }
    
    // Same miss/return/hit cycle for the SmallVec pool.
    #[test]
    fn test_small_vec_pooling() {
        let mut pool = MemoryPool::new();
        
        // Get a SmallVec from empty pool
        let vec1 = pool.get_small_vec();
        assert_eq!(pool.stats().small_vec_requests, 1);
        assert_eq!(pool.stats().small_vec_hits, 0);
        
        // Return the SmallVec
        pool.return_small_vec(vec1);
        assert_eq!(pool.pool_sizes().2, 1);
        
        // Get another SmallVec (should come from pool)
        let _vec2 = pool.get_small_vec();
        assert_eq!(pool.stats().small_vec_requests, 2);
        assert_eq!(pool.stats().small_vec_hits, 1);
    }
    
    // hit_rates() must report 0.0 (not NaN) before any requests, and exact
    // ratios afterwards: one miss + one hit = 0.5.
    #[test]
    fn test_hit_rates() {
        let mut pool = MemoryPool::new();
        let interner = Arc::new(StringInterner::new());
        
        // Initial hit rates should be 0
        let (context_rate, trie_rate, vec_rate) = pool.hit_rates();
        assert_eq!(context_rate, 0.0);
        assert_eq!(trie_rate, 0.0);
        assert_eq!(vec_rate, 0.0);
        
        // Get and return some objects
        let node = pool.get_context_node(Arc::clone(&interner));
        pool.return_context_node(node);
        let _node = pool.get_context_node(Arc::clone(&interner));
        
        // Should have 50% hit rate for context nodes
        let (context_rate, _, _) = pool.hit_rates();
        assert_eq!(context_rate, 0.5);
    }
    
    // memory_usage() counts at least the pool struct itself.
    #[test]
    fn test_memory_usage_calculation() {
        let pool = MemoryPool::new();
        let usage = pool.memory_usage();
        assert!(usage > 0);
        assert!(usage >= std::mem::size_of::<MemoryPool>());
    }
    
    #[test]
    fn test_auto_tuning() {
        let mut pool = MemoryPool::with_capacity(2, 2, 2);
        let interner = Arc::new(StringInterner::new());
        
        // Generate low hit rate scenario (more requests than capacity)
        for _ in 0..15 {
            let node = pool.get_context_node(Arc::clone(&interner));
            // Don't return nodes to keep hit rate low
            std::mem::drop(node);
        }
        
        let initial_capacity = pool.context_nodes.capacity();
        let (context_hit_rate, _, _) = pool.hit_rates();
        
        // Verify we have low hit rate
        assert!(context_hit_rate < 0.8, "Hit rate should be low: {context_hit_rate:.2}");
        
        pool.auto_tune();
        
        // NOTE(review): this only asserts the capacity did not shrink, not
        // strict growth — Vec may already over-allocate past the 1.5x
        // target, so ">=" is the strongest portable check here.
        assert!(pool.context_nodes.capacity() >= initial_capacity, 
               "Capacity should not decrease: {} -> {}", 
               initial_capacity, pool.context_nodes.capacity());
    }
    
    // summary() produces a human-readable report containing the expected
    // headings.
    #[test]
    fn test_pool_stats_summary() {
        let mut pool = MemoryPool::new();
        let interner = Arc::new(StringInterner::new());
        
        // Generate some activity
        let node = pool.get_context_node(Arc::clone(&interner));
        pool.return_context_node(node);
        let _node = pool.get_context_node(Arc::clone(&interner));
        
        let summary = pool.stats().summary();
        assert!(summary.contains("Pool Stats"));
        assert!(summary.contains("hit rate"));
    }
}
457}