1use super::connection_pool::ConnectionPool;
11use super::connection_trait::ConnectionProvider;
12use super::deadpool_connection::{Pool, TcpManager};
13use super::health_check::{HealthCheckMetrics, check_date_response};
14use crate::pool::PoolStatus;
15use crate::tls::TlsConfig;
16use anyhow::Result;
17use async_trait::async_trait;
18use deadpool::managed;
19use std::sync::Arc;
20use tokio::sync::{Mutex, broadcast};
21use tracing::{debug, info, warn};
22
/// Connection provider backed by a deadpool managed pool of [`TcpManager`] connections.
///
/// Cloning the provider clones the pool handle, the optional shutdown sender and the
/// `Arc` around the health-check metrics, so all clones observe the same metrics.
#[derive(Debug, Clone)]
pub struct DeadpoolConnectionProvider {
    /// Underlying deadpool pool that connections are checked out of.
    pool: Pool,
    /// Label for this pool, used in log lines and error messages.
    name: String,
    /// Keepalive interval recorded from the server configuration, if any.
    // Stored for reference but not read by any code visible in this module, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    keepalive_interval: Option<std::time::Duration>,
    /// Sender used by `shutdown()` to stop the periodic health-check task; `None` when no
    /// task was started.
    // Only read from `shutdown()`; the allowance keeps builds quiet where that method is
    // unused. NOTE(review): confirm the allowance is still needed.
    #[allow(dead_code)]
    shutdown_tx: Option<broadcast::Sender<()>>,
    /// Per-cycle health-check counters, shared with the background health-check task when
    /// one is running.
    pub health_check_metrics: Arc<Mutex<HealthCheckMetrics>>,
}
38
39pub struct DeadpoolConnectionProviderBuilder {
72 host: String,
73 port: u16,
74 name: Option<String>,
75 max_size: usize,
76 username: Option<String>,
77 password: Option<String>,
78 tls_config: Option<TlsConfig>,
79}
80
81impl DeadpoolConnectionProviderBuilder {
82 #[must_use]
88 pub fn new(host: impl Into<String>, port: u16) -> Self {
89 Self {
90 host: host.into(),
91 port,
92 name: None,
93 max_size: 10, username: None,
95 password: None,
96 tls_config: None,
97 }
98 }
99
100 #[must_use]
102 pub fn name(mut self, name: impl Into<String>) -> Self {
103 self.name = Some(name.into());
104 self
105 }
106
107 #[must_use]
109 pub fn max_connections(mut self, max_size: usize) -> Self {
110 self.max_size = max_size;
111 self
112 }
113
114 #[must_use]
116 pub fn username(mut self, username: impl Into<String>) -> Self {
117 self.username = Some(username.into());
118 self
119 }
120
121 #[must_use]
123 pub fn password(mut self, password: impl Into<String>) -> Self {
124 self.password = Some(password.into());
125 self
126 }
127
128 #[must_use]
130 pub fn tls_config(mut self, config: TlsConfig) -> Self {
131 self.tls_config = Some(config);
132 self
133 }
134
135 pub fn build(self) -> Result<DeadpoolConnectionProvider> {
141 let name = self
142 .name
143 .unwrap_or_else(|| format!("{}:{}", self.host, self.port));
144
145 if let Some(tls_config) = self.tls_config {
146 let manager = TcpManager::new_with_tls(
148 self.host,
149 self.port,
150 name.clone(),
151 self.username,
152 self.password,
153 tls_config,
154 )?;
155 let pool = Pool::builder(manager)
156 .max_size(self.max_size)
157 .build()
158 .expect("Failed to create connection pool");
159
160 Ok(DeadpoolConnectionProvider {
161 pool,
162 name,
163 keepalive_interval: None,
164 shutdown_tx: None,
165 health_check_metrics: Arc::new(Mutex::new(HealthCheckMetrics::new())),
166 })
167 } else {
168 let manager = TcpManager::new(
170 self.host,
171 self.port,
172 name.clone(),
173 self.username,
174 self.password,
175 );
176 let pool = Pool::builder(manager)
177 .max_size(self.max_size)
178 .build()
179 .expect("Failed to create connection pool");
180
181 Ok(DeadpoolConnectionProvider {
182 pool,
183 name,
184 keepalive_interval: None,
185 shutdown_tx: None,
186 health_check_metrics: Arc::new(Mutex::new(HealthCheckMetrics::new())),
187 })
188 }
189 }
190}
191
192impl DeadpoolConnectionProvider {
193 #[must_use]
207 pub fn builder(host: impl Into<String>, port: u16) -> DeadpoolConnectionProviderBuilder {
208 DeadpoolConnectionProviderBuilder::new(host, port)
209 }
210 pub fn new(
212 host: String,
213 port: u16,
214 name: String,
215 max_size: usize,
216 username: Option<String>,
217 password: Option<String>,
218 ) -> Self {
219 let manager = TcpManager::new(host, port, name.clone(), username, password);
220 let pool = Pool::builder(manager)
221 .max_size(max_size)
222 .build()
223 .expect("Failed to create connection pool");
224
225 Self {
226 pool,
227 name,
228 keepalive_interval: None,
229 shutdown_tx: None,
230 health_check_metrics: Arc::new(Mutex::new(HealthCheckMetrics::new())),
231 }
232 }
233
234 pub fn new_with_tls(
236 host: String,
237 port: u16,
238 name: String,
239 max_size: usize,
240 username: Option<String>,
241 password: Option<String>,
242 tls_config: TlsConfig,
243 ) -> Result<Self> {
244 let manager =
245 TcpManager::new_with_tls(host, port, name.clone(), username, password, tls_config)?;
246 let pool = Pool::builder(manager)
247 .max_size(max_size)
248 .build()
249 .expect("Failed to create connection pool");
250
251 Ok(Self {
252 pool,
253 name,
254 keepalive_interval: None,
255 shutdown_tx: None,
256 health_check_metrics: Arc::new(Mutex::new(HealthCheckMetrics::new())),
257 })
258 }
259
260 pub fn from_server_config(server: &crate::config::ServerConfig) -> Result<Self> {
264 let tls_builder = TlsConfig::builder()
265 .enabled(server.use_tls)
266 .verify_cert(server.tls_verify_cert);
267
268 let tls_builder = server
270 .tls_cert_path
271 .as_ref()
272 .map(|cert_path| tls_builder.clone().cert_path(cert_path.as_str()))
273 .unwrap_or(tls_builder);
274
275 let tls_config = tls_builder.build();
276
277 let manager = TcpManager::new_with_tls(
278 server.host.as_str().to_string(),
279 server.port.get(),
280 server.name.as_str().to_string(),
281 server.username.clone(),
282 server.password.clone(),
283 tls_config,
284 )?;
285 let pool = Pool::builder(manager)
286 .max_size(server.max_connections.get())
287 .build()
288 .expect("Failed to create connection pool");
289
290 let keepalive_interval = server.connection_keepalive;
291
292 let metrics = Arc::new(Mutex::new(HealthCheckMetrics::new()));
294 let shutdown_tx = if let Some(interval) = keepalive_interval {
295 let (tx, rx) = broadcast::channel(1);
296
297 let pool_clone = pool.clone();
299 let name_clone = server.name.as_str().to_string();
300 let metrics_clone = metrics.clone();
301 tokio::spawn(async move {
302 Self::run_periodic_health_checks(
303 pool_clone,
304 name_clone,
305 interval,
306 rx,
307 metrics_clone,
308 )
309 .await;
310 });
311
312 Some(tx)
313 } else {
314 None
315 };
316
317 Ok(Self {
318 pool,
319 name: server.name.as_str().to_string(),
320 keepalive_interval,
321 shutdown_tx,
322 health_check_metrics: metrics,
323 })
324 }
325
326 pub async fn get_pooled_connection(&self) -> Result<managed::Object<TcpManager>> {
328 self.pool
329 .get()
330 .await
331 .map_err(|e| anyhow::anyhow!("Failed to get connection from {}: {}", self.name, e))
332 }
333
334 #[must_use]
336 #[inline]
337 pub fn max_size(&self) -> usize {
338 self.pool.status().max_size
339 }
340
341 pub async fn get_health_check_metrics(&self) -> HealthCheckMetrics {
343 self.health_check_metrics.lock().await.clone()
344 }
345
346 pub fn shutdown(&self) {
351 if let Some(tx) = &self.shutdown_tx {
352 let _ = tx.send(());
353 }
354 }
355
356 async fn run_periodic_health_checks(
362 pool: Pool,
363 name: String,
364 interval: std::time::Duration,
365 mut shutdown_rx: broadcast::Receiver<()>,
366 metrics: Arc<Mutex<HealthCheckMetrics>>,
367 ) {
368 use crate::constants::pool::{
369 HEALTH_CHECK_POOL_TIMEOUT_MS, MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE,
370 };
371 use tokio::time::{Duration, sleep};
372
373 info!(
374 pool = %name,
375 interval_secs = interval.as_secs(),
376 "Starting periodic health checks"
377 );
378
379 loop {
380 tokio::select! {
381 _ = sleep(interval) => {
382 }
384 _ = shutdown_rx.recv() => {
385 info!(pool = %name, "Shutting down periodic health check task");
386 break;
387 }
388 }
389
390 let status = pool.status();
391 if status.available == 0 {
392 continue;
393 }
394
395 debug!(
396 pool = %name,
397 available = status.available,
398 max_check = MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE,
399 "Running health check cycle"
400 );
401
402 let check_count =
404 std::cmp::min(status.available, MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE);
405 let mut checked = 0;
406 let mut failed = 0;
407
408 let mut timeouts = managed::Timeouts::new();
409 timeouts.wait = Some(Duration::from_millis(HEALTH_CHECK_POOL_TIMEOUT_MS));
410
411 for _ in 0..check_count {
412 if let Ok(mut conn_obj) = pool.timeout_get(&timeouts).await {
413 checked += 1;
414
415 if let Err(e) = check_date_response(&mut conn_obj).await {
417 failed += 1;
418 warn!(
419 pool = %name,
420 error = %e,
421 "Health check failed, discarding connection"
422 );
423 drop(managed::Object::take(conn_obj));
425 } else {
426 drop(conn_obj);
428 }
429 } else {
430 break;
431 }
432 }
433
434 if checked > 0 {
435 {
437 let mut m = metrics.lock().await;
438 m.record_cycle(checked, failed);
439 }
440
441 debug!(
442 pool = %name,
443 checked = checked,
444 failed = failed,
445 "Health check cycle complete"
446 );
447 }
448 }
449
450 info!(pool = %name, "Periodic health check task terminated");
451 }
452
453 pub async fn graceful_shutdown(&self) {
455 use deadpool::managed::Object;
456 use tokio::io::AsyncWriteExt;
457
458 let status = self.pool.status();
459 info!(
460 "Shutting down pool '{}' ({} idle connections)",
461 self.name, status.available
462 );
463
464 let mut timeouts = managed::Timeouts::new();
466 timeouts.wait = Some(std::time::Duration::from_millis(1));
467
468 for _ in 0..status.available {
469 if let Ok(conn_obj) = self.pool.timeout_get(&timeouts).await {
470 let mut conn = Object::take(conn_obj);
471 let _ = conn.write_all(b"QUIT\r\n").await;
472 } else {
473 break;
474 }
475 }
476
477 self.pool.close();
478 }
479}
480
481#[async_trait]
482impl ConnectionProvider for DeadpoolConnectionProvider {
483 fn status(&self) -> PoolStatus {
484 let status = self.pool.status();
485 PoolStatus {
486 available: status.available,
487 max_size: status.max_size,
488 created: status.size,
489 }
490 }
491}
492
493#[async_trait]
494impl ConnectionPool for DeadpoolConnectionProvider {
495 async fn get(&self) -> Result<crate::stream::ConnectionStream> {
496 let conn = self.get_pooled_connection().await?;
497
498 let stream = deadpool::managed::Object::take(conn);
508 Ok(stream)
509 }
510 fn name(&self) -> &str {
511 &self.name
512 }
513
514 fn status(&self) -> PoolStatus {
515 let status = self.pool.status();
516 PoolStatus {
517 available: status.available,
518 max_size: status.max_size,
519 created: status.size,
520 }
521 }
522
523 fn host(&self) -> &str {
524 &self.pool.manager().host
525 }
526
527 fn port(&self) -> u16 {
528 self.pool.manager().port
529 }
530}