Skip to main content

px_core/
buffer_pool.rs

//! Reusable byte-buffer pool.
//!
//! Used by the WebSocket hot path and HTTP body reads to avoid reallocation
//! churn. By default each buffer is 512 KiB and the pool retains at most 10
//! idle buffers. A buffer whose capacity has grown past 2x the configured
//! size is shrunk back when it is returned, so the pool does not retain
//! pathologically large allocations.
7
8use std::sync::Arc;
9use tokio::sync::Mutex;
10
11pub struct BufferPool {
12    buffers: Arc<Mutex<Vec<Vec<u8>>>>,
13    buffer_size: usize,
14    max_pool_size: usize,
15}
16
17impl BufferPool {
18    pub const DEFAULT_BUFFER_SIZE: usize = 512 * 1024;
19    pub const DEFAULT_POOL_SIZE: usize = 10;
20
21    pub fn new(buffer_size: usize, max_pool_size: usize) -> Self {
22        Self {
23            buffers: Arc::new(Mutex::new(Vec::with_capacity(max_pool_size))),
24            buffer_size,
25            max_pool_size,
26        }
27    }
28
29    /// A `Vec<u8>` with capacity `buffer_size` (or more). If the pool is empty,
30    /// allocates a fresh one.
31    pub async fn get(&self) -> Vec<u8> {
32        let mut pool = self.buffers.lock().await;
33        pool.pop()
34            .map(|mut b| {
35                b.clear();
36                b
37            })
38            .unwrap_or_else(|| Vec::with_capacity(self.buffer_size))
39    }
40
41    /// Returns `buf` to the pool, shrinking it if it has grown past 2x the
42    /// configured size. If the pool is full, the buffer is dropped.
43    pub async fn return_buffer(&self, mut buf: Vec<u8>) {
44        buf.clear();
45        if buf.capacity() > self.buffer_size * 2 {
46            buf.shrink_to(self.buffer_size);
47        }
48        let mut pool = self.buffers.lock().await;
49        if pool.len() < self.max_pool_size {
50            pool.push(buf);
51        }
52    }
53
54    /// Pre-allocates `count` buffers at startup so the first request does not
55    /// pay the allocation cost.
56    pub async fn prewarm(&self, count: usize) {
57        let mut pool = self.buffers.lock().await;
58        let target = count.min(self.max_pool_size);
59        while pool.len() < target {
60            pool.push(Vec::with_capacity(self.buffer_size));
61        }
62    }
63
64    pub fn buffer_size(&self) -> usize {
65        self.buffer_size
66    }
67}
68
69impl Default for BufferPool {
70    fn default() -> Self {
71        Self::new(Self::DEFAULT_BUFFER_SIZE, Self::DEFAULT_POOL_SIZE)
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// With nothing idle in the pool, `get` must hand back an empty buffer
    /// of at least the configured capacity.
    #[tokio::test]
    async fn get_returns_allocated_vec() {
        const SIZE: usize = 1024;
        let pool = BufferPool::new(SIZE, 4);
        let fresh = pool.get().await;
        assert!(fresh.capacity() >= SIZE);
        assert!(fresh.is_empty());
    }

    /// Returning more buffers than `max_pool_size` keeps only the first
    /// `max_pool_size` of them.
    #[tokio::test]
    async fn return_reuses_under_limit() {
        let pool = BufferPool::new(1024, 2);
        let first = pool.get().await;
        pool.return_buffer(first).await;
        for _ in 0..2 {
            pool.return_buffer(Vec::with_capacity(1024)).await;
        }
        let idle = pool.buffers.lock().await.len();
        assert_eq!(idle, 2);
    }

    /// A buffer returned at 8x the configured size comes back out at no more
    /// than 2x.
    #[tokio::test]
    async fn overgrown_buffer_is_shrunk() {
        let pool = BufferPool::new(1024, 4);
        let oversized = vec![0u8; 1024 * 8];
        pool.return_buffer(oversized).await;
        let capacity = pool.get().await.capacity();
        assert!(capacity <= 1024 * 2);
    }

    /// `prewarm(5)` on an empty pool with room for 8 leaves exactly 5 idle
    /// buffers.
    #[tokio::test]
    async fn prewarm_fills_pool() {
        let pool = BufferPool::new(1024, 8);
        pool.prewarm(5).await;
        let idle = pool.buffers.lock().await.len();
        assert_eq!(idle, 5);
    }
}