1use std::sync::Arc;
4use tokio::sync::{Mutex, Semaphore};
5
6use crate::client::{Client, Connection};
7use crate::error::Result;
8
9pub struct ConnectionPool {
11 client: Client,
12 connections: Arc<Mutex<Vec<Connection>>>,
13 semaphore: Arc<Semaphore>,
14 max_size: usize,
15}
16
17impl ConnectionPool {
18 pub fn new(host: impl Into<String>, port: u16, max_size: usize) -> Self {
20 Self {
21 client: Client::new(host, port),
22 connections: Arc::new(Mutex::new(Vec::new())),
23 semaphore: Arc::new(Semaphore::new(max_size)),
24 max_size,
25 }
26 }
27
28 pub fn skip_verify(mut self, skip: bool) -> Self {
30 self.client = self.client.skip_verify(skip);
31 self
32 }
33
34 pub fn page_size(mut self, size: usize) -> Self {
36 self.client = self.client.page_size(size);
37 self
38 }
39
40 pub async fn acquire(&self) -> Result<PooledConnection> {
42 let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap();
43
44 let conn = {
46 let mut connections = self.connections.lock().await;
47 connections.pop()
48 };
49
50 let connection = match conn {
51 Some(c) => c,
52 None => {
53 let client = self.client.clone();
54 client.connect().await?
56 }
57 };
58
59 Ok(PooledConnection {
60 connection: Some(connection),
61 pool: self.connections.clone(),
62 _permit: permit,
63 })
64 }
65
66 pub async fn size(&self) -> usize {
68 self.connections.lock().await.len()
69 }
70
71 pub fn max_size(&self) -> usize {
73 self.max_size
74 }
75}
76
77pub struct PooledConnection {
79 connection: Option<Connection>,
80 pool: Arc<Mutex<Vec<Connection>>>,
81 _permit: tokio::sync::OwnedSemaphorePermit,
82}
83
84impl PooledConnection {
85 pub fn inner(&self) -> &Connection {
87 self.connection.as_ref().unwrap()
88 }
89}
90
91impl Drop for PooledConnection {
92 fn drop(&mut self) {
93 if let Some(conn) = self.connection.take() {
94 let pool = self.pool.clone();
95 tokio::spawn(async move {
96 let mut connections = pool.lock().await;
97 connections.push(conn);
98 });
99 }
100 }
101}
102
103impl std::ops::Deref for PooledConnection {
104 type Target = Connection;
105
106 fn deref(&self) -> &Self::Target {
107 self.connection.as_ref().unwrap()
108 }
109}
110
111impl std::ops::DerefMut for PooledConnection {
112 fn deref_mut(&mut self) -> &mut Self::Target {
113 self.connection.as_mut().unwrap()
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_pool_new() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_new_different_host() {
        let pool = ConnectionPool::new("192.168.1.100", 8443, 5);
        assert_eq!(pool.max_size(), 5);
    }

    #[test]
    fn test_connection_pool_new_string_host() {
        let host = String::from("geode.example.com");
        let pool = ConnectionPool::new(host, 3141, 20);
        assert_eq!(pool.max_size(), 20);
    }

    #[test]
    fn test_connection_pool_skip_verify() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(true);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_skip_verify_false() {
        let pool = ConnectionPool::new("localhost", 3141, 10).skip_verify(false);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_page_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10).page_size(500);
        assert_eq!(pool.max_size(), 10);
    }

    #[test]
    fn test_connection_pool_chained_config() {
        let pool = ConnectionPool::new("localhost", 3141, 10)
            .skip_verify(true)
            .page_size(1000);
        assert_eq!(pool.max_size(), 10);
    }

    #[tokio::test]
    async fn test_connection_pool_initial_size() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        assert_eq!(pool.size().await, 0);
    }

    #[test]
    fn test_connection_pool_max_size_zero() {
        let pool = ConnectionPool::new("localhost", 3141, 0);
        assert_eq!(pool.max_size(), 0);
    }

    #[test]
    fn test_connection_pool_max_size_one() {
        let pool = ConnectionPool::new("localhost", 3141, 1);
        assert_eq!(pool.max_size(), 1);
    }

    #[test]
    fn test_connection_pool_max_size_large() {
        let pool = ConnectionPool::new("localhost", 3141, 1000);
        assert_eq!(pool.max_size(), 1000);
    }

    #[test]
    fn test_semaphore_permits_match_max_size() {
        let pool = ConnectionPool::new("localhost", 3141, 5);
        assert_eq!(pool.semaphore.available_permits(), 5);
    }

    // A zero-capacity pool hands out no permits, so `acquire` can never succeed.
    #[test]
    fn test_semaphore_no_permits_for_max_size_zero() {
        let pool = ConnectionPool::new("localhost", 3141, 0);
        assert_eq!(pool.semaphore.available_permits(), 0);
    }

    // Previously this only checked `max_size`; inspect the idle list itself.
    #[test]
    fn test_connections_vec_initially_empty() {
        let pool = ConnectionPool::new("localhost", 3141, 10);
        let idle = pool
            .connections
            .try_lock()
            .expect("no other task holds the pool lock in this test");
        assert!(idle.is_empty());
    }
}