Skip to main content

amaters_sdk_rust/
connection.rs

1//! Connection management and pooling
2
3use crate::config::ClientConfig;
4use crate::error::{Result, SdkError};
5use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::time::timeout;
10use tonic::transport::{Channel, Endpoint};
11use tracing::{debug, info, warn};
12
13/// Connection wrapper with metadata
14#[derive(Clone)]
15pub struct Connection {
16    channel: Channel,
17    created_at: Instant,
18    last_used: Arc<RwLock<Instant>>,
19}
20
21impl Connection {
22    /// Create a new connection
23    fn new(channel: Channel) -> Self {
24        let now = Instant::now();
25        Self {
26            channel,
27            created_at: now,
28            last_used: Arc::new(RwLock::new(now)),
29        }
30    }
31
32    /// Get the underlying channel
33    pub fn channel(&self) -> &Channel {
34        *self.last_used.write() = Instant::now();
35        &self.channel
36    }
37
38    /// Check if connection is idle for too long
39    fn is_idle(&self, idle_timeout: std::time::Duration) -> bool {
40        self.last_used.read().elapsed() > idle_timeout
41    }
42
43    /// Get age of connection
44    fn age(&self) -> std::time::Duration {
45        self.created_at.elapsed()
46    }
47}
48
49/// Connection pool for managing multiple connections
50pub struct ConnectionPool {
51    config: Arc<ClientConfig>,
52    connections: DashMap<usize, Connection>,
53    next_id: Arc<parking_lot::Mutex<usize>>,
54}
55
56impl ConnectionPool {
57    /// Create a new connection pool
58    pub fn new(config: ClientConfig) -> Self {
59        Self {
60            config: Arc::new(config),
61            connections: DashMap::new(),
62            next_id: Arc::new(parking_lot::Mutex::new(0)),
63        }
64    }
65
66    /// Get a connection from the pool or create a new one
67    pub async fn get(&self) -> Result<Connection> {
68        // Try to find a healthy connection
69        for entry in self.connections.iter() {
70            let conn = entry.value();
71            if !conn.is_idle(self.config.idle_timeout) {
72                debug!("Reusing connection {}", entry.key());
73                return Ok(conn.clone());
74            }
75        }
76
77        // Clean up idle connections
78        self.cleanup_idle();
79
80        // Check if we can create a new connection
81        if self.connections.len() >= self.config.max_connections {
82            // Wait and retry
83            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
84
85            // Try again
86            for entry in self.connections.iter() {
87                let conn = entry.value();
88                if !conn.is_idle(self.config.idle_timeout) {
89                    return Ok(conn.clone());
90                }
91            }
92
93            return Err(SdkError::Connection(
94                "connection pool exhausted".to_string(),
95            ));
96        }
97
98        // Create new connection
99        self.create_connection().await
100    }
101
102    /// Create a new connection
103    async fn create_connection(&self) -> Result<Connection> {
104        info!("Creating new connection to {}", self.config.server_addr);
105
106        let mut endpoint = Endpoint::from_shared(self.config.server_addr.clone())
107            .map_err(|e| SdkError::Configuration(format!("invalid server address: {}", e)))?;
108
109        // Configure timeouts
110        endpoint = endpoint
111            .timeout(self.config.request_timeout)
112            .connect_timeout(self.config.connect_timeout);
113
114        // Configure keep-alive
115        if self.config.keep_alive {
116            endpoint = endpoint
117                .keep_alive_timeout(self.config.keep_alive_timeout)
118                .http2_keep_alive_interval(self.config.keep_alive_interval);
119        }
120
121        // Configure TLS if enabled
122        if self.config.tls_enabled {
123            if let Some(tls_config) = &self.config.tls_config {
124                let mut client_tls = tonic::transport::ClientTlsConfig::new();
125
126                // Set domain name for SNI if provided
127                if let Some(domain) = &tls_config.domain_name {
128                    client_tls = client_tls.domain_name(domain.clone());
129                }
130
131                // Load CA certificate if provided
132                if let Some(ca_path) = &tls_config.ca_cert_path {
133                    let ca_pem = std::fs::read(ca_path).map_err(|e| {
134                        SdkError::Configuration(format!(
135                            "failed to read CA certificate at {}: {}",
136                            ca_path.display(),
137                            e
138                        ))
139                    })?;
140                    let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
141                    client_tls = client_tls.ca_certificate(ca_cert);
142                }
143
144                // Load client certificate and key for mTLS if provided
145                if let (Some(cert_path), Some(key_path)) =
146                    (&tls_config.client_cert_path, &tls_config.client_key_path)
147                {
148                    let cert_pem = std::fs::read(cert_path).map_err(|e| {
149                        SdkError::Configuration(format!(
150                            "failed to read client certificate at {}: {}",
151                            cert_path.display(),
152                            e
153                        ))
154                    })?;
155                    let key_pem = std::fs::read(key_path).map_err(|e| {
156                        SdkError::Configuration(format!(
157                            "failed to read client key at {}: {}",
158                            key_path.display(),
159                            e
160                        ))
161                    })?;
162                    let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
163                    client_tls = client_tls.identity(identity);
164                }
165
166                endpoint = endpoint.tls_config(client_tls).map_err(|e| {
167                    SdkError::Configuration(format!("failed to configure TLS: {}", e))
168                })?;
169                debug!("TLS configured successfully");
170            }
171        }
172
173        // Connect with timeout
174        let channel = timeout(self.config.connect_timeout, endpoint.connect())
175            .await
176            .map_err(|_| {
177                SdkError::Timeout(format!(
178                    "connection timeout after {:?}",
179                    self.config.connect_timeout
180                ))
181            })?
182            .map_err(SdkError::Transport)?;
183
184        let conn = Connection::new(channel);
185
186        // Store in pool
187        let id = {
188            let mut next = self.next_id.lock();
189            let id = *next;
190            *next += 1;
191            id
192        };
193
194        self.connections.insert(id, conn.clone());
195        info!("Connection {} created successfully", id);
196
197        Ok(conn)
198    }
199
200    /// Clean up idle connections
201    fn cleanup_idle(&self) {
202        let mut to_remove = Vec::new();
203
204        for entry in self.connections.iter() {
205            if entry.value().is_idle(self.config.idle_timeout) {
206                to_remove.push(*entry.key());
207            }
208        }
209
210        for id in to_remove {
211            if let Some((_, conn)) = self.connections.remove(&id) {
212                warn!("Removing idle connection {} (age: {:?})", id, conn.age());
213            }
214        }
215    }
216
217    /// Close all connections
218    pub fn close_all(&self) {
219        info!("Closing all connections ({})", self.connections.len());
220        self.connections.clear();
221    }
222
223    /// Get pool statistics
224    pub fn stats(&self) -> PoolStats {
225        let total = self.connections.len();
226        let mut idle = 0;
227
228        for entry in self.connections.iter() {
229            if entry.value().is_idle(self.config.idle_timeout) {
230                idle += 1;
231            }
232        }
233
234        PoolStats {
235            total_connections: total,
236            active_connections: total - idle,
237            idle_connections: idle,
238            max_connections: self.config.max_connections,
239        }
240    }
241}
242
/// Connection pool statistics, as returned by `ConnectionPool::stats`.
///
/// This is a snapshot taken at call time; the live pool may change immediately
/// afterwards.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of connections currently held by the pool.
    pub total_connections: usize,
    /// Connections that are not idle (created or accessed within the pool's idle timeout).
    pub active_connections: usize,
    /// Connections unused for longer than the pool's idle timeout.
    pub idle_connections: usize,
    /// Configured upper bound on pooled connections.
    pub max_connections: usize,
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_connection_idle() {
        // Building a `Connection` needs a live `Channel`, so this only exercises the
        // timing primitive `Connection::is_idle` relies on: an `Instant` behind a
        // lock whose `elapsed()` grows over time.
        let last_used = Arc::new(RwLock::new(Instant::now()));

        std::thread::sleep(Duration::from_millis(10));

        let elapsed = last_used.read().elapsed();
        assert!(elapsed >= Duration::from_millis(10));
        // `is_idle` uses a strict `>` against the timeout.
        assert!(elapsed > Duration::from_millis(5));
        assert!(elapsed <= Duration::from_secs(3600));
    }

    #[test]
    fn test_pool_stats() {
        let pool = ConnectionPool::new(ClientConfig::default());

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.max_connections, 10);
    }

    #[test]
    fn test_cleanup_and_close_on_empty_pool() {
        let pool = ConnectionPool::new(ClientConfig::default());

        // Both are no-ops on an empty pool and must not panic.
        pool.cleanup_idle();
        pool.close_all();

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.idle_connections, 0);
    }

    #[test]
    fn test_tls_config_construction() {
        use crate::config::TlsConfig;

        // TLS enabled with only a domain name and no cert files; this checks config
        // construction only and never connects.
        let tls = TlsConfig::new().with_domain_name("example.com");

        let config = ClientConfig::new("https://example.com:50051").with_tls(tls);

        assert!(config.tls_enabled);
        assert!(config.tls_config.is_some());

        let tls_cfg = config
            .tls_config
            .as_ref()
            .expect("tls_config should be Some");
        assert_eq!(tls_cfg.domain_name, Some("example.com".to_string()));
        assert!(tls_cfg.ca_cert_path.is_none());
        assert!(tls_cfg.client_cert_path.is_none());
        assert!(tls_cfg.client_key_path.is_none());
    }

    #[test]
    fn test_tls_config_with_mtls_paths() {
        use crate::config::TlsConfig;
        use std::path::Path;

        let tls = TlsConfig::new()
            .with_ca_cert("/path/to/ca.pem")
            .with_client_cert("/path/to/client.pem", "/path/to/client.key")
            .with_domain_name("db.example.com");

        let config = ClientConfig::new("https://db.example.com:50051").with_tls(tls);

        assert!(config.tls_enabled);
        let tls_cfg = config
            .tls_config
            .as_ref()
            .expect("tls_config should be Some");
        assert_eq!(
            tls_cfg.ca_cert_path.as_deref(),
            Some(Path::new("/path/to/ca.pem"))
        );
        assert_eq!(
            tls_cfg.client_cert_path.as_deref(),
            Some(Path::new("/path/to/client.pem"))
        );
        assert_eq!(
            tls_cfg.client_key_path.as_deref(),
            Some(Path::new("/path/to/client.key"))
        );
    }
}