Skip to main content

ax_codec_bytes/
pool.rs

1use ax_codec_core::{BufferWriter, EncodeError};
2use std::sync::{Arc, Mutex};
3use std::vec::Vec;
4
/// A thread-safe pool of reusable byte buffers.
///
/// Cloning a `BufferPool` is cheap and yields a handle to the *same* free
/// list (the storage lives behind an `Arc<Mutex<..>>`), so buffers released
/// through one clone can be acquired through another.
#[derive(Debug, Clone)]
pub struct BufferPool {
    /// Cleared buffers waiting to be handed out again.
    free: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Maximum number of buffers retained in `free`; extra releases are dropped.
    max_capacity: usize,
    /// Capacity given to freshly allocated buffers when `free` is empty.
    default_buf_size: usize,
}

impl Default for BufferPool {
    /// Equivalent to [`BufferPool::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl BufferPool {
    /// Number of buffers retained by a pool built with [`BufferPool::new`].
    const DEFAULT_MAX_CAPACITY: usize = 32;
    /// Initial capacity of fresh buffers for a pool built with [`BufferPool::new`].
    const DEFAULT_BUF_SIZE: usize = 1024;
    /// Upper bound on the free-list slots reserved eagerly by
    /// [`BufferPool::with_capacity`]. Reserving `max_capacity` slots
    /// unconditionally panics with "capacity overflow" for very large limits
    /// (e.g. `usize::MAX` used as "unbounded"); past this bound the free list
    /// simply grows on demand.
    const MAX_PREALLOCATED_SLOTS: usize = 64;

    /// Creates a pool that retains up to 32 buffers and allocates new buffers
    /// with a capacity of 1024 bytes.
    pub fn new() -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::new())),
            max_capacity: Self::DEFAULT_MAX_CAPACITY,
            default_buf_size: Self::DEFAULT_BUF_SIZE,
        }
    }

    /// Creates a pool that retains at most `max_capacity` released buffers and
    /// allocates new buffers with `default_buf_size` bytes of capacity.
    ///
    /// A `max_capacity` of 0 disables retention: every released buffer is
    /// dropped.
    pub fn with_capacity(max_capacity: usize, default_buf_size: usize) -> Self {
        let reserved = max_capacity.min(Self::MAX_PREALLOCATED_SLOTS);
        Self {
            free: Arc::new(Mutex::new(Vec::with_capacity(reserved))),
            max_capacity,
            default_buf_size,
        }
    }

    /// Takes a buffer from the pool, or allocates a new one with
    /// `default_buf_size` capacity if the pool is empty.
    ///
    /// The returned buffer is always empty (`len() == 0`) because buffers are
    /// cleared on release. A recycled buffer's capacity is whatever it had
    /// when it was released, which may differ from `default_buf_size`.
    pub fn acquire(&self) -> Vec<u8> {
        // Pop in a separate statement so the lock guard is dropped before any
        // fresh allocation happens.
        let recycled = self.lock_free().pop();
        recycled.unwrap_or_else(|| Vec::with_capacity(self.default_buf_size))
    }

    /// Returns `buf` to the pool for later reuse.
    ///
    /// The buffer is cleared (its capacity is kept). If the pool already holds
    /// `max_capacity` buffers, `buf` is dropped instead.
    pub fn release(&self, mut buf: Vec<u8>) {
        buf.clear();
        let mut free = self.lock_free();
        if free.len() < self.max_capacity {
            free.push(buf);
        }
    }

    /// Returns the number of buffers currently waiting in the pool.
    pub fn len(&self) -> usize {
        self.lock_free().len()
    }

    /// Returns `true` if the pool currently holds no buffers.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the free list, recovering the guard if the mutex is poisoned.
    ///
    /// Recovery is acceptable because the critical sections in this impl only
    /// push, pop, or read the length of the `Vec`. A panic elsewhere while the
    /// lock was held therefore cannot leave a half-updated buffer in the list.
    fn lock_free(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u8>>> {
        self.free.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
66
67#[derive(Debug)]
68pub struct PooledWriter {
69    buf: Option<Vec<u8>>,
70    pool: Option<Arc<BufferPool>>,
71}
72
73impl Default for PooledWriter {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl PooledWriter {
80    pub fn new() -> Self {
81        Self {
82            buf: Some(Vec::new()),
83            pool: None,
84        }
85    }
86
87    pub fn with_pool(pool: Arc<BufferPool>) -> Self {
88        Self {
89            buf: Some(pool.acquire()),
90            pool: Some(pool),
91        }
92    }
93
94    pub fn finish(mut self) -> Vec<u8> {
95        self.buf.take().unwrap_or_default()
96    }
97
98    pub fn recycle(mut self) {
99        if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
100            pool.release(buf);
101        }
102    }
103
104    pub fn as_slice(&self) -> &[u8] {
105        self.buf.as_deref().unwrap_or_default()
106    }
107}
108
109impl BufferWriter for PooledWriter {
110    #[inline]
111    fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodeError> {
112        self.buf.as_mut().unwrap().extend_from_slice(buf);
113        Ok(())
114    }
115}
116
117impl Drop for PooledWriter {
118    fn drop(&mut self) {
119        if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
120            pool.release(buf);
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_acquire_release() {
        let pool = BufferPool::new();
        let buf = pool.acquire();
        assert!(buf.capacity() >= 1024);
        pool.release(buf);
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn pooled_writer_roundtrip() {
        let pool = Arc::new(BufferPool::new());
        let mut w = PooledWriter::with_pool(pool.clone());
        w.write_all(b"hello").unwrap();
        assert_eq!(w.as_slice(), b"hello");

        let buf = w.finish();
        assert_eq!(&buf, b"hello");
        // `finish` gives the buffer to the caller, so nothing returns to the pool.
        assert!(pool.is_empty());
    }

    #[test]
    fn pooled_writer_auto_recycle_on_drop() {
        let pool = Arc::new(BufferPool::new());
        {
            let mut w = PooledWriter::with_pool(pool.clone());
            w.write_all(b"temp").unwrap();
        }
        assert_eq!(pool.len(), 1);

        let buf = pool.acquire();
        assert!(buf.is_empty());
        assert!(buf.capacity() >= 4);
    }

    #[test]
    fn pooled_writer_explicit_recycle() {
        let pool = Arc::new(BufferPool::new());
        let mut w = PooledWriter::with_pool(pool.clone());
        w.write_all(b"scratch").unwrap();
        w.recycle();
        // Exactly one buffer comes back: `recycle` returns it and the
        // subsequent `Drop` must not release it a second time.
        assert_eq!(pool.len(), 1);
        assert!(pool.acquire().is_empty());
    }

    #[test]
    fn standalone_writer_without_pool() {
        let mut w = PooledWriter::new();
        w.write_all(b"ab").unwrap();
        w.write_all(b"cd").unwrap();
        assert_eq!(w.as_slice(), b"abcd");
        assert_eq!(w.finish(), b"abcd");

        // Recycling a writer that has no pool is a harmless no-op.
        PooledWriter::default().recycle();
    }

    #[test]
    fn pool_release_respects_max_capacity() {
        let pool = Arc::new(BufferPool::with_capacity(2, 64));
        let b1 = pool.acquire();
        let b2 = pool.acquire();
        let b3 = pool.acquire();
        pool.release(b1);
        pool.release(b2);
        pool.release(b3);
        assert_eq!(pool.len(), 2);
    }
}