Skip to main content

pylon_storage/
pool.rs

1use std::collections::VecDeque;
2use std::sync::{Condvar, Mutex};
3use std::time::Duration;
4
5// ---------------------------------------------------------------------------
6// Connection pool
7// ---------------------------------------------------------------------------
8
/// A small, blocking pool of reusable database connections.
///
/// Idle connections wait in a FIFO queue protected by a mutex.
/// [`ConnectionPool::get`] takes one from the front, blocking on a condition
/// variable (up to a caller-supplied timeout) while the queue is empty.
/// [`ConnectionPool::add`] refuses (by panicking) to grow the idle queue past
/// `max_size`.
///
/// Every acquired connection is wrapped in a [`PooledConnection`] guard that
/// puts it back when the guard is dropped, so an early return or `?` in
/// caller code does not leak the slot.
pub struct ConnectionPool<T> {
    /// Upper limit on queued idle connections, set by `new`.
    max_size: usize,
    /// Idle connections; handed out from the front, returned to the back.
    inner: Mutex<VecDeque<T>>,
    /// Notified when a connection is pushed onto `inner`.
    available: Condvar,
}
22
23/// RAII guard that returns the connection to the pool on drop.
24pub struct PooledConnection<'a, T> {
25    pool: &'a ConnectionPool<T>,
26    conn: Option<T>,
27}
28
29impl<T> std::fmt::Debug for PooledConnection<'_, T> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("PooledConnection")
32            .field("has_conn", &self.conn.is_some())
33            .finish()
34    }
35}
36
37impl<T> ConnectionPool<T> {
38    /// Create a new, empty pool with the given capacity.
39    ///
40    /// Connections must be added via [`add`] before they can be acquired.
41    pub fn new(max_size: usize) -> Self {
42        assert!(max_size > 0, "pool max_size must be at least 1");
43        Self {
44            inner: Mutex::new(VecDeque::with_capacity(max_size)),
45            available: Condvar::new(),
46            max_size,
47        }
48    }
49
50    /// Add a connection to the pool.
51    ///
52    /// Panics if the pool is already at capacity.
53    pub fn add(&self, conn: T) {
54        let mut queue = self.inner.lock().expect("pool lock poisoned");
55        assert!(
56            queue.len() < self.max_size,
57            "cannot add connection: pool is at capacity ({})",
58            self.max_size,
59        );
60        queue.push_back(conn);
61        self.available.notify_one();
62    }
63
64    /// Acquire a connection, blocking up to `timeout`.
65    ///
66    /// Returns `Err` if the timeout expires before a connection becomes
67    /// available.
68    pub fn get(&self, timeout: Duration) -> Result<PooledConnection<'_, T>, PoolError> {
69        let mut queue = self.inner.lock().expect("pool lock poisoned");
70
71        // Fast path: a connection is already available.
72        if let Some(conn) = queue.pop_front() {
73            return Ok(PooledConnection {
74                pool: self,
75                conn: Some(conn),
76            });
77        }
78
79        // Slow path: wait for a connection to be returned.
80        let (mut queue, wait_result) = self
81            .available
82            .wait_timeout_while(queue, timeout, |q| q.is_empty())
83            .expect("pool lock poisoned");
84
85        if wait_result.timed_out() && queue.is_empty() {
86            return Err(PoolError::Timeout);
87        }
88
89        match queue.pop_front() {
90            Some(conn) => Ok(PooledConnection {
91                pool: self,
92                conn: Some(conn),
93            }),
94            None => Err(PoolError::Unavailable),
95        }
96    }
97
98    /// Number of connections currently idle in the pool.
99    pub fn available_count(&self) -> usize {
100        self.inner.lock().expect("pool lock poisoned").len()
101    }
102
103    /// Maximum number of connections this pool can hold.
104    pub fn max_size(&self) -> usize {
105        self.max_size
106    }
107}
108
109// ---------------------------------------------------------------------------
110// Pool errors
111// ---------------------------------------------------------------------------
112
/// Error returned when a connection cannot be acquired from the pool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolError {
    /// The timeout expired before a connection became available.
    Timeout,
    /// No connection was available after waiting.
    ///
    /// In practice [`ConnectionPool::get`] does not produce this:
    /// it waits with `Condvar::wait_timeout_while`, which re-checks the idle
    /// queue after every wakeup. Spurious wakeups are therefore already
    /// absorbed, and an empty queue after waiting means the timeout elapsed
    /// ([`PoolError::Timeout`]).
    Unavailable,
}

impl std::fmt::Display for PoolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Static messages: no formatting machinery needed.
        let msg = match self {
            PoolError::Timeout => "connection pool: timed out waiting for a connection",
            PoolError::Unavailable => "connection pool: no connection available after wait",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for PoolError {}
134
135// ---------------------------------------------------------------------------
136// PooledConnection — RAII guard
137// ---------------------------------------------------------------------------
138
139impl<T> Drop for PooledConnection<'_, T> {
140    fn drop(&mut self) {
141        if let Some(conn) = self.conn.take() {
142            self.pool.add(conn);
143        }
144    }
145}
146
147impl<T> std::ops::Deref for PooledConnection<'_, T> {
148    type Target = T;
149    fn deref(&self) -> &T {
150        self.conn
151            .as_ref()
152            .expect("PooledConnection used after take (bug)")
153    }
154}
155
156impl<T> std::ops::DerefMut for PooledConnection<'_, T> {
157    fn deref_mut(&mut self) -> &mut T {
158        self.conn
159            .as_mut()
160            .expect("PooledConnection used after take (bug)")
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Tests
166// ---------------------------------------------------------------------------
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn basic_get_and_return() {
        let pool = ConnectionPool::new(2);
        pool.add("conn1");
        pool.add("conn2");

        assert_eq!(pool.available_count(), 2);
        assert_eq!(pool.max_size(), 2);

        {
            let c = pool.get(Duration::from_millis(100)).unwrap();
            assert_eq!(*c, "conn1");
            assert_eq!(pool.available_count(), 1);
        }

        // Guard dropped, connection returned.
        assert_eq!(pool.available_count(), 2);
    }

    #[test]
    fn pool_exhaustion_blocks_then_succeeds() {
        let pool = Arc::new(ConnectionPool::new(1));
        pool.add(42u32);

        // The holder signals once it owns the only connection. A fixed sleep
        // cannot guarantee that ordering: if the main thread won the race it
        // would keep the connection until after `join`, and the holder's
        // `get` would time out and panic.
        let (acquired_tx, acquired_rx) = mpsc::channel();
        let pool2 = Arc::clone(&pool);
        let holder = thread::spawn(move || {
            let conn = pool2.get(Duration::from_millis(100)).unwrap();
            assert_eq!(*conn, 42);
            acquired_tx.send(()).expect("main thread hung up");
            thread::sleep(Duration::from_millis(100));
            // `conn` drops here, returning the connection to the pool.
        });

        acquired_rx
            .recv()
            .expect("holder thread exited before acquiring the connection");

        // This thread blocks until the holder releases.
        let c = pool.get(Duration::from_secs(2)).unwrap();
        assert_eq!(*c, 42);

        holder.join().expect("holder thread panicked");
    }

    #[test]
    fn pool_exhaustion_timeout() {
        let pool = ConnectionPool::new(1);
        pool.add("only");

        let _held = pool.get(Duration::from_millis(100)).unwrap();
        let result = pool.get(Duration::from_millis(50));
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
    }

    #[test]
    fn zero_timeout_on_empty_pool_times_out() {
        let pool = ConnectionPool::<u32>::new(1);
        let result = pool.get(Duration::from_millis(0));
        assert_eq!(result.unwrap_err(), PoolError::Timeout);
    }

    #[test]
    fn dropped_guard_returns_connection() {
        let pool = ConnectionPool::new(1);
        pool.add(99u32);

        assert_eq!(pool.available_count(), 1);
        {
            let _c = pool.get(Duration::from_millis(100)).unwrap();
            assert_eq!(pool.available_count(), 0);
        }
        assert_eq!(pool.available_count(), 1);
    }

    #[test]
    fn multiple_concurrent_gets() {
        let pool = Arc::new(ConnectionPool::new(4));
        for i in 0..4u32 {
            pool.add(i);
        }

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    let c = pool.get(Duration::from_secs(2)).unwrap();
                    // Simulate work.
                    thread::sleep(Duration::from_millis(10));
                    assert!(*c < 4);
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        assert_eq!(pool.available_count(), 4);
    }

    #[test]
    fn deref_mut_works() {
        let pool = ConnectionPool::new(1);
        pool.add(vec![1, 2, 3]);

        let mut c = pool.get(Duration::from_millis(100)).unwrap();
        c.push(4);
        assert_eq!(*c, vec![1, 2, 3, 4]);
    }

    #[test]
    #[should_panic(expected = "pool max_size must be at least 1")]
    fn zero_size_panics() {
        let _pool = ConnectionPool::<u32>::new(0);
    }

    #[test]
    #[should_panic(expected = "pool is at capacity")]
    fn add_beyond_capacity_panics() {
        let pool = ConnectionPool::new(1);
        pool.add(1);
        pool.add(2);
    }

    #[test]
    fn pool_error_display() {
        assert!(PoolError::Timeout.to_string().contains("timed out"));
        assert!(PoolError::Unavailable.to_string().contains("no connection"));
    }
}