1use crate::context_tree::ContextNode;
7use crate::context_trie::{TrieNode, NodeId};
8use crate::string_interner::{StateId, StringInterner};
9use smallvec::SmallVec;
10use std::sync::Arc;
11
12#[derive(Debug)]
14pub struct MemoryPool {
15 context_nodes: Vec<ContextNode>,
17 free_context_nodes: Vec<usize>,
19
20 trie_nodes: Vec<TrieNode>,
22 free_trie_nodes: Vec<usize>,
24
25 small_vecs: Vec<SmallVec<[(StateId, usize); 4]>>,
27 free_small_vecs: Vec<usize>,
29
30 stats: PoolStats,
32}
33
34#[derive(Debug, Clone, Default)]
36pub struct PoolStats {
37 pub context_node_requests: usize,
39 pub context_node_hits: usize,
41
42 pub trie_node_requests: usize,
44 pub trie_node_hits: usize,
46
47 pub small_vec_requests: usize,
49 pub small_vec_hits: usize,
51
52 pub peak_context_nodes: usize,
54 pub peak_trie_nodes: usize,
55 pub peak_small_vecs: usize,
56}
57
58impl MemoryPool {
59 pub fn new() -> Self {
61 Self::with_capacity(64, 256, 128)
62 }
63
64 pub fn with_capacity(
66 context_nodes: usize,
67 trie_nodes: usize,
68 small_vecs: usize,
69 ) -> Self {
70 Self {
71 context_nodes: Vec::with_capacity(context_nodes),
72 free_context_nodes: Vec::with_capacity(context_nodes),
73 trie_nodes: Vec::with_capacity(trie_nodes),
74 free_trie_nodes: Vec::with_capacity(trie_nodes),
75 small_vecs: Vec::with_capacity(small_vecs),
76 free_small_vecs: Vec::with_capacity(small_vecs),
77 stats: PoolStats::default(),
78 }
79 }
80
81 pub fn get_context_node(&mut self, interner: Arc<StringInterner>) -> ContextNode {
83 self.stats.context_node_requests += 1;
84
85 if let Some(index) = self.free_context_nodes.pop() {
86 self.stats.context_node_hits += 1;
87
88 let mut node = std::mem::take(&mut self.context_nodes[index]);
90 node.reset(interner);
91 node
92 } else {
93 ContextNode::new(interner)
95 }
96 }
97
98 pub fn return_context_node(&mut self, mut node: ContextNode) {
100 node.clear();
102
103 if self.context_nodes.len() < self.context_nodes.capacity() {
105 let index = self.context_nodes.len();
106 self.context_nodes.push(node);
107 self.free_context_nodes.push(index);
108
109 if self.context_nodes.len() > self.stats.peak_context_nodes {
111 self.stats.peak_context_nodes = self.context_nodes.len();
112 }
113 }
114 }
116
117 pub fn get_trie_node(&mut self, parent: Option<NodeId>, state_from_parent: Option<StateId>) -> TrieNode {
119 self.stats.trie_node_requests += 1;
120
121 if let Some(index) = self.free_trie_nodes.pop() {
122 self.stats.trie_node_hits += 1;
123
124 let mut node = std::mem::take(&mut self.trie_nodes[index]);
126 node.reset(parent, state_from_parent);
127 node
128 } else {
129 TrieNode::new(parent, state_from_parent)
131 }
132 }
133
134 pub fn return_trie_node(&mut self, mut node: TrieNode) {
136 node.clear();
138
139 if self.trie_nodes.len() < self.trie_nodes.capacity() {
141 let index = self.trie_nodes.len();
142 self.trie_nodes.push(node);
143 self.free_trie_nodes.push(index);
144
145 if self.trie_nodes.len() > self.stats.peak_trie_nodes {
147 self.stats.peak_trie_nodes = self.trie_nodes.len();
148 }
149 }
150 }
152
153 pub fn get_small_vec(&mut self) -> SmallVec<[(StateId, usize); 4]> {
155 self.stats.small_vec_requests += 1;
156
157 if let Some(index) = self.free_small_vecs.pop() {
158 self.stats.small_vec_hits += 1;
159
160 let mut vec = std::mem::take(&mut self.small_vecs[index]);
162 vec.clear();
163 vec
164 } else {
165 SmallVec::new()
167 }
168 }
169
170 pub fn return_small_vec(&mut self, mut vec: SmallVec<[(StateId, usize); 4]>) {
172 vec.clear();
174
175 if self.small_vecs.len() < self.small_vecs.capacity() {
177 let index = self.small_vecs.len();
178 self.small_vecs.push(vec);
179 self.free_small_vecs.push(index);
180
181 if self.small_vecs.len() > self.stats.peak_small_vecs {
183 self.stats.peak_small_vecs = self.small_vecs.len();
184 }
185 }
186 }
188
189 pub fn stats(&self) -> &PoolStats {
191 &self.stats
192 }
193
194 pub fn reset_stats(&mut self) {
196 self.stats = PoolStats::default();
197 }
198
199 pub fn pool_sizes(&self) -> (usize, usize, usize) {
201 (
202 self.context_nodes.len(),
203 self.trie_nodes.len(),
204 self.small_vecs.len(),
205 )
206 }
207
208 pub fn hit_rates(&self) -> (f64, f64, f64) {
210 let context_hit_rate = if self.stats.context_node_requests > 0 {
211 self.stats.context_node_hits as f64 / self.stats.context_node_requests as f64
212 } else {
213 0.0
214 };
215
216 let trie_hit_rate = if self.stats.trie_node_requests > 0 {
217 self.stats.trie_node_hits as f64 / self.stats.trie_node_requests as f64
218 } else {
219 0.0
220 };
221
222 let small_vec_hit_rate = if self.stats.small_vec_requests > 0 {
223 self.stats.small_vec_hits as f64 / self.stats.small_vec_requests as f64
224 } else {
225 0.0
226 };
227
228 (context_hit_rate, trie_hit_rate, small_vec_hit_rate)
229 }
230
231 pub fn memory_usage(&self) -> usize {
233 let mut total = std::mem::size_of::<Self>();
234
235 total += self.context_nodes.capacity() * std::mem::size_of::<ContextNode>();
237 total += self.free_context_nodes.capacity() * std::mem::size_of::<usize>();
238
239 total += self.trie_nodes.capacity() * std::mem::size_of::<TrieNode>();
241 total += self.free_trie_nodes.capacity() * std::mem::size_of::<usize>();
242
243 total += self.small_vecs.capacity() * std::mem::size_of::<SmallVec<[(StateId, usize); 4]>>();
245 total += self.free_small_vecs.capacity() * std::mem::size_of::<usize>();
246
247 total
248 }
249
250 pub fn auto_tune(&mut self) {
252 let (context_hit_rate, trie_hit_rate, small_vec_hit_rate) = self.hit_rates();
254
255 const MIN_HIT_RATE: f64 = 0.8;
257 const GROWTH_FACTOR: f64 = 1.5;
258
259 if context_hit_rate < MIN_HIT_RATE && self.stats.context_node_requests > 10 {
260 let new_capacity = (self.context_nodes.capacity() as f64 * GROWTH_FACTOR) as usize;
261 self.context_nodes.reserve(new_capacity - self.context_nodes.capacity());
262 self.free_context_nodes.reserve(new_capacity - self.free_context_nodes.capacity());
263 }
264
265 if trie_hit_rate < MIN_HIT_RATE && self.stats.trie_node_requests > 10 {
266 let new_capacity = (self.trie_nodes.capacity() as f64 * GROWTH_FACTOR) as usize;
267 self.trie_nodes.reserve(new_capacity - self.trie_nodes.capacity());
268 self.free_trie_nodes.reserve(new_capacity - self.free_trie_nodes.capacity());
269 }
270
271 if small_vec_hit_rate < MIN_HIT_RATE && self.stats.small_vec_requests > 10 {
272 let new_capacity = (self.small_vecs.capacity() as f64 * GROWTH_FACTOR) as usize;
273 self.small_vecs.reserve(new_capacity - self.small_vecs.capacity());
274 self.free_small_vecs.reserve(new_capacity - self.free_small_vecs.capacity());
275 }
276 }
277}
278
279impl Default for MemoryPool {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
285impl PoolStats {
286 pub fn overall_hit_rate(&self) -> f64 {
288 let total_requests = self.context_node_requests + self.trie_node_requests + self.small_vec_requests;
289 let total_hits = self.context_node_hits + self.trie_node_hits + self.small_vec_hits;
290
291 if total_requests > 0 {
292 total_hits as f64 / total_requests as f64
293 } else {
294 0.0
295 }
296 }
297
298 pub fn summary(&self) -> String {
300 format!(
301 "Pool Stats: Overall hit rate: {:.1}%, Context: {}/{} ({:.1}%), Trie: {}/{} ({:.1}%), SmallVec: {}/{} ({:.1}%)",
302 self.overall_hit_rate() * 100.0,
303 self.context_node_hits, self.context_node_requests,
304 if self.context_node_requests > 0 { self.context_node_hits as f64 / self.context_node_requests as f64 * 100.0 } else { 0.0 },
305 self.trie_node_hits, self.trie_node_requests,
306 if self.trie_node_requests > 0 { self.trie_node_hits as f64 / self.trie_node_requests as f64 * 100.0 } else { 0.0 },
307 self.small_vec_hits, self.small_vec_requests,
308 if self.small_vec_requests > 0 { self.small_vec_hits as f64 / self.small_vec_requests as f64 * 100.0 } else { 0.0 }
309 )
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_interner::StringInterner;

    /// Builds a fresh, shareable interner for tests that need context nodes.
    fn shared_interner() -> Arc<StringInterner> {
        Arc::new(StringInterner::new())
    }

    #[test]
    fn test_memory_pool_creation() {
        // A new pool holds no idle objects, whatever capacity it was given.
        assert_eq!(MemoryPool::new().pool_sizes(), (0, 0, 0));

        let sized = MemoryPool::with_capacity(10, 20, 15);
        assert_eq!(sized.pool_sizes(), (0, 0, 0));
        assert!(sized.context_nodes.capacity() >= 10);
        assert!(sized.trie_nodes.capacity() >= 20);
        assert!(sized.small_vecs.capacity() >= 15);
    }

    #[test]
    fn test_context_node_pooling() {
        let interner = shared_interner();
        let mut pool = MemoryPool::new();

        // Nothing has been returned yet, so the first request is a miss.
        let first = pool.get_context_node(Arc::clone(&interner));
        let s = pool.stats();
        assert_eq!((s.context_node_requests, s.context_node_hits), (1, 0));

        pool.return_context_node(first);
        assert_eq!(pool.pool_sizes().0, 1);

        // The returned node is handed out again, which counts as a hit.
        let _second = pool.get_context_node(Arc::clone(&interner));
        let s = pool.stats();
        assert_eq!((s.context_node_requests, s.context_node_hits), (2, 1));
    }

    #[test]
    fn test_trie_node_pooling() {
        let mut pool = MemoryPool::new();

        let root = pool.get_trie_node(None, None);
        let s = pool.stats();
        assert_eq!((s.trie_node_requests, s.trie_node_hits), (1, 0));

        pool.return_trie_node(root);
        assert_eq!(pool.pool_sizes().1, 1);

        // Reuse with different initialisation arguments.
        let _child = pool.get_trie_node(Some(NodeId::new(1)), Some(StateId::new(1)));
        let s = pool.stats();
        assert_eq!((s.trie_node_requests, s.trie_node_hits), (2, 1));
    }

    #[test]
    fn test_small_vec_pooling() {
        let mut pool = MemoryPool::new();

        let first = pool.get_small_vec();
        let s = pool.stats();
        assert_eq!((s.small_vec_requests, s.small_vec_hits), (1, 0));

        pool.return_small_vec(first);
        assert_eq!(pool.pool_sizes().2, 1);

        let _second = pool.get_small_vec();
        let s = pool.stats();
        assert_eq!((s.small_vec_requests, s.small_vec_hits), (2, 1));
    }

    #[test]
    fn test_hit_rates() {
        let interner = shared_interner();
        let mut pool = MemoryPool::new();

        // No requests yet: every rate reports 0.0.
        assert_eq!(pool.hit_rates(), (0.0, 0.0, 0.0));

        // One miss followed by one hit gives exactly 50%.
        let recycled = pool.get_context_node(Arc::clone(&interner));
        pool.return_context_node(recycled);
        let _reused = pool.get_context_node(Arc::clone(&interner));

        assert_eq!(pool.hit_rates().0, 0.5);
    }

    #[test]
    fn test_memory_usage_calculation() {
        let usage = MemoryPool::new().memory_usage();
        assert!(usage > 0);
        assert!(usage >= std::mem::size_of::<MemoryPool>());
    }

    #[test]
    fn test_auto_tuning() {
        let interner = shared_interner();
        let mut pool = MemoryPool::with_capacity(2, 2, 2);

        // Never return anything, so every request misses.
        (0..15).for_each(|_| drop(pool.get_context_node(Arc::clone(&interner))));

        let initial_capacity = pool.context_nodes.capacity();
        let context_hit_rate = pool.hit_rates().0;
        assert!(context_hit_rate < 0.8, "Hit rate should be low: {context_hit_rate:.2}");

        pool.auto_tune();

        let tuned_capacity = pool.context_nodes.capacity();
        assert!(
            tuned_capacity >= initial_capacity,
            "Capacity should not decrease: {} -> {}",
            initial_capacity,
            tuned_capacity
        );
    }

    #[test]
    fn test_pool_stats_summary() {
        let interner = shared_interner();
        let mut pool = MemoryPool::new();

        let recycled = pool.get_context_node(Arc::clone(&interner));
        pool.return_context_node(recycled);
        let _reused = pool.get_context_node(Arc::clone(&interner));

        let summary = pool.stats().summary();
        for needle in ["Pool Stats", "hit rate"] {
            assert!(summary.contains(needle));
        }
    }
}