Skip to main content

ronn_core/
memory_pool.rs

1//! Memory pooling for efficient tensor allocation.
2//!
3//! Reduces allocation overhead by reusing memory buffers across operations.
4//! Critical for low-latency, high-throughput inference.
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
/// A reusable byte buffer handed out by a memory pool.
///
/// The buffer remembers the capacity it was created with, which is tracked
/// separately from the number of bytes currently stored in it. A freshly
/// created buffer is empty: call [`PooledBuffer::resize`] before writing
/// through [`PooledBuffer::as_mut_slice`].
pub struct PooledBuffer {
    /// Backing storage; its length is the number of bytes currently in use.
    data: Vec<u8>,
    /// Capacity requested at construction time. Not updated by `resize`.
    capacity: usize,
}

impl Clone for PooledBuffer {
    /// Copy the contents while keeping the advertised capacity reserved.
    ///
    /// `Vec::clone` is not guaranteed to preserve spare capacity, so a derived
    /// `Clone` could yield a buffer whose `capacity()` overstates what is
    /// actually allocated. Reserving explicitly keeps the two in agreement.
    fn clone(&self) -> Self {
        // `max` covers a buffer that was resized beyond its original capacity.
        let mut data = Vec::with_capacity(self.capacity.max(self.data.len()));
        data.extend_from_slice(&self.data);
        Self {
            data,
            capacity: self.capacity,
        }
    }
}

impl PooledBuffer {
    /// Create an empty buffer with room reserved for `capacity` bytes.
    ///
    /// The returned buffer has length zero; only the allocation is reserved.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            capacity,
        }
    }

    /// Capacity this buffer was created with, in bytes.
    ///
    /// This is the value passed to [`PooledBuffer::new`]; it does not change
    /// when the buffer is resized.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Mutable view of the bytes currently stored in the buffer.
    ///
    /// The slice covers the current length only, so it is empty until
    /// [`PooledBuffer::resize`] has been called.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Read-only view of the bytes currently stored in the buffer.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Set the buffer length to `new_size` bytes.
    ///
    /// When growing, the added bytes are set to `value`; when shrinking, the
    /// tail is truncated. Growing past `capacity()` makes the underlying
    /// vector reallocate but leaves the value reported by `capacity()` as is.
    pub fn resize(&mut self, new_size: usize, value: u8) {
        self.data.resize(new_size, value);
    }
}
45
46/// Memory pool for efficient allocation and reuse of buffers.
47///
48/// Maintains separate pools for different size classes to minimize fragmentation.
49/// Thread-safe via internal locking.
50///
51/// # Performance
52///
53/// - Cache hit: O(1) - reuse existing buffer
54/// - Cache miss: O(1) - allocate new buffer
55/// - Return: O(1) - add to pool
56///
57/// # Example
58///
59/// ```
60/// use ronn_core::memory_pool::MemoryPool;
61///
62/// let pool = MemoryPool::new();
63///
64/// // Get buffer (cache miss - allocates)
65/// let mut buf = pool.get(1024);
66/// // Use buffer...
67///
68/// // Return buffer (adds to pool)
69/// pool.return_buffer(buf);
70///
71/// // Get buffer again (cache hit - reuses)
72/// let buf2 = pool.get(1024);
73/// ```
74pub struct MemoryPool {
75    /// Pools organized by size class
76    pools: Arc<Mutex<HashMap<usize, Vec<PooledBuffer>>>>,
77    /// Statistics
78    stats: Arc<Mutex<PoolStats>>,
79    /// Configuration
80    config: PoolConfig,
81}
82
/// Tunable limits and policies for a memory pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on idle buffers kept for any single size class.
    pub max_buffers_per_size: usize,
    /// Upper bound on idle buffers kept across all size classes combined.
    pub max_total_buffers: usize,
    /// When `true`, requested sizes are rounded up to a power of two.
    pub round_sizes: bool,
}

impl Default for PoolConfig {
    /// Defaults: 16 buffers per size class, 256 buffers overall, rounding on.
    fn default() -> Self {
        let (max_buffers_per_size, max_total_buffers, round_sizes) = (16, 256, true);
        PoolConfig {
            max_buffers_per_size,
            max_total_buffers,
            round_sizes,
        }
    }
}
103
/// Usage counters for a memory pool.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total number of get() calls
    pub total_gets: u64,
    /// Number of get() calls served by reusing a pooled buffer
    pub cache_hits: u64,
    /// Number of get() calls that had to allocate a new buffer
    pub cache_misses: u64,
    /// Total number of return_buffer() calls
    pub total_returns: u64,
    /// Number of idle buffers currently held by the pool
    pub buffers_in_pool: usize,
    /// Sum of the capacities of the buffers currently held by the pool
    pub bytes_in_pool: usize,
}

impl PoolStats {
    /// Fraction of `get()` calls that were cache hits (0.0 to 1.0).
    ///
    /// Returns 0.0 when no `get()` calls have been recorded.
    pub fn hit_rate(&self) -> f64 {
        if self.total_gets == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / self.total_gets as f64
    }

    /// Fraction of `get()` calls that were cache misses (0.0 to 1.0).
    ///
    /// Returns 0.0 when no `get()` calls have been recorded, matching
    /// [`PoolStats::hit_rate`]: an idle pool has neither hits nor misses, so
    /// it must not be reported as missing 100% of the time.
    pub fn miss_rate(&self) -> f64 {
        if self.total_gets == 0 {
            return 0.0;
        }
        1.0 - self.hit_rate()
    }
}
136
137impl MemoryPool {
138    /// Create a new memory pool with default configuration.
139    pub fn new() -> Self {
140        Self::with_config(PoolConfig::default())
141    }
142
143    /// Create a new memory pool with custom configuration.
144    pub fn with_config(config: PoolConfig) -> Self {
145        Self {
146            pools: Arc::new(Mutex::new(HashMap::new())),
147            stats: Arc::new(Mutex::new(PoolStats::default())),
148            config,
149        }
150    }
151
152    /// Get a buffer of at least the specified size.
153    ///
154    /// Returns a cached buffer if available, otherwise allocates a new one.
155    ///
156    /// # Arguments
157    ///
158    /// * `size` - Minimum size in bytes
159    ///
160    /// # Returns
161    ///
162    /// A pooled buffer with capacity >= size
163    pub fn get(&self, size: usize) -> PooledBuffer {
164        let size_class = self.size_class(size);
165
166        let mut pools = self.pools.lock().unwrap();
167        let mut stats = self.stats.lock().unwrap();
168
169        stats.total_gets += 1;
170
171        // Try to get from pool
172        if let Some(pool) = pools.get_mut(&size_class) {
173            if let Some(buffer) = pool.pop() {
174                stats.cache_hits += 1;
175                stats.buffers_in_pool -= 1;
176                stats.bytes_in_pool -= buffer.capacity();
177                return buffer;
178            }
179        }
180
181        // Cache miss - allocate new buffer
182        stats.cache_misses += 1;
183        PooledBuffer::new(size_class)
184    }
185
186    /// Return a buffer to the pool for reuse.
187    ///
188    /// If the pool for this size class is full, the buffer is dropped.
189    ///
190    /// # Arguments
191    ///
192    /// * `buffer` - The buffer to return
193    pub fn return_buffer(&self, buffer: PooledBuffer) {
194        let size_class = buffer.capacity();
195
196        let mut pools = self.pools.lock().unwrap();
197        let mut stats = self.stats.lock().unwrap();
198
199        stats.total_returns += 1;
200
201        // Check if pool is full
202        let pool = pools.entry(size_class).or_insert_with(Vec::new);
203
204        if pool.len() < self.config.max_buffers_per_size
205            && stats.buffers_in_pool < self.config.max_total_buffers
206        {
207            stats.buffers_in_pool += 1;
208            stats.bytes_in_pool += buffer.capacity();
209            pool.push(buffer);
210        }
211        // Otherwise drop the buffer (will be freed)
212    }
213
214    /// Get pool statistics.
215    pub fn stats(&self) -> PoolStats {
216        self.stats.lock().unwrap().clone()
217    }
218
219    /// Clear all buffers from the pool.
220    pub fn clear(&self) {
221        let mut pools = self.pools.lock().unwrap();
222        let mut stats = self.stats.lock().unwrap();
223
224        pools.clear();
225        stats.buffers_in_pool = 0;
226        stats.bytes_in_pool = 0;
227    }
228
229    /// Calculate size class for a given size.
230    ///
231    /// If round_sizes is enabled, rounds up to nearest power of 2.
232    /// Otherwise returns the size as-is.
233    fn size_class(&self, size: usize) -> usize {
234        if self.config.round_sizes {
235            size.next_power_of_two()
236        } else {
237            size
238        }
239    }
240}
241
242impl Default for MemoryPool {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248/// Global memory pool instance.
249///
250/// Thread-safe singleton for sharing across the application.
251static GLOBAL_POOL: once_cell::sync::Lazy<MemoryPool> =
252    once_cell::sync::Lazy::new(|| MemoryPool::new());
253
254/// Get the global memory pool instance.
255pub fn global_pool() -> &'static MemoryPool {
256    &GLOBAL_POOL
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// A brand-new pool has recorded no activity.
    #[test]
    fn test_pool_creation() {
        let snapshot = MemoryPool::new().stats();
        assert_eq!(snapshot.total_gets, 0);
        assert_eq!(snapshot.cache_hits, 0);
    }

    /// Miss, return, then hit on the same size class.
    #[test]
    fn test_get_and_return() {
        let pool = MemoryPool::new();

        // Nothing pooled yet, so the first request must allocate.
        let first = pool.get(1024);
        assert_eq!(first.capacity(), 1024);

        let after_miss = pool.stats();
        assert_eq!(after_miss.total_gets, 1);
        assert_eq!(after_miss.cache_misses, 1);
        assert_eq!(after_miss.cache_hits, 0);

        // Handing the buffer back makes it available for reuse.
        pool.return_buffer(first);

        let after_return = pool.stats();
        assert_eq!(after_return.total_returns, 1);
        assert_eq!(after_return.buffers_in_pool, 1);

        // The second request is served from the pool.
        let second = pool.get(1024);
        assert_eq!(second.capacity(), 1024);

        let after_hit = pool.stats();
        assert_eq!(after_hit.total_gets, 2);
        assert_eq!(after_hit.cache_hits, 1);
        assert_eq!(after_hit.hit_rate(), 0.5);
    }

    /// With default rounding, 1000 bytes lands in the 1024 size class.
    #[test]
    fn test_size_rounding() {
        let rounded = MemoryPool::new().get(1000);
        assert_eq!(rounded.capacity(), 1024);
    }

    /// The per-size-class limit caps how many buffers are retained.
    #[test]
    fn test_pool_limit() {
        let pool = MemoryPool::with_config(PoolConfig {
            max_buffers_per_size: 2,
            ..Default::default()
        });

        // Offer three buffers of one class; only two fit under the limit.
        for _ in 0..3 {
            pool.return_buffer(PooledBuffer::new(1024));
        }

        assert_eq!(pool.stats().buffers_in_pool, 2);
    }

    /// Buffers of different classes are pooled and counted independently.
    #[test]
    fn test_multiple_sizes() {
        let pool = MemoryPool::new();
        let sizes = [1024usize, 2048, 4096];

        // Take all three buffers first, then hand them all back.
        let taken: Vec<PooledBuffer> = sizes.iter().map(|&size| pool.get(size)).collect();
        for buffer in taken {
            pool.return_buffer(buffer);
        }

        let snapshot = pool.stats();
        assert_eq!(snapshot.buffers_in_pool, 3);
        assert_eq!(snapshot.bytes_in_pool, 1024 + 2048 + 4096);
    }

    /// `clear` empties the pool and zeroes the in-pool counters.
    #[test]
    fn test_clear() {
        let pool = MemoryPool::new();

        for &capacity in &[1024usize, 2048] {
            pool.return_buffer(PooledBuffer::new(capacity));
        }
        assert_eq!(pool.stats().buffers_in_pool, 2);

        pool.clear();

        let snapshot = pool.stats();
        assert_eq!(snapshot.buffers_in_pool, 0);
        assert_eq!(snapshot.bytes_in_pool, 0);
    }

    /// `global_pool` always hands out the same instance.
    #[test]
    fn test_global_pool() {
        assert!(std::ptr::eq(global_pool(), global_pool()));
    }
}