Skip to main content

ax_codec_bytes/
pool.rs

1use ax_codec_core::{BufferWriter, EncodeError};
2use std::sync::{Arc, Mutex};
3use std::vec::Vec;
4
/// A bounded free list of reusable byte buffers, guarded by a mutex.
///
/// Cloning a `BufferPool` yields another handle to the *same* free list (the
/// list lives behind an `Arc`), so a buffer released through one clone can be
/// acquired through another.
#[derive(Debug, Clone)]
pub struct BufferPool {
    /// Idle, already-cleared buffers waiting to be handed out again.
    free: Arc<Mutex<Vec<Vec<u8>>>>,
    /// Maximum number of idle buffers `release` will retain.
    max_capacity: usize,
    /// Capacity requested for buffers allocated when the free list is empty.
    default_buf_size: usize,
}

impl Default for BufferPool {
    /// Equivalent to [`BufferPool::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl BufferPool {
    /// Upper bound on the number of free-list slots reserved up front by
    /// [`BufferPool::with_capacity`]. The list still grows on demand up to
    /// `max_capacity`; this only limits the eager allocation.
    const MAX_PREALLOCATED_SLOTS: usize = 1024;

    /// Creates a pool that retains at most 32 idle buffers and allocates new
    /// buffers with a capacity of 1024 bytes.
    pub fn new() -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::new())),
            max_capacity: 32,
            default_buf_size: 1024,
        }
    }

    /// Creates a pool that retains at most `max_capacity` idle buffers and
    /// allocates new buffers with `default_buf_size` bytes of capacity.
    ///
    /// Any `max_capacity` is accepted, including `usize::MAX` for an
    /// effectively unbounded pool: only a bounded number of free-list slots is
    /// reserved eagerly, so a huge limit no longer triggers a
    /// "capacity overflow" panic or a giant allocation at construction time.
    pub fn with_capacity(max_capacity: usize, default_buf_size: usize) -> Self {
        let reserved_slots = max_capacity.min(Self::MAX_PREALLOCATED_SLOTS);
        Self {
            free: Arc::new(Mutex::new(Vec::with_capacity(reserved_slots))),
            max_capacity,
            default_buf_size,
        }
    }

    /// Returns an empty buffer, reusing an idle one when available and
    /// otherwise allocating a new one with `default_buf_size` capacity.
    pub fn acquire(&self) -> Vec<u8> {
        // The guard is a temporary of this statement, so the lock is released
        // before a fresh buffer is (possibly) allocated below.
        let recycled = self.free_list().pop();
        recycled.unwrap_or_else(|| Vec::with_capacity(self.default_buf_size))
    }

    /// Clears `buf` and keeps it for reuse, unless the pool already holds
    /// `max_capacity` idle buffers, in which case `buf` is simply dropped.
    pub fn release(&self, mut buf: Vec<u8>) {
        // Clear before taking the lock to keep the critical section short.
        buf.clear();
        let mut free = self.free_list();
        if free.len() < self.max_capacity {
            free.push(buf);
        }
    }

    /// Number of idle buffers currently held by the pool.
    pub fn len(&self) -> usize {
        self.free_list().len()
    }

    /// Returns `true` when the pool holds no idle buffers.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the free list, recovering the guard if the mutex was poisoned.
    ///
    /// The list only ever contains cleared byte buffers, so every state it can
    /// be left in is valid; a panic on some other thread should not make the
    /// pool unusable for everyone else.
    fn free_list(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u8>>> {
        self.free
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
57
58#[derive(Debug)]
59pub struct PooledWriter {
60    buf: Option<Vec<u8>>,
61    pool: Option<Arc<BufferPool>>,
62}
63
64impl Default for PooledWriter {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl PooledWriter {
71    pub fn new() -> Self {
72        Self {
73            buf: Some(Vec::new()),
74            pool: None,
75        }
76    }
77
78    pub fn with_pool(pool: Arc<BufferPool>) -> Self {
79        Self {
80            buf: Some(pool.acquire()),
81            pool: Some(pool),
82        }
83    }
84
85    pub fn finish(mut self) -> Vec<u8> {
86        self.buf.take().unwrap_or_default()
87    }
88
89    pub fn recycle(mut self) {
90        if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
91            pool.release(buf);
92        }
93    }
94
95    pub fn as_slice(&self) -> &[u8] {
96        self.buf.as_deref().unwrap_or_default()
97    }
98}
99
100impl BufferWriter for PooledWriter {
101    #[inline]
102    fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodeError> {
103        self.buf.as_mut().unwrap().extend_from_slice(buf);
104        Ok(())
105    }
106}
107
108impl Drop for PooledWriter {
109    fn drop(&mut self) {
110        if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
111            pool.release(buf);
112        }
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// A buffer from a default pool has at least 1024 bytes of capacity, and
    /// releasing it leaves exactly one idle buffer in the pool.
    #[test]
    fn pool_acquire_release() {
        let pool = BufferPool::new();

        let fresh = pool.acquire();
        assert!(fresh.capacity() >= 1024);

        pool.release(fresh);
        assert_eq!(pool.len(), 1);
    }

    /// Bytes written through a pooled writer are visible via `as_slice` and
    /// are returned unchanged by `finish`.
    #[test]
    fn pooled_writer_roundtrip() {
        let pool = Arc::new(BufferPool::new());
        let mut writer = PooledWriter::with_pool(pool.clone());

        writer.write_all(b"hello").unwrap();
        assert_eq!(writer.as_slice(), b"hello");

        let bytes = writer.finish();
        assert_eq!(&bytes, b"hello");
    }

    /// Dropping a pooled writer puts its buffer back in the pool, cleared but
    /// with its capacity intact.
    #[test]
    fn pooled_writer_auto_recycle_on_drop() {
        let pool = Arc::new(BufferPool::new());

        let mut writer = PooledWriter::with_pool(pool.clone());
        writer.write_all(b"temp").unwrap();
        drop(writer);
        assert_eq!(pool.len(), 1);

        let reused = pool.acquire();
        assert!(reused.is_empty());
        assert!(reused.capacity() >= 4);
    }

    /// Releasing more buffers than `max_capacity` keeps only `max_capacity`
    /// of them.
    #[test]
    fn pooled_writer_max_capacity() {
        let pool = Arc::new(BufferPool::with_capacity(2, 64));

        // Acquire all three first, then release them in acquisition order.
        let held: Vec<Vec<u8>> = (0..3).map(|_| pool.acquire()).collect();
        for buf in held {
            pool.release(buf);
        }

        assert_eq!(pool.len(), 2);
    }
}