amaters_sdk_rust/
connection.rs1use crate::config::ClientConfig;
4use crate::error::{Result, SdkError};
5use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::time::timeout;
10use tonic::transport::{Channel, Endpoint};
11use tracing::{debug, info, warn};
12
13#[derive(Clone)]
15pub struct Connection {
16 channel: Channel,
17 created_at: Instant,
18 last_used: Arc<RwLock<Instant>>,
19}
20
21impl Connection {
22 fn new(channel: Channel) -> Self {
24 let now = Instant::now();
25 Self {
26 channel,
27 created_at: now,
28 last_used: Arc::new(RwLock::new(now)),
29 }
30 }
31
32 pub fn channel(&self) -> &Channel {
34 *self.last_used.write() = Instant::now();
35 &self.channel
36 }
37
38 fn is_idle(&self, idle_timeout: std::time::Duration) -> bool {
40 self.last_used.read().elapsed() > idle_timeout
41 }
42
43 fn age(&self) -> std::time::Duration {
45 self.created_at.elapsed()
46 }
47}
48
49pub struct ConnectionPool {
51 config: Arc<ClientConfig>,
52 connections: DashMap<usize, Connection>,
53 next_id: Arc<parking_lot::Mutex<usize>>,
54}
55
56impl ConnectionPool {
57 pub fn new(config: ClientConfig) -> Self {
59 Self {
60 config: Arc::new(config),
61 connections: DashMap::new(),
62 next_id: Arc::new(parking_lot::Mutex::new(0)),
63 }
64 }
65
66 pub async fn get(&self) -> Result<Connection> {
68 for entry in self.connections.iter() {
70 let conn = entry.value();
71 if !conn.is_idle(self.config.idle_timeout) {
72 debug!("Reusing connection {}", entry.key());
73 return Ok(conn.clone());
74 }
75 }
76
77 self.cleanup_idle();
79
80 if self.connections.len() >= self.config.max_connections {
82 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
84
85 for entry in self.connections.iter() {
87 let conn = entry.value();
88 if !conn.is_idle(self.config.idle_timeout) {
89 return Ok(conn.clone());
90 }
91 }
92
93 return Err(SdkError::Connection(
94 "connection pool exhausted".to_string(),
95 ));
96 }
97
98 self.create_connection().await
100 }
101
102 async fn create_connection(&self) -> Result<Connection> {
104 info!("Creating new connection to {}", self.config.server_addr);
105
106 let mut endpoint = Endpoint::from_shared(self.config.server_addr.clone())
107 .map_err(|e| SdkError::Configuration(format!("invalid server address: {}", e)))?;
108
109 endpoint = endpoint
111 .timeout(self.config.request_timeout)
112 .connect_timeout(self.config.connect_timeout);
113
114 if self.config.keep_alive {
116 endpoint = endpoint
117 .keep_alive_timeout(self.config.keep_alive_timeout)
118 .http2_keep_alive_interval(self.config.keep_alive_interval);
119 }
120
121 if self.config.tls_enabled {
123 if let Some(tls_config) = &self.config.tls_config {
124 let mut client_tls = tonic::transport::ClientTlsConfig::new();
125
126 if let Some(domain) = &tls_config.domain_name {
128 client_tls = client_tls.domain_name(domain.clone());
129 }
130
131 if let Some(ca_path) = &tls_config.ca_cert_path {
133 let ca_pem = std::fs::read(ca_path).map_err(|e| {
134 SdkError::Configuration(format!(
135 "failed to read CA certificate at {}: {}",
136 ca_path.display(),
137 e
138 ))
139 })?;
140 let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
141 client_tls = client_tls.ca_certificate(ca_cert);
142 }
143
144 if let (Some(cert_path), Some(key_path)) =
146 (&tls_config.client_cert_path, &tls_config.client_key_path)
147 {
148 let cert_pem = std::fs::read(cert_path).map_err(|e| {
149 SdkError::Configuration(format!(
150 "failed to read client certificate at {}: {}",
151 cert_path.display(),
152 e
153 ))
154 })?;
155 let key_pem = std::fs::read(key_path).map_err(|e| {
156 SdkError::Configuration(format!(
157 "failed to read client key at {}: {}",
158 key_path.display(),
159 e
160 ))
161 })?;
162 let identity = tonic::transport::Identity::from_pem(cert_pem, key_pem);
163 client_tls = client_tls.identity(identity);
164 }
165
166 endpoint = endpoint.tls_config(client_tls).map_err(|e| {
167 SdkError::Configuration(format!("failed to configure TLS: {}", e))
168 })?;
169 debug!("TLS configured successfully");
170 }
171 }
172
173 let channel = timeout(self.config.connect_timeout, endpoint.connect())
175 .await
176 .map_err(|_| {
177 SdkError::Timeout(format!(
178 "connection timeout after {:?}",
179 self.config.connect_timeout
180 ))
181 })?
182 .map_err(SdkError::Transport)?;
183
184 let conn = Connection::new(channel);
185
186 let id = {
188 let mut next = self.next_id.lock();
189 let id = *next;
190 *next += 1;
191 id
192 };
193
194 self.connections.insert(id, conn.clone());
195 info!("Connection {} created successfully", id);
196
197 Ok(conn)
198 }
199
200 fn cleanup_idle(&self) {
202 let mut to_remove = Vec::new();
203
204 for entry in self.connections.iter() {
205 if entry.value().is_idle(self.config.idle_timeout) {
206 to_remove.push(*entry.key());
207 }
208 }
209
210 for id in to_remove {
211 if let Some((_, conn)) = self.connections.remove(&id) {
212 warn!("Removing idle connection {} (age: {:?})", id, conn.age());
213 }
214 }
215 }
216
217 pub fn close_all(&self) {
219 info!("Closing all connections ({})", self.connections.len());
220 self.connections.clear();
221 }
222
223 pub fn stats(&self) -> PoolStats {
225 let total = self.connections.len();
226 let mut idle = 0;
227
228 for entry in self.connections.iter() {
229 if entry.value().is_idle(self.config.idle_timeout) {
230 idle += 1;
231 }
232 }
233
234 PoolStats {
235 total_connections: total,
236 active_connections: total - idle,
237 idle_connections: idle,
238 max_connections: self.config.max_connections,
239 }
240 }
241}
242
/// Point-in-time occupancy figures for a connection pool.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Number of connections held by the pool.
    pub total_connections: usize,
    /// Connections that were not idle when the snapshot was taken.
    pub active_connections: usize,
    /// Connections that were idle when the snapshot was taken.
    pub idle_connections: usize,
    /// Configured upper bound on the number of pooled connections.
    pub max_connections: usize,
}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A timestamp stored behind the same lock type used for `last_used`
    /// reports at least the time that was slept.
    #[test]
    fn test_connection_idle() {
        let pause = Duration::from_millis(10);
        let stamp = Arc::new(RwLock::new(Instant::now()));

        std::thread::sleep(pause);

        let waited = stamp.read().elapsed();
        assert!(waited >= pause);
    }

    /// A freshly created pool is empty and reports the default size limit.
    #[test]
    fn test_pool_stats() {
        let pool = ConnectionPool::new(ClientConfig::default());

        let PoolStats {
            total_connections,
            active_connections,
            max_connections,
            ..
        } = pool.stats();

        assert_eq!(total_connections, 0);
        assert_eq!(active_connections, 0);
        assert_eq!(max_connections, 10);
    }

    /// Attaching a TLS config with only a domain name enables TLS and leaves
    /// every certificate path unset.
    #[test]
    fn test_tls_config_construction() {
        use crate::config::TlsConfig;

        let config = ClientConfig::new("https://example.com:50051")
            .with_tls(TlsConfig::new().with_domain_name("example.com"));

        assert!(config.tls_enabled);
        assert!(config.tls_config.is_some());

        let tls_cfg = config
            .tls_config
            .as_ref()
            .expect("tls_config should be Some");
        assert_eq!(tls_cfg.domain_name.as_deref(), Some("example.com"));
        assert!(tls_cfg.ca_cert_path.is_none());
        assert!(tls_cfg.client_cert_path.is_none());
        assert!(tls_cfg.client_key_path.is_none());
    }

    /// CA, client certificate and client key paths survive the builder chain
    /// unchanged.
    #[test]
    fn test_tls_config_with_mtls_paths() {
        use crate::config::TlsConfig;

        let mtls = TlsConfig::new()
            .with_ca_cert("/path/to/ca.pem")
            .with_client_cert("/path/to/client.pem", "/path/to/client.key")
            .with_domain_name("db.example.com");
        let config = ClientConfig::new("https://db.example.com:50051").with_tls(mtls);

        assert!(config.tls_enabled);

        let tls_cfg = config
            .tls_config
            .as_ref()
            .expect("tls_config should be Some");

        let ca = tls_cfg.ca_cert_path.as_ref().and_then(|p| p.to_str());
        assert_eq!(ca, Some("/path/to/ca.pem"));

        let cert = tls_cfg.client_cert_path.as_ref().and_then(|p| p.to_str());
        assert_eq!(cert, Some("/path/to/client.pem"));

        let key = tls_cfg.client_key_path.as_ref().and_then(|p| p.to_str());
        assert_eq!(key, Some("/path/to/client.key"));
    }
}