1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::postgres::{PgPool, PgPoolOptions};
5
6use forge_core::config::DatabaseConfig;
7use forge_core::error::{ForgeError, Result};
8
9#[cfg(feature = "embedded-db")]
10use tokio::sync::OnceCell;
11
12#[cfg(feature = "embedded-db")]
13use tracing::info;
14
/// Process-wide embedded PostgreSQL server.
///
/// It is initialised lazily, at most once, by `Database::start_embedded_postgres` via
/// `get_or_try_init`. Every embedded `Database` built afterwards reuses the same instance.
/// Being a `static`, it is never dropped.
#[cfg(feature = "embedded-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
18
/// Handle to the primary PostgreSQL pool plus any read-replica pools.
///
/// Cloning is cheap: the pools and the round-robin counter sit behind `Arc`, so clones share
/// them.
#[derive(Clone)]
pub struct Database {
    /// Pool connected to the primary URL (the embedded server's URL or `config.url`).
    primary: Arc<PgPool>,

    /// One pool per entry of `config.replica_urls`, in the same order.
    replicas: Vec<Arc<PgPool>>,

    /// Copy of the configuration this handle was built from.
    /// `read_pool` consults its `read_from_replica` flag.
    config: DatabaseConfig,

    /// Round-robin cursor used by `read_pool` to pick a replica.
    /// It is shared via `Arc`, so all clones advance the same counter.
    replica_counter: Arc<std::sync::atomic::AtomicUsize>,

    /// `true` when `primary` points at the embedded PostgreSQL server.
    embedded: bool,
}
37
38impl Database {
39 pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
41 let (url, embedded) = if config.embedded {
42 #[cfg(feature = "embedded-db")]
43 {
44 let url = Self::start_embedded_postgres(config.data_dir.as_deref()).await?;
45 (url, true)
46 }
47 #[cfg(not(feature = "embedded-db"))]
48 {
49 return Err(ForgeError::Database(
50 "Embedded PostgreSQL requires the 'embedded-db' feature. \
51 Build with: cargo build --features embedded-db"
52 .to_string(),
53 ));
54 }
55 } else {
56 if config.url.is_empty() {
57 return Err(ForgeError::Database(
58 "Database URL is required when embedded = false. Set database.url or database.embedded = true".to_string()
59 ));
60 }
61 (config.url.clone(), false)
62 };
63
64 let primary = Self::create_pool(&url, config.pool_size, config.pool_timeout_secs)
65 .await
66 .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
67
68 let mut replicas = Vec::new();
69 for replica_url in &config.replica_urls {
70 let pool =
71 Self::create_pool(replica_url, config.pool_size / 2, config.pool_timeout_secs)
72 .await
73 .map_err(|e| {
74 ForgeError::Database(format!("Failed to connect to replica: {}", e))
75 })?;
76 replicas.push(Arc::new(pool));
77 }
78
79 Ok(Self {
80 primary: Arc::new(primary),
81 replicas,
82 config: config.clone(),
83 replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
84 embedded,
85 })
86 }
87
88 #[cfg(feature = "embedded-db")]
90 async fn start_embedded_postgres(data_dir: Option<&str>) -> Result<String> {
91 let pg = EMBEDDED_PG
92 .get_or_try_init(|| async {
93 info!("Starting embedded PostgreSQL...");
94
95 let settings = if let Some(dir) = data_dir {
97 postgresql_embedded::Settings {
98 data_dir: std::path::PathBuf::from(dir),
99 ..Default::default()
100 }
101 } else {
102 postgresql_embedded::Settings::default()
103 };
104
105 let mut pg = postgresql_embedded::PostgreSQL::new(settings);
106 pg.setup().await.map_err(|e| {
107 ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
108 })?;
109 pg.start().await.map_err(|e| {
110 ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
111 })?;
112 info!("Embedded PostgreSQL started successfully");
113 Ok::<_, ForgeError>(pg)
114 })
115 .await?;
116
117 Ok(pg.settings().url("forge"))
118 }
119
120 pub fn is_embedded(&self) -> bool {
122 self.embedded
123 }
124
125 async fn create_pool(url: &str, size: u32, timeout_secs: u64) -> sqlx::Result<PgPool> {
127 PgPoolOptions::new()
128 .max_connections(size)
129 .acquire_timeout(Duration::from_secs(timeout_secs))
130 .connect(url)
131 .await
132 }
133
134 pub fn primary(&self) -> &PgPool {
136 &self.primary
137 }
138
139 pub fn read_pool(&self) -> &PgPool {
141 if self.config.read_from_replica && !self.replicas.is_empty() {
142 let idx = self
144 .replica_counter
145 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
146 % self.replicas.len();
147 &self.replicas[idx]
148 } else {
149 &self.primary
150 }
151 }
152
153 pub async fn health_check(&self) -> Result<()> {
155 sqlx::query("SELECT 1")
156 .execute(self.primary.as_ref())
157 .await
158 .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
159 Ok(())
160 }
161
162 pub async fn close(&self) {
164 self.primary.close().await;
165 for replica in &self.replicas {
166 replica.close().await;
167 }
168 }
169}
170
/// Alias for the concrete sqlx pool type returned by `Database::primary` and
/// `Database::read_pool`.
pub type DatabasePool = PgPool;
173
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `DatabaseConfig` keeps the explicitly set `url` and `pool_size`.
    #[test]
    fn test_database_config_clone() {
        let config = DatabaseConfig {
            url: "postgres://localhost/test".to_string(),
            pool_size: 10,
            ..Default::default()
        };

        let cloned = config.clone();
        assert_eq!(cloned.url, config.url);
        assert_eq!(cloned.pool_size, config.pool_size);
    }
}