1use super::connection_trait::ConnectionProvider;
11use super::deadpool_connection::{Pool, TcpManager};
12use super::health_check::{HealthCheckMetrics, check_date_response};
13use crate::pool::PoolStatus;
14use crate::tls::TlsConfig;
15use anyhow::Result;
16use async_trait::async_trait;
17use deadpool::managed;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tracing::{debug, info, warn};
21
22#[derive(Debug, Clone)]
24pub struct DeadpoolConnectionProvider {
25 pool: Pool,
26 name: String,
27 shutdown_tx: Option<broadcast::Sender<()>>,
31 pub health_check_metrics: Arc<HealthCheckMetrics>,
33}
34
35pub struct Builder {
68 host: String,
69 port: u16,
70 name: Option<String>,
71 max_size: usize,
72 username: Option<String>,
73 password: Option<String>,
74 tls_config: Option<TlsConfig>,
75}
76
77impl Builder {
78 #[must_use]
84 pub fn new(host: impl Into<String>, port: u16) -> Self {
85 Self {
86 host: host.into(),
87 port,
88 name: None,
89 max_size: 10, username: None,
91 password: None,
92 tls_config: None,
93 }
94 }
95
96 #[must_use]
98 pub fn name(mut self, name: impl Into<String>) -> Self {
99 self.name = Some(name.into());
100 self
101 }
102
103 #[must_use]
105 pub fn max_connections(mut self, max_size: usize) -> Self {
106 self.max_size = max_size;
107 self
108 }
109
110 #[must_use]
112 pub fn username(mut self, username: impl Into<String>) -> Self {
113 self.username = Some(username.into());
114 self
115 }
116
117 #[must_use]
119 pub fn password(mut self, password: impl Into<String>) -> Self {
120 self.password = Some(password.into());
121 self
122 }
123
124 #[must_use]
126 pub fn tls_config(mut self, config: TlsConfig) -> Self {
127 self.tls_config = Some(config);
128 self
129 }
130
131 pub fn build(self) -> Result<DeadpoolConnectionProvider> {
137 let name = self
138 .name
139 .unwrap_or_else(|| format!("{}:{}", self.host, self.port));
140
141 if let Some(tls_config) = self.tls_config {
142 let manager = TcpManager::new_with_tls(
144 self.host,
145 self.port,
146 name.clone(),
147 self.username,
148 self.password,
149 tls_config,
150 )?;
151 let pool = Pool::builder(manager)
152 .max_size(self.max_size)
153 .build()
154 .expect("Failed to create connection pool");
155
156 Ok(DeadpoolConnectionProvider {
157 pool,
158 name,
159 shutdown_tx: None,
160 health_check_metrics: Arc::new(HealthCheckMetrics::new()),
161 })
162 } else {
163 let manager = TcpManager::new(
165 self.host,
166 self.port,
167 name.clone(),
168 self.username,
169 self.password,
170 );
171 let pool = Pool::builder(manager)
172 .max_size(self.max_size)
173 .build()
174 .expect("Failed to create connection pool");
175
176 Ok(DeadpoolConnectionProvider {
177 pool,
178 name,
179 shutdown_tx: None,
180 health_check_metrics: Arc::new(HealthCheckMetrics::new()),
181 })
182 }
183 }
184}
185
186impl DeadpoolConnectionProvider {
187 #[must_use]
201 pub fn builder(host: impl Into<String>, port: u16) -> Builder {
202 Builder::new(host, port)
203 }
204
205 pub fn simple(host: impl Into<String>, port: u16) -> Result<Self> {
218 Self::builder(host, port).build()
219 }
220
221 pub fn with_auth(
239 host: impl Into<String>,
240 port: u16,
241 username: impl Into<String>,
242 password: impl Into<String>,
243 ) -> Result<Self> {
244 Self::builder(host, port)
245 .username(username)
246 .password(password)
247 .build()
248 }
249
250 pub fn with_tls(host: impl Into<String>, port: u16) -> Result<Self> {
264 Self::builder(host, port)
265 .tls_config(TlsConfig::default())
266 .build()
267 }
268
269 pub fn with_tls_auth(
288 host: impl Into<String>,
289 port: u16,
290 username: impl Into<String>,
291 password: impl Into<String>,
292 ) -> Result<Self> {
293 Self::builder(host, port)
294 .username(username)
295 .password(password)
296 .tls_config(TlsConfig::default())
297 .build()
298 }
299
300 pub fn new(
302 host: String,
303 port: u16,
304 name: String,
305 max_size: usize,
306 username: Option<String>,
307 password: Option<String>,
308 ) -> Self {
309 let manager = TcpManager::new(host, port, name.clone(), username, password);
310 let pool = Pool::builder(manager)
311 .max_size(max_size)
312 .build()
313 .expect("Failed to create connection pool");
314
315 Self {
316 pool,
317 name,
318 shutdown_tx: None,
319 health_check_metrics: Arc::new(HealthCheckMetrics::new()),
320 }
321 }
322
323 pub fn new_with_tls(
325 host: String,
326 port: u16,
327 name: String,
328 max_size: usize,
329 username: Option<String>,
330 password: Option<String>,
331 tls_config: TlsConfig,
332 ) -> Result<Self> {
333 let manager =
334 TcpManager::new_with_tls(host, port, name.clone(), username, password, tls_config)?;
335 let pool = Pool::builder(manager)
336 .max_size(max_size)
337 .build()
338 .expect("Failed to create connection pool");
339
340 Ok(Self {
341 pool,
342 name,
343 shutdown_tx: None,
344 health_check_metrics: Arc::new(HealthCheckMetrics::new()),
345 })
346 }
347
348 pub fn from_server_config(server: &crate::config::Server) -> Result<Self> {
352 let tls_builder = TlsConfig::builder()
353 .enabled(server.use_tls)
354 .verify_cert(server.tls_verify_cert);
355
356 let tls_builder = server
358 .tls_cert_path
359 .as_ref()
360 .map(|cert_path| tls_builder.clone().cert_path(cert_path.as_str()))
361 .unwrap_or(tls_builder);
362
363 let tls_config = tls_builder.build();
364
365 let manager = TcpManager::new_with_tls(
366 server.host.to_string(),
367 server.port.get(),
368 server.name.to_string(),
369 server.username.clone(),
370 server.password.clone(),
371 tls_config,
372 )?;
373 let pool = Pool::builder(manager)
374 .max_size(server.max_connections.get())
375 .build()
376 .expect("Failed to create connection pool");
377
378 let keepalive_interval = server.connection_keepalive;
379
380 let metrics = Arc::new(HealthCheckMetrics::new());
382 let shutdown_tx = if let Some(interval) = keepalive_interval {
383 let (tx, rx) = broadcast::channel(1);
384
385 let pool_clone = pool.clone();
387 let name_clone = server.name.to_string();
388 let metrics_clone = metrics.clone();
389 tokio::spawn(async move {
390 Self::run_periodic_health_checks(
391 pool_clone,
392 name_clone,
393 interval,
394 rx,
395 metrics_clone,
396 )
397 .await;
398 });
399
400 Some(tx)
401 } else {
402 None
403 };
404
405 Ok(Self {
406 pool,
407 name: server.name.to_string(),
408 shutdown_tx,
409 health_check_metrics: metrics,
410 })
411 }
412
413 pub async fn get_pooled_connection(&self) -> Result<managed::Object<TcpManager>> {
415 self.pool
416 .get()
417 .await
418 .map_err(|e| anyhow::anyhow!("Failed to get connection from {}: {}", self.name, e))
419 }
420
421 pub fn clear_idle_connections(&self) {
429 let max_size = self.pool.status().max_size;
430 let available = self.pool.status().available;
431
432 if available > 0 {
433 debug!(
434 pool = %self.name,
435 available = available,
436 "Clearing idle connections from pool"
437 );
438
439 self.pool.resize(0);
441 self.pool.resize(max_size);
443 }
444 }
445
446 #[must_use]
448 #[inline]
449 pub fn max_size(&self) -> usize {
450 self.pool.status().max_size
451 }
452
453 #[must_use]
455 #[inline]
456 pub fn name(&self) -> &str {
457 &self.name
458 }
459
460 #[must_use]
462 #[inline]
463 pub fn host(&self) -> &str {
464 &self.pool.manager().host
465 }
466
467 #[must_use]
469 #[inline]
470 pub fn port(&self) -> u16 {
471 self.pool.manager().port
472 }
473
474 pub fn health_check_metrics(&self) -> &HealthCheckMetrics {
476 &self.health_check_metrics
477 }
478
479 pub fn shutdown(&self) {
484 if let Some(tx) = &self.shutdown_tx {
485 let _ = tx.send(());
486 }
487 }
488
489 async fn run_periodic_health_checks(
495 pool: Pool,
496 name: String,
497 interval: std::time::Duration,
498 mut shutdown_rx: broadcast::Receiver<()>,
499 metrics: Arc<HealthCheckMetrics>,
500 ) {
501 use crate::constants::pool::{
502 HEALTH_CHECK_POOL_TIMEOUT_MS, MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE,
503 };
504 use tokio::time::{Duration, sleep};
505
506 info!(
507 pool = %name,
508 interval_secs = interval.as_secs(),
509 "Starting periodic health checks"
510 );
511
512 loop {
513 tokio::select! {
514 _ = sleep(interval) => {
515 }
517 _ = shutdown_rx.recv() => {
518 info!(pool = %name, "Shutting down periodic health check task");
519 break;
520 }
521 }
522
523 let status = pool.status();
524 if status.available == 0 {
525 continue;
526 }
527
528 debug!(
529 pool = %name,
530 available = status.available,
531 max_check = MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE,
532 "Running health check cycle"
533 );
534
535 let check_count =
537 std::cmp::min(status.available, MAX_CONNECTIONS_PER_HEALTH_CHECK_CYCLE);
538 let mut checked = 0;
539 let mut failed = 0;
540
541 let mut timeouts = managed::Timeouts::new();
542 timeouts.wait = Some(Duration::from_millis(HEALTH_CHECK_POOL_TIMEOUT_MS));
543
544 for _ in 0..check_count {
545 if let Ok(mut conn_obj) = pool.timeout_get(&timeouts).await {
546 checked += 1;
547
548 if let Err(e) = check_date_response(&mut conn_obj).await {
550 failed += 1;
551 warn!(
552 pool = %name,
553 error = %e,
554 "Health check failed, discarding connection"
555 );
556 drop(managed::Object::take(conn_obj));
558 } else {
559 drop(conn_obj);
561 }
562 } else {
563 break;
564 }
565 }
566
567 if checked > 0 {
568 metrics.record_cycle(checked, failed);
570
571 debug!(
572 pool = %name,
573 checked = checked,
574 failed = failed,
575 "Health check cycle complete"
576 );
577 }
578 }
579
580 info!(pool = %name, "Periodic health check task terminated");
581 }
582
583 pub async fn graceful_shutdown(&self) {
585 use deadpool::managed::Object;
586 use tokio::io::AsyncWriteExt;
587
588 let status = self.pool.status();
589 info!(
590 "Shutting down pool '{}' ({} idle connections)",
591 self.name, status.available
592 );
593
594 let mut timeouts = managed::Timeouts::new();
596 timeouts.wait = Some(std::time::Duration::from_millis(1));
597
598 for _ in 0..status.available {
599 if let Ok(conn_obj) = self.pool.timeout_get(&timeouts).await {
600 let mut conn = Object::take(conn_obj);
601 let _ = conn.write_all(b"QUIT\r\n").await;
602 } else {
603 break;
604 }
605 }
606
607 self.pool.close();
608 }
609}
610
611#[async_trait]
612impl ConnectionProvider for DeadpoolConnectionProvider {
613 fn status(&self) -> PoolStatus {
614 use crate::types::{AvailableConnections, CreatedConnections, MaxPoolSize};
615 let status = self.pool.status();
616 PoolStatus {
617 available: AvailableConnections::new(status.available),
618 max_size: MaxPoolSize::new(status.max_size),
619 created: CreatedConnections::new(status.size),
620 }
621 }
622}
623
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_new() {
        let builder = Builder::new("news.example.com", 119);
        assert_eq!(builder.host, "news.example.com");
        assert_eq!(builder.port, 119);
        assert_eq!(builder.max_size, 10);
        assert!(builder.name.is_none());
        assert!(builder.username.is_none());
        assert!(builder.password.is_none());
        assert!(builder.tls_config.is_none());
    }

    #[test]
    fn test_builder_with_name() {
        let builder = Builder::new("example.com", 119).name("Test Server");
        assert_eq!(builder.name, Some("Test Server".to_string()));
    }

    #[test]
    fn test_builder_with_max_connections() {
        let builder = Builder::new("example.com", 119).max_connections(25);
        assert_eq!(builder.max_size, 25);
    }

    #[test]
    fn test_builder_with_username() {
        let builder = Builder::new("example.com", 119).username("testuser");
        assert_eq!(builder.username, Some("testuser".to_string()));
    }

    #[test]
    fn test_builder_with_password() {
        let builder = Builder::new("example.com", 119).password("testpass");
        assert_eq!(builder.password, Some("testpass".to_string()));
    }

    #[test]
    fn test_builder_with_tls_config() {
        let tls_config = TlsConfig::builder().enabled(true).build();
        let builder = Builder::new("example.com", 563).tls_config(tls_config.clone());
        assert!(builder.tls_config.is_some());
    }

    #[test]
    fn test_builder_chaining() {
        let builder = Builder::new("news.example.com", 119)
            .name("Chained Server")
            .max_connections(30)
            .username("user")
            .password("pass");

        assert_eq!(builder.name, Some("Chained Server".to_string()));
        assert_eq!(builder.max_size, 30);
        assert_eq!(builder.username, Some("user".to_string()));
        assert_eq!(builder.password, Some("pass".to_string()));
    }

    #[test]
    fn test_builder_default_name_from_host_port() {
        let provider = Builder::new("test.example.com", 8119)
            .max_connections(5)
            .build()
            .unwrap();

        assert_eq!(provider.name(), "test.example.com:8119");
    }

    #[test]
    fn test_builder_custom_name_used() {
        let provider = Builder::new("test.example.com", 8119)
            .name("Custom Name")
            .build()
            .unwrap();

        assert_eq!(provider.name(), "Custom Name");
    }

    #[test]
    fn test_provider_builder_method() {
        let builder = DeadpoolConnectionProvider::builder("example.com", 119);
        assert_eq!(builder.host, "example.com");
        assert_eq!(builder.port, 119);
    }

    #[test]
    fn test_provider_status_conversion() {
        let provider = DeadpoolConnectionProvider::builder("localhost", 119)
            .max_connections(15)
            .build()
            .unwrap();

        let status = ConnectionProvider::status(&provider);
        assert_eq!(status.max_size.get(), 15);
        assert_eq!(status.created.get(), 0);
    }

    #[test]
    fn test_provider_inherent_methods() {
        let provider = DeadpoolConnectionProvider::builder("localhost", 119)
            .build()
            .unwrap();

        assert_eq!(provider.name(), "localhost:119");
        assert_eq!(provider.host(), "localhost");
        assert_eq!(provider.port(), 119);
    }

    #[test]
    fn test_provider_with_all_builder_options() {
        let tls_config = TlsConfig::builder().enabled(false).build();

        let provider = DeadpoolConnectionProvider::builder("news.test.com", 563)
            .name("Full Test")
            .max_connections(42)
            .username("testuser")
            .password("testpass")
            .tls_config(tls_config)
            .build()
            .unwrap();

        assert_eq!(provider.name(), "Full Test");
        assert_eq!(provider.host(), "news.test.com");
        assert_eq!(provider.port(), 563);

        let status = ConnectionProvider::status(&provider);
        assert_eq!(status.max_size.get(), 42);
    }

    #[test]
    fn test_health_check_metrics_initialization() {
        let provider = DeadpoolConnectionProvider::builder("localhost", 119)
            .build()
            .unwrap();

        let metrics = &provider.health_check_metrics;
        assert_eq!(metrics.cycles_run(), 0);
        assert_eq!(metrics.connections_checked(), 0);
        assert_eq!(metrics.connections_failed(), 0);
        assert_eq!(metrics.failure_rate(), 0.0);
    }

    #[test]
    fn test_builder_accepts_string_types() {
        let _ = Builder::new("example.com", 119);
        let _ = Builder::new(String::from("example.com"), 119);
        let _ = Builder::new("example.com", 119).name("test");
        let _ = Builder::new("example.com", 119).name(String::from("test"));
    }

    #[test]
    fn test_builder_zero_max_connections() {
        let provider = Builder::new("localhost", 119)
            .max_connections(0)
            .build()
            .unwrap();

        let status = ConnectionProvider::status(&provider);
        assert_eq!(status.max_size.get(), 0);
    }

    #[test]
    fn test_builder_large_max_connections() {
        let provider = Builder::new("localhost", 119)
            .max_connections(1000)
            .build()
            .unwrap();

        let status = ConnectionProvider::status(&provider);
        assert_eq!(status.max_size.get(), 1000);
    }

    #[test]
    fn test_provider_name_special_characters() {
        let provider = Builder::new("example.com", 119)
            .name("Server-123_Test.Name")
            .build()
            .unwrap();

        assert_eq!(provider.name(), "Server-123_Test.Name");
    }

    #[test]
    fn test_provider_name_unicode() {
        let provider = Builder::new("example.com", 119)
            .name("测试服务器")
            .build()
            .unwrap();

        assert_eq!(provider.name(), "测试服务器");
    }

    #[test]
    fn test_provider_empty_name() {
        let provider = Builder::new("example.com", 119).name("").build().unwrap();

        assert_eq!(provider.name(), "");
    }

    #[test]
    fn test_builder_idempotent_chaining() {
        let builder = Builder::new("example.com", 119)
            .name("First")
            .name("Second")
            .max_connections(10)
            .max_connections(20);

        assert_eq!(builder.name, Some("Second".to_string()));
        assert_eq!(builder.max_size, 20);
    }

    #[test]
    fn test_simple_constructor_uses_defaults() {
        let provider = DeadpoolConnectionProvider::simple("localhost", 119).unwrap();

        assert_eq!(provider.name(), "localhost:119");
        assert_eq!(provider.max_size(), 10);
    }

    #[test]
    fn test_with_auth_constructor() {
        let provider =
            DeadpoolConnectionProvider::with_auth("localhost", 119, "user", "pass").unwrap();

        assert_eq!(provider.host(), "localhost");
        assert_eq!(provider.port(), 119);
    }

    #[test]
    fn test_new_constructor() {
        let provider = DeadpoolConnectionProvider::new(
            "localhost".to_string(),
            119,
            "Explicit".to_string(),
            7,
            None,
            None,
        );

        assert_eq!(provider.name(), "Explicit");
        assert_eq!(provider.max_size(), 7);
    }

    #[test]
    fn test_new_with_tls_disabled() {
        let tls_config = TlsConfig::builder().enabled(false).build();
        let provider = DeadpoolConnectionProvider::new_with_tls(
            "localhost".to_string(),
            563,
            "TLS Off".to_string(),
            3,
            None,
            None,
            tls_config,
        )
        .unwrap();

        assert_eq!(provider.name(), "TLS Off");
        assert_eq!(provider.max_size(), 3);
    }

    #[test]
    fn test_clear_idle_connections_on_empty_pool_keeps_max_size() {
        let provider = Builder::new("localhost", 119)
            .max_connections(4)
            .build()
            .unwrap();

        provider.clear_idle_connections();

        assert_eq!(provider.max_size(), 4);
    }

    #[test]
    fn test_shutdown_without_health_check_task_is_noop() {
        let provider = Builder::new("localhost", 119).build().unwrap();

        provider.shutdown();

        assert_eq!(provider.name(), "localhost:119");
    }
}