Skip to main content

geode_client/
pool.rs

1//! Connection pooling for Geode connections.
2
3use std::sync::Arc;
4use tokio::sync::{Mutex, Semaphore};
5
6use crate::client::{Client, Connection};
7use crate::error::{Error, Result};
8
/// Connection pool for managing QUIC connections.
///
/// Idle connections are kept in a shared vector, and a semaphore created
/// with `max_size` permits limits how many [`PooledConnection`]s can be
/// checked out at the same time.
pub struct ConnectionPool {
    /// Client used to open a new connection when no idle one can be reused.
    client: Client,
    /// Idle connections waiting to be reused; the same `Arc` is handed to
    /// every `PooledConnection` so it can return its connection.
    connections: Arc<Mutex<Vec<Connection>>>,
    /// Source of the permit each checked-out connection holds.
    semaphore: Arc<Semaphore>,
    /// Number of permits the semaphore was created with.
    max_size: usize,
}
16
17impl ConnectionPool {
18    /// Create a new connection pool.
19    ///
20    /// # Panics
21    ///
22    /// Panics if `max_size` is 0. A connection pool must have at least one
23    /// connection slot to function properly. (Gap #18: CWE-400, CWE-835)
24    pub fn new(host: impl Into<String>, port: u16, max_size: usize) -> Self {
25        assert!(
26            max_size > 0,
27            "ConnectionPool max_size must be at least 1 (was 0). \
28             A pool with 0 connections would deadlock on acquire()."
29        );
30        Self {
31            client: Client::new(host, port),
32            connections: Arc::new(Mutex::new(Vec::new())),
33            semaphore: Arc::new(Semaphore::new(max_size)),
34            max_size,
35        }
36    }
37
38    /// Configure to skip TLS verification
39    pub fn skip_verify(mut self, skip: bool) -> Self {
40        self.client = self.client.skip_verify(skip);
41        self
42    }
43
44    /// Set page size for queries
45    pub fn page_size(mut self, size: usize) -> Self {
46        self.client = self.client.page_size(size);
47        self
48    }
49
50    /// Acquire a connection from the pool
51    ///
52    /// Returns a healthy connection from the pool, or creates a new one if needed.
53    /// Stale connections (those where the underlying QUIC connection has been closed)
54    /// are automatically discarded during acquisition.
55    pub async fn acquire(&self) -> Result<PooledConnection> {
56        let permit = Arc::clone(&self.semaphore)
57            .acquire_owned()
58            .await
59            .map_err(|_| Error::pool("Connection pool has been closed"))?;
60
61        // Try to get a healthy existing connection
62        let connection = loop {
63            let conn = {
64                let mut connections = self.connections.lock().await;
65                connections.pop()
66            };
67
68            match conn {
69                Some(c) if c.is_healthy() => {
70                    // Connection is healthy, use it
71                    break c;
72                }
73                Some(_) => {
74                    // Connection is stale, discard it and try another
75                    // (the connection is dropped here, cleaning up resources)
76                    continue;
77                }
78                None => {
79                    // No pooled connections available, create a new one
80                    let client = self.client.clone();
81                    break client.connect().await?;
82                }
83            }
84        };
85
86        Ok(PooledConnection {
87            connection: Some(connection),
88            pool: self.connections.clone(),
89            _permit: permit,
90        })
91    }
92
93    /// Get current pool size
94    pub async fn size(&self) -> usize {
95        self.connections.lock().await.len()
96    }
97
98    /// Get the maximum pool size
99    pub fn max_size(&self) -> usize {
100        self.max_size
101    }
102}
103
/// A pooled connection that returns to the pool when dropped.
///
/// On drop, a connection that still reports healthy is handed back to the
/// pool's idle list; an unhealthy one is discarded.
pub struct PooledConnection {
    /// The checked-out connection. It is `Some` until `Drop` takes it.
    connection: Option<Connection>,
    /// Shared idle list the connection is returned to.
    pool: Arc<Mutex<Vec<Connection>>>,
    /// Semaphore permit held for as long as this value is alive.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
110
111impl PooledConnection {
112    /// Get a reference to the underlying connection.
113    ///
114    /// # Panics
115    ///
116    /// Panics if called after the connection has been dropped or taken.
117    /// This should never happen in normal usage as the connection is only
118    /// taken during Drop.
119    pub fn inner(&self) -> &Connection {
120        self.connection
121            .as_ref()
122            .expect("PooledConnection invariant violated: connection was None")
123    }
124}
125
126impl Drop for PooledConnection {
127    fn drop(&mut self) {
128        if let Some(conn) = self.connection.take() {
129            // Only return healthy connections to the pool
130            // Stale connections are dropped, cleaning up their resources
131            if conn.is_healthy() {
132                let pool = self.pool.clone();
133                tokio::spawn(async move {
134                    let mut connections = pool.lock().await;
135                    connections.push(conn);
136                });
137            }
138            // If unhealthy, conn is dropped here and not returned to pool
139        }
140    }
141}
142
143impl std::ops::Deref for PooledConnection {
144    type Target = Connection;
145
146    fn deref(&self) -> &Self::Target {
147        self.connection
148            .as_ref()
149            .expect("PooledConnection invariant violated: connection was None")
150    }
151}
152
153impl std::ops::DerefMut for PooledConnection {
154    fn deref_mut(&mut self) -> &mut Self::Target {
155        self.connection
156            .as_mut()
157            .expect("PooledConnection invariant violated: connection was None")
158    }
159}
160
#[cfg(test)]
mod tests {
    //! Unit tests for pool construction and configuration.
    //!
    //! None of these tests open a network connection.

    use super::*;

    #[test]
    fn test_connection_pool_new() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_new_different_host() {
        let pool = ConnectionPool::new("192.168.1.100", 8443, 5);
        assert_eq!(pool.max_size(), 5);
    }

    #[test]
    fn test_connection_pool_new_string_host() {
        let host = String::from("geode.example.com");
        let pool = ConnectionPool::new(host, 3141, 20);
        assert_eq!(pool.max_size(), 20);
    }

    #[test]
    fn test_connection_pool_skip_verify() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(true);
        // The builder must leave the pool's own settings untouched.
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_skip_verify_false() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(false);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_page_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10).page_size(500);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_chained_config() {
        let pool = ConnectionPool::new("localhost", 3141, 10)
            .skip_verify(true)
            .page_size(1000);
        assert_eq!(pool.max_size(), 10);
    }

    #[tokio::test]
    async fn test_connection_pool_initial_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        // Pool starts empty
        assert_eq!(pool.size().await, 0);
    }

    #[test]
    #[should_panic(expected = "ConnectionPool max_size must be at least 1")]
    fn test_connection_pool_max_size_zero_panics() {
        // Gap #18: max_size=0 would cause deadlock on acquire() since semaphore
        // would have 0 permits. Now properly panics at construction time.
        let _pool = ConnectionPool::new("localhost", 3141, 0);
    }

    #[test]
    fn test_connection_pool_max_size_one() {
        let pool = ConnectionPool::new("localhost", 3141, 1);
        assert_eq!(pool.max_size(), 1);
    }

    #[test]
    fn test_connection_pool_max_size_large() {
        let pool = ConnectionPool::new("localhost", 3141, 1000);
        assert_eq!(pool.max_size(), 1000);
    }

    // Note: Full integration tests for acquire() and health checking require
    // a running Geode server and are covered in the integration test suite.
    // The health check functionality (is_healthy(), stale connection discard)
    // is verified in integration tests with real connections (Gap #16).

    // The following tests verify the structural aspects of the pool
    // without actually establishing connections.

    #[test]
    fn test_semaphore_permits_match_max_size() {
        let pool = ConnectionPool::new("localhost", 3141, 5);
        // Semaphore should have permits equal to max_size
        assert_eq!(pool.semaphore.available_permits(), 5);
    }

    #[test]
    fn test_connections_vec_initially_empty() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        // `try_lock` needs no runtime, so the idle list can be inspected
        // directly from a synchronous test.
        let idle = pool
            .connections
            .try_lock()
            .expect("a freshly built pool's mutex must be uncontended");
        assert!(idle.is_empty());
        assert_eq!(pool.max_size(), 10);
    }
}
260}