1use crate::error::{Error, Result};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::net::TcpStream;
16use tokio::sync::Mutex;
17
/// Default per-`host:port` limit; used both as the cap on checked-out
/// connections and as the cap on idle connections retained for one host.
const DEFAULT_MAX_CONNECTIONS_PER_HOST: usize = 32;

/// Default time an idle pooled connection may go unused before it is discarded (90 s).
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_millis(90_000);

/// Default deadline for establishing a new TCP connection (10 s).
const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_millis(10_000);
26
27struct PooledConnection {
29 stream: TcpStream,
30 last_used: Instant,
31}
32
33impl PooledConnection {
34 fn new(stream: TcpStream) -> Self {
35 Self {
36 stream,
37 last_used: Instant::now(),
38 }
39 }
40
41 fn is_expired(&self, idle_timeout: Duration) -> bool {
42 self.last_used.elapsed() > idle_timeout
43 }
44
45 fn update_last_used(&mut self) {
46 self.last_used = Instant::now();
47 }
48}
49
50#[derive(Clone, Debug)]
52pub struct PoolConfig {
53 pub max_connections_per_host: usize,
55 pub idle_timeout: Duration,
57 pub connection_timeout: Duration,
59 pub keep_alive: bool,
61}
62
63impl Default for PoolConfig {
64 fn default() -> Self {
65 Self {
66 max_connections_per_host: DEFAULT_MAX_CONNECTIONS_PER_HOST,
67 idle_timeout: DEFAULT_IDLE_TIMEOUT,
68 connection_timeout: DEFAULT_CONNECTION_TIMEOUT,
69 keep_alive: true,
70 }
71 }
72}
73
74pub struct ConnectionPool {
76 config: PoolConfig,
77 idle_connections: Arc<Mutex<HashMap<String, Vec<PooledConnection>>>>,
79 active_counts: Arc<Mutex<HashMap<String, usize>>>,
81}
82
83impl ConnectionPool {
84 pub fn new() -> Self {
86 Self::with_config(PoolConfig::default())
87 }
88
89 pub fn with_config(config: PoolConfig) -> Self {
91 Self {
92 config,
93 idle_connections: Arc::new(Mutex::new(HashMap::new())),
94 active_counts: Arc::new(Mutex::new(HashMap::new())),
95 }
96 }
97
98 pub async fn get_connection(&self, host: &str, port: u16) -> Result<TcpStream> {
100 let key = format!("{}:{}", host, port);
101
102 {
104 let mut idle = self.idle_connections.lock().await;
105 if let Some(connections) = idle.get_mut(&key) {
106 connections.retain(|conn| !conn.is_expired(self.config.idle_timeout));
108
109 while let Some(mut conn) = connections.pop() {
111 conn.update_last_used();
112
113 let mut active = self.active_counts.lock().await;
115 *active.entry(key.clone()).or_insert(0) += 1;
116
117 return Ok(conn.stream);
118 }
119 }
120 }
121
122 {
124 let active = self.active_counts.lock().await;
125 let count = active.get(&key).copied().unwrap_or(0);
126 if count >= self.config.max_connections_per_host {
127 return Err(Error::TooManyConnections {
128 host: key,
129 max: self.config.max_connections_per_host,
130 });
131 }
132 }
133
134 let addr = format!("{}:{}", host, port);
136 let stream = tokio::time::timeout(
137 self.config.connection_timeout,
138 TcpStream::connect(&addr),
139 )
140 .await
141 .map_err(|_| Error::ConnectionTimeout { host: addr.clone() })?
142 .map_err(|e| Error::ConnectionFailed {
143 host: addr,
144 source: e,
145 })?;
146
147 if self.config.keep_alive {
149 #[cfg(unix)]
150 {
151 use std::os::unix::io::AsRawFd;
152 let fd = stream.as_raw_fd();
153 unsafe {
154 let optval: libc::c_int = 1;
155 libc::setsockopt(
156 fd,
157 libc::SOL_SOCKET,
158 libc::SO_KEEPALIVE,
159 &optval as *const _ as *const libc::c_void,
160 std::mem::size_of_val(&optval) as libc::socklen_t,
161 );
162 }
163 }
164 #[cfg(windows)]
165 {
166 use std::os::windows::io::AsRawSocket;
167 let socket = stream.as_raw_socket();
168 unsafe {
169 let optval: u32 = 1;
170 windows_sys::Win32::Networking::WinSock::setsockopt(
171 socket as usize,
172 windows_sys::Win32::Networking::WinSock::SOL_SOCKET,
173 windows_sys::Win32::Networking::WinSock::SO_KEEPALIVE,
174 &optval as *const _ as *const u8,
175 std::mem::size_of_val(&optval) as i32,
176 );
177 }
178 }
179 }
180
181 let mut active = self.active_counts.lock().await;
183 *active.entry(key).or_insert(0) += 1;
184
185 Ok(stream)
186 }
187
188 pub async fn return_connection(&self, host: &str, port: u16, stream: TcpStream) {
190 let key = format!("{}:{}", host, port);
191
192 {
194 let mut active = self.active_counts.lock().await;
195 if let Some(count) = active.get_mut(&key) {
196 *count = count.saturating_sub(1);
197 }
198 }
199
200 let mut idle = self.idle_connections.lock().await;
202 let connections = idle.entry(key.clone()).or_insert_with(Vec::new);
203
204 if connections.len() < self.config.max_connections_per_host {
206 connections.push(PooledConnection::new(stream));
207 }
208 }
209
210 pub async fn cleanup_expired(&self) {
212 let mut idle = self.idle_connections.lock().await;
213 for connections in idle.values_mut() {
214 connections.retain(|conn| !conn.is_expired(self.config.idle_timeout));
215 }
216 idle.retain(|_, v| !v.is_empty());
218 }
219
220 pub async fn stats(&self) -> PoolStats {
222 let idle = self.idle_connections.lock().await;
223 let active = self.active_counts.lock().await;
224
225 let total_idle: usize = idle.values().map(|v| v.len()).sum();
226 let total_active: usize = active.values().sum();
227
228 PoolStats {
229 idle_connections: total_idle,
230 active_connections: total_active,
231 hosts: idle.len(),
232 }
233 }
234
235 pub async fn close_all(&self) {
237 let mut idle = self.idle_connections.lock().await;
238 idle.clear();
239
240 let mut active = self.active_counts.lock().await;
241 active.clear();
242 }
243}
244
245impl Default for ConnectionPool {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
/// Point-in-time snapshot of pool usage, as returned by [`ConnectionPool::stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Total idle connections stored across all hosts.
    pub idle_connections: usize,
    /// Total connections counted as checked out across all hosts.
    pub active_connections: usize,
    /// Number of `host:port` keys present in the idle-connection map.
    pub hosts: usize,
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// Binds a listener on an ephemeral loopback port and returns it with that port.
    ///
    /// The listener never accepts. Connects still complete via the kernel's
    /// accept backlog, which is all the pool needs.
    async fn local_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let port = listener
            .local_addr()
            .expect("listener has a local address")
            .port();
        (listener, port)
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new();
        let stats = pool.stats().await;
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.active_connections, 0);
    }

    #[tokio::test]
    async fn test_pool_config() {
        let config = PoolConfig {
            max_connections_per_host: 50,
            idle_timeout: Duration::from_secs(120),
            connection_timeout: Duration::from_secs(5),
            keep_alive: true,
        };
        let pool = ConnectionPool::with_config(config);
        assert_eq!(pool.config.max_connections_per_host, 50);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let pool = ConnectionPool::new();
        pool.cleanup_expired().await;
        let stats = pool.stats().await;
        assert_eq!(stats.idle_connections, 0);
    }

    #[tokio::test]
    async fn test_connection_is_reused_after_return() {
        let (_listener, port) = local_listener().await;
        let pool = ConnectionPool::new();

        let stream = pool
            .get_connection("127.0.0.1", port)
            .await
            .expect("open new connection");
        assert_eq!(pool.stats().await.active_connections, 1);

        pool.return_connection("127.0.0.1", port, stream).await;
        let stats = pool.stats().await;
        assert_eq!(stats.idle_connections, 1);
        assert_eq!(stats.active_connections, 0);

        let _reused = pool
            .get_connection("127.0.0.1", port)
            .await
            .expect("reuse idle connection");
        let stats = pool.stats().await;
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.active_connections, 1);
    }

    #[tokio::test]
    async fn test_max_connections_per_host_is_enforced() {
        let (_listener, port) = local_listener().await;
        let pool = ConnectionPool::with_config(PoolConfig {
            max_connections_per_host: 1,
            ..PoolConfig::default()
        });

        let _first = pool
            .get_connection("127.0.0.1", port)
            .await
            .expect("first connection fits the limit");
        let second = pool.get_connection("127.0.0.1", port).await;
        assert!(matches!(
            second,
            Err(Error::TooManyConnections { max: 1, .. })
        ));
    }

    #[tokio::test]
    async fn test_cleanup_expired_drops_stale_idle_connections() {
        let (_listener, port) = local_listener().await;
        let pool = ConnectionPool::with_config(PoolConfig {
            idle_timeout: Duration::from_millis(1),
            ..PoolConfig::default()
        });

        let stream = pool
            .get_connection("127.0.0.1", port)
            .await
            .expect("open new connection");
        pool.return_connection("127.0.0.1", port, stream).await;
        assert_eq!(pool.stats().await.idle_connections, 1);

        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.cleanup_expired().await;

        let stats = pool.stats().await;
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.hosts, 0);
    }
}
293}