forge_runtime/db/pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::postgres::{PgPool, PgPoolOptions};
5
6use forge_core::config::DatabaseConfig;
7use forge_core::error::{ForgeError, Result};
8
9#[cfg(feature = "embedded-db")]
10use tokio::sync::OnceCell;
11
12#[cfg(feature = "embedded-db")]
13use tracing::info;
14
/// Process-wide embedded PostgreSQL server.
///
/// It is initialised lazily, at most once, by `Database::start_embedded_postgres`. Every
/// `Database` created in embedded mode then reuses that same server.
#[cfg(feature = "embedded-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
18
19/// Database connection wrapper providing connection pooling.
20#[derive(Clone)]
21pub struct Database {
22    /// Primary connection pool.
23    primary: Arc<PgPool>,
24
25    /// Read replica pools (optional).
26    replicas: Vec<Arc<PgPool>>,
27
28    /// Configuration.
29    config: DatabaseConfig,
30
31    /// Counter for round-robin replica selection.
32    replica_counter: Arc<std::sync::atomic::AtomicUsize>,
33
34    /// Whether using embedded PostgreSQL.
35    embedded: bool,
36}
37
38impl Database {
39    /// Create a new database connection from configuration.
40    pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
41        let (url, embedded) = if config.embedded {
42            #[cfg(feature = "embedded-db")]
43            {
44                let url = Self::start_embedded_postgres(config.data_dir.as_deref()).await?;
45                (url, true)
46            }
47            #[cfg(not(feature = "embedded-db"))]
48            {
49                return Err(ForgeError::Database(
50                    "Embedded PostgreSQL requires the 'embedded-db' feature. \
51                    Build with: cargo build --features embedded-db"
52                        .to_string(),
53                ));
54            }
55        } else {
56            if config.url.is_empty() {
57                return Err(ForgeError::Database(
58                    "Database URL is required when embedded = false. Set database.url or database.embedded = true".to_string()
59                ));
60            }
61            (config.url.clone(), false)
62        };
63
64        let primary = Self::create_pool(&url, config.pool_size, config.pool_timeout_secs)
65            .await
66            .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
67
68        let mut replicas = Vec::new();
69        for replica_url in &config.replica_urls {
70            let pool =
71                Self::create_pool(replica_url, config.pool_size / 2, config.pool_timeout_secs)
72                    .await
73                    .map_err(|e| {
74                        ForgeError::Database(format!("Failed to connect to replica: {}", e))
75                    })?;
76            replicas.push(Arc::new(pool));
77        }
78
79        Ok(Self {
80            primary: Arc::new(primary),
81            replicas,
82            config: config.clone(),
83            replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
84            embedded,
85        })
86    }
87
88    /// Start embedded PostgreSQL and return the connection URL.
89    #[cfg(feature = "embedded-db")]
90    async fn start_embedded_postgres(data_dir: Option<&str>) -> Result<String> {
91        let pg = EMBEDDED_PG
92            .get_or_try_init(|| async {
93                info!("Starting embedded PostgreSQL...");
94
95                // Create settings with custom data directory if specified
96                let settings = if let Some(dir) = data_dir {
97                    postgresql_embedded::Settings {
98                        data_dir: std::path::PathBuf::from(dir),
99                        ..Default::default()
100                    }
101                } else {
102                    postgresql_embedded::Settings::default()
103                };
104
105                let mut pg = postgresql_embedded::PostgreSQL::new(settings);
106                pg.setup().await.map_err(|e| {
107                    ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
108                })?;
109                pg.start().await.map_err(|e| {
110                    ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
111                })?;
112                info!("Embedded PostgreSQL started successfully");
113                Ok::<_, ForgeError>(pg)
114            })
115            .await?;
116
117        Ok(pg.settings().url("forge"))
118    }
119
120    /// Check if using embedded PostgreSQL.
121    pub fn is_embedded(&self) -> bool {
122        self.embedded
123    }
124
125    /// Create a connection pool with the given parameters.
126    async fn create_pool(url: &str, size: u32, timeout_secs: u64) -> sqlx::Result<PgPool> {
127        PgPoolOptions::new()
128            .max_connections(size)
129            .acquire_timeout(Duration::from_secs(timeout_secs))
130            .connect(url)
131            .await
132    }
133
134    /// Get the primary pool for writes.
135    pub fn primary(&self) -> &PgPool {
136        &self.primary
137    }
138
139    /// Get a pool for reads (uses replica if configured, otherwise primary).
140    pub fn read_pool(&self) -> &PgPool {
141        if self.config.read_from_replica && !self.replicas.is_empty() {
142            // Round-robin replica selection
143            let idx = self
144                .replica_counter
145                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
146                % self.replicas.len();
147            &self.replicas[idx]
148        } else {
149            &self.primary
150        }
151    }
152
153    /// Check database connectivity.
154    pub async fn health_check(&self) -> Result<()> {
155        sqlx::query("SELECT 1")
156            .execute(self.primary.as_ref())
157            .await
158            .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
159        Ok(())
160    }
161
162    /// Close all connections gracefully.
163    pub async fn close(&self) {
164        self.primary.close().await;
165        for replica in &self.replicas {
166            replica.close().await;
167        }
168    }
169}
170
171/// Type alias for the pool type.
172pub type DatabasePool = PgPool;
173
#[cfg(test)]
mod tests {
    use super::*;

    // Anything that opens a pool needs a live PostgreSQL server, so the tests here stick
    // to pieces that can be checked without one.

    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig {
            pool_size: 10,
            url: "postgres://localhost/test".to_string(),
            ..Default::default()
        };

        let copy = original.clone();
        assert_eq!(copy.pool_size, original.pool_size);
        assert_eq!(copy.url, original.url);
    }
}