// forge_runtime/db/pool.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use sqlx::ConnectOptions;
6use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
7use tokio::task::JoinHandle;
8use tracing::log::LevelFilter;
9
10use forge_core::config::{DatabaseConfig, PoolConfig};
11use forge_core::error::{ForgeError, Result};
12
13struct ReplicaEntry {
14    pool: Arc<PgPool>,
15    healthy: Arc<AtomicBool>,
16}
17
18/// Database connection wrapper with health-aware replica routing and workload isolation.
19#[derive(Clone)]
20pub struct Database {
21    primary: Arc<PgPool>,
22    replicas: Arc<Vec<ReplicaEntry>>,
23    config: DatabaseConfig,
24    replica_counter: Arc<AtomicUsize>,
25    /// Isolated pool for background jobs, cron, daemons, workflows.
26    jobs_pool: Option<Arc<PgPool>>,
27    /// Isolated pool for observability writes.
28    observability_pool: Option<Arc<PgPool>>,
29    /// Isolated pool for long-running analytics queries.
30    analytics_pool: Option<Arc<PgPool>>,
31}
32
33impl Database {
34    /// Create a new database connection from configuration.
35    pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
36        Self::from_config_with_service(config, "forge").await
37    }
38
39    /// Create a new database connection with a service name for tracing.
40    ///
41    /// The service name is set as PostgreSQL's `application_name`, visible in
42    /// `pg_stat_activity` for correlating queries to the originating service.
43    pub async fn from_config_with_service(
44        config: &DatabaseConfig,
45        service_name: &str,
46    ) -> Result<Self> {
47        if config.url.is_empty() {
48            return Err(ForgeError::Database(
49                "database.url cannot be empty. Provide a PostgreSQL connection URL.".into(),
50            ));
51        }
52
53        // If pools.default overrides the primary pool size, use it
54        let primary_size = config
55            .pools
56            .default
57            .as_ref()
58            .map(|p| p.size)
59            .unwrap_or(config.pool_size);
60        let primary_timeout = config
61            .pools
62            .default
63            .as_ref()
64            .map(|p| p.timeout_secs)
65            .unwrap_or(config.pool_timeout_secs);
66
67        let primary_min = config
68            .pools
69            .default
70            .as_ref()
71            .map(|p| p.min_size)
72            .unwrap_or(config.min_pool_size);
73        let primary_test = config
74            .pools
75            .default
76            .as_ref()
77            .map(|p| p.test_before_acquire)
78            .unwrap_or(config.test_before_acquire);
79
80        let statement_timeout = config
81            .pools
82            .default
83            .as_ref()
84            .and_then(|p| p.statement_timeout_secs)
85            .unwrap_or(config.statement_timeout_secs);
86
87        let primary = Self::create_pool_with_statement_timeout(
88            &config.url,
89            primary_size,
90            primary_min,
91            primary_timeout,
92            statement_timeout,
93            primary_test,
94            service_name,
95        )
96        .await
97        .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
98
99        let mut replicas = Vec::new();
100        for replica_url in &config.replica_urls {
101            let pool = Self::create_pool(
102                replica_url,
103                config.pool_size / 2,
104                config.pool_timeout_secs,
105                service_name,
106            )
107            .await
108            .map_err(|e| ForgeError::Database(format!("Failed to connect to replica: {}", e)))?;
109            replicas.push(ReplicaEntry {
110                pool: Arc::new(pool),
111                healthy: Arc::new(AtomicBool::new(true)),
112            });
113        }
114
115        let jobs_pool =
116            Self::create_isolated_pool(&config.url, config.pools.jobs.as_ref(), service_name)
117                .await?;
118        let observability_pool = Self::create_isolated_pool(
119            &config.url,
120            config.pools.observability.as_ref(),
121            service_name,
122        )
123        .await?;
124        let analytics_pool =
125            Self::create_isolated_pool(&config.url, config.pools.analytics.as_ref(), service_name)
126                .await?;
127
128        Ok(Self {
129            primary: Arc::new(primary),
130            replicas: Arc::new(replicas),
131            config: config.clone(),
132            replica_counter: Arc::new(AtomicUsize::new(0)),
133            jobs_pool,
134            observability_pool,
135            analytics_pool,
136        })
137    }
138
139    fn connect_options(url: &str, service_name: &str) -> sqlx::Result<PgConnectOptions> {
140        let options: PgConnectOptions = url.parse()?;
141        Ok(options
142            .application_name(service_name)
143            .log_statements(LevelFilter::Off)
144            .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500)))
145    }
146
147    fn connect_options_with_timeout(
148        url: &str,
149        service_name: &str,
150        statement_timeout_secs: u64,
151    ) -> sqlx::Result<PgConnectOptions> {
152        let options: PgConnectOptions = url.parse()?;
153        let mut opts = options
154            .application_name(service_name)
155            .log_statements(LevelFilter::Off)
156            .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500));
157        if statement_timeout_secs > 0 {
158            // Set PostgreSQL statement_timeout to prevent unbounded query execution
159            opts = opts.options([("statement_timeout", &format!("{}s", statement_timeout_secs))]);
160        }
161        Ok(opts)
162    }
163
164    async fn create_pool(
165        url: &str,
166        size: u32,
167        timeout_secs: u64,
168        service_name: &str,
169    ) -> sqlx::Result<PgPool> {
170        Self::create_pool_with_opts(url, size, 0, timeout_secs, true, service_name).await
171    }
172
173    async fn create_pool_with_opts(
174        url: &str,
175        size: u32,
176        min_size: u32,
177        timeout_secs: u64,
178        test_before_acquire: bool,
179        service_name: &str,
180    ) -> sqlx::Result<PgPool> {
181        Self::create_pool_with_statement_timeout(
182            url,
183            size,
184            min_size,
185            timeout_secs,
186            0,
187            test_before_acquire,
188            service_name,
189        )
190        .await
191    }
192
193    async fn create_pool_with_statement_timeout(
194        url: &str,
195        size: u32,
196        min_size: u32,
197        timeout_secs: u64,
198        statement_timeout_secs: u64,
199        test_before_acquire: bool,
200        service_name: &str,
201    ) -> sqlx::Result<PgPool> {
202        let options = if statement_timeout_secs > 0 {
203            Self::connect_options_with_timeout(url, service_name, statement_timeout_secs)?
204        } else {
205            Self::connect_options(url, service_name)?
206        };
207        PgPoolOptions::new()
208            .max_connections(size)
209            .min_connections(min_size)
210            .acquire_timeout(Duration::from_secs(timeout_secs))
211            .test_before_acquire(test_before_acquire)
212            .connect_with(options)
213            .await
214    }
215
216    async fn create_isolated_pool(
217        url: &str,
218        config: Option<&PoolConfig>,
219        service_name: &str,
220    ) -> Result<Option<Arc<PgPool>>> {
221        let Some(cfg) = config else {
222            return Ok(None);
223        };
224        let pool = Self::create_pool_with_opts(
225            url,
226            cfg.size,
227            cfg.min_size,
228            cfg.timeout_secs,
229            cfg.test_before_acquire,
230            service_name,
231        )
232        .await
233        .map_err(|e| ForgeError::Database(format!("Failed to create isolated pool: {}", e)))?;
234        Ok(Some(Arc::new(pool)))
235    }
236
237    /// Get the primary pool for writes.
238    pub fn primary(&self) -> &PgPool {
239        &self.primary
240    }
241
242    /// Get a pool for reads. Skips unhealthy replicas, falls back to primary.
243    pub fn read_pool(&self) -> &PgPool {
244        if !self.config.read_from_replica || self.replicas.is_empty() {
245            return &self.primary;
246        }
247
248        let len = self.replicas.len();
249        let start = self.replica_counter.fetch_add(1, Ordering::Relaxed) % len;
250
251        // Try each replica starting from round-robin position
252        for offset in 0..len {
253            let idx = (start + offset) % len;
254            if let Some(entry) = self.replicas.get(idx)
255                && entry.healthy.load(Ordering::Relaxed)
256            {
257                return &entry.pool;
258            }
259        }
260
261        // All replicas unhealthy, fall back to primary
262        &self.primary
263    }
264
265    /// Pool for background jobs, cron, daemons, and workflows.
266    /// Falls back to primary if no isolated pool is configured.
267    pub fn jobs_pool(&self) -> &PgPool {
268        self.jobs_pool.as_deref().unwrap_or(&self.primary)
269    }
270
271    /// Pool for observability writes (metrics, slow query logs).
272    /// Falls back to primary if no isolated pool is configured.
273    pub fn observability_pool(&self) -> &PgPool {
274        self.observability_pool.as_deref().unwrap_or(&self.primary)
275    }
276
277    /// Pool for long-running analytics queries.
278    /// Falls back to primary if no isolated pool is configured.
279    pub fn analytics_pool(&self) -> &PgPool {
280        self.analytics_pool.as_deref().unwrap_or(&self.primary)
281    }
282
283    /// Start background health monitoring for replicas. Returns None if no replicas configured.
284    pub fn start_health_monitor(&self) -> Option<JoinHandle<()>> {
285        if self.replicas.is_empty() {
286            return None;
287        }
288
289        let replicas = Arc::clone(&self.replicas);
290        let handle = tokio::spawn(async move {
291            let mut interval = tokio::time::interval(Duration::from_secs(15));
292            loop {
293                interval.tick().await;
294                for entry in replicas.iter() {
295                    let ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
296                        .fetch_one(entry.pool.as_ref())
297                        .await
298                        .is_ok();
299                    let was_healthy = entry.healthy.swap(ok, Ordering::Relaxed);
300                    if was_healthy && !ok {
301                        tracing::warn!("Replica marked unhealthy");
302                    } else if !was_healthy && ok {
303                        tracing::info!("Replica recovered");
304                    }
305                }
306            }
307        });
308
309        Some(handle)
310    }
311
312    /// Create a Database wrapper from an existing pool (for testing).
313    #[cfg(test)]
314    pub fn from_pool(pool: PgPool) -> Self {
315        Self {
316            primary: Arc::new(pool),
317            replicas: Arc::new(Vec::new()),
318            config: DatabaseConfig::default(),
319            replica_counter: Arc::new(AtomicUsize::new(0)),
320            jobs_pool: None,
321            observability_pool: None,
322            analytics_pool: None,
323        }
324    }
325
326    /// Check database connectivity.
327    pub async fn health_check(&self) -> Result<()> {
328        sqlx::query_scalar!("SELECT 1 as \"v!\"")
329            .fetch_one(self.primary.as_ref())
330            .await
331            .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
332        Ok(())
333    }
334
335    /// Close all connections gracefully.
336    pub async fn close(&self) {
337        self.primary.close().await;
338        for entry in self.replicas.iter() {
339            entry.pool.close().await;
340        }
341        if let Some(ref p) = self.jobs_pool {
342            p.close().await;
343        }
344        if let Some(ref p) = self.observability_pool {
345            p.close().await;
346        }
347        if let Some(ref p) = self.analytics_pool {
348            p.close().await;
349        }
350    }
351}
352
353/// Type alias for the pool type.
354pub type DatabasePool = PgPool;
355
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `DatabaseConfig` keeps the original's URL and pool size.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::new("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.url(), original.url());
        assert_eq!(copy.pool_size, original.pool_size);
    }
}