Skip to main content

phantom_protocol/transport/
buffer_pool.rs

//! Pre-allocated buffer pool.
//!
//! Recycles I/O buffers so that steady-state packet processing does not need a heap allocation per packet.
4
5use crossbeam_queue::ArrayQueue;
6use std::cell::RefCell;
7use std::sync::atomic::{AtomicUsize, Ordering};
8
/// Number of buffers pulled from the shared queue into a thread's local cache
/// in one go, whenever that cache has run dry.
const BATCH_SIZE: usize = 32;
/// Maximum number of buffers a thread's local cache holds; when it is full,
/// half of it is flushed back to the shared queue.
const MAX_LOCAL_BUFFERS: usize = 64;
11
12/// A pool of pre-allocated buffers for zero-allocation I/O
13pub struct BufferPool {
14    /// Global pool of available buffers
15    buffers: ArrayQueue<Vec<u8>>,
16    /// Buffer size
17    buffer_size: usize,
18    /// Max pool size
19    #[allow(dead_code)]
20    max_buffers: usize,
21    /// Stats: total allocations
22    allocations: AtomicUsize,
23    /// Stats: pool hits
24    hits: AtomicUsize,
25}
26
27thread_local! {
28    static LOCAL_POOL: RefCell<Vec<Vec<u8>>> = RefCell::new(Vec::with_capacity(MAX_LOCAL_BUFFERS));
29}
30
31impl BufferPool {
32    /// Create a new buffer pool
33    pub fn new(buffer_size: usize, initial_count: usize, max_buffers: usize) -> Self {
34        let buffers = ArrayQueue::new(max_buffers);
35        let count = std::cmp::min(initial_count, max_buffers);
36        for _ in 0..count {
37            let _ = buffers.push(vec![0u8; buffer_size]);
38        }
39
40        Self {
41            buffers,
42            buffer_size,
43            max_buffers,
44            allocations: AtomicUsize::new(count),
45            hits: AtomicUsize::new(0),
46        }
47    }
48
49    /// Acquire a buffer from the pool
50    #[inline]
51    pub fn acquire(&self) -> PooledBuffer<'_> {
52        let mut buffer: Option<Vec<u8>> = LOCAL_POOL.with(|local| {
53            let mut local_pool = local.borrow_mut();
54            if let Some(mut buf) = local_pool.pop() {
55                buf.clear();
56                return Some(buf);
57            }
58            None
59        });
60
61        if buffer.is_none() {
62            // Refill local pool from global pool
63            LOCAL_POOL.with(|local| {
64                let mut local_pool = local.borrow_mut();
65                for _ in 0..BATCH_SIZE {
66                    if let Some(mut buf) = self.buffers.pop() {
67                        buf.clear();
68                        local_pool.push(buf);
69                    } else {
70                        break;
71                    }
72                }
73            });
74
75            buffer = LOCAL_POOL.with(|local| {
76                let mut local_pool = local.borrow_mut();
77                local_pool.pop()
78            });
79        }
80
81        let buffer = if let Some(buf) = buffer {
82            self.hits.fetch_add(1, Ordering::Relaxed);
83            buf
84        } else {
85            self.allocations.fetch_add(1, Ordering::Relaxed);
86            Vec::with_capacity(self.buffer_size)
87        };
88
89        PooledBuffer { buffer, pool: self }
90    }
91
92    /// Return a buffer to the pool
93    #[inline]
94    fn return_buffer(&self, mut buffer: Vec<u8>) {
95        buffer.clear();
96        LOCAL_POOL.with(|local| {
97            let mut local_pool = local.borrow_mut();
98            if local_pool.len() < MAX_LOCAL_BUFFERS {
99                local_pool.push(buffer);
100            } else {
101                // If local pool is full, flush half back to global pool
102                let half = MAX_LOCAL_BUFFERS / 2;
103                for _ in 0..half {
104                    if let Some(buf) = local_pool.pop() {
105                        let _ = self.buffers.push(buf);
106                    }
107                }
108
109                local_pool.push(buffer);
110            }
111        });
112    }
113
114    /// Get pool statistics
115    pub fn stats(&self) -> PoolStats {
116        PoolStats {
117            allocations: self.allocations.load(Ordering::Relaxed),
118            hits: self.hits.load(Ordering::Relaxed),
119            pool_size: self.buffers.len(),
120        }
121    }
122}
123
124/// A buffer borrowed from the pool
125pub struct PooledBuffer<'a> {
126    buffer: Vec<u8>,
127    pool: &'a BufferPool,
128}
129
130impl<'a> PooledBuffer<'a> {
131    /// Get mutable reference to inner buffer
132    ///
133    /// Note: returns `&mut Vec<u8>` (not `&mut T`) so it cannot implement
134    /// `std::convert::AsMut` without a concrete target type — the inherent
135    /// method intentionally provides the richer `Vec` interface.
136    #[allow(clippy::should_implement_trait)]
137    #[inline]
138    pub fn as_mut(&mut self) -> &mut Vec<u8> {
139        &mut self.buffer
140    }
141
142    /// Get reference to inner buffer
143    ///
144    /// Note: returns `&[u8]` (not `&Vec<u8>`) so the signature differs from
145    /// what `std::convert::AsRef<Vec<u8>>` would produce — an inherent method
146    /// avoids the ambiguity.
147    #[allow(clippy::should_implement_trait)]
148    #[inline]
149    pub fn as_ref(&self) -> &[u8] {
150        &self.buffer
151    }
152
153    /// Get the buffer length
154    #[inline]
155    pub fn len(&self) -> usize {
156        self.buffer.len()
157    }
158
159    /// Check if buffer is empty
160    #[inline]
161    pub fn is_empty(&self) -> bool {
162        self.buffer.is_empty()
163    }
164}
165
166impl<'a> std::ops::Deref for PooledBuffer<'a> {
167    type Target = Vec<u8>;
168
169    #[inline]
170    fn deref(&self) -> &Self::Target {
171        &self.buffer
172    }
173}
174
175impl<'a> std::ops::DerefMut for PooledBuffer<'a> {
176    #[inline]
177    fn deref_mut(&mut self) -> &mut Self::Target {
178        &mut self.buffer
179    }
180}
181
182impl<'a> Drop for PooledBuffer<'a> {
183    fn drop(&mut self) {
184        let buffer = std::mem::take(&mut self.buffer);
185        // Only return if it has some capacity (not completely empty shell)
186        if buffer.capacity() > 0 {
187            self.pool.return_buffer(buffer);
188        }
189    }
190}
191
/// Point-in-time snapshot of a [`BufferPool`]'s counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PoolStats {
    /// Buffers allocated by the pool, including the initial prefill.
    pub allocations: usize,
    /// Acquisitions served by reusing an existing buffer.
    pub hits: usize,
    /// Idle buffers in the shared queue (excludes per-thread caches).
    pub pool_size: usize,
}

impl PoolStats {
    /// Ratio `hits / (allocations + hits)`, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no activity has been recorded. The denominator is
    /// summed in `f64` because the underlying atomic counters wrap, and adding
    /// two large `usize` values could overflow (a panic in debug builds).
    pub fn hit_rate(&self) -> f64 {
        if self.allocations == 0 && self.hits == 0 {
            return 0.0;
        }
        let hits = self.hits as f64;
        hits / (self.allocations as f64 + hits)
    }
}
210
211/// Global buffer pool for common use
212static GLOBAL_POOL: std::sync::OnceLock<BufferPool> = std::sync::OnceLock::new();
213
214/// Get the global buffer pool
215pub fn global_pool() -> &'static BufferPool {
216    GLOBAL_POOL.get_or_init(|| {
217        // 64 KB buffers, 1024 initial, 65536 max
218        BufferPool::new(64 * 1024, 1024, 65536)
219    })
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Number of buffers parked in the calling thread's local cache.
    fn local_cached() -> usize {
        LOCAL_POOL.with(|local| local.borrow().len())
    }

    #[test]
    fn test_buffer_pool() {
        let pool = BufferPool::new(1024, 4, 16);

        // LOCAL_POOL is shared by every pool used on a thread. Run on a fresh
        // thread so the cache starts empty, regardless of what other tests (or a
        // harness that reuses threads) left behind.
        thread::scope(|scope| {
            scope.spawn(|| {
                let mut buf1 = pool.acquire();
                buf1.extend_from_slice(b"hello");
                assert_eq!(buf1.len(), 5);

                let buf2 = pool.acquire();
                assert_eq!(buf2.len(), 0);

                drop(buf1);
                drop(buf2);

                // The first acquire moved all 4 preallocated buffers into this
                // thread's cache, and both handles returned theirs there.
                assert_eq!(local_cached(), 4);
            });
        });

        let stats = pool.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.allocations, 4);
        assert_eq!(stats.pool_size, 0);
    }

    #[test]
    fn test_thread_local_flushing() {
        // Enough buffers to overflow the local cache exactly once.
        const COUNT: usize = MAX_LOCAL_BUFFERS + 6;
        let flushed = MAX_LOCAL_BUFFERS / 2;
        let pool = Arc::new(BufferPool::new(1024, 0, 100));

        let worker_pool = Arc::clone(&pool);
        thread::spawn(move || {
            let bufs: Vec<_> = (0..COUNT).map(|_| worker_pool.acquire()).collect();
            drop(bufs);

            // The first MAX_LOCAL_BUFFERS returns fill the cache, the next one
            // flushes half of it, and the rest are cached again.
            assert_eq!(local_cached(), COUNT - flushed);
        })
        .join()
        .unwrap();

        let stats = pool.stats();
        // The flushed half landed in the shared queue.
        assert_eq!(stats.pool_size, flushed);
        // Nothing was preallocated, so every acquire allocated.
        assert_eq!(stats.allocations, COUNT);
        assert_eq!(stats.hits, 0);
    }
}