1use std::sync::Arc;
2use std::time::Duration;
3
4use sqlx::postgres::{PgPool, PgPoolOptions};
5
6use forge_core::config::{DatabaseConfig, DatabaseSource};
7use forge_core::error::{ForgeError, Result};
8
9#[cfg(feature = "embedded-db")]
10use tokio::sync::OnceCell;
11
12#[cfg(feature = "embedded-db")]
13use tracing::info;
14
/// Process-wide embedded PostgreSQL instance, initialised lazily by
/// `Database::start_embedded_postgres` via `get_or_try_init`.
///
/// The initialiser runs at most once. Later callers receive the already-started
/// instance, so any `data_dir` they pass is not applied.
#[cfg(feature = "embedded-db")]
static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
18
/// Handle to the application's PostgreSQL connection pools: one primary pool
/// plus zero or more read-replica pools.
///
/// Clones share the underlying pools and the round-robin counter through `Arc`.
/// The `config` value and the replica list (a `Vec` of `Arc`s) are copied.
#[derive(Clone)]
pub struct Database {
    /// Pool connected to the primary database URL (remote or embedded).
    primary: Arc<PgPool>,

    /// One pool per entry in `config.replica_urls`, in the same order.
    replicas: Vec<Arc<PgPool>>,

    /// Configuration this handle was built from. `read_pool` consults its
    /// `read_from_replica` flag.
    config: DatabaseConfig,

    /// Round-robin cursor used by `read_pool` to pick a replica; shared across clones.
    replica_counter: Arc<std::sync::atomic::AtomicUsize>,

    /// `true` when the primary URL came from the embedded PostgreSQL instance.
    embedded: bool,
}
37
38impl Database {
39 pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
41 let (url, embedded) = match &config.source {
42 DatabaseSource::Remote { url } => {
43 if url.is_empty() {
44 return Err(ForgeError::Database(
45 "database.url cannot be empty. Provide a PostgreSQL connection URL.".into(),
46 ));
47 }
48 (url.clone(), false)
49 }
50 DatabaseSource::Embedded { data_dir } => {
51 #[cfg(feature = "embedded-db")]
52 {
53 let url = Self::start_embedded_postgres(data_dir.as_deref()).await?;
54 (url, true)
55 }
56 #[cfg(not(feature = "embedded-db"))]
57 {
58 let _ = data_dir;
59 return Err(ForgeError::Database(
60 "Embedded PostgreSQL requires the 'embedded-db' feature. \
61 Build with: cargo build --features embedded-db"
62 .to_string(),
63 ));
64 }
65 }
66 };
67
68 let primary = Self::create_pool(&url, config.pool_size, config.pool_timeout_secs)
69 .await
70 .map_err(|e| ForgeError::Database(format!("Failed to connect to primary: {}", e)))?;
71
72 let mut replicas = Vec::new();
73 for replica_url in &config.replica_urls {
74 let pool =
75 Self::create_pool(replica_url, config.pool_size / 2, config.pool_timeout_secs)
76 .await
77 .map_err(|e| {
78 ForgeError::Database(format!("Failed to connect to replica: {}", e))
79 })?;
80 replicas.push(Arc::new(pool));
81 }
82
83 Ok(Self {
84 primary: Arc::new(primary),
85 replicas,
86 config: config.clone(),
87 replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
88 embedded,
89 })
90 }
91
92 #[cfg(feature = "embedded-db")]
94 async fn start_embedded_postgres(data_dir: Option<&str>) -> Result<String> {
95 let pg = EMBEDDED_PG
96 .get_or_try_init(|| async {
97 info!("Starting embedded PostgreSQL...");
98
99 let settings = if let Some(dir) = data_dir {
101 postgresql_embedded::Settings {
102 data_dir: std::path::PathBuf::from(dir),
103 ..Default::default()
104 }
105 } else {
106 postgresql_embedded::Settings::default()
107 };
108
109 let mut pg = postgresql_embedded::PostgreSQL::new(settings);
110 pg.setup().await.map_err(|e| {
111 ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
112 })?;
113 pg.start().await.map_err(|e| {
114 ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
115 })?;
116 info!("Embedded PostgreSQL started successfully");
117 Ok::<_, ForgeError>(pg)
118 })
119 .await?;
120
121 Ok(pg.settings().url("forge"))
122 }
123
124 pub fn is_embedded(&self) -> bool {
126 self.embedded
127 }
128
129 async fn create_pool(url: &str, size: u32, timeout_secs: u64) -> sqlx::Result<PgPool> {
131 PgPoolOptions::new()
132 .max_connections(size)
133 .acquire_timeout(Duration::from_secs(timeout_secs))
134 .connect(url)
135 .await
136 }
137
138 pub fn primary(&self) -> &PgPool {
140 &self.primary
141 }
142
143 pub fn read_pool(&self) -> &PgPool {
145 if self.config.read_from_replica && !self.replicas.is_empty() {
146 let idx = self
148 .replica_counter
149 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
150 % self.replicas.len();
151 self.replicas.get(idx).unwrap_or(&self.primary)
152 } else {
153 &self.primary
154 }
155 }
156
157 #[cfg(test)]
159 pub fn from_pool(pool: PgPool) -> Self {
160 Self {
161 primary: Arc::new(pool),
162 replicas: Vec::new(),
163 config: DatabaseConfig::default(),
164 replica_counter: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
165 embedded: false,
166 }
167 }
168
169 pub async fn health_check(&self) -> Result<()> {
171 sqlx::query("SELECT 1")
172 .execute(self.primary.as_ref())
173 .await
174 .map_err(|e| ForgeError::Database(format!("Health check failed: {}", e)))?;
175 Ok(())
176 }
177
178 pub async fn close(&self) {
180 self.primary.close().await;
181 for replica in &self.replicas {
182 replica.close().await;
183 }
184 }
185}
186
/// Alias for the sqlx PostgreSQL pool type used throughout this module.
pub type DatabasePool = PgPool;
189
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned remote `DatabaseConfig` keeps the same pool size and URL as
    /// the value it was cloned from.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::remote("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.pool_size, original.pool_size);
        assert_eq!(copy.url(), original.url());
    }
}