1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use sqlx::ConnectOptions;
6use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
7use tokio::task::JoinHandle;
8use tracing::log::LevelFilter;
9
10use forge_core::config::{DatabaseConfig, PoolConfig};
11use forge_core::error::{ForgeError, Result};
12
13struct ReplicaEntry {
14 pool: Arc<PgPool>,
15 healthy: Arc<AtomicBool>,
16}
17
18#[derive(Clone)]
20pub struct Database {
21 primary: Arc<PgPool>,
22 replicas: Arc<Vec<ReplicaEntry>>,
23 config: DatabaseConfig,
24 replica_counter: Arc<AtomicUsize>,
25 jobs_pool: Option<Arc<PgPool>>,
27 observability_pool: Option<Arc<PgPool>>,
29 analytics_pool: Option<Arc<PgPool>>,
31}
32
33impl Database {
34 pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
36 Self::from_config_with_service(config, "forge").await
37 }
38
39 pub async fn from_config_with_service(
44 config: &DatabaseConfig,
45 service_name: &str,
46 ) -> Result<Self> {
47 if config.url.is_empty() {
48 return Err(ForgeError::Database(
49 "database.url cannot be empty. Provide a PostgreSQL connection URL.".into(),
50 ));
51 }
52
53 let primary_size = config
55 .pools
56 .default
57 .as_ref()
58 .map(|p| p.size)
59 .unwrap_or(config.pool_size);
60 let primary_timeout = config
61 .pools
62 .default
63 .as_ref()
64 .map(|p| p.timeout_secs)
65 .unwrap_or(config.pool_timeout_secs);
66
67 let primary_min = config
68 .pools
69 .default
70 .as_ref()
71 .map(|p| p.min_size)
72 .unwrap_or(config.min_pool_size);
73 let primary_test = config
74 .pools
75 .default
76 .as_ref()
77 .map(|p| p.test_before_acquire)
78 .unwrap_or(config.test_before_acquire);
79
80 let statement_timeout = config
81 .pools
82 .default
83 .as_ref()
84 .and_then(|p| p.statement_timeout_secs)
85 .unwrap_or(config.statement_timeout_secs);
86
87 let primary = Self::create_pool_with_statement_timeout(
88 &config.url,
89 primary_size,
90 primary_min,
91 primary_timeout,
92 statement_timeout,
93 primary_test,
94 service_name,
95 )
96 .await
97 .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
98
99 let mut replicas = Vec::new();
100 for replica_url in &config.replica_urls {
101 let pool = Self::create_pool(
102 replica_url,
103 config.pool_size / 2,
104 config.pool_timeout_secs,
105 service_name,
106 )
107 .await
108 .map_err(|e| ForgeError::Database(format!("Failed to connect to replica: {}", e)))?;
109 replicas.push(ReplicaEntry {
110 pool: Arc::new(pool),
111 healthy: Arc::new(AtomicBool::new(true)),
112 });
113 }
114
115 let jobs_pool =
116 Self::create_isolated_pool(&config.url, config.pools.jobs.as_ref(), service_name)
117 .await?;
118 let observability_pool = Self::create_isolated_pool(
119 &config.url,
120 config.pools.observability.as_ref(),
121 service_name,
122 )
123 .await?;
124 let analytics_pool =
125 Self::create_isolated_pool(&config.url, config.pools.analytics.as_ref(), service_name)
126 .await?;
127
128 Ok(Self {
129 primary: Arc::new(primary),
130 replicas: Arc::new(replicas),
131 config: config.clone(),
132 replica_counter: Arc::new(AtomicUsize::new(0)),
133 jobs_pool,
134 observability_pool,
135 analytics_pool,
136 })
137 }
138
139 fn connect_options(url: &str, service_name: &str) -> sqlx::Result<PgConnectOptions> {
140 let options: PgConnectOptions = url.parse()?;
141 Ok(options
142 .application_name(service_name)
143 .log_statements(LevelFilter::Off)
144 .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500)))
145 }
146
147 fn connect_options_with_timeout(
148 url: &str,
149 service_name: &str,
150 statement_timeout_secs: u64,
151 ) -> sqlx::Result<PgConnectOptions> {
152 let options: PgConnectOptions = url.parse()?;
153 let mut opts = options
154 .application_name(service_name)
155 .log_statements(LevelFilter::Off)
156 .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500));
157 if statement_timeout_secs > 0 {
158 opts = opts.options([("statement_timeout", &format!("{}s", statement_timeout_secs))]);
160 }
161 Ok(opts)
162 }
163
164 async fn create_pool(
165 url: &str,
166 size: u32,
167 timeout_secs: u64,
168 service_name: &str,
169 ) -> sqlx::Result<PgPool> {
170 Self::create_pool_with_opts(url, size, 0, timeout_secs, true, service_name).await
171 }
172
173 async fn create_pool_with_opts(
174 url: &str,
175 size: u32,
176 min_size: u32,
177 timeout_secs: u64,
178 test_before_acquire: bool,
179 service_name: &str,
180 ) -> sqlx::Result<PgPool> {
181 Self::create_pool_with_statement_timeout(
182 url,
183 size,
184 min_size,
185 timeout_secs,
186 0,
187 test_before_acquire,
188 service_name,
189 )
190 .await
191 }
192
193 async fn create_pool_with_statement_timeout(
194 url: &str,
195 size: u32,
196 min_size: u32,
197 timeout_secs: u64,
198 statement_timeout_secs: u64,
199 test_before_acquire: bool,
200 service_name: &str,
201 ) -> sqlx::Result<PgPool> {
202 let options = if statement_timeout_secs > 0 {
203 Self::connect_options_with_timeout(url, service_name, statement_timeout_secs)?
204 } else {
205 Self::connect_options(url, service_name)?
206 };
207 PgPoolOptions::new()
208 .max_connections(size)
209 .min_connections(min_size)
210 .acquire_timeout(Duration::from_secs(timeout_secs))
211 .test_before_acquire(test_before_acquire)
212 .connect_with(options)
213 .await
214 }
215
216 async fn create_isolated_pool(
217 url: &str,
218 config: Option<&PoolConfig>,
219 service_name: &str,
220 ) -> Result<Option<Arc<PgPool>>> {
221 let Some(cfg) = config else {
222 return Ok(None);
223 };
224 let pool = Self::create_pool_with_opts(
225 url,
226 cfg.size,
227 cfg.min_size,
228 cfg.timeout_secs,
229 cfg.test_before_acquire,
230 service_name,
231 )
232 .await
233 .map_err(|e| ForgeError::Database(format!("Failed to create isolated pool: {}", e)))?;
234 Ok(Some(Arc::new(pool)))
235 }
236
237 pub fn primary(&self) -> &PgPool {
239 &self.primary
240 }
241
242 pub fn read_pool(&self) -> &PgPool {
244 if !self.config.read_from_replica || self.replicas.is_empty() {
245 return &self.primary;
246 }
247
248 let len = self.replicas.len();
249 let start = self.replica_counter.fetch_add(1, Ordering::Relaxed) % len;
250
251 for offset in 0..len {
253 let idx = (start + offset) % len;
254 if let Some(entry) = self.replicas.get(idx)
255 && entry.healthy.load(Ordering::Relaxed)
256 {
257 return &entry.pool;
258 }
259 }
260
261 &self.primary
263 }
264
265 pub fn jobs_pool(&self) -> &PgPool {
268 self.jobs_pool.as_deref().unwrap_or(&self.primary)
269 }
270
271 pub fn observability_pool(&self) -> &PgPool {
274 self.observability_pool.as_deref().unwrap_or(&self.primary)
275 }
276
277 pub fn analytics_pool(&self) -> &PgPool {
280 self.analytics_pool.as_deref().unwrap_or(&self.primary)
281 }
282
283 pub fn start_health_monitor(&self) -> Option<JoinHandle<()>> {
285 if self.replicas.is_empty() {
286 return None;
287 }
288
289 let replicas = Arc::clone(&self.replicas);
290 let handle = tokio::spawn(async move {
291 let mut interval = tokio::time::interval(Duration::from_secs(15));
292 loop {
293 interval.tick().await;
294 for entry in replicas.iter() {
295 let ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
296 .fetch_one(entry.pool.as_ref())
297 .await
298 .is_ok();
299 let was_healthy = entry.healthy.swap(ok, Ordering::Relaxed);
300 if was_healthy && !ok {
301 tracing::warn!("Replica marked unhealthy");
302 } else if !was_healthy && ok {
303 tracing::info!("Replica recovered");
304 }
305 }
306 }
307 });
308
309 Some(handle)
310 }
311
312 #[cfg(test)]
314 pub fn from_pool(pool: PgPool) -> Self {
315 Self {
316 primary: Arc::new(pool),
317 replicas: Arc::new(Vec::new()),
318 config: DatabaseConfig::default(),
319 replica_counter: Arc::new(AtomicUsize::new(0)),
320 jobs_pool: None,
321 observability_pool: None,
322 analytics_pool: None,
323 }
324 }
325
326 pub async fn health_check(&self) -> Result<()> {
328 sqlx::query_scalar!("SELECT 1 as \"v!\"")
329 .fetch_one(self.primary.as_ref())
330 .await
331 .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
332 Ok(())
333 }
334
335 pub async fn close(&self) {
337 self.primary.close().await;
338 for entry in self.replicas.iter() {
339 entry.pool.close().await;
340 }
341 if let Some(ref p) = self.jobs_pool {
342 p.close().await;
343 }
344 if let Some(ref p) = self.observability_pool {
345 p.close().await;
346 }
347 if let Some(ref p) = self.analytics_pool {
348 p.close().await;
349 }
350 }
351}
352
353pub type DatabasePool = PgPool;
355
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `DatabaseConfig` keeps the same URL and pool size.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::new("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.url(), original.url());
        assert_eq!(copy.pool_size, original.pool_size);
    }
}