1use ax_codec_core::{BufferWriter, EncodeError};
2use std::sync::{Arc, Mutex};
3use std::vec::Vec;
4
/// A thread-safe pool of reusable byte buffers.
///
/// Cloning a `BufferPool` is cheap: clones share the same free list through
/// the `Arc`, so a buffer released via one handle becomes available to all.
#[derive(Debug, Clone)]
pub struct BufferPool {
    // Buffers available for reuse. Every stored buffer is empty
    // (`len == 0`, cleared in `release`) but keeps its allocated capacity.
    free: Arc<Mutex<Vec<Vec<u8>>>>,
    // Maximum number of buffers retained; releases beyond this are dropped.
    max_capacity: usize,
    // Capacity for buffers freshly allocated when the pool is empty.
    default_buf_size: usize,
}

impl Default for BufferPool {
    fn default() -> Self {
        Self::new()
    }
}

impl BufferPool {
    /// Creates a pool retaining up to 32 buffers, allocating new buffers
    /// with 1024 bytes of capacity when empty.
    pub fn new() -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::new())),
            max_capacity: 32,
            default_buf_size: 1024,
        }
    }

    /// Creates a pool retaining at most `max_capacity` buffers, allocating
    /// new buffers with `default_buf_size` capacity when the pool is empty.
    pub fn with_capacity(max_capacity: usize, default_buf_size: usize) -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::with_capacity(max_capacity))),
            max_capacity,
            default_buf_size,
        }
    }

    /// Locks the free list, recovering from mutex poisoning.
    ///
    /// The guarded data is only a list of empty buffers, so there is no
    /// invariant a panicking thread could have left half-updated. Recovering
    /// with `PoisonError::into_inner` keeps one thread's panic from
    /// permanently panicking every other user of the shared pool.
    fn lock_free(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u8>>> {
        self.free
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Takes a buffer from the pool, or allocates a fresh one with
    /// `default_buf_size` capacity if the pool is empty.
    pub fn acquire(&self) -> Vec<u8> {
        self.lock_free()
            .pop()
            .unwrap_or_else(|| Vec::with_capacity(self.default_buf_size))
    }

    /// Returns a buffer to the pool for reuse.
    ///
    /// The buffer is cleared first; it is dropped instead of stored when the
    /// pool already holds `max_capacity` buffers.
    pub fn release(&self, mut buf: Vec<u8>) {
        buf.clear();
        let mut free = self.lock_free();
        if free.len() < self.max_capacity {
            free.push(buf);
        }
    }

    /// Number of buffers currently available for reuse.
    pub fn len(&self) -> usize {
        self.lock_free().len()
    }

    /// Whether the pool currently holds no reusable buffers.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
57
58#[derive(Debug)]
59pub struct PooledWriter {
60 buf: Option<Vec<u8>>,
61 pool: Option<Arc<BufferPool>>,
62}
63
64impl Default for PooledWriter {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl PooledWriter {
71 pub fn new() -> Self {
72 Self {
73 buf: Some(Vec::new()),
74 pool: None,
75 }
76 }
77
78 pub fn with_pool(pool: Arc<BufferPool>) -> Self {
79 Self {
80 buf: Some(pool.acquire()),
81 pool: Some(pool),
82 }
83 }
84
85 pub fn finish(mut self) -> Vec<u8> {
86 self.buf.take().unwrap_or_default()
87 }
88
89 pub fn recycle(mut self) {
90 if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
91 pool.release(buf);
92 }
93 }
94
95 pub fn as_slice(&self) -> &[u8] {
96 self.buf.as_deref().unwrap_or_default()
97 }
98}
99
100impl BufferWriter for PooledWriter {
101 #[inline]
102 fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodeError> {
103 self.buf.as_mut().unwrap().extend_from_slice(buf);
104 Ok(())
105 }
106}
107
108impl Drop for PooledWriter {
109 fn drop(&mut self) {
110 if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
111 pool.release(buf);
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Acquiring from a fresh pool allocates at the default size; releasing
    /// makes the buffer available again.
    #[test]
    fn pool_acquire_release() {
        let pool = BufferPool::new();
        let acquired = pool.acquire();
        assert!(acquired.capacity() >= 1024);
        pool.release(acquired);
        assert_eq!(pool.len(), 1);
    }

    /// Bytes written through a pooled writer come back out of `finish`.
    #[test]
    fn pooled_writer_roundtrip() {
        let pool = Arc::new(BufferPool::new());
        let mut writer = PooledWriter::with_pool(Arc::clone(&pool));
        writer.write_all(b"hello").unwrap();
        assert_eq!(writer.as_slice(), b"hello");

        let out = writer.finish();
        assert_eq!(&out, b"hello");
    }

    /// Dropping a pooled writer returns its cleared buffer to the pool.
    #[test]
    fn pooled_writer_auto_recycle_on_drop() {
        let pool = Arc::new(BufferPool::new());
        {
            let mut writer = PooledWriter::with_pool(Arc::clone(&pool));
            writer.write_all(b"temp").unwrap();
        }
        assert_eq!(pool.len(), 1);

        let reused = pool.acquire();
        assert!(reused.is_empty());
        assert!(reused.capacity() >= 4);
    }

    /// The pool never retains more than `max_capacity` buffers.
    #[test]
    fn pooled_writer_max_capacity() {
        let pool = Arc::new(BufferPool::with_capacity(2, 64));
        let first = pool.acquire();
        let second = pool.acquire();
        let third = pool.acquire();
        pool.release(first);
        pool.release(second);
        pool.release(third);
        assert_eq!(pool.len(), 2);
    }
}