Skip to main content

forge_runtime/db/
pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::postgres::{PgPool, PgPoolOptions};
5
6use forge_core::config::{DatabaseConfig, DatabaseSource};
7use forge_core::error::{ForgeError, Result};
8
9#[cfg(feature = "embedded-db")]
10use tokio::sync::OnceCell;
11
12#[cfg(feature = "embedded-db")]
13use tracing::info;
14
/// Process-wide handle to the embedded PostgreSQL server.
///
/// The cell is filled at most once; every `Database` built from an embedded
/// source reuses the instance stored here instead of starting another server.
#[cfg(feature = "embedded-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
18
19/// Database connection wrapper providing connection pooling.
20#[derive(Clone)]
21pub struct Database {
22    /// Primary connection pool.
23    primary: Arc<PgPool>,
24
25    /// Read replica pools (optional).
26    replicas: Vec<Arc<PgPool>>,
27
28    /// Configuration.
29    config: DatabaseConfig,
30
31    /// Counter for round-robin replica selection.
32    replica_counter: Arc<std::sync::atomic::AtomicUsize>,
33
34    /// Whether using embedded PostgreSQL.
35    embedded: bool,
36}
37
38impl Database {
39    /// Create a new database connection from configuration.
40    pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
41        let (url, embedded) = match &config.source {
42            DatabaseSource::Remote { url } => {
43                if url.is_empty() {
44                    return Err(ForgeError::Database(
45                        "database.url cannot be empty. Provide a PostgreSQL connection URL.".into(),
46                    ));
47                }
48                (url.clone(), false)
49            }
50            DatabaseSource::Embedded { data_dir } => {
51                #[cfg(feature = "embedded-db")]
52                {
53                    let url = Self::start_embedded_postgres(data_dir.as_deref()).await?;
54                    (url, true)
55                }
56                #[cfg(not(feature = "embedded-db"))]
57                {
58                    let _ = data_dir;
59                    return Err(ForgeError::Database(
60                        "Embedded PostgreSQL requires the 'embedded-db' feature. \
61                        Build with: cargo build --features embedded-db"
62                            .to_string(),
63                    ));
64                }
65            }
66        };
67
68        let primary = Self::create_pool(&url, config.pool_size, config.pool_timeout_secs)
69            .await
70            .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
71
72        let mut replicas = Vec::new();
73        for replica_url in &config.replica_urls {
74            let pool =
75                Self::create_pool(replica_url, config.pool_size / 2, config.pool_timeout_secs)
76                    .await
77                    .map_err(|e| {
78                        ForgeError::Database(format!("Failed to connect to replica: {}", e))
79                    })?;
80            replicas.push(Arc::new(pool));
81        }
82
83        Ok(Self {
84            primary: Arc::new(primary),
85            replicas,
86            config: config.clone(),
87            replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
88            embedded,
89        })
90    }
91
92    /// Start embedded PostgreSQL and return the connection URL.
93    #[cfg(feature = "embedded-db")]
94    async fn start_embedded_postgres(data_dir: Option<&str>) -> Result<String> {
95        let pg = EMBEDDED_PG
96            .get_or_try_init(|| async {
97                info!("Starting embedded PostgreSQL...");
98
99                // Create settings with custom data directory if specified
100                let settings = if let Some(dir) = data_dir {
101                    postgresql_embedded::Settings {
102                        data_dir: std::path::PathBuf::from(dir),
103                        ..Default::default()
104                    }
105                } else {
106                    postgresql_embedded::Settings::default()
107                };
108
109                let mut pg = postgresql_embedded::PostgreSQL::new(settings);
110                pg.setup().await.map_err(|e| {
111                    ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
112                })?;
113                pg.start().await.map_err(|e| {
114                    ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
115                })?;
116                info!("Embedded PostgreSQL started successfully");
117                Ok::<_, ForgeError>(pg)
118            })
119            .await?;
120
121        Ok(pg.settings().url("forge"))
122    }
123
124    /// Check if using embedded PostgreSQL.
125    pub fn is_embedded(&self) -> bool {
126        self.embedded
127    }
128
129    /// Create a connection pool with the given parameters.
130    async fn create_pool(url: &str, size: u32, timeout_secs: u64) -> sqlx::Result<PgPool> {
131        PgPoolOptions::new()
132            .max_connections(size)
133            .acquire_timeout(Duration::from_secs(timeout_secs))
134            .connect(url)
135            .await
136    }
137
138    /// Get the primary pool for writes.
139    pub fn primary(&self) -> &PgPool {
140        &self.primary
141    }
142
143    /// Get a pool for reads (uses replica if configured, otherwise primary).
144    pub fn read_pool(&self) -> &PgPool {
145        if self.config.read_from_replica && !self.replicas.is_empty() {
146            // Round-robin replica selection
147            let idx = self
148                .replica_counter
149                .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
150                % self.replicas.len();
151            self.replicas.get(idx).unwrap_or(&self.primary)
152        } else {
153            &self.primary
154        }
155    }
156
157    /// Create a Database wrapper from an existing pool (for testing).
158    #[cfg(test)]
159    pub fn from_pool(pool: PgPool) -> Self {
160        Self {
161            primary: Arc::new(pool),
162            replicas: Vec::new(),
163            config: DatabaseConfig::default(),
164            replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
165            embedded: false,
166        }
167    }
168
169    /// Check database connectivity.
170    pub async fn health_check(&self) -> Result<()> {
171        sqlx::query("SELECT 1")
172            .execute(self.primary.as_ref())
173            .await
174            .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
175        Ok(())
176    }
177
178    /// Close all connections gracefully.
179    pub async fn close(&self) {
180        self.primary.close().await;
181        for replica in &self.replicas {
182            replica.close().await;
183        }
184    }
185}
186
187/// Type alias for the pool type.
188pub type DatabasePool = PgPool;
189
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing in this module opens a connection, so no PostgreSQL server is
    // needed; tests that exercise a live database are integration tests.

    /// A cloned configuration reports the same URL and pool size as its source.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::remote("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.url(), original.url());
        assert_eq!(copy.pool_size, original.pool_size);
    }
}
205}