1use crate::backends::{
7 DatabaseBackendRegistry, DatabaseBackendType, DatabasePool as DatabasePoolTrait,
8 DatabasePoolConfig, DatabasePoolStats,
9};
10use crate::error::ModelError;
11use elif_core::providers::ProviderError;
12use elif_core::{Container, ContainerBuilder, ServiceProvider};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
/// Errors reported by the connection-pool layer defined in this module.
///
/// Every variant is convertible into `ModelError::Connection` through the
/// `From<PoolError>` implementation in this file.
#[derive(Debug, thiserror::Error)]
pub enum PoolError {
    /// A pool operation failed; carries the underlying error rendered as text.
    #[error("Connection acquisition failed: {0}")]
    AcquisitionFailed(String),

    /// The pool has been closed.
    #[error("Pool is closed")]
    PoolClosed,

    /// A connection was not obtained within `timeout` seconds.
    #[error("Connection timeout after {timeout}s")]
    ConnectionTimeout { timeout: u64 },

    /// All `max_connections` connections are currently in use.
    #[error("Pool exhausted: all {max_connections} connections in use")]
    PoolExhausted { max_connections: u32 },

    /// A health check failed; `reason` holds the underlying error text.
    #[error("Health check failed: {reason}")]
    HealthCheckFailed { reason: String },

    /// A configuration problem described by `message`.
    #[error("Configuration error: {message}")]
    ConfigurationError { message: String },
}
38
39impl From<PoolError> for ModelError {
41 fn from(err: PoolError) -> Self {
42 match err {
43 PoolError::AcquisitionFailed(err_msg) => {
44 ModelError::Connection(format!("Database connection failed: {}", err_msg))
45 }
46 PoolError::PoolClosed => ModelError::Connection("Database pool is closed".to_string()),
47 PoolError::ConnectionTimeout { timeout } => {
48 ModelError::Connection(format!("Database connection timeout after {}s", timeout))
49 }
50 PoolError::PoolExhausted { max_connections } => ModelError::Connection(format!(
51 "Database pool exhausted: {} connections in use",
52 max_connections
53 )),
54 PoolError::HealthCheckFailed { reason } => {
55 ModelError::Connection(format!("Database health check failed: {}", reason))
56 }
57 PoolError::ConfigurationError { message } => {
58 ModelError::Connection(format!("Database configuration error: {}", message))
59 }
60 }
61 }
62}
63
/// Alias for [`DatabasePoolConfig`], the backend-level pool configuration.
pub type PoolConfig = DatabasePoolConfig;
66
/// Alias for [`DatabasePoolStats`], the backend-level pool statistics.
pub type PoolStats = DatabasePoolStats;
69
/// Backend pool statistics combined with the counters kept by [`ManagedPool`].
#[derive(Debug, Clone)]
pub struct ExtendedPoolStats {
    /// Statistics reported by the underlying backend pool.
    pub pool_stats: DatabasePoolStats,
    /// Number of `acquire` / `begin_transaction` attempts made through the managed pool.
    pub acquire_count: u64,
    /// Number of those attempts that returned an error.
    pub acquire_errors: u64,
    /// Instant at which the managed pool was constructed.
    pub created_at: Instant,
}
78
/// Result of [`ManagedPool::detailed_health_check`].
#[derive(Debug, Clone)]
pub struct PoolHealthReport {
    /// Duration returned by the backend health check itself.
    pub check_duration: Duration,
    /// Wall-clock time spent producing the whole report.
    pub total_check_time: Duration,
    /// Total connections reported by the backend pool.
    pub pool_size: u32,
    /// Idle connections reported by the backend pool.
    pub idle_connections: u32,
    /// Active connections reported by the backend pool.
    pub active_connections: u32,
    /// Acquire attempts counted by the managed pool.
    pub total_acquires: u64,
    /// Failed acquire attempts counted by the managed pool.
    pub total_errors: u64,
    /// `total_errors / total_acquires` as a percentage; `0.0` when nothing was acquired yet.
    pub error_rate: f64,
    /// Instant at which the managed pool was constructed.
    pub created_at: Instant,
}
92
/// A backend pool wrapped with acquire/error counters and logging.
pub struct ManagedPool {
    /// The wrapped backend pool.
    pool: Arc<dyn DatabasePoolTrait>,
    /// Configuration this pool was registered with (returned by `config()`).
    config: DatabasePoolConfig,
    /// Count of `acquire` / `begin_transaction` attempts.
    acquire_count: AtomicU64,
    /// Count of those attempts that failed.
    acquire_errors: AtomicU64,
    /// Set once in `new()`.
    created_at: Instant,
}
101
102impl ManagedPool {
103 pub fn new(pool: Arc<dyn DatabasePoolTrait>, config: DatabasePoolConfig) -> Self {
104 Self {
105 pool,
106 config,
107 acquire_count: AtomicU64::new(0),
108 acquire_errors: AtomicU64::new(0),
109 created_at: Instant::now(),
110 }
111 }
112
113 pub fn pool(&self) -> &dyn DatabasePoolTrait {
115 &*self.pool
116 }
117
118 pub async fn acquire(&self) -> Result<Box<dyn crate::backends::DatabaseConnection>, PoolError> {
120 self.acquire_count.fetch_add(1, Ordering::Relaxed);
121
122 match self.pool.acquire().await {
123 Ok(conn) => {
124 let stats = self.pool.stats();
125 tracing::debug!(
126 "Database connection acquired successfully (total: {}, idle: {})",
127 stats.total_connections,
128 stats.idle_connections
129 );
130 Ok(conn)
131 }
132 Err(e) => {
133 self.acquire_errors.fetch_add(1, Ordering::Relaxed);
134 let pool_error = PoolError::AcquisitionFailed(e.to_string());
135 tracing::error!("Failed to acquire database connection: {}", pool_error);
136 Err(pool_error)
137 }
138 }
139 }
140
141 pub async fn execute(
143 &self,
144 sql: &str,
145 params: &[crate::backends::DatabaseValue],
146 ) -> Result<u64, PoolError> {
147 self.pool
148 .execute(sql, params)
149 .await
150 .map_err(|e| PoolError::AcquisitionFailed(e.to_string()))
151 }
152
153 pub async fn begin_transaction(
155 &self,
156 ) -> Result<Box<dyn crate::backends::DatabaseTransaction>, PoolError> {
157 self.acquire_count.fetch_add(1, Ordering::Relaxed);
158
159 match self.pool.begin_transaction().await {
160 Ok(tx) => {
161 tracing::debug!("Database transaction started successfully");
162 Ok(tx)
163 }
164 Err(e) => {
165 self.acquire_errors.fetch_add(1, Ordering::Relaxed);
166 let pool_error = PoolError::AcquisitionFailed(e.to_string());
167 tracing::error!("Failed to begin database transaction: {}", pool_error);
168 Err(pool_error)
169 }
170 }
171 }
172
173 pub fn extended_stats(&self) -> ExtendedPoolStats {
175 ExtendedPoolStats {
176 pool_stats: self.pool.stats(),
177 acquire_count: self.acquire_count.load(Ordering::Relaxed),
178 acquire_errors: self.acquire_errors.load(Ordering::Relaxed),
179 created_at: self.created_at,
180 }
181 }
182
183 pub fn stats(&self) -> DatabasePoolStats {
185 self.pool.stats()
186 }
187
188 pub async fn health_check(&self) -> Result<Duration, PoolError> {
190 match self.pool.health_check().await {
191 Ok(duration) => {
192 tracing::debug!("Database health check passed in {:?}", duration);
193 Ok(duration)
194 }
195 Err(e) => {
196 let pool_error = PoolError::HealthCheckFailed {
197 reason: e.to_string(),
198 };
199 tracing::error!("Database health check failed: {}", pool_error);
200 Err(pool_error)
201 }
202 }
203 }
204
205 pub async fn detailed_health_check(&self) -> Result<PoolHealthReport, PoolError> {
207 let start = Instant::now();
208 let _initial_stats = self.extended_stats();
209
210 let check_duration = self.health_check().await?;
212
213 let final_stats = self.extended_stats();
215
216 let report = PoolHealthReport {
217 check_duration,
218 total_check_time: start.elapsed(),
219 pool_size: final_stats.pool_stats.total_connections,
220 idle_connections: final_stats.pool_stats.idle_connections,
221 active_connections: final_stats.pool_stats.active_connections,
222 total_acquires: final_stats.acquire_count,
223 total_errors: final_stats.acquire_errors,
224 error_rate: if final_stats.acquire_count > 0 {
225 (final_stats.acquire_errors as f64 / final_stats.acquire_count as f64) * 100.0
226 } else {
227 0.0
228 },
229 created_at: final_stats.created_at,
230 };
231
232 tracing::info!("Database pool health report: {:?}", report);
233 Ok(report)
234 }
235
236 pub fn config(&self) -> &DatabasePoolConfig {
238 &self.config
239 }
240
241 pub async fn close(&self) -> Result<(), PoolError> {
243 self.pool
244 .close()
245 .await
246 .map_err(|e| PoolError::ConfigurationError {
247 message: e.to_string(),
248 })
249 }
250}
251
/// Builder-style provider that knows how to create database pools for one URL.
pub struct DatabaseServiceProvider {
    /// Connection URL handed to the backend registry when a pool is created.
    database_url: String,
    /// Pool configuration; starts as `DatabasePoolConfig::default()`.
    config: DatabasePoolConfig,
    /// Service name; defaults to `"database_pool"`.
    service_name: String,
    /// Registry used to create pools; `new()` registers the PostgreSQL backend.
    backend_registry: Arc<DatabaseBackendRegistry>,
}
259
260impl DatabaseServiceProvider {
261 pub fn new(database_url: String) -> Self {
262 let mut registry = DatabaseBackendRegistry::new();
263 registry.register(
264 DatabaseBackendType::PostgreSQL,
265 Arc::new(crate::backends::PostgresBackend::new()),
266 );
267
268 Self {
269 database_url,
270 config: DatabasePoolConfig::default(),
271 service_name: "database_pool".to_string(),
272 backend_registry: Arc::new(registry),
273 }
274 }
275
276 pub fn with_registry(mut self, registry: Arc<DatabaseBackendRegistry>) -> Self {
277 self.backend_registry = registry;
278 self
279 }
280
281 pub fn with_config(mut self, config: DatabasePoolConfig) -> Self {
282 self.config = config;
283 self
284 }
285
286 pub fn with_max_connections(mut self, max_connections: u32) -> Self {
287 self.config.max_connections = max_connections;
288 self
289 }
290
291 pub fn with_min_connections(mut self, min_connections: u32) -> Self {
292 self.config.min_connections = min_connections;
293 self
294 }
295
296 pub fn with_acquire_timeout(mut self, timeout_seconds: u64) -> Self {
297 self.config.acquire_timeout_seconds = timeout_seconds;
298 self
299 }
300
301 pub fn with_idle_timeout(mut self, timeout_seconds: Option<u64>) -> Self {
302 self.config.idle_timeout_seconds = timeout_seconds;
303 self
304 }
305
306 pub fn with_max_lifetime(mut self, lifetime_seconds: Option<u64>) -> Self {
307 self.config.max_lifetime_seconds = lifetime_seconds;
308 self
309 }
310
311 pub fn with_test_before_acquire(mut self, enabled: bool) -> Self {
312 self.config.test_before_acquire = enabled;
313 self
314 }
315
316 pub fn with_service_name(mut self, service_name: String) -> Self {
317 self.service_name = service_name;
318 self
319 }
320
321 pub async fn create_pool(&self) -> Result<Arc<dyn DatabasePoolTrait>, ModelError> {
323 self.backend_registry
324 .create_pool(&self.database_url, self.config.clone())
325 .await
326 .map_err(|e| ModelError::Connection(e.to_string()))
327 }
328
329 pub async fn create_managed_pool(&self) -> Result<ManagedPool, ModelError> {
331 let pool = self.create_pool().await?;
332 Ok(ManagedPool::new(pool, self.config.clone()))
333 }
334
335 pub fn database_url(&self) -> &str {
337 &self.database_url
338 }
339
340 pub fn service_name(&self) -> &str {
342 &self.service_name
343 }
344
345 pub fn config(&self) -> &DatabasePoolConfig {
347 &self.config
348 }
349}
350
351impl ServiceProvider for DatabaseServiceProvider {
352 fn name(&self) -> &'static str {
353 "DatabaseServiceProvider"
354 }
355
356 fn register(&self, builder: ContainerBuilder) -> Result<ContainerBuilder, ProviderError> {
357 tracing::debug!(
360 "Registering database service with URL: {}",
361 self.database_url
362 .split('@')
363 .next_back()
364 .unwrap_or("unknown")
365 );
366 Ok(builder)
367 }
368
369 fn boot(&self, _container: &Container) -> Result<(), ProviderError> {
370 tracing::info!("✅ Database service provider booted successfully");
371 tracing::debug!("Database pool configuration: max_connections={}, min_connections={}, acquire_timeout={}s, idle_timeout={:?}s, max_lifetime={:?}s, test_before_acquire={}",
372 self.config.max_connections, self.config.min_connections, self.config.acquire_timeout_seconds,
373 self.config.idle_timeout_seconds, self.config.max_lifetime_seconds, self.config.test_before_acquire);
374 Ok(())
375 }
376}
377
378pub async fn create_database_pool(
380 database_url: &str,
381) -> Result<Arc<dyn DatabasePoolTrait>, ModelError> {
382 create_database_pool_with_config(database_url, &DatabasePoolConfig::default()).await
383}
384
385pub async fn create_database_pool_with_config(
387 database_url: &str,
388 config: &DatabasePoolConfig,
389) -> Result<Arc<dyn DatabasePoolTrait>, ModelError> {
390 tracing::debug!("Creating database pool with config: max={}, min={}, timeout={}s, idle_timeout={:?}s, max_lifetime={:?}s, test_before_acquire={}",
391 config.max_connections, config.min_connections, config.acquire_timeout_seconds,
392 config.idle_timeout_seconds, config.max_lifetime_seconds, config.test_before_acquire);
393
394 let mut registry = DatabaseBackendRegistry::new();
395 registry.register(
396 DatabaseBackendType::PostgreSQL,
397 Arc::new(crate::backends::PostgresBackend::new()),
398 );
399
400 let pool = registry
401 .create_pool(database_url, config.clone())
402 .await
403 .map_err(|e| {
404 tracing::error!("Failed to create database pool: {}", e);
405 ModelError::Connection(format!("Failed to create database pool: {}", e))
406 })?;
407
408 tracing::info!(
409 "✅ Database pool created successfully with {} max connections",
410 config.max_connections
411 );
412 Ok(pool)
413}
414
/// Name-indexed collection of shared [`ManagedPool`]s.
pub struct PoolRegistry {
    /// Registered pools keyed by name.
    pools: std::collections::HashMap<String, Arc<ManagedPool>>,
}
419
/// Alias for [`ManagedPool`].
pub type DatabasePool = ManagedPool;
422
423impl PoolRegistry {
424 pub fn new() -> Self {
425 Self {
426 pools: std::collections::HashMap::new(),
427 }
428 }
429
430 pub fn register(&mut self, name: String, pool: Arc<ManagedPool>) {
432 tracing::info!("Registering database pool: {}", name);
433 self.pools.insert(name, pool);
434 }
435
436 pub fn get(&self, name: &str) -> Option<Arc<ManagedPool>> {
438 self.pools.get(name).cloned()
439 }
440
441 pub fn get_default(&self) -> Option<Arc<ManagedPool>> {
443 self.get("database_pool")
444 }
445
446 pub fn pool_names(&self) -> Vec<&String> {
448 self.pools.keys().collect()
449 }
450
451 pub fn get_all_stats(&self) -> std::collections::HashMap<String, DatabasePoolStats> {
453 self.pools
454 .iter()
455 .map(|(name, pool)| (name.clone(), pool.stats()))
456 .collect()
457 }
458
459 pub async fn health_check_all(
461 &self,
462 ) -> std::collections::HashMap<String, Result<Duration, PoolError>> {
463 let mut results = std::collections::HashMap::new();
464
465 for (name, pool) in &self.pools {
466 let result = pool.health_check().await;
467 results.insert(name.clone(), result);
468 }
469
470 results
471 }
472}
473
474impl Default for PoolRegistry {
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
480pub async fn get_database_pool(
482 _container: &Container,
483) -> Result<Arc<dyn DatabasePoolTrait>, String> {
484 Err("Database pool not yet integrated with current Container implementation - use PoolRegistry for now".to_string())
487}
488
489pub async fn get_named_database_pool(
491 _container: &Container,
492 service_name: &str,
493) -> Result<Arc<dyn DatabasePoolTrait>, String> {
494 Err(format!("Database pool '{}' not yet integrated with current Container implementation - use PoolRegistry for now", service_name))
497}
498
499pub async fn create_default_pool_registry(database_url: &str) -> Result<PoolRegistry, ModelError> {
501 let mut registry = PoolRegistry::new();
502
503 let provider = DatabaseServiceProvider::new(database_url.to_string());
504 let managed_pool = provider.create_managed_pool().await?;
505
506 registry.register("database_pool".to_string(), Arc::new(managed_pool));
507
508 tracing::info!("Created default pool registry with database_pool");
509 Ok(registry)
510}
511
512pub async fn create_custom_pool_registry(
514 pools: Vec<(String, String, DatabasePoolConfig)>,
515) -> Result<PoolRegistry, ModelError> {
516 let mut registry = PoolRegistry::new();
517
518 for (name, database_url, config) in pools {
519 let provider = DatabaseServiceProvider::new(database_url).with_config(config);
520 let managed_pool = provider.create_managed_pool().await?;
521
522 registry.register(name, Arc::new(managed_pool));
523 }
524
525 tracing::info!(
526 "Created custom pool registry with {} pools",
527 registry.pool_names().len()
528 );
529 Ok(registry)
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// `DatabasePoolConfig::default()` exposes the documented default values.
    #[test]
    fn test_pool_config_defaults() {
        let defaults = DatabasePoolConfig::default();

        assert_eq!(defaults.max_connections, 10);
        assert_eq!(defaults.min_connections, 1);
        assert_eq!(defaults.acquire_timeout_seconds, 30);
        assert_eq!(defaults.idle_timeout_seconds, Some(600));
        assert_eq!(defaults.max_lifetime_seconds, Some(1800));
        assert!(defaults.test_before_acquire);
    }

    /// A freshly created provider keeps its URL and uses default settings.
    #[test]
    fn test_database_service_provider_creation() {
        let provider = DatabaseServiceProvider::new("postgresql://test".to_string());
        let cfg = provider.config();

        assert_eq!(provider.database_url(), "postgresql://test");
        assert_eq!(cfg.max_connections, 10);
        assert_eq!(cfg.min_connections, 1);
        assert_eq!(cfg.acquire_timeout_seconds, 30);
        assert_eq!(provider.service_name(), "database_pool");
    }

    /// Each builder method updates exactly the setting it names.
    #[test]
    fn test_database_service_provider_configuration() {
        let provider = DatabaseServiceProvider::new("postgresql://test".to_string())
            .with_max_connections(20)
            .with_min_connections(5)
            .with_acquire_timeout(60)
            .with_idle_timeout(Some(300))
            .with_max_lifetime(Some(900))
            .with_test_before_acquire(false)
            .with_service_name("custom_db".to_string());
        let cfg = provider.config();

        assert_eq!(cfg.max_connections, 20);
        assert_eq!(cfg.min_connections, 5);
        assert_eq!(cfg.acquire_timeout_seconds, 60);
        assert_eq!(cfg.idle_timeout_seconds, Some(300));
        assert_eq!(cfg.max_lifetime_seconds, Some(900));
        assert!(!cfg.test_before_acquire);
        assert_eq!(provider.service_name(), "custom_db");
    }

    /// The `ServiceProvider` name is the fixed type name.
    #[test]
    fn test_provider_name() {
        let provider = DatabaseServiceProvider::new("postgresql://test".to_string());

        assert_eq!(provider.name(), "DatabaseServiceProvider");
    }

    /// Accessors return the URL and the overridden service name.
    #[test]
    fn test_database_service_provider_accessors() {
        let provider = DatabaseServiceProvider::new("postgresql://test_db".to_string())
            .with_service_name("custom_service".to_string());

        assert_eq!(provider.database_url(), "postgresql://test_db");
        assert_eq!(provider.service_name(), "custom_service");
    }

    /// A provider without overrides reports every default setting.
    #[test]
    fn test_database_service_provider_defaults() {
        let provider = DatabaseServiceProvider::new("postgresql://test".to_string());
        let cfg = provider.config();

        assert_eq!(cfg.max_connections, 10);
        assert_eq!(cfg.min_connections, 1);
        assert_eq!(cfg.acquire_timeout_seconds, 30);
        assert_eq!(cfg.idle_timeout_seconds, Some(600));
        assert_eq!(cfg.max_lifetime_seconds, Some(1800));
        assert!(cfg.test_before_acquire);
        assert_eq!(provider.service_name(), "database_pool");
    }

    /// Builder calls chain fluently, including clearing the idle timeout.
    #[test]
    fn test_database_service_provider_fluent_configuration() {
        let provider = DatabaseServiceProvider::new("postgresql://test".to_string())
            .with_max_connections(50)
            .with_min_connections(10)
            .with_acquire_timeout(120)
            .with_idle_timeout(None)
            .with_max_lifetime(Some(3600))
            .with_service_name("production_db".to_string());
        let cfg = provider.config();

        assert_eq!(cfg.max_connections, 50);
        assert_eq!(cfg.min_connections, 10);
        assert_eq!(cfg.acquire_timeout_seconds, 120);
        assert_eq!(cfg.idle_timeout_seconds, None);
        assert_eq!(cfg.max_lifetime_seconds, Some(3600));
        assert_eq!(provider.service_name(), "production_db");
        assert_eq!(provider.database_url(), "postgresql://test");
    }

    /// The `PoolConfig` alias yields the same defaults as `DatabasePoolConfig`.
    #[test]
    fn test_pool_config_creation() {
        let defaults = PoolConfig::default();

        assert_eq!(defaults.max_connections, 10);
        assert_eq!(defaults.min_connections, 1);
        assert_eq!(defaults.acquire_timeout_seconds, 30);
        assert_eq!(defaults.idle_timeout_seconds, Some(600));
        assert_eq!(defaults.max_lifetime_seconds, Some(1800));
        assert!(defaults.test_before_acquire);
    }

    /// A hand-built configuration keeps every field as written.
    #[test]
    fn test_managed_pool_config_access() {
        let custom = DatabasePoolConfig {
            max_connections: 5,
            min_connections: 2,
            acquire_timeout_seconds: 60,
            idle_timeout_seconds: None,
            max_lifetime_seconds: Some(3600),
            test_before_acquire: false,
        };

        assert_eq!(custom.max_connections, 5);
        assert_eq!(custom.min_connections, 2);
        assert_eq!(custom.acquire_timeout_seconds, 60);
        assert_eq!(custom.idle_timeout_seconds, None);
        assert_eq!(custom.max_lifetime_seconds, Some(3600));
        assert!(!custom.test_before_acquire);
    }

    /// `with_config` installs the supplied configuration wholesale.
    #[test]
    fn test_pool_config_builder() {
        let custom = DatabasePoolConfig {
            max_connections: 20,
            min_connections: 2,
            acquire_timeout_seconds: 45,
            idle_timeout_seconds: Some(300),
            max_lifetime_seconds: Some(1200),
            test_before_acquire: false,
        };

        let provider = DatabaseServiceProvider::new("postgresql://test".to_string())
            .with_config(custom.clone());
        let cfg = provider.config();

        assert_eq!(cfg.max_connections, 20);
        assert_eq!(cfg.min_connections, 2);
        assert_eq!(cfg.acquire_timeout_seconds, 45);
        assert_eq!(cfg.idle_timeout_seconds, Some(300));
        assert_eq!(cfg.max_lifetime_seconds, Some(1200));
        assert!(!cfg.test_before_acquire);
    }

    /// A new registry has no default pool, no names and no statistics.
    #[test]
    fn test_pool_registry_creation() {
        let registry = PoolRegistry::new();

        assert!(registry.get_default().is_none());
        assert!(registry.pool_names().is_empty());

        let all_stats = registry.get_all_stats();
        assert!(all_stats.is_empty());
    }

    /// Display output of the error variants mentions the failure kind.
    #[test]
    fn test_pool_error_types() {
        let timed_out = PoolError::ConnectionTimeout { timeout: 30 };
        let closed = PoolError::PoolClosed;
        let exhausted = PoolError::PoolExhausted {
            max_connections: 10,
        };

        assert!(timed_out.to_string().contains("timeout"));
        assert!(closed.to_string().contains("closed"));
        assert!(exhausted.to_string().contains("exhausted"));
    }

    /// Pool errors convert into the `Connection` variant of `ModelError`.
    #[test]
    fn test_pool_error_model_conversion() {
        let converted: ModelError = PoolError::PoolExhausted { max_connections: 5 }.into();

        assert!(matches!(converted, ModelError::Connection(_)));
    }
}