1use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
/// A byte buffer handed out by [`MemoryPool`], tagged with the size class it was
/// allocated for.
///
/// A freshly created buffer is empty (`data().len() == 0`) but has room for at
/// least `capacity()` bytes; call [`PooledBuffer::resize`] to give it a length.
#[derive(Debug)]
pub struct PooledBuffer {
    data: Vec<u8>,
    /// Size class this buffer was created for; used as its bucket key when it is
    /// returned to a pool. Not updated if `data` later grows past it.
    capacity: usize,
}

impl PooledBuffer {
    /// Creates an empty buffer with room for at least `capacity` bytes.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Returns the size class this buffer was created for.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Mutable view of the current contents.
    ///
    /// The slice length is the current data length, not `capacity()`, so it is
    /// empty for a freshly created buffer.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Read-only view of the current contents.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Resizes the contents to `new_size` bytes, filling any new bytes with `value`.
    ///
    /// Growing past the reserved space reallocates; the recorded size class
    /// (`capacity()`) is left unchanged.
    pub fn resize(&mut self, new_size: usize, value: u8) {
        self.data.resize(new_size, value);
    }
}

impl Clone for PooledBuffer {
    /// Copies the contents into a new allocation that reserves the same size
    /// class as the source, so the clone is just as reusable when returned to a
    /// pool.
    fn clone(&self) -> Self {
        // A derived clone would use `Vec::clone`, which does not keep spare
        // capacity; the clone would then advertise a size class it cannot hold.
        let mut data = Vec::with_capacity(self.capacity.max(self.data.len()));
        data.extend_from_slice(&self.data);
        Self {
            data,
            capacity: self.capacity,
        }
    }
}
45
46pub struct MemoryPool {
75 pools: Arc<Mutex<HashMap<usize, Vec<PooledBuffer>>>>,
77 stats: Arc<Mutex<PoolStats>>,
79 config: PoolConfig,
81}
82
/// Limits and sizing policy for a [`MemoryPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of idle buffers kept for any single size class.
    pub max_buffers_per_size: usize,
    /// Maximum number of idle buffers kept across all size classes combined.
    pub max_total_buffers: usize,
    /// When `true`, requested sizes are rounded up to the next power of two so
    /// nearby sizes share a bucket; when `false`, each exact size is its own bucket.
    pub round_sizes: bool,
}

impl Default for PoolConfig {
    /// 16 idle buffers per size class, 256 in total, power-of-two rounding on.
    fn default() -> Self {
        let round_sizes = true;
        let max_total_buffers = 256;
        let max_buffers_per_size = 16;
        PoolConfig {
            max_buffers_per_size,
            max_total_buffers,
            round_sizes,
        }
    }
}
103
/// Snapshot of a [`MemoryPool`]'s counters.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of `get` calls.
    pub total_gets: u64,
    /// `get` calls served by reusing an idle buffer.
    pub cache_hits: u64,
    /// `get` calls that had to allocate a new buffer.
    pub cache_misses: u64,
    /// Number of `return_buffer` calls, counted whether or not the buffer was kept.
    pub total_returns: u64,
    /// Idle buffers currently held by the pool.
    pub buffers_in_pool: usize,
    /// Sum of `capacity()` over the idle buffers currently held.
    pub bytes_in_pool: usize,
}

impl PoolStats {
    /// Fraction of `get` calls served from the pool; `0.0` before any `get`.
    pub fn hit_rate(&self) -> f64 {
        Self::ratio(self.cache_hits, self.total_gets)
    }

    /// Fraction of `get` calls that had to allocate; `0.0` before any `get`.
    ///
    /// Computed from `cache_misses` directly rather than as `1 - hit_rate()`.
    /// Otherwise a pool that has served nothing would report a 100% miss rate.
    pub fn miss_rate(&self) -> f64 {
        Self::ratio(self.cache_misses, self.total_gets)
    }

    /// `part / total` as a float, or `0.0` when `total` is zero.
    fn ratio(part: u64, total: u64) -> f64 {
        if total == 0 {
            0.0
        } else {
            part as f64 / total as f64
        }
    }
}
136
137impl MemoryPool {
138 pub fn new() -> Self {
140 Self::with_config(PoolConfig::default())
141 }
142
143 pub fn with_config(config: PoolConfig) -> Self {
145 Self {
146 pools: Arc::new(Mutex::new(HashMap::new())),
147 stats: Arc::new(Mutex::new(PoolStats::default())),
148 config,
149 }
150 }
151
152 pub fn get(&self, size: usize) -> PooledBuffer {
164 let size_class = self.size_class(size);
165
166 let mut pools = self.pools.lock().unwrap();
167 let mut stats = self.stats.lock().unwrap();
168
169 stats.total_gets += 1;
170
171 if let Some(pool) = pools.get_mut(&size_class) {
173 if let Some(buffer) = pool.pop() {
174 stats.cache_hits += 1;
175 stats.buffers_in_pool -= 1;
176 stats.bytes_in_pool -= buffer.capacity();
177 return buffer;
178 }
179 }
180
181 stats.cache_misses += 1;
183 PooledBuffer::new(size_class)
184 }
185
186 pub fn return_buffer(&self, buffer: PooledBuffer) {
194 let size_class = buffer.capacity();
195
196 let mut pools = self.pools.lock().unwrap();
197 let mut stats = self.stats.lock().unwrap();
198
199 stats.total_returns += 1;
200
201 let pool = pools.entry(size_class).or_insert_with(Vec::new);
203
204 if pool.len() < self.config.max_buffers_per_size
205 && stats.buffers_in_pool < self.config.max_total_buffers
206 {
207 stats.buffers_in_pool += 1;
208 stats.bytes_in_pool += buffer.capacity();
209 pool.push(buffer);
210 }
211 }
213
214 pub fn stats(&self) -> PoolStats {
216 self.stats.lock().unwrap().clone()
217 }
218
219 pub fn clear(&self) {
221 let mut pools = self.pools.lock().unwrap();
222 let mut stats = self.stats.lock().unwrap();
223
224 pools.clear();
225 stats.buffers_in_pool = 0;
226 stats.bytes_in_pool = 0;
227 }
228
229 fn size_class(&self, size: usize) -> usize {
234 if self.config.round_sizes {
235 size.next_power_of_two()
236 } else {
237 size
238 }
239 }
240}
241
242impl Default for MemoryPool {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248static GLOBAL_POOL: once_cell::sync::Lazy<MemoryPool> =
252 once_cell::sync::Lazy::new(|| MemoryPool::new());
253
254pub fn global_pool() -> &'static MemoryPool {
256 &GLOBAL_POOL
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_creation() {
        let snapshot = MemoryPool::new().stats();
        assert_eq!(snapshot.total_gets, 0);
        assert_eq!(snapshot.cache_hits, 0);
    }

    #[test]
    fn test_get_and_return() {
        let pool = MemoryPool::new();

        // Nothing is pooled yet, so the first request must allocate.
        let first = pool.get(1024);
        assert_eq!(first.capacity(), 1024);
        let after_miss = pool.stats();
        assert_eq!(after_miss.total_gets, 1);
        assert_eq!(after_miss.cache_misses, 1);
        assert_eq!(after_miss.cache_hits, 0);

        // Handing it back parks it in the pool.
        pool.return_buffer(first);
        let after_return = pool.stats();
        assert_eq!(after_return.total_returns, 1);
        assert_eq!(after_return.buffers_in_pool, 1);

        // A second request of the same size is served from the pool.
        let second = pool.get(1024);
        assert_eq!(second.capacity(), 1024);
        let after_hit = pool.stats();
        assert_eq!(after_hit.total_gets, 2);
        assert_eq!(after_hit.cache_hits, 1);
        assert_eq!(after_hit.hit_rate(), 0.5);
    }

    #[test]
    fn test_size_rounding() {
        // 1000 is not a power of two, so it lands in the 1024 size class.
        assert_eq!(MemoryPool::new().get(1000).capacity(), 1024);
    }

    #[test]
    fn test_pool_limit() {
        let pool = MemoryPool::with_config(PoolConfig {
            max_buffers_per_size: 2,
            ..Default::default()
        });

        // Three returns into a bucket capped at two: the third is dropped.
        for _ in 0..3 {
            pool.return_buffer(PooledBuffer::new(1024));
        }

        assert_eq!(pool.stats().buffers_in_pool, 2);
    }

    #[test]
    fn test_multiple_sizes() {
        let pool = MemoryPool::new();
        let sizes = [1024, 2048, 4096];

        let buffers: Vec<PooledBuffer> = sizes.iter().map(|&n| pool.get(n)).collect();
        for buffer in buffers {
            pool.return_buffer(buffer);
        }

        let stats = pool.stats();
        assert_eq!(stats.buffers_in_pool, sizes.len());
        assert_eq!(stats.bytes_in_pool, sizes.iter().sum::<usize>());
    }

    #[test]
    fn test_clear() {
        let pool = MemoryPool::new();
        for size in [1024, 2048] {
            pool.return_buffer(PooledBuffer::new(size));
        }
        assert_eq!(pool.stats().buffers_in_pool, 2);

        pool.clear();

        let stats = pool.stats();
        assert_eq!(stats.buffers_in_pool, 0);
        assert_eq!(stats.bytes_in_pool, 0);
    }

    #[test]
    fn test_global_pool() {
        // Repeated calls must yield the very same instance.
        assert!(std::ptr::eq(global_pool(), global_pool()));
    }
}