mitoxide_ssh/
pool.rs

1//! Connection pool and management
2
3use crate::{Transport, Connection, TransportError, SshConfig, StdioTransport};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use tokio::sync::RwLock;
8use tokio::time::{sleep, timeout};
9use tracing::{debug, info, warn};
10use uuid::Uuid;
11
/// Tunables for a [`ConnectionPool`].
///
/// Every time-based setting is expressed as a [`Duration`]; the `Default`
/// impl below lists the values used when nothing is configured explicitly.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on pooled connections kept for a single host key.
    pub max_connections_per_host: usize,
    /// How long a pooled connection may sit unused before the health check drops it.
    pub max_idle_time: Duration,
    /// Deadline applied to each individual connection attempt.
    pub connection_timeout: Duration,
    /// Period between runs of the background health check.
    pub health_check_interval: Duration,
    /// Number of connection attempts made before giving up.
    pub max_retries: u32,
    /// Pause inserted between consecutive connection attempts.
    pub retry_delay: Duration,
}

impl Default for PoolConfig {
    /// Ten connections per host, a five-minute idle limit, a 30 s connect
    /// timeout, a one-minute health-check period, three attempts and a one
    /// second pause between attempts.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;

        let retry_delay = Duration::from_secs(1);
        let connection_timeout = Duration::from_secs(30);
        let health_check_interval = Duration::from_secs(MINUTE_SECS);
        let max_idle_time = Duration::from_secs(5 * MINUTE_SECS);

        PoolConfig {
            max_connections_per_host: 10,
            max_idle_time,
            connection_timeout,
            health_check_interval,
            max_retries: 3,
            retry_delay,
        }
    }
}
41
42/// Connection pool entry
43#[derive(Debug)]
44struct PoolEntry {
45    /// Connection instance
46    connection: Connection,
47    /// Last used timestamp
48    last_used: Instant,
49    /// Connection health status
50    healthy: bool,
51    /// Number of times this connection has been used
52    use_count: u64,
53}
54
55/// Connection pool for managing SSH connections
56pub struct ConnectionPool {
57    /// Pool configuration
58    config: PoolConfig,
59    /// Active connections grouped by host
60    connections: Arc<RwLock<HashMap<String, Vec<PoolEntry>>>>,
61    /// Connection configurations
62    ssh_configs: Arc<RwLock<HashMap<String, SshConfig>>>,
63    /// Health check task handle
64    health_check_handle: Option<tokio::task::JoinHandle<()>>,
65}
66
67/// A pooled connection wrapper
68pub struct PooledConnection {
69    /// Connection ID
70    id: Uuid,
71    /// Host key
72    host_key: String,
73    /// Underlying connection
74    connection: Option<Connection>,
75    /// Reference to the pool for returning the connection
76    pool: Arc<ConnectionPool>,
77}
78
79impl ConnectionPool {
80    /// Create a new connection pool
81    pub fn new(config: PoolConfig) -> Self {
82        let pool = Self {
83            config,
84            connections: Arc::new(RwLock::new(HashMap::new())),
85            ssh_configs: Arc::new(RwLock::new(HashMap::new())),
86            health_check_handle: None,
87        };
88        
89        pool
90    }
91    
92    /// Start the connection pool with health checking
93    pub async fn start(&mut self) -> Result<(), TransportError> {
94        info!("Starting connection pool");
95        
96        // Start health check task
97        let connections = Arc::clone(&self.connections);
98        let config = self.config.clone();
99        
100        let handle = tokio::spawn(async move {
101            Self::health_check_loop(connections, config).await;
102        });
103        
104        self.health_check_handle = Some(handle);
105        Ok(())
106    }
107    
108    /// Stop the connection pool
109    pub async fn stop(&mut self) -> Result<(), TransportError> {
110        info!("Stopping connection pool");
111        
112        // Stop health check task
113        if let Some(handle) = self.health_check_handle.take() {
114            handle.abort();
115        }
116        
117        // Close all connections
118        let mut connections = self.connections.write().await;
119        for (host, entries) in connections.drain() {
120            info!("Closing {} connections for host: {}", entries.len(), host);
121            for mut entry in entries {
122                if let Err(e) = entry.connection.close().await {
123                    warn!("Error closing connection to {}: {}", host, e);
124                }
125            }
126        }
127        
128        Ok(())
129    }
130    
131    /// Add SSH configuration for a host
132    pub async fn add_host(&self, host: String, config: SshConfig) {
133        let mut configs = self.ssh_configs.write().await;
134        configs.insert(host.clone(), config);
135        debug!("Added SSH configuration for host: {}", host);
136    }
137    
138    /// Get a connection from the pool
139    pub async fn get_connection(&self, host: &str) -> Result<PooledConnection, TransportError> {
140        let host_key = host.to_string();
141        
142        // Try to get an existing connection
143        if let Some(connection) = self.get_existing_connection(&host_key).await? {
144            return Ok(connection);
145        }
146        
147        // Create a new connection
148        self.create_new_connection(&host_key).await
149    }
150    
151    /// Get an existing connection from the pool
152    async fn get_existing_connection(&self, host_key: &str) -> Result<Option<PooledConnection>, TransportError> {
153        let mut connections = self.connections.write().await;
154        
155        if let Some(entries) = connections.get_mut(host_key) {
156            // Find a healthy, idle connection
157            for (i, entry) in entries.iter().enumerate() {
158                if entry.healthy && entry.connection.is_connected() {
159                    let mut entry = entries.remove(i);
160                    entry.last_used = Instant::now();
161                    entry.use_count += 1;
162                    
163                    debug!("Reusing existing connection to {}", host_key);
164                    
165                    return Ok(Some(PooledConnection {
166                        id: Uuid::new_v4(),
167                        host_key: host_key.to_string(),
168                        connection: Some(entry.connection),
169                        pool: Arc::new(self.clone()),
170                    }));
171                }
172            }
173        }
174        
175        Ok(None)
176    }
177    
178    /// Create a new connection
179    async fn create_new_connection(&self, host_key: &str) -> Result<PooledConnection, TransportError> {
180        // Check if we've reached the connection limit
181        {
182            let connections = self.connections.read().await;
183            if let Some(entries) = connections.get(host_key) {
184                if entries.len() >= self.config.max_connections_per_host {
185                    return Err(TransportError::Configuration(
186                        format!("Maximum connections reached for host: {}", host_key)
187                    ));
188                }
189            }
190        }
191        
192        // Get SSH configuration
193        let ssh_config = {
194            let configs = self.ssh_configs.read().await;
195            configs.get(host_key).cloned()
196                .ok_or_else(|| TransportError::Configuration(
197                    format!("No SSH configuration found for host: {}", host_key)
198                ))?
199        };
200        
201        debug!("Creating new connection to {}", host_key);
202        
203        // Create transport and connect with retries
204        let connection = self.connect_with_retries(ssh_config).await?;
205        
206        info!("Successfully created new connection to {}", host_key);
207        
208        Ok(PooledConnection {
209            id: Uuid::new_v4(),
210            host_key: host_key.to_string(),
211            connection: Some(connection),
212            pool: Arc::new(self.clone()),
213        })
214    }
215    
216    /// Connect with retries
217    async fn connect_with_retries(&self, ssh_config: SshConfig) -> Result<Connection, TransportError> {
218        let mut last_error = None;
219        
220        for attempt in 1..=self.config.max_retries {
221            debug!("Connection attempt {} of {}", attempt, self.config.max_retries);
222            
223            let mut transport = StdioTransport::new(ssh_config.clone());
224            
225            match timeout(self.config.connection_timeout, transport.connect()).await {
226                Ok(Ok(connection)) => {
227                    debug!("Connection successful on attempt {}", attempt);
228                    return Ok(connection);
229                }
230                Ok(Err(e)) => {
231                    warn!("Connection attempt {} failed: {}", attempt, e);
232                    last_error = Some(e);
233                }
234                Err(_) => {
235                    let timeout_error = TransportError::Timeout;
236                    warn!("Connection attempt {} timed out", attempt);
237                    last_error = Some(timeout_error);
238                }
239            }
240            
241            if attempt < self.config.max_retries {
242                sleep(self.config.retry_delay).await;
243            }
244        }
245        
246        Err(last_error.unwrap_or_else(|| {
247            TransportError::Connection("All connection attempts failed".to_string())
248        }))
249    }
250    
251    /// Return a connection to the pool
252    async fn return_connection(&self, host_key: String, connection: Connection) -> Result<(), TransportError> {
253        if !connection.is_connected() {
254            debug!("Not returning disconnected connection to pool");
255            return Ok(());
256        }
257        
258        let entry = PoolEntry {
259            connection,
260            last_used: Instant::now(),
261            healthy: true,
262            use_count: 1,
263        };
264        
265        let mut connections = self.connections.write().await;
266        let entries = connections.entry(host_key.clone()).or_insert_with(Vec::new);
267        
268        // Check if we're under the limit
269        if entries.len() < self.config.max_connections_per_host {
270            entries.push(entry);
271            debug!("Returned connection to pool for host: {}", host_key);
272        } else {
273            debug!("Pool full, closing connection for host: {}", host_key);
274            // Pool is full, close the connection
275            drop(entry);
276        }
277        
278        Ok(())
279    }
280    
281    /// Health check loop
282    async fn health_check_loop(
283        connections: Arc<RwLock<HashMap<String, Vec<PoolEntry>>>>,
284        config: PoolConfig,
285    ) {
286        let mut interval = tokio::time::interval(config.health_check_interval);
287        
288        loop {
289            interval.tick().await;
290            
291            debug!("Running connection health check");
292            
293            let mut connections_guard = connections.write().await;
294            let now = Instant::now();
295            
296            for (host, entries) in connections_guard.iter_mut() {
297                entries.retain_mut(|entry| {
298                    // Check if connection is too old
299                    if now.duration_since(entry.last_used) > config.max_idle_time {
300                        debug!("Closing idle connection to {}", host);
301                        let _ = entry.connection.close();
302                        return false;
303                    }
304                    
305                    // Check if connection is still healthy
306                    if !entry.connection.is_connected() {
307                        debug!("Removing unhealthy connection to {}", host);
308                        entry.healthy = false;
309                        return false;
310                    }
311                    
312                    true
313                });
314            }
315            
316            // Remove empty host entries
317            connections_guard.retain(|_, entries| !entries.is_empty());
318        }
319    }
320    
321    /// Get pool statistics
322    pub async fn stats(&self) -> PoolStats {
323        let connections = self.connections.read().await;
324        let mut total_connections = 0;
325        let mut healthy_connections = 0;
326        let mut hosts = 0;
327        
328        for (_, entries) in connections.iter() {
329            hosts += 1;
330            for entry in entries {
331                total_connections += 1;
332                if entry.healthy {
333                    healthy_connections += 1;
334                }
335            }
336        }
337        
338        PoolStats {
339            total_connections,
340            healthy_connections,
341            hosts,
342        }
343    }
344}
345
346impl Clone for ConnectionPool {
347    fn clone(&self) -> Self {
348        Self {
349            config: self.config.clone(),
350            connections: Arc::clone(&self.connections),
351            ssh_configs: Arc::clone(&self.ssh_configs),
352            health_check_handle: None, // Don't clone the handle
353        }
354    }
355}
356
357impl Drop for ConnectionPool {
358    fn drop(&mut self) {
359        if let Some(handle) = self.health_check_handle.take() {
360            handle.abort();
361        }
362    }
363}
364
/// Point-in-time counts describing the pool's idle connections.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Idle connections stored across all hosts.
    pub total_connections: usize,
    /// Subset of `total_connections` whose entry is flagged healthy.
    pub healthy_connections: usize,
    /// Distinct host keys present in the pool.
    pub hosts: usize,
}
375
376impl PooledConnection {
377    /// Get the connection ID
378    pub fn id(&self) -> Uuid {
379        self.id
380    }
381    
382    /// Get the host key
383    pub fn host_key(&self) -> &str {
384        &self.host_key
385    }
386    
387    /// Get mutable reference to the underlying connection
388    pub fn connection_mut(&mut self) -> Option<&mut Connection> {
389        self.connection.as_mut()
390    }
391    
392    /// Take ownership of the underlying connection
393    pub fn take_connection(&mut self) -> Option<Connection> {
394        self.connection.take()
395    }
396    
397    /// Check if the connection is still active
398    pub fn is_connected(&self) -> bool {
399        self.connection.as_ref().map_or(false, |c| c.is_connected())
400    }
401}
402
403impl Drop for PooledConnection {
404    fn drop(&mut self) {
405        if let Some(connection) = self.connection.take() {
406            let pool = Arc::clone(&self.pool);
407            let host_key = self.host_key.clone();
408            
409            // Return connection to pool in background
410            tokio::spawn(async move {
411                if let Err(e) = pool.return_connection(host_key, connection).await {
412                    warn!("Failed to return connection to pool: {}", e);
413                }
414            });
415        }
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SshConfig;

    /// Fresh pool with default settings, shared by the tests below.
    fn default_pool() -> ConnectionPool {
        ConnectionPool::new(PoolConfig::default())
    }

    #[test]
    fn test_pool_config_default() {
        let PoolConfig {
            max_connections_per_host,
            max_idle_time,
            connection_timeout,
            ..
        } = PoolConfig::default();

        assert_eq!(max_connections_per_host, 10);
        assert_eq!(max_idle_time, Duration::from_secs(300));
        assert_eq!(connection_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = default_pool();

        let PoolStats {
            total_connections,
            healthy_connections,
            hosts,
        } = pool.stats().await;

        assert_eq!(total_connections, 0);
        assert_eq!(healthy_connections, 0);
        assert_eq!(hosts, 0);
    }

    #[tokio::test]
    async fn test_add_host() {
        let pool = default_pool();
        let host = "test.example.com";

        pool.add_host(host.to_string(), SshConfig::default()).await;

        let configs = pool.ssh_configs.read().await;
        assert!(configs.contains_key(host));
    }

    #[tokio::test]
    async fn test_pool_start_stop() {
        let mut pool = default_pool();

        pool.start().await.expect("starting the pool should succeed");
        assert!(pool.health_check_handle.is_some());

        pool.stop().await.expect("stopping the pool should succeed");
        assert!(pool.health_check_handle.is_none());
    }

    #[tokio::test]
    async fn test_pooled_connection_properties() {
        let pool = Arc::new(default_pool());

        let pooled_conn = PooledConnection {
            id: Uuid::new_v4(),
            host_key: String::from("test.example.com"),
            connection: Some(Connection::new(None)),
            pool,
        };

        assert_eq!(pooled_conn.host_key(), "test.example.com");
        // `Connection::new(None)` has no SSH process behind it.
        assert!(!pooled_conn.is_connected());
    }

    #[test]
    fn test_pool_stats() {
        let stats = PoolStats {
            hosts: 2,
            healthy_connections: 4,
            total_connections: 5,
        };

        assert_eq!(
            (stats.total_connections, stats.healthy_connections, stats.hosts),
            (5, 4, 2)
        );
    }
}