Skip to main content

geode_client/
pool.rs

//! Connection pooling for Geode connections.
2
3use std::sync::Arc;
4use tokio::sync::{Mutex, Semaphore};
5
6use crate::client::{Client, Connection};
7use crate::error::Result;
8
9/// Connection pool for managing QUIC connections
10pub struct ConnectionPool {
11    client: Client,
12    connections: Arc<Mutex<Vec<Connection>>>,
13    semaphore: Arc<Semaphore>,
14    max_size: usize,
15}
16
17impl ConnectionPool {
18    /// Create a new connection pool
19    pub fn new(host: impl Into<String>, port: u16, max_size: usize) -> Self {
20        Self {
21            client: Client::new(host, port),
22            connections: Arc::new(Mutex::new(Vec::new())),
23            semaphore: Arc::new(Semaphore::new(max_size)),
24            max_size,
25        }
26    }
27
28    /// Configure to skip TLS verification
29    pub fn skip_verify(mut self, skip: bool) -> Self {
30        self.client = self.client.skip_verify(skip);
31        self
32    }
33
34    /// Set page size for queries
35    pub fn page_size(mut self, size: usize) -> Self {
36        self.client = self.client.page_size(size);
37        self
38    }
39
40    /// Acquire a connection from the pool
41    pub async fn acquire(&self) -> Result<PooledConnection> {
42        let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
43
44        // Try to get an existing connection
45        let conn = {
46            let mut connections = self.connections.lock().await;
47            connections.pop()
48        };
49
50        let connection = match conn {
51            Some(c) => c,
52            None => {
53                let client = self.client.clone();
54                // Quinn is async, directly await
55                client.connect().await?
56            }
57        };
58
59        Ok(PooledConnection {
60            connection: Some(connection),
61            pool: self.connections.clone(),
62            _permit: permit,
63        })
64    }
65
66    /// Get current pool size
67    pub async fn size(&self) -> usize {
68        self.connections.lock().await.len()
69    }
70
71    /// Get the maximum pool size
72    pub fn max_size(&self) -> usize {
73        self.max_size
74    }
75}
76
77/// A pooled connection that returns to the pool when dropped
78pub struct PooledConnection {
79    connection: Option<Connection>,
80    pool: Arc<Mutex<Vec<Connection>>>,
81    _permit: tokio::sync::OwnedSemaphorePermit,
82}
83
84impl PooledConnection {
85    /// Get a reference to the underlying connection
86    pub fn inner(&self) -> &Connection {
87        self.connection.as_ref().unwrap()
88    }
89}
90
91impl Drop for PooledConnection {
92    fn drop(&mut self) {
93        if let Some(conn) = self.connection.take() {
94            let pool = self.pool.clone();
95            tokio::spawn(async move {
96                let mut connections = pool.lock().await;
97                connections.push(conn);
98            });
99        }
100    }
101}
102
103impl std::ops::Deref for PooledConnection {
104    type Target = Connection;
105
106    fn deref(&self) -> &Self::Target {
107        self.connection.as_ref().unwrap()
108    }
109}
110
111impl std::ops::DerefMut for PooledConnection {
112    fn deref_mut(&mut self) -> &mut Self::Target {
113        self.connection.as_mut().unwrap()
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    // Constructing a pool never opens a connection, so everything below runs
    // without a Geode server. Tests for `acquire()` need a running server and
    // live in the integration test suite.

    /// Builds a pool pointed at the local endpoint used by most tests.
    fn local_pool(max_size: usize) -> ConnectionPool {
        ConnectionPool::new("localhost", 3141, max_size)
    }

    #[test]
    fn test_connection_pool_new() {
        assert_eq!(local_pool(10).max_size(), 10);
    }

    #[test]
    fn test_connection_pool_new_different_host() {
        let pool = ConnectionPool::new("192.168.1.100", 8443, 5);
        assert_eq!(pool.max_size(), 5);
    }

    #[test]
    fn test_connection_pool_new_string_host() {
        // An owned `String` host must be accepted as well as a `&str`.
        let host = String::from("geode.example.com");
        let pool = ConnectionPool::new(host, 3141, 20);
        assert_eq!(pool.max_size(), 20);
    }

    #[test]
    fn test_connection_pool_skip_verify() {
        // The flag is forwarded to the client; the pool limit must be untouched.
        let pool = local_pool(10).skip_verify(true);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_skip_verify_false() {
        let pool = local_pool(10).skip_verify(false);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_page_size() {
        let pool = local_pool(10).page_size(500);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_chained_config() {
        let pool = local_pool(10).skip_verify(true).page_size(1000);
        assert_eq!(pool.max_size(), 10);
    }

    #[tokio::test]
    async fn test_connection_pool_initial_size() {
        // Pool starts with no idle connections.
        assert_eq!(local_pool(10).size().await, 0);
    }

    #[test]
    fn test_connection_pool_max_size_zero() {
        // Edge case: a zero-sized pool is constructible but has no permits,
        // so `acquire()` on it would never complete.
        let pool = local_pool(0);
        assert_eq!(pool.max_size(), 0);
        assert_eq!(pool.semaphore.available_permits(), 0);
    }

    #[test]
    fn test_connection_pool_max_size_one() {
        assert_eq!(local_pool(1).max_size(), 1);
    }

    #[test]
    fn test_connection_pool_max_size_large() {
        assert_eq!(local_pool(1000).max_size(), 1000);
    }

    #[test]
    fn test_semaphore_permits_match_max_size() {
        // The semaphore must start with exactly `max_size` permits.
        let pool = local_pool(5);
        assert_eq!(pool.semaphore.available_permits(), 5);
    }

    #[test]
    fn test_connections_vec_initially_empty() {
        // `try_lock` needs no runtime, so the idle list can be inspected
        // directly from a synchronous test.
        let pool = local_pool(10);
        let idle = pool
            .connections
            .try_lock()
            .expect("a freshly created pool has no other lock holders");
        assert!(idle.is_empty());
    }
}