Skip to main content

forge_runtime/db/
pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::postgres::{PgPool, PgPoolOptions};
5
6use forge_core::config::DatabaseConfig;
7use forge_core::error::{ForgeError, Result};
8
9/// Database connection wrapper providing connection pooling.
10#[derive(Clone)]
11pub struct Database {
12    /// Primary connection pool.
13    primary: Arc<PgPool>,
14
15    /// Read replica pools (optional).
16    replicas: Vec<Arc<PgPool>>,
17
18    /// Configuration.
19    config: DatabaseConfig,
20
21    /// Counter for round-robin replica selection.
22    replica_counter: Arc<std::sync::atomic::AtomicUsize>,
23}
24
25impl Database {
26    /// Create a new database connection from configuration.
27    pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
28        if config.url.is_empty() {
29            return Err(ForgeError::Database(
30                "database.url cannot be empty. Provide a PostgreSQL connection URL.".into(),
31            ));
32        }
33
34        let primary = Self::create_pool(&config.url, config.pool_size, config.pool_timeout_secs)
35            .await
36            .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
37
38        let mut replicas = Vec::new();
39        for replica_url in &config.replica_urls {
40            let pool =
41                Self::create_pool(replica_url, config.pool_size / 2, config.pool_timeout_secs)
42                    .await
43                    .map_err(|e| {
44                        ForgeError::Database(format!("Failed to connect to replica: {}", e))
45                    })?;
46            replicas.push(Arc::new(pool));
47        }
48
49        Ok(Self {
50            primary: Arc::new(primary),
51            replicas,
52            config: config.clone(),
53            replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
54        })
55    }
56
57    /// Create a connection pool with the given parameters.
58    async fn create_pool(url: &str, size: u32, timeout_secs: u64) -> sqlx::Result<PgPool> {
59        PgPoolOptions::new()
60            .max_connections(size)
61            .acquire_timeout(Duration::from_secs(timeout_secs))
62            .connect(url)
63            .await
64    }
65
66    /// Get the primary pool for writes.
67    pub fn primary(&self) -> &PgPool {
68        &self.primary
69    }
70
71    /// Get a pool for reads (uses replica if configured, otherwise primary).
72    pub fn read_pool(&self) -> &PgPool {
73        if self.config.read_from_replica && !self.replicas.is_empty() {
74            let idx = self
75                .replica_counter
76                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
77                % self.replicas.len();
78            self.replicas.get(idx).unwrap_or(&self.primary)
79        } else {
80            &self.primary
81        }
82    }
83
84    /// Create a Database wrapper from an existing pool (for testing).
85    #[cfg(test)]
86    pub fn from_pool(pool: PgPool) -> Self {
87        Self {
88            primary: Arc::new(pool),
89            replicas: Vec::new(),
90            config: DatabaseConfig::default(),
91            replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
92        }
93    }
94
95    /// Check database connectivity.
96    pub async fn health_check(&self) -> Result<()> {
97        sqlx::query("SELECT 1")
98            .execute(self.primary.as_ref())
99            .await
100            .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
101        Ok(())
102    }
103
104    /// Close all connections gracefully.
105    pub async fn close(&self) {
106        self.primary.close().await;
107        for replica in &self.replicas {
108            replica.close().await;
109        }
110    }
111}
112
113/// Type alias for the pool type.
114pub type DatabasePool = PgPool;
115
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `DatabaseConfig` must report the same URL and pool size as the
    /// value it was cloned from.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::new("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.url(), original.url());
        assert_eq!(copy.pool_size, original.pool_size);
    }
}