1use crate::{Transport, Connection, TransportError, SshConfig, StdioTransport};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use tokio::time::{sleep, timeout};
9use tracing::{debug, info, warn};
10use uuid::Uuid;
11
/// Tuning parameters for a [`ConnectionPool`].
///
/// All fields are public so callers can start from `PoolConfig::default()` and
/// override individual values with struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Maximum number of idle connections kept in the pool for a single host.
    ///
    /// Surplus connections handed back beyond this count are not kept, and new
    /// connections are refused while this many idle entries are stored for the host.
    /// NOTE(review): connections that are currently checked out are not counted.
    pub max_connections_per_host: usize,
    /// Idle connections unused for longer than this are evicted by the health check.
    pub max_idle_time: Duration,
    /// Time limit applied to each individual connection attempt.
    pub connection_timeout: Duration,
    /// Period between two runs of the background health check.
    pub health_check_interval: Duration,
    /// Total number of connection attempts made when opening a new connection.
    /// This counts all attempts, not only the retries after the first one.
    pub max_retries: u32,
    /// Pause between two consecutive failed connection attempts.
    pub retry_delay: Duration,
}
28
29impl Default for PoolConfig {
30 fn default() -> Self {
31 Self {
32 max_connections_per_host: 10,
33 max_idle_time: Duration::from_secs(300), connection_timeout: Duration::from_secs(30),
35 health_check_interval: Duration::from_secs(60), max_retries: 3,
37 retry_delay: Duration::from_secs(1),
38 }
39 }
40}
41
42#[derive(Debug)]
44struct PoolEntry {
45 connection: Connection,
47 last_used: Instant,
49 healthy: bool,
51 use_count: u64,
53}
54
55pub struct ConnectionPool {
57 config: PoolConfig,
59 connections: Arc<RwLock<HashMap<String, Vec<PoolEntry>>>>,
61 ssh_configs: Arc<RwLock<HashMap<String, SshConfig>>>,
63 health_check_handle: Option<tokio::task::JoinHandle<()>>,
65}
66
67pub struct PooledConnection {
69 id: Uuid,
71 host_key: String,
73 connection: Option<Connection>,
75 pool: Arc<ConnectionPool>,
77}
78
79impl ConnectionPool {
80 pub fn new(config: PoolConfig) -> Self {
82 let pool = Self {
83 config,
84 connections: Arc::new(RwLock::new(HashMap::new())),
85 ssh_configs: Arc::new(RwLock::new(HashMap::new())),
86 health_check_handle: None,
87 };
88
89 pool
90 }
91
92 pub async fn start(&mut self) -> Result<(), TransportError> {
94 info!("Starting connection pool");
95
96 let connections = Arc::clone(&self.connections);
98 let config = self.config.clone();
99
100 let handle = tokio::spawn(async move {
101 Self::health_check_loop(connections, config).await;
102 });
103
104 self.health_check_handle = Some(handle);
105 Ok(())
106 }
107
108 pub async fn stop(&mut self) -> Result<(), TransportError> {
110 info!("Stopping connection pool");
111
112 if let Some(handle) = self.health_check_handle.take() {
114 handle.abort();
115 }
116
117 let mut connections = self.connections.write().await;
119 for (host, entries) in connections.drain() {
120 info!("Closing {} connections for host: {}", entries.len(), host);
121 for mut entry in entries {
122 if let Err(e) = entry.connection.close().await {
123 warn!("Error closing connection to {}: {}", host, e);
124 }
125 }
126 }
127
128 Ok(())
129 }
130
131 pub async fn add_host(&self, host: String, config: SshConfig) {
133 let mut configs = self.ssh_configs.write().await;
134 configs.insert(host.clone(), config);
135 debug!("Added SSH configuration for host: {}", host);
136 }
137
138 pub async fn get_connection(&self, host: &str) -> Result<PooledConnection, TransportError> {
140 let host_key = host.to_string();
141
142 if let Some(connection) = self.get_existing_connection(&host_key).await? {
144 return Ok(connection);
145 }
146
147 self.create_new_connection(&host_key).await
149 }
150
151 async fn get_existing_connection(&self, host_key: &str) -> Result<Option<PooledConnection>, TransportError> {
153 let mut connections = self.connections.write().await;
154
155 if let Some(entries) = connections.get_mut(host_key) {
156 for (i, entry) in entries.iter().enumerate() {
158 if entry.healthy && entry.connection.is_connected() {
159 let mut entry = entries.remove(i);
160 entry.last_used = Instant::now();
161 entry.use_count += 1;
162
163 debug!("Reusing existing connection to {}", host_key);
164
165 return Ok(Some(PooledConnection {
166 id: Uuid::new_v4(),
167 host_key: host_key.to_string(),
168 connection: Some(entry.connection),
169 pool: Arc::new(self.clone()),
170 }));
171 }
172 }
173 }
174
175 Ok(None)
176 }
177
178 async fn create_new_connection(&self, host_key: &str) -> Result<PooledConnection, TransportError> {
180 {
182 let connections = self.connections.read().await;
183 if let Some(entries) = connections.get(host_key) {
184 if entries.len() >= self.config.max_connections_per_host {
185 return Err(TransportError::Configuration(
186 format!("Maximum connections reached for host: {}", host_key)
187 ));
188 }
189 }
190 }
191
192 let ssh_config = {
194 let configs = self.ssh_configs.read().await;
195 configs.get(host_key).cloned()
196 .ok_or_else(|| TransportError::Configuration(
197 format!("No SSH configuration found for host: {}", host_key)
198 ))?
199 };
200
201 debug!("Creating new connection to {}", host_key);
202
203 let connection = self.connect_with_retries(ssh_config).await?;
205
206 info!("Successfully created new connection to {}", host_key);
207
208 Ok(PooledConnection {
209 id: Uuid::new_v4(),
210 host_key: host_key.to_string(),
211 connection: Some(connection),
212 pool: Arc::new(self.clone()),
213 })
214 }
215
216 async fn connect_with_retries(&self, ssh_config: SshConfig) -> Result<Connection, TransportError> {
218 let mut last_error = None;
219
220 for attempt in 1..=self.config.max_retries {
221 debug!("Connection attempt {} of {}", attempt, self.config.max_retries);
222
223 let mut transport = StdioTransport::new(ssh_config.clone());
224
225 match timeout(self.config.connection_timeout, transport.connect()).await {
226 Ok(Ok(connection)) => {
227 debug!("Connection successful on attempt {}", attempt);
228 return Ok(connection);
229 }
230 Ok(Err(e)) => {
231 warn!("Connection attempt {} failed: {}", attempt, e);
232 last_error = Some(e);
233 }
234 Err(_) => {
235 let timeout_error = TransportError::Timeout;
236 warn!("Connection attempt {} timed out", attempt);
237 last_error = Some(timeout_error);
238 }
239 }
240
241 if attempt < self.config.max_retries {
242 sleep(self.config.retry_delay).await;
243 }
244 }
245
246 Err(last_error.unwrap_or_else(|| {
247 TransportError::Connection("All connection attempts failed".to_string())
248 }))
249 }
250
251 async fn return_connection(&self, host_key: String, connection: Connection) -> Result<(), TransportError> {
253 if !connection.is_connected() {
254 debug!("Not returning disconnected connection to pool");
255 return Ok(());
256 }
257
258 let entry = PoolEntry {
259 connection,
260 last_used: Instant::now(),
261 healthy: true,
262 use_count: 1,
263 };
264
265 let mut connections = self.connections.write().await;
266 let entries = connections.entry(host_key.clone()).or_insert_with(Vec::new);
267
268 if entries.len() < self.config.max_connections_per_host {
270 entries.push(entry);
271 debug!("Returned connection to pool for host: {}", host_key);
272 } else {
273 debug!("Pool full, closing connection for host: {}", host_key);
274 drop(entry);
276 }
277
278 Ok(())
279 }
280
281 async fn health_check_loop(
283 connections: Arc<RwLock<HashMap<String, Vec<PoolEntry>>>>,
284 config: PoolConfig,
285 ) {
286 let mut interval = tokio::time::interval(config.health_check_interval);
287
288 loop {
289 interval.tick().await;
290
291 debug!("Running connection health check");
292
293 let mut connections_guard = connections.write().await;
294 let now = Instant::now();
295
296 for (host, entries) in connections_guard.iter_mut() {
297 entries.retain_mut(|entry| {
298 if now.duration_since(entry.last_used) > config.max_idle_time {
300 debug!("Closing idle connection to {}", host);
301 let _ = entry.connection.close();
302 return false;
303 }
304
305 if !entry.connection.is_connected() {
307 debug!("Removing unhealthy connection to {}", host);
308 entry.healthy = false;
309 return false;
310 }
311
312 true
313 });
314 }
315
316 connections_guard.retain(|_, entries| !entries.is_empty());
318 }
319 }
320
321 pub async fn stats(&self) -> PoolStats {
323 let connections = self.connections.read().await;
324 let mut total_connections = 0;
325 let mut healthy_connections = 0;
326 let mut hosts = 0;
327
328 for (_, entries) in connections.iter() {
329 hosts += 1;
330 for entry in entries {
331 total_connections += 1;
332 if entry.healthy {
333 healthy_connections += 1;
334 }
335 }
336 }
337
338 PoolStats {
339 total_connections,
340 healthy_connections,
341 hosts,
342 }
343 }
344}
345
346impl Clone for ConnectionPool {
347 fn clone(&self) -> Self {
348 Self {
349 config: self.config.clone(),
350 connections: Arc::clone(&self.connections),
351 ssh_configs: Arc::clone(&self.ssh_configs),
352 health_check_handle: None, }
354 }
355}
356
357impl Drop for ConnectionPool {
358 fn drop(&mut self) {
359 if let Some(handle) = self.health_check_handle.take() {
360 handle.abort();
361 }
362 }
363}
364
/// Snapshot of the pool's idle-connection bookkeeping, as produced by
/// `ConnectionPool::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of idle connections stored in the pool, across all hosts.
    /// Connections that are currently checked out are not included.
    pub total_connections: usize,
    /// How many of those idle connections are flagged healthy.
    pub healthy_connections: usize,
    /// Number of host keys present in the pool's connection map.
    pub hosts: usize,
}
375
376impl PooledConnection {
377 pub fn id(&self) -> Uuid {
379 self.id
380 }
381
382 pub fn host_key(&self) -> &str {
384 &self.host_key
385 }
386
387 pub fn connection_mut(&mut self) -> Option<&mut Connection> {
389 self.connection.as_mut()
390 }
391
392 pub fn take_connection(&mut self) -> Option<Connection> {
394 self.connection.take()
395 }
396
397 pub fn is_connected(&self) -> bool {
399 self.connection.as_ref().map_or(false, |c| c.is_connected())
400 }
401}
402
403impl Drop for PooledConnection {
404 fn drop(&mut self) {
405 if let Some(connection) = self.connection.take() {
406 let pool = Arc::clone(&self.pool);
407 let host_key = self.host_key.clone();
408
409 tokio::spawn(async move {
411 if let Err(e) = pool.return_connection(host_key, connection).await {
412 warn!("Failed to return connection to pool: {}", e);
413 }
414 });
415 }
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SshConfig;

    /// The default configuration carries the expected stock limits.
    #[test]
    fn test_pool_config_default() {
        let defaults = PoolConfig::default();
        assert_eq!(defaults.max_connections_per_host, 10);
        assert_eq!(defaults.max_idle_time, Duration::from_secs(300));
        assert_eq!(defaults.connection_timeout, Duration::from_secs(30));
    }

    /// A freshly built pool reports no connections and no hosts.
    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new(PoolConfig::default());

        let PoolStats {
            total_connections,
            healthy_connections,
            hosts,
        } = pool.stats().await;
        assert_eq!(total_connections, 0);
        assert_eq!(healthy_connections, 0);
        assert_eq!(hosts, 0);
    }

    /// `add_host` stores the SSH configuration under the given host key.
    #[tokio::test]
    async fn test_add_host() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let host_config = SshConfig::default();
        pool.add_host("test.example.com".to_string(), host_config).await;

        let registered = pool.ssh_configs.read().await;
        assert!(registered.contains_key("test.example.com"));
    }

    /// `start` installs the health-check handle and `stop` clears it again.
    #[tokio::test]
    async fn test_pool_start_stop() {
        let mut pool = ConnectionPool::new(PoolConfig::default());

        pool.start().await.unwrap();
        assert!(pool.health_check_handle.is_some());

        pool.stop().await.unwrap();
        assert!(pool.health_check_handle.is_none());
    }

    /// Accessors expose the host key and the (disconnected) connection state.
    #[tokio::test]
    async fn test_pooled_connection_properties() {
        let shared_pool = Arc::new(ConnectionPool::new(PoolConfig::default()));

        let pooled = PooledConnection {
            id: Uuid::new_v4(),
            host_key: "test.example.com".to_string(),
            connection: Some(Connection::new(None)),
            pool: shared_pool,
        };

        assert_eq!(pooled.host_key(), "test.example.com");
        // This test expects a connection built with `Connection::new(None)` to report disconnected.
        assert!(!pooled.is_connected());
    }

    /// `PoolStats` is a plain data carrier; its fields read back as written.
    #[test]
    fn test_pool_stats() {
        let snapshot = PoolStats {
            total_connections: 5,
            healthy_connections: 4,
            hosts: 2,
        };

        assert_eq!(snapshot.total_connections, 5);
        assert_eq!(snapshot.healthy_connections, 4);
        assert_eq!(snapshot.hosts, 2);
    }
}