1#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
13
14use sqlx::PgPool;
15use std::path::Path;
16use tracing::{debug, info};
17
18use crate::error::{ForgeError, Result};
19
20#[cfg(feature = "embedded-db")]
21use tokio::sync::OnceCell;
22
23#[cfg(feature = "embedded-db")]
24static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
25
26pub struct TestDatabase {
45 pool: PgPool,
46 url: String,
47}
48
49impl TestDatabase {
50 pub async fn from_url(url: &str) -> Result<Self> {
54 let pool = sqlx::postgres::PgPoolOptions::new()
55 .max_connections(10)
56 .connect(url)
57 .await
58 .map_err(ForgeError::Sql)?;
59
60 Ok(Self {
61 pool,
62 url: url.to_string(),
63 })
64 }
65
66 pub async fn from_env() -> Result<Self> {
71 let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
72 ForgeError::Database(
73 "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
74 )
75 })?;
76 Self::from_url(&url).await
77 }
78
79 #[cfg(feature = "embedded-db")]
84 pub async fn embedded() -> Result<Self> {
85 let pg = EMBEDDED_PG
86 .get_or_try_init(|| async {
87 let mut pg = postgresql_embedded::PostgreSQL::default();
88 pg.setup().await.map_err(|e| {
89 ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
90 })?;
91 pg.start().await.map_err(|e| {
92 ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
93 })?;
94 Ok::<_, ForgeError>(pg)
95 })
96 .await?;
97
98 let url = pg.settings().url("postgres");
99 Self::from_url(&url).await
100 }
101
102 pub fn pool(&self) -> &PgPool {
104 &self.pool
105 }
106
107 pub fn url(&self) -> &str {
109 &self.url
110 }
111
112 pub async fn execute(&self, sql: &str) -> Result<()> {
114 sqlx::query(sql)
115 .execute(&self.pool)
116 .await
117 .map_err(ForgeError::Sql)?;
118 Ok(())
119 }
120
121 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
126 let base_url = self.url.clone();
127 let db_name = format!(
129 "forge_test_{}_{}",
130 sanitize_db_name(test_name),
131 uuid::Uuid::new_v4().simple()
132 );
133
134 let pool = sqlx::postgres::PgPoolOptions::new()
136 .max_connections(1)
137 .connect(&base_url)
138 .await
139 .map_err(ForgeError::Sql)?;
140
141 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
143 .execute(&pool)
144 .await
145 .map_err(ForgeError::Sql)?;
146
147 let test_url = replace_db_name(&base_url, &db_name);
149
150 let test_pool = sqlx::postgres::PgPoolOptions::new()
151 .max_connections(5)
152 .connect(&test_url)
153 .await
154 .map_err(ForgeError::Sql)?;
155
156 Ok(IsolatedTestDb {
157 pool: test_pool,
158 db_name,
159 base_url,
160 })
161 }
162}
163
164pub struct IsolatedTestDb {
170 pool: PgPool,
171 db_name: String,
172 base_url: String,
173}
174
175impl IsolatedTestDb {
176 pub fn pool(&self) -> &PgPool {
178 &self.pool
179 }
180
181 pub fn db_name(&self) -> &str {
183 &self.db_name
184 }
185
186 pub async fn execute(&self, sql: &str) -> Result<()> {
188 sqlx::query(sql)
189 .execute(&self.pool)
190 .await
191 .map_err(ForgeError::Sql)?;
192 Ok(())
193 }
194
195 pub async fn run_sql(&self, sql: &str) -> Result<()> {
200 let statements = split_sql_statements(sql);
201 for statement in statements {
202 let statement = statement.trim();
203 if statement.is_empty()
204 || statement.lines().all(|l| {
205 let l = l.trim();
206 l.is_empty() || l.starts_with("--")
207 })
208 {
209 continue;
210 }
211
212 sqlx::query(statement)
213 .execute(&self.pool)
214 .await
215 .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
216 }
217 Ok(())
218 }
219
220 pub async fn cleanup(self) -> Result<()> {
225 self.pool.close().await;
227
228 let pool = sqlx::postgres::PgPoolOptions::new()
230 .max_connections(1)
231 .connect(&self.base_url)
232 .await
233 .map_err(ForgeError::Sql)?;
234
235 let _ = sqlx::query(&format!(
237 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
238 self.db_name
239 ))
240 .execute(&pool)
241 .await;
242
243 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
244 .execute(&pool)
245 .await
246 .map_err(ForgeError::Sql)?;
247
248 Ok(())
249 }
250
251 pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
264 if !migrations_dir.exists() {
265 debug!("Migrations directory does not exist: {:?}", migrations_dir);
266 return Ok(());
267 }
268
269 let mut migrations = Vec::new();
270
271 let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
272
273 for entry in entries {
274 let entry = entry.map_err(ForgeError::Io)?;
275 let path = entry.path();
276
277 if path.extension().map(|e| e == "sql").unwrap_or(false) {
278 let name = path
279 .file_stem()
280 .and_then(|s| s.to_str())
281 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
282 .to_string();
283
284 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
285 migrations.push((name, content));
286 }
287 }
288
289 migrations.sort_by(|a, b| a.0.cmp(&b.0));
291
292 debug!("Running {} migrations for test", migrations.len());
293
294 for (name, content) in migrations {
295 info!("Applying test migration: {}", name);
296
297 let up_sql = parse_up_sql(&content);
299
300 let statements = split_sql_statements(&up_sql);
302 for statement in statements {
303 let statement = statement.trim();
304 if statement.is_empty()
305 || statement.lines().all(|l| {
306 let l = l.trim();
307 l.is_empty() || l.starts_with("--")
308 })
309 {
310 continue;
311 }
312
313 sqlx::query(statement)
314 .execute(&self.pool)
315 .await
316 .map_err(|e| {
317 ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
318 })?;
319 }
320 }
321
322 Ok(())
323 }
324}
325
/// Turns an arbitrary test name into a fragment usable inside a Postgres database name.
///
/// Every character outside `[A-Za-z0-9]` becomes `_`, and the result is capped at
/// 32 characters.
///
/// Restricting to ASCII keeps each character to one byte, which does two things:
/// - The fragment stays URL-safe; it is spliced into a connection URL.
/// - Its byte length stays bounded. Postgres truncates identifiers to 63 bytes by
///   default, and multi-byte characters could otherwise push the unique suffix that
///   `TestDatabase::isolated` appends past that limit, making names collide.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .take(32)
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
333
/// Rewrites the database component of a Postgres connection URL to `new_db`.
///
/// Scheme, credentials, host, port and any `?query` string are preserved. The query
/// is split off first, so `/` characters inside parameters (e.g.
/// `sslrootcert=/etc/ssl/ca.pem` or `host=/var/run/postgresql`) are never mistaken for
/// the path. The `//` after the scheme is skipped, so a URL with no database path
/// (e.g. `postgres://localhost:5432`) gets `/new_db` appended instead of losing its host.
///
/// Input without a `scheme://` prefix is treated as `server[/db]`.
fn replace_db_name(url: &str, new_db: &str) -> String {
    let (base, query) = url.split_at(url.find('?').unwrap_or(url.len()));

    // Start looking for the path only after the authority section begins.
    let authority_start = base.find("://").map_or(0, |i| i + 3);
    let server = match base[authority_start..].find('/') {
        Some(slash) => &base[..authority_start + slash],
        None => base,
    };

    format!("{}/{}{}", server, new_db, query)
}
350
/// Returns `true` if `line` is a migration section marker such as `-- @down`.
///
/// A marker is a `--` comment whose first word is `marker`, compared ASCII
/// case-insensitively. Examples: `-- @down`, `--@DOWN`, `-- @down drop the table`.
/// Comments that merely start with the same letters (`-- @updated_at ...`) are not
/// markers.
fn is_migration_marker(line: &str, marker: &str) -> bool {
    let rest = match line.trim().strip_prefix("--") {
        Some(rest) => rest.trim_start(),
        None => return false,
    };
    match rest.get(..marker.len()) {
        Some(tag) if tag.eq_ignore_ascii_case(marker) => {
            let after = &rest[marker.len()..];
            after.is_empty() || after.starts_with(char::is_whitespace)
        }
        _ => false,
    }
}

/// Extracts the "up" section of a migration file.
///
/// Everything from the first `-- @down` marker line onward is discarded, and
/// `-- @up` marker lines are removed. The result is trimmed.
///
/// Markers are recognised only as whole comment lines, so ordinary comments like
/// `-- @updated_at tracks edits` are left intact rather than being cut into
/// non-comment SQL. Line endings of the kept lines are preserved as-is.
fn parse_up_sql(content: &str) -> String {
    content
        .split_inclusive('\n')
        .take_while(|line| !is_migration_marker(line, "@down"))
        .filter(|line| !is_migration_marker(line, "@up"))
        .collect::<String>()
        .trim()
        .to_string()
}
377
/// Splits a SQL script into individual statements on top-level `;` terminators.
///
/// A `;` only terminates a statement when it appears outside:
/// - single-quoted string literals (`'a;b'`, with `''` escapes, plus backslash
///   escapes for `E'...'` strings),
/// - double-quoted identifiers (`"a;b"`),
/// - `--` line comments and (nestable) `/* ... */` block comments,
/// - dollar-quoted bodies (`$$ ... $$`, `$tag$ ... $tag$`), as used by function bodies.
///
/// Each returned statement is trimmed and has its terminating `;` removed. Empty
/// statements are dropped. Comments stay in place as part of the surrounding
/// statement, so callers may receive comment-only entries (e.g. a trailing comment
/// after the last `;`). An unterminated quote, comment or dollar quote swallows the
/// rest of the input into the current statement.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Bytes that may continue an unquoted identifier or dollar-quote tag. Non-ASCII
    /// bytes count, as in the Postgres lexer.
    fn is_sql_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
    }

    /// Trims `raw` and records it as a statement unless it is empty.
    fn push_statement(out: &mut Vec<String>, raw: &str) {
        let stmt = raw.trim().trim_end_matches(';').trim();
        if !stmt.is_empty() {
            out.push(stmt.to_string());
        }
    }

    /// Returns the index just past the quoted section opened by `quote` at `open`.
    fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
        let mut j = open + 1;
        while j < bytes.len() {
            let b = bytes[j];
            if backslash_escapes && b == b'\\' {
                // Skip the escaped byte, whatever it is.
                j += 2;
            } else if b == quote {
                if bytes.get(j + 1) == Some(&quote) {
                    // A doubled quote is an escaped quote, not the end.
                    j += 2;
                } else {
                    return j + 1;
                }
            } else {
                j += 1;
            }
        }
        bytes.len()
    }

    /// Returns the index just past the (possibly nested) block comment opened at `open`.
    fn skip_block_comment(bytes: &[u8], open: usize) -> usize {
        let mut depth = 0usize;
        let mut j = open;
        while j < bytes.len() {
            if bytes[j..].starts_with(b"/*") {
                depth += 1;
                j += 2;
            } else if bytes[j..].starts_with(b"*/") {
                depth -= 1;
                j += 2;
                if depth == 0 {
                    return j;
                }
            } else {
                j += 1;
            }
        }
        bytes.len()
    }

    /// Handles a `$` at `open` that may start a dollar quote (`$$` or `$tag$`).
    ///
    /// If it does, returns the index just past the matching closing tag, or the end of
    /// input if the quote is never closed. Otherwise returns `None`.
    fn skip_dollar_quoted(bytes: &[u8], open: usize) -> Option<usize> {
        let mut close = open + 1;
        // Tags follow identifier rules (no leading digit), so `$1` stays a parameter.
        if matches!(bytes.get(close), Some(&b) if is_sql_ident_byte(b) && !b.is_ascii_digit()) {
            while matches!(bytes.get(close), Some(&b) if is_sql_ident_byte(b)) {
                close += 1;
            }
        }
        if bytes.get(close) != Some(&b'$') {
            return None;
        }
        let tag = &bytes[open..=close];
        let body = close + 1;
        let end = bytes[body..]
            .windows(tag.len())
            .position(|window| window == tag)
            .map_or(bytes.len(), |offset| body + offset + tag.len());
        Some(end)
    }

    // Every delimiter is ASCII and UTF-8 continuation bytes are never ASCII, so
    // scanning bytes and slicing `sql` at delimiter positions always lands on a char
    // boundary.
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut start = 0;
    let mut i = 0;

    while i < bytes.len() {
        i = match bytes[i] {
            b';' => {
                push_statement(&mut statements, &sql[start..i]);
                start = i + 1;
                i + 1
            }
            b'\'' => {
                // `E'...'` (escape-string syntax) honours backslash escapes like `\'`.
                let escape_string = i > 0
                    && matches!(bytes[i - 1], b'e' | b'E')
                    && (i < 2 || !is_sql_ident_byte(bytes[i - 2]));
                skip_quoted(bytes, i, b'\'', escape_string)
            }
            b'"' => skip_quoted(bytes, i, b'"', false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => bytes[i..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(bytes.len(), |offset| i + offset),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i),
            // A `$` continuing an identifier (Postgres allows `foo$bar`) cannot open a
            // dollar quote.
            b'$' if i == 0 || !(is_sql_ident_byte(bytes[i - 1]) || bytes[i - 1] == b'$') => {
                skip_dollar_quoted(bytes, i).unwrap_or(i + 1)
            }
            _ => i + 1,
        };
    }

    push_statement(&mut statements, &sql[start..]);
    statements
}
441
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_db_name() {
        // Separators and punctuation collapse to underscores; plain names pass through.
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(sanitize_db_name(input), expected, "sanitizing {:?}", input);
        }
    }

    #[test]
    fn test_replace_db_name() {
        // Only the database segment changes; credentials, port and query survive.
        let new_db = "newdb";
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for &(url, expected) in &cases {
            assert_eq!(replace_db_name(url, new_db), expected, "rewriting {:?}", url);
        }
    }
}