1use ax_codec_core::{BufferWriter, EncodeError};
2use std::sync::{Arc, Mutex};
3use std::vec::Vec;
4
/// A thread-safe pool of reusable byte buffers.
///
/// Cloning a `BufferPool` is cheap: clones share the same free list via
/// `Arc`, so a buffer released through one clone becomes available to all.
#[derive(Debug, Clone)]
pub struct BufferPool {
    // Buffers currently available for reuse.
    free: Arc<Mutex<Vec<Vec<u8>>>>,
    // Maximum number of buffers retained; excess releases are dropped.
    max_capacity: usize,
    // Capacity used when `acquire` must allocate a fresh buffer.
    default_buf_size: usize,
}

impl Default for BufferPool {
    fn default() -> Self {
        Self::new()
    }
}

impl BufferPool {
    /// Creates a pool that retains up to 32 buffers, each freshly
    /// allocated with 1024 bytes of capacity.
    pub fn new() -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::new())),
            max_capacity: 32,
            default_buf_size: 1024,
        }
    }

    /// Creates a pool retaining at most `max_capacity` buffers, each
    /// freshly allocated with `default_buf_size` bytes of capacity.
    pub fn with_capacity(max_capacity: usize, default_buf_size: usize) -> Self {
        Self {
            free: Arc::new(Mutex::new(Vec::with_capacity(max_capacity))),
            max_capacity,
            default_buf_size,
        }
    }

    /// Locks the free list, recovering from mutex poisoning.
    ///
    /// Poisoning only means another thread panicked while holding the
    /// lock; every critical section here leaves the free list in a
    /// consistent state, so continuing with the inner guard is sound.
    /// Centralized here instead of repeating the recovery closure at
    /// every lock site.
    fn lock_free(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u8>>> {
        self.free.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Takes a buffer from the pool, or allocates a fresh one with
    /// `default_buf_size` capacity when the pool is empty.
    pub fn acquire(&self) -> Vec<u8> {
        self.lock_free()
            .pop()
            .unwrap_or_else(|| Vec::with_capacity(self.default_buf_size))
    }

    /// Clears `buf` and returns it to the pool; the buffer is dropped
    /// instead when the pool already holds `max_capacity` buffers.
    ///
    /// NOTE(review): `clear()` keeps the buffer's grown capacity, so a
    /// buffer that grew very large stays large while pooled — consider
    /// a capacity cap if memory retention becomes a concern.
    pub fn release(&self, mut buf: Vec<u8>) {
        buf.clear();
        let mut free = self.lock_free();
        if free.len() < self.max_capacity {
            free.push(buf);
        }
    }

    /// Number of buffers currently available for reuse.
    pub fn len(&self) -> usize {
        self.lock_free().len()
    }

    /// Returns `true` when no pooled buffers are available.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
66
/// A growable byte sink that optionally borrows its backing buffer from
/// a [`BufferPool`], returning it on drop or explicit recycle.
#[derive(Debug)]
pub struct PooledWriter {
    // Always `Some` while the writer is alive; taken exactly once by
    // `finish`, `recycle`, or `Drop`.
    buf: Option<Vec<u8>>,
    // Pool to return the buffer to; `None` for writers built via `new`.
    pool: Option<Arc<BufferPool>>,
}
72
73impl Default for PooledWriter {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl PooledWriter {
80 pub fn new() -> Self {
81 Self {
82 buf: Some(Vec::new()),
83 pool: None,
84 }
85 }
86
87 pub fn with_pool(pool: Arc<BufferPool>) -> Self {
88 Self {
89 buf: Some(pool.acquire()),
90 pool: Some(pool),
91 }
92 }
93
94 pub fn finish(mut self) -> Vec<u8> {
95 self.buf.take().unwrap_or_default()
96 }
97
98 pub fn recycle(mut self) {
99 if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
100 pool.release(buf);
101 }
102 }
103
104 pub fn as_slice(&self) -> &[u8] {
105 self.buf.as_deref().unwrap_or_default()
106 }
107}
108
impl BufferWriter for PooledWriter {
    /// Appends `buf` to the internal buffer; never returns `Err`.
    ///
    /// The `unwrap` is sound: `self.buf` is only `None` after `finish`
    /// or `recycle`, both of which consume the writer, so no caller can
    /// reach this method with an empty slot.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> Result<(), EncodeError> {
        self.buf.as_mut().unwrap().extend_from_slice(buf);
        Ok(())
    }
}
116
117impl Drop for PooledWriter {
118 fn drop(&mut self) {
119 if let (Some(pool), Some(buf)) = (self.pool.take(), self.buf.take()) {
120 pool.release(buf);
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_acquire_release() {
        let pool = BufferPool::new();
        let buffer = pool.acquire();
        assert!(buffer.capacity() >= 1024);
        pool.release(buffer);
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn pooled_writer_roundtrip() {
        let pool = Arc::new(BufferPool::new());
        let mut writer = PooledWriter::with_pool(Arc::clone(&pool));
        writer.write_all(b"hello").unwrap();
        assert_eq!(writer.as_slice(), b"hello");

        let out = writer.finish();
        assert_eq!(&out, b"hello");
    }

    #[test]
    fn pooled_writer_auto_recycle_on_drop() {
        let pool = Arc::new(BufferPool::new());
        {
            let mut writer = PooledWriter::with_pool(Arc::clone(&pool));
            writer.write_all(b"temp").unwrap();
        } // writer dropped here -> its buffer returns to the pool

        assert_eq!(pool.len(), 1);

        let recycled = pool.acquire();
        assert!(recycled.is_empty());
        assert!(recycled.capacity() >= 4);
    }

    #[test]
    fn pooled_writer_max_capacity() {
        let pool = Arc::new(BufferPool::with_capacity(2, 64));
        let first = pool.acquire();
        let second = pool.acquire();
        let third = pool.acquire();
        pool.release(first);
        pool.release(second);
        pool.release(third);
        // Only `max_capacity` buffers are retained; the third is dropped.
        assert_eq!(pool.len(), 2);
    }
}