1use std::collections::VecDeque;
2use std::sync::{Condvar, Mutex};
3use std::time::Duration;
4
/// A bounded, thread-safe pool of reusable connections of type `T`.
///
/// Idle connections live in a mutex-protected queue; callers that find the queue empty
/// wait on a condition variable until a connection is pushed or their timeout expires.
pub struct ConnectionPool<T> {
    /// Idle connections; new ones are pushed at the back and handed out from the front.
    inner: Mutex<VecDeque<T>>,
    /// Notified each time a connection is pushed onto `inner`.
    available: Condvar,
    /// Maximum number of idle connections `inner` is allowed to hold.
    max_size: usize,
}
22
23pub struct PooledConnection<'a, T> {
25 pool: &'a ConnectionPool<T>,
26 conn: Option<T>,
27}
28
29impl<T> std::fmt::Debug for PooledConnection<'_, T> {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("PooledConnection")
32 .field("has_conn", &self.conn.is_some())
33 .finish()
34 }
35}
36
37impl<T> ConnectionPool<T> {
38 pub fn new(max_size: usize) -> Self {
42 assert!(max_size > 0, "pool max_size must be at least 1");
43 Self {
44 inner: Mutex::new(VecDeque::with_capacity(max_size)),
45 available: Condvar::new(),
46 max_size,
47 }
48 }
49
50 pub fn add(&self, conn: T) {
54 let mut queue = self.inner.lock().expect("pool lock poisoned");
55 assert!(
56 queue.len() < self.max_size,
57 "cannot add connection: pool is at capacity ({})",
58 self.max_size,
59 );
60 queue.push_back(conn);
61 self.available.notify_one();
62 }
63
64 pub fn get(&self, timeout: Duration) -> Result<PooledConnection<'_, T>, PoolError> {
69 let mut queue = self.inner.lock().expect("pool lock poisoned");
70
71 if let Some(conn) = queue.pop_front() {
73 return Ok(PooledConnection {
74 pool: self,
75 conn: Some(conn),
76 });
77 }
78
79 let (mut queue, wait_result) = self
81 .available
82 .wait_timeout_while(queue, timeout, |q| q.is_empty())
83 .expect("pool lock poisoned");
84
85 if wait_result.timed_out() && queue.is_empty() {
86 return Err(PoolError::Timeout);
87 }
88
89 match queue.pop_front() {
90 Some(conn) => Ok(PooledConnection {
91 pool: self,
92 conn: Some(conn),
93 }),
94 None => Err(PoolError::Unavailable),
95 }
96 }
97
98 pub fn available_count(&self) -> usize {
100 self.inner.lock().expect("pool lock poisoned").len()
101 }
102
103 pub fn max_size(&self) -> usize {
105 self.max_size
106 }
107}
108
/// Reasons a connection could not be checked out of a `ConnectionPool`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// No connection became available before the caller's timeout elapsed.
    Timeout,
    /// The wait finished, but there was still no connection to hand out.
    Unavailable,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Timeout => "connection pool: timed out waiting for a connection",
            Self::Unavailable => "connection pool: no connection available after wait",
        };
        f.write_str(message)
    }
}

/// `PoolError` wraps no underlying cause, so the default `source()` (`None`) is kept.
impl std::error::Error for PoolError {}
134
135impl<T> Drop for PooledConnection<'_, T> {
140 fn drop(&mut self) {
141 if let Some(conn) = self.conn.take() {
142 self.pool.add(conn);
143 }
144 }
145}
146
147impl<T> std::ops::Deref for PooledConnection<'_, T> {
148 type Target = T;
149 fn deref(&self) -> &T {
150 self.conn
151 .as_ref()
152 .expect("PooledConnection used after take (bug)")
153 }
154}
155
156impl<T> std::ops::DerefMut for PooledConnection<'_, T> {
157 fn deref_mut(&mut self) -> &mut T {
158 self.conn
159 .as_mut()
160 .expect("PooledConnection used after take (bug)")
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{mpsc, Arc};
    use std::thread;

    #[test]
    fn basic_get_and_return() {
        let pool = ConnectionPool::new(2);
        pool.add("conn1");
        pool.add("conn2");

        assert_eq!(pool.available_count(), 2);
        assert_eq!(pool.max_size(), 2);

        {
            let c = pool.get(Duration::from_millis(100)).unwrap();
            // Connections are handed out oldest-first.
            assert_eq!(*c, "conn1");
            assert_eq!(pool.available_count(), 1);
        }

        assert_eq!(pool.available_count(), 2);
    }

    #[test]
    fn pool_exhaustion_blocks_then_succeeds() {
        let pool = Arc::new(ConnectionPool::new(1));
        pool.add(42u32);

        // Handshake so the holder is guaranteed to own the only connection before the
        // main thread asks for one. A fixed sleep made the thread ordering (and thus the
        // test) depend on the scheduler.
        let (acquired_tx, acquired_rx) = mpsc::channel();
        let holder_pool = Arc::clone(&pool);
        let holder = thread::spawn(move || {
            let conn = holder_pool.get(Duration::from_secs(2)).unwrap();
            assert_eq!(*conn, 42);
            acquired_tx.send(()).expect("main thread stopped listening");
            // Keep the connection briefly so the main thread has to wait for it.
            thread::sleep(Duration::from_millis(50));
        });

        acquired_rx
            .recv()
            .expect("holder thread exited before acquiring the connection");

        let c = pool.get(Duration::from_secs(2)).unwrap();
        assert_eq!(*c, 42);

        holder.join().expect("holder thread panicked");
    }

    #[test]
    fn pool_exhaustion_timeout() {
        let pool = ConnectionPool::new(1);
        pool.add("only");

        let _held = pool.get(Duration::from_millis(100)).unwrap();
        let result = pool.get(Duration::from_millis(50));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
    }

    #[test]
    fn zero_timeout_on_empty_pool_times_out() {
        let pool = ConnectionPool::<u32>::new(1);
        assert_eq!(
            pool.get(Duration::from_millis(0)).unwrap_err(),
            PoolError::Timeout
        );
    }

    #[test]
    fn waiter_is_woken_by_add() {
        let pool = Arc::new(ConnectionPool::<u32>::new(1));
        let waiter_pool = Arc::clone(&pool);
        let waiter = thread::spawn(move || {
            let conn = waiter_pool.get(Duration::from_secs(2)).unwrap();
            *conn
        });

        // Give the waiter a head start so it is most likely blocked in `get`; the test is
        // correct either way because `get` also takes an already-idle connection.
        thread::sleep(Duration::from_millis(20));
        pool.add(5);

        assert_eq!(waiter.join().expect("waiter thread panicked"), 5);
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn dropped_guard_returns_connection() {
        let pool = ConnectionPool::new(1);
        pool.add(99u32);

        assert_eq!(pool.available_count(), 1);
        {
            let _c = pool.get(Duration::from_millis(100)).unwrap();
            assert_eq!(pool.available_count(), 0);
        }
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn multiple_concurrent_gets() {
        let pool = Arc::new(ConnectionPool::new(4));
        for i in 0..4u32 {
            pool.add(i);
        }

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    let c = pool.get(Duration::from_secs(2)).unwrap();
                    thread::sleep(Duration::from_millis(10));
                    assert!(*c < 4);
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn deref_mut_works() {
        let pool = ConnectionPool::new(1);
        pool.add(vec![1, 2, 3]);

        let mut c = pool.get(Duration::from_millis(100)).unwrap();
        c.push(4);
        assert_eq!(*c, vec![1, 2, 3, 4]);
    }

    #[test]
    fn guard_debug_does_not_require_debug_connection() {
        struct Opaque;
        let pool = ConnectionPool::new(1);
        pool.add(Opaque);

        let c = pool.get(Duration::from_millis(100)).unwrap();
        assert_eq!(format!("{:?}", c), "PooledConnection { has_conn: true }");
    }

    #[test]
    #[should_panic(expected = "pool max_size must be at least 1")]
    fn zero_size_panics() {
        let _pool = ConnectionPool::<u32>::new(0);
    }

    #[test]
    #[should_panic(expected = "pool is at capacity")]
    fn add_beyond_capacity_panics() {
        let pool = ConnectionPool::new(1);
        pool.add(1);
        pool.add(2);
    }

    #[test]
    fn pool_error_display() {
        assert!(format!("{}", PoolError::Timeout).contains("timed out"));
        assert!(format!("{}", PoolError::Unavailable).contains("no connection"));
    }
}