1use std::sync::Arc;
4use tokio::sync::{Mutex, Semaphore};
5
6use crate::client::{Client, Connection};
7use crate::error::{Error, Result};
8
/// A bounded pool of [`Connection`]s opened through a shared [`Client`].
///
/// Each checked-out [`PooledConnection`] holds one permit from `semaphore`,
/// which is created with `max_size` permits, so at most `max_size` handles
/// can be outstanding at the same time. Healthy connections are returned to
/// `connections` when their handle is dropped and reused by later
/// `acquire()` calls.
pub struct ConnectionPool {
    // Client used to open a new connection when no healthy idle one exists.
    client: Client,
    // Idle connections. The same `Arc` is cloned into every
    // `PooledConnection` so the handle can push its connection back on drop.
    connections: Arc<Mutex<Vec<Connection>>>,
    // Limits concurrently checked-out connections; starts with `max_size` permits.
    semaphore: Arc<Semaphore>,
    // Upper bound passed to `ConnectionPool::new`, kept for `max_size()`.
    max_size: usize,
}
16
17impl ConnectionPool {
18 pub fn new(host: impl Into<String>, port: u16, max_size: usize) -> Self {
25 assert!(
26 max_size > 0,
27 "ConnectionPool max_size must be at least 1 (was 0). \
28 A pool with 0 connections would deadlock on acquire()."
29 );
30 Self {
31 client: Client::new(host, port),
32 connections: Arc::new(Mutex::new(Vec::new())),
33 semaphore: Arc::new(Semaphore::new(max_size)),
34 max_size,
35 }
36 }
37
38 pub fn skip_verify(mut self, skip: bool) -> Self {
40 self.client = self.client.skip_verify(skip);
41 self
42 }
43
44 pub fn page_size(mut self, size: usize) -> Self {
46 self.client = self.client.page_size(size);
47 self
48 }
49
50 pub async fn acquire(&self) -> Result<PooledConnection> {
56 let permit = Arc::clone(&self.semaphore)
57 .acquire_owned()
58 .await
59 .map_err(|_| Error::pool("Connection pool has been closed"))?;
60
61 let connection = loop {
63 let conn = {
64 let mut connections = self.connections.lock().await;
65 connections.pop()
66 };
67
68 match conn {
69 Some(c) if c.is_healthy() => {
70 break c;
72 }
73 Some(_) => {
74 continue;
77 }
78 None => {
79 let client = self.client.clone();
81 break client.connect().await?;
82 }
83 }
84 };
85
86 Ok(PooledConnection {
87 connection: Some(connection),
88 pool: self.connections.clone(),
89 _permit: permit,
90 })
91 }
92
93 pub async fn size(&self) -> usize {
95 self.connections.lock().await.len()
96 }
97
98 pub fn max_size(&self) -> usize {
100 self.max_size
101 }
102}
103
/// A connection checked out of a [`ConnectionPool`].
///
/// Dereferences to [`Connection`]. Dropping the handle offers the connection
/// back to the pool's idle list (if it is still healthy) and releases the
/// pool permit it holds.
pub struct PooledConnection {
    // Always `Some` while the handle is usable; it is an `Option` only so that
    // `Drop::drop` can move the connection out with `take()`.
    connection: Option<Connection>,
    // The originating pool's idle list, where the connection is returned.
    pool: Arc<Mutex<Vec<Connection>>>,
    // Pool permit held for the lifetime of this handle. Fields are dropped
    // after `Drop::drop` runs, so the permit is released only once `drop`
    // has finished handling `connection`.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
110
111impl PooledConnection {
112 pub fn inner(&self) -> &Connection {
120 self.connection
121 .as_ref()
122 .expect("PooledConnection invariant violated: connection was None")
123 }
124}
125
126impl Drop for PooledConnection {
127 fn drop(&mut self) {
128 if let Some(conn) = self.connection.take() {
129 if conn.is_healthy() {
132 let pool = self.pool.clone();
133 tokio::spawn(async move {
134 let mut connections = pool.lock().await;
135 connections.push(conn);
136 });
137 }
138 }
140 }
141}
142
143impl std::ops::Deref for PooledConnection {
144 type Target = Connection;
145
146 fn deref(&self) -> &Self::Target {
147 self.connection
148 .as_ref()
149 .expect("PooledConnection invariant violated: connection was None")
150 }
151}
152
153impl std::ops::DerefMut for PooledConnection {
154 fn deref_mut(&mut self) -> &mut Self::Target {
155 self.connection
156 .as_mut()
157 .expect("PooledConnection invariant violated: connection was None")
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_pool_new() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_new_different_host() {
        let pool = ConnectionPool::new("192.168.1.100", 8443, 5);
        assert_eq!(pool.max_size(), 5);
    }

    #[test]
    fn test_connection_pool_new_string_host() {
        let host = String::from("geode.example.com");
        let pool = ConnectionPool::new(host, 3141, 20);
        assert_eq!(pool.max_size(), 20);
    }

    #[test]
    fn test_connection_pool_skip_verify() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(true);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_skip_verify_false() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(false);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_page_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10).page_size(500);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_chained_config() {
        let pool = ConnectionPool::new("localhost", 3141, 10)
            .skip_verify(true)
            .page_size(1000);
        assert_eq!(pool.max_size(), 10);
    }

    #[tokio::test]
    async fn test_connection_pool_initial_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        assert_eq!(pool.size().await, 0);
    }

    #[test]
    #[should_panic(expected = "ConnectionPool max_size must be at least 1")]
    fn test_connection_pool_max_size_zero_panics() {
        let _pool = ConnectionPool::new("localhost", 3141, 0);
    }

    #[test]
    fn test_connection_pool_max_size_one() {
        let pool = ConnectionPool::new("localhost", 3141, 1);
        assert_eq!(pool.max_size(), 1);
    }

    #[test]
    fn test_connection_pool_max_size_large() {
        let pool = ConnectionPool::new("localhost", 3141, 1000);
        assert_eq!(pool.max_size(), 1000);
    }

    #[test]
    fn test_semaphore_permits_match_max_size() {
        let pool = ConnectionPool::new("localhost", 3141, 5);
        assert_eq!(pool.semaphore.available_permits(), 5);
    }

    #[test]
    fn test_connections_vec_initially_empty() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        // `try_lock` needs no runtime, and nothing else can hold the lock on a
        // freshly built pool, so this inspects the idle list synchronously.
        let idle = pool
            .connections
            .try_lock()
            .expect("a freshly created pool's mutex should be unlocked");
        assert!(idle.is_empty());
    }
}