1use sqlx::PgPool;
13use std::path::Path;
14use tracing::{debug, info};
15
16use crate::error::{ForgeError, Result};
17
18#[cfg(feature = "embedded-db")]
19use tokio::sync::OnceCell;
20
21#[cfg(feature = "embedded-db")]
22static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
23
24pub struct TestDatabase {
43 pool: PgPool,
44 url: String,
45}
46
47impl TestDatabase {
48 pub async fn from_url(url: &str) -> Result<Self> {
52 let pool = sqlx::postgres::PgPoolOptions::new()
53 .max_connections(10)
54 .connect(url)
55 .await
56 .map_err(ForgeError::Sql)?;
57
58 Ok(Self {
59 pool,
60 url: url.to_string(),
61 })
62 }
63
64 pub async fn from_env() -> Result<Self> {
69 let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
70 ForgeError::Database(
71 "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
72 )
73 })?;
74 Self::from_url(&url).await
75 }
76
77 #[cfg(feature = "embedded-db")]
82 pub async fn embedded() -> Result<Self> {
83 let pg = EMBEDDED_PG
84 .get_or_try_init(|| async {
85 let mut pg = postgresql_embedded::PostgreSQL::default();
86 pg.setup().await.map_err(|e| {
87 ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
88 })?;
89 pg.start().await.map_err(|e| {
90 ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
91 })?;
92 Ok::<_, ForgeError>(pg)
93 })
94 .await?;
95
96 let url = pg.settings().url("postgres");
97 Self::from_url(&url).await
98 }
99
100 pub fn pool(&self) -> &PgPool {
102 &self.pool
103 }
104
105 pub fn url(&self) -> &str {
107 &self.url
108 }
109
110 pub async fn execute(&self, sql: &str) -> Result<()> {
112 sqlx::query(sql)
113 .execute(&self.pool)
114 .await
115 .map_err(ForgeError::Sql)?;
116 Ok(())
117 }
118
119 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
124 let base_url = self.url.clone();
125 let db_name = format!(
127 "forge_test_{}_{}",
128 sanitize_db_name(test_name),
129 uuid::Uuid::new_v4().simple()
130 );
131
132 let pool = sqlx::postgres::PgPoolOptions::new()
134 .max_connections(1)
135 .connect(&base_url)
136 .await
137 .map_err(ForgeError::Sql)?;
138
139 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
141 .execute(&pool)
142 .await
143 .map_err(ForgeError::Sql)?;
144
145 let test_url = replace_db_name(&base_url, &db_name);
147
148 let test_pool = sqlx::postgres::PgPoolOptions::new()
149 .max_connections(5)
150 .connect(&test_url)
151 .await
152 .map_err(ForgeError::Sql)?;
153
154 Ok(IsolatedTestDb {
155 pool: test_pool,
156 db_name,
157 base_url,
158 })
159 }
160}
161
162pub struct IsolatedTestDb {
168 pool: PgPool,
169 db_name: String,
170 base_url: String,
171}
172
173impl IsolatedTestDb {
174 pub fn pool(&self) -> &PgPool {
176 &self.pool
177 }
178
179 pub fn db_name(&self) -> &str {
181 &self.db_name
182 }
183
184 pub async fn execute(&self, sql: &str) -> Result<()> {
186 sqlx::query(sql)
187 .execute(&self.pool)
188 .await
189 .map_err(ForgeError::Sql)?;
190 Ok(())
191 }
192
193 pub async fn run_sql(&self, sql: &str) -> Result<()> {
198 let statements = split_sql_statements(sql);
199 for statement in statements {
200 let statement = statement.trim();
201 if statement.is_empty()
202 || statement.lines().all(|l| {
203 let l = l.trim();
204 l.is_empty() || l.starts_with("--")
205 })
206 {
207 continue;
208 }
209
210 sqlx::query(statement)
211 .execute(&self.pool)
212 .await
213 .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
214 }
215 Ok(())
216 }
217
218 pub async fn cleanup(self) -> Result<()> {
223 self.pool.close().await;
225
226 let pool = sqlx::postgres::PgPoolOptions::new()
228 .max_connections(1)
229 .connect(&self.base_url)
230 .await
231 .map_err(ForgeError::Sql)?;
232
233 let _ = sqlx::query(&format!(
235 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
236 self.db_name
237 ))
238 .execute(&pool)
239 .await;
240
241 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
242 .execute(&pool)
243 .await
244 .map_err(ForgeError::Sql)?;
245
246 Ok(())
247 }
248
249 pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
262 if !migrations_dir.exists() {
263 debug!("Migrations directory does not exist: {:?}", migrations_dir);
264 return Ok(());
265 }
266
267 let mut migrations = Vec::new();
268
269 let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
270
271 for entry in entries {
272 let entry = entry.map_err(ForgeError::Io)?;
273 let path = entry.path();
274
275 if path.extension().map(|e| e == "sql").unwrap_or(false) {
276 let name = path
277 .file_stem()
278 .and_then(|s| s.to_str())
279 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
280 .to_string();
281
282 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
283 migrations.push((name, content));
284 }
285 }
286
287 migrations.sort_by(|a, b| a.0.cmp(&b.0));
289
290 debug!("Running {} migrations for test", migrations.len());
291
292 for (name, content) in migrations {
293 info!("Applying test migration: {}", name);
294
295 let up_sql = parse_up_sql(&content);
297
298 let statements = split_sql_statements(&up_sql);
300 for statement in statements {
301 let statement = statement.trim();
302 if statement.is_empty()
303 || statement.lines().all(|l| {
304 let l = l.trim();
305 l.is_empty() || l.starts_with("--")
306 })
307 {
308 continue;
309 }
310
311 sqlx::query(statement)
312 .execute(&self.pool)
313 .await
314 .map_err(|e| {
315 ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
316 })?;
317 }
318 }
319
320 Ok(())
321 }
322}
323
/// Turns an arbitrary test name into a fragment that is safe to embed in a
/// generated database name.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and the
/// result is cut to 19 characters. Restricting the output to ASCII keeps the
/// fragment usable verbatim inside a connection URL and makes its byte length
/// equal to its character count.
///
/// The limit exists because the fragment is wrapped as
/// `forge_test_<fragment>_<32-hex-digit uuid>`: 11 + 19 + 1 + 32 = 63 bytes,
/// which is the longest identifier PostgreSQL stores without truncating it.
/// A longer name would be truncated when the database is created, so the name
/// later used to connect to it or look it up would no longer match.
fn sanitize_db_name(name: &str) -> String {
    /// Bytes left for the test-name fragment inside a 63-byte identifier.
    const MAX_FRAGMENT_LEN: usize = 19;

    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_FRAGMENT_LEN)
        .collect()
}
331
/// Returns `url` with its database name (the last path segment) replaced by
/// `new_db`, keeping any `?query` suffix intact.
///
/// The query string is split off first and the `scheme://` prefix is skipped
/// before looking for the path separator. This keeps a `/` inside a query
/// value (for example `sslrootcert=/etc/ssl/ca.pem`) or the `//` of the scheme
/// from being mistaken for the start of the database name. When the URL has
/// no path at all, `/new_db` is appended after the host.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Everything from the first `?` onwards is carried over untouched.
    let (location, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Only slashes after the authority can delimit the database name.
    let authority_start = location.find("://").map_or(0, |idx| idx + 3);

    match location[authority_start..].rfind('/') {
        Some(rel) => {
            let base = &location[..=authority_start + rel];
            format!("{}{}{}", base, new_db, query)
        }
        None => format!("{}/{}{}", location, new_db, query),
    }
}
348
/// Extracts the "up" section of a migration file.
///
/// Everything from the first `-- @down` marker onwards is discarded, every
/// `-- @up` marker is removed from what remains, and the result is trimmed.
/// A file without a down marker is treated as consisting only of up SQL.
///
/// Markers are recognised with or without a single space after `--`, in any
/// letter case, and only when the marker word ends there. A comment such as
/// `-- @updated_at ...` or `-- @downgrade ...` is therefore left alone
/// instead of being cut in half.
fn parse_up_sql(content: &str) -> String {
    // Cut at the earliest down marker, whichever spelling it uses.
    let up_end = find_migration_marker(content, "@down", 0)
        .map_or(content.len(), |(start, _)| start);
    let up_part = &content[..up_end];

    // Copy the up section, leaving out the marker text itself.
    let mut up_sql = String::with_capacity(up_part.len());
    let mut copied_to = 0;
    while let Some((start, end)) = find_migration_marker(up_part, "@up", copied_to) {
        up_sql.push_str(&up_part[copied_to..start]);
        copied_to = end;
    }
    up_sql.push_str(&up_part[copied_to..]);

    up_sql.trim().to_string()
}

/// Finds the next `--<name>` or `-- <name>` marker in `text` at or after byte
/// offset `from`, matching `name` ASCII-case-insensitively.
///
/// Returns the byte range `(start, end)` of the marker text. A candidate is
/// rejected when the byte after it is an ASCII letter, digit or `_`, because
/// it is then the prefix of a longer word rather than a marker.
fn find_migration_marker(text: &str, name: &str, from: usize) -> Option<(usize, usize)> {
    let bytes = text.as_bytes();
    let mut search_from = from;

    while let Some(rel) = text[search_from..].find("--") {
        let start = search_from + rel;

        let mut cursor = start + 2;
        if bytes.get(cursor) == Some(&b' ') {
            cursor += 1;
        }
        let end = cursor + name.len();

        let name_matches = bytes
            .get(cursor..end)
            .map_or(false, |candidate| candidate.eq_ignore_ascii_case(name.as_bytes()));
        let at_word_end = bytes
            .get(end)
            .map_or(true, |b| !(b.is_ascii_alphanumeric() || *b == b'_'));

        if name_matches && at_word_end {
            return Some((start, end));
        }

        // Step one byte so that runs such as `---@up` are still examined.
        search_from = start + 1;
    }

    None
}
375
/// Splits a SQL script into individual statements on top-level semicolons.
///
/// A semicolon does not end a statement when it appears inside
///
/// * a single-quoted string or double-quoted identifier (a doubled quote is
///   an escaped quote),
/// * a `--` line comment or a (possibly nested) `/* ... */` block comment,
/// * a dollar-quoted string such as `$$ ... $$` or `$tag$ ... $tag$`.
///
/// Each returned statement is trimmed and excludes its terminating semicolon;
/// empty statements are dropped. Comments are kept as part of the statement
/// text they appear in. An unterminated quote or comment extends to the end
/// of the input. Backslash escapes in `E'...'` strings are not interpreted.
fn split_sql_statements(sql: &str) -> Vec<String> {
    // Every delimiter looked for is ASCII, so the byte offsets computed here
    // are always valid `str` boundaries even with multi-byte text in `sql`.
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut statement_start = 0;
    let mut pos = 0;

    while pos < bytes.len() {
        pos = match bytes[pos] {
            b'\'' | b'"' => end_of_quoted(bytes, pos),
            b'-' if bytes.get(pos + 1) == Some(&b'-') => end_of_line_comment(bytes, pos),
            b'/' if bytes.get(pos + 1) == Some(&b'*') => end_of_block_comment(bytes, pos),
            b'$' => end_of_dollar_quoted(bytes, pos),
            b';' => {
                push_statement(&mut statements, &sql[statement_start..pos]);
                statement_start = pos + 1;
                pos + 1
            }
            _ => pos + 1,
        };
    }

    // Whatever follows the last semicolon is a final, unterminated statement.
    push_statement(&mut statements, &sql[statement_start..]);
    statements
}

/// Appends `raw` to `statements` after trimming, unless nothing is left.
fn push_statement(statements: &mut Vec<String>, raw: &str) {
    let statement = raw.trim();
    if !statement.is_empty() {
        statements.push(statement.to_string());
    }
}

/// Returns the offset just past the quoted token opening at `open`.
///
/// The quote character is whatever byte sits at `open`; a doubled quote
/// inside the token is an escape and does not close it.
fn end_of_quoted(bytes: &[u8], open: usize) -> usize {
    let quote = bytes[open];
    let mut pos = open + 1;
    while pos < bytes.len() {
        if bytes[pos] == quote {
            if bytes.get(pos + 1) == Some(&quote) {
                pos += 2;
                continue;
            }
            return pos + 1;
        }
        pos += 1;
    }
    bytes.len()
}

/// Returns the offset of the newline ending the `--` comment at `open`, or
/// the end of input when the comment is on the last line.
fn end_of_line_comment(bytes: &[u8], open: usize) -> usize {
    bytes[open..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |rel| open + rel)
}

/// Returns the offset just past the `/* ... */` comment opening at `open`,
/// honouring nested block comments.
fn end_of_block_comment(bytes: &[u8], open: usize) -> usize {
    let mut depth = 1usize;
    let mut pos = open + 2;
    while pos < bytes.len() {
        match (bytes[pos], bytes.get(pos + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                pos += 2;
            }
            (b'*', Some(b'/')) => {
                depth -= 1;
                pos += 2;
                if depth == 0 {
                    return pos;
                }
            }
            _ => pos += 1,
        }
    }
    bytes.len()
}

/// Returns the offset just past the dollar-quoted string opening at `open`,
/// or `open + 1` when the `$` there does not start one.
///
/// A `$` that directly follows an identifier character is part of that
/// identifier, and a tag may not begin with a digit (so `$1` stays a
/// parameter placeholder).
fn end_of_dollar_quoted(bytes: &[u8], open: usize) -> usize {
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80;

    if open > 0 && (is_ident(bytes[open - 1]) || bytes[open - 1] == b'$') {
        return open + 1;
    }

    let mut tag_end = open + 1;
    while tag_end < bytes.len() && is_ident(bytes[tag_end]) {
        tag_end += 1;
    }

    let starts_with_digit = bytes.get(open + 1).map_or(false, |b| b.is_ascii_digit());
    if starts_with_digit || bytes.get(tag_end) != Some(&b'$') {
        return open + 1;
    }

    // The body ends at the next occurrence of the exact opening delimiter.
    let tag = &bytes[open..=tag_end];
    let body_start = tag_end + 1;
    bytes[body_start..]
        .windows(tag.len())
        .position(|window| window == tag)
        .map_or(bytes.len(), |rel| body_start + rel + tag.len())
}
439
440#[cfg(test)]
441mod tests {
442 use super::*;
443
444 #[test]
445 fn test_sanitize_db_name() {
446 assert_eq!(sanitize_db_name("my_test"), "my_test");
447 assert_eq!(sanitize_db_name("my-test"), "my_test");
448 assert_eq!(sanitize_db_name("my test"), "my_test");
449 assert_eq!(sanitize_db_name("test::function"), "test__function");
450 }
451
452 #[test]
453 fn test_replace_db_name() {
454 assert_eq!(
455 replace_db_name("postgres://localhost/olddb", "newdb"),
456 "postgres://localhost/newdb"
457 );
458 assert_eq!(
459 replace_db_name("postgres://user:pass@localhost:5432/olddb", "newdb"),
460 "postgres://user:pass@localhost:5432/newdb"
461 );
462 assert_eq!(
463 replace_db_name("postgres://localhost/olddb?sslmode=disable", "newdb"),
464 "postgres://localhost/newdb?sslmode=disable"
465 );
466 }
467}