1use std::sync::Arc;
2
3use crossbeam_queue::ArrayQueue;
4
5pub struct BufferPool {
25 pool: Arc<ArrayQueue<bytes::BytesMut>>,
26 buf_size: usize,
27}
28
29impl BufferPool {
30 #[must_use]
35 pub fn new(capacity: usize, buf_size: usize) -> Self {
36 Self {
37 pool: Arc::new(ArrayQueue::new(capacity)),
38 buf_size,
39 }
40 }
41
42 #[must_use]
47 pub fn get(&self) -> bytes::BytesMut {
48 match self.pool.pop() {
49 Some(mut b) => {
50 b.clear();
51 b
52 }
53 None => bytes::BytesMut::zeroed(self.buf_size),
54 }
55 }
56
57 pub fn put(&self, buf: bytes::BytesMut) {
62 let _ = self.pool.push(buf);
63 }
64
65 #[must_use]
67 pub fn available(&self) -> usize {
68 self.pool.len()
69 }
70
71 #[must_use]
73 pub fn capacity(&self) -> usize {
74 self.pool.capacity()
75 }
76}
77
78impl Clone for BufferPool {
79 fn clone(&self) -> Self {
80 Self {
81 pool: Arc::clone(&self.pool),
82 buf_size: self.buf_size,
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_pool_is_empty() {
        let pool = BufferPool::new(4, 64);
        assert_eq!(pool.available(), 0);
        assert_eq!(pool.capacity(), 4);
    }

    #[test]
    fn get_returns_zeroed_buffer_when_empty() {
        let pool = BufferPool::new(4, 64);
        let buf = pool.get();
        assert_eq!(buf.len(), 64);
        assert!(buf.iter().all(|&b| b == 0));
    }

    #[test]
    fn put_and_get_roundtrip() {
        let pool = BufferPool::new(4, 64);
        let mut buf = pool.get();
        buf.extend_from_slice(b"hello");
        pool.put(buf);

        assert_eq!(pool.available(), 1);

        let buf2 = pool.get();
        assert_eq!(buf2.len(), 0, "buffer should be cleared on get()");
        assert!(
            buf2.capacity() >= 5,
            "buffer should retain underlying capacity"
        );
    }

    #[test]
    fn put_does_not_exceed_capacity() {
        let pool = BufferPool::new(2, 64);
        let b1 = pool.get();
        let b2 = pool.get();
        let b3 = pool.get();

        pool.put(b1);
        pool.put(b2);
        // The pool is already full, so this buffer is dropped rather than retained.
        pool.put(b3);
        assert_eq!(pool.available(), 2);
    }

    #[test]
    fn clone_shares_pool() {
        let pool = BufferPool::new(4, 64);
        let cloned = pool.clone();
        assert_eq!(cloned.capacity(), 4);

        let buf = pool.get();
        assert_eq!(buf.len(), 64);

        // A buffer returned through one handle must be visible through the other;
        // a clone that only copied the configuration would still report 0 here.
        pool.put(buf);
        assert_eq!(cloned.available(), 1);

        // Taking it back out through the clone drains the shared queue.
        let _recycled = cloned.get();
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn concurrent_access() {
        let pool = Arc::new(BufferPool::new(16, 64));
        let mut handles = Vec::new();

        for _ in 0..8 {
            let p = Arc::clone(&pool);
            handles.push(std::thread::spawn(move || {
                for _ in 0..100 {
                    let buf = p.get();
                    p.put(buf);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert!(pool.available() <= pool.capacity());
    }
}