1#![allow(
6 clippy::unwrap_used,
7 clippy::indexing_slicing,
8 clippy::disallowed_methods
9)]
10use sqlx::PgPool;
14use std::path::Path;
15#[cfg(feature = "testcontainers")]
16use std::sync::Arc;
17use tracing::{debug, info};
18
19use crate::error::{ForgeError, Result};
20
/// Shared, optional handle to the PostgreSQL container started by
/// `TestDatabase::from_container`.
///
/// Holds `None` when the server was supplied through a URL. The `Arc` is cloned
/// into every isolated database, so the container is only dropped once the last
/// handle referencing it is gone.
#[cfg(feature = "testcontainers")]
type PgContainer =
    Arc<Option<testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>>>;
24
/// Handle to a PostgreSQL server used by tests: a connection pool plus the URL
/// it was opened with.
///
/// Isolated per-test databases are created from this handle via
/// `TestDatabase::isolated`.
pub struct TestDatabase {
    // Pool connected to `url`.
    pool: PgPool,
    // URL the pool was opened with; also the base URL for isolated databases.
    url: String,
    // Auto-provisioned container (or `None` for URL-based connections). It is held
    // only so it is not dropped while this handle, or an isolated database sharing
    // the same `Arc`, still needs the server.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
32
33impl TestDatabase {
34 pub async fn from_url(url: &str) -> Result<Self> {
35 let pool = sqlx::postgres::PgPoolOptions::new()
36 .max_connections(10)
37 .connect(url)
38 .await
39 .map_err(ForgeError::Database)?;
40
41 Ok(Self {
42 pool,
43 url: url.to_string(),
44 #[cfg(feature = "testcontainers")]
45 _container: Arc::new(None),
46 })
47 }
48
49 pub async fn from_env() -> Result<Self> {
52 match std::env::var("TEST_DATABASE_URL") {
53 Ok(url) => Self::from_url(&url).await,
54 Err(_) => {
55 #[cfg(feature = "testcontainers")]
56 {
57 Self::from_container().await
58 }
59 #[cfg(not(feature = "testcontainers"))]
60 {
61 Err(ForgeError::internal(
62 "TEST_DATABASE_URL not set. Set it explicitly for database tests, \
63 or enable the `testcontainers` feature for automatic provisioning.",
64 ))
65 }
66 }
67 }
68 }
69
70 #[cfg(feature = "testcontainers")]
71 async fn from_container() -> Result<Self> {
72 use testcontainers::ImageExt;
73 use testcontainers::runners::AsyncRunner;
74 use testcontainers_modules::postgres::Postgres;
75
76 let container = Postgres::default()
77 .with_tag("18-alpine")
78 .start()
79 .await
80 .map_err(|e| ForgeError::internal_with("Failed to start PG container", e))?;
81
82 let port = container
83 .get_host_port_ipv4(5432)
84 .await
85 .map_err(|e| ForgeError::internal_with("Failed to get container port", e))?;
86
87 let url = format!("postgres://postgres:postgres@localhost:{port}/postgres");
88 let pool = sqlx::postgres::PgPoolOptions::new()
89 .max_connections(10)
90 .acquire_timeout(std::time::Duration::from_secs(30))
91 .connect(&url)
92 .await
93 .map_err(ForgeError::Database)?;
94
95 Ok(Self {
96 pool,
97 url,
98 _container: Arc::new(Some(container)),
99 })
100 }
101
102 pub fn pool(&self) -> &PgPool {
103 &self.pool
104 }
105
106 pub fn url(&self) -> &str {
107 &self.url
108 }
109
110 pub async fn execute(&self, sql: &str) -> Result<()> {
112 sqlx::query(sql)
113 .execute(&self.pool)
114 .await
115 .map_err(ForgeError::Database)?;
116 Ok(())
117 }
118
119 pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
121 let base_url = self.url.clone();
122 let db_name = format!(
123 "forge_test_{}_{}",
124 sanitize_db_name(test_name),
125 uuid::Uuid::new_v4().simple()
126 );
127
128 sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
129 .execute(&self.pool)
130 .await
131 .map_err(ForgeError::Database)?;
132
133 let test_url = replace_db_name(&base_url, &db_name);
134
135 let test_pool = sqlx::postgres::PgPoolOptions::new()
136 .max_connections(5)
137 .connect(&test_url)
138 .await
139 .map_err(ForgeError::Database)?;
140
141 Ok(IsolatedTestDb {
142 pool: test_pool,
143 db_name,
144 base_url,
145 #[cfg(feature = "testcontainers")]
146 _container: self._container.clone(),
147 })
148 }
149}
150
/// A throwaway database created by `TestDatabase::isolated`.
///
/// Holds its own pool, plus the database name and the server's base URL that
/// `IsolatedTestDb::cleanup` uses to drop the database.
// NOTE(review): no `Drop` impl is visible here, so the database is presumably left
// behind unless `cleanup` is called explicitly — confirm.
pub struct IsolatedTestDb {
    // Pool connected to the isolated database itself.
    pool: PgPool,
    // Name of the isolated database, as used in CREATE/DROP DATABASE.
    db_name: String,
    // URL of the originating server; `cleanup` reconnects here to drop the database.
    base_url: String,
    // Shared container handle cloned from the originating `TestDatabase`.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
160
161impl IsolatedTestDb {
162 pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result<Self> {
164 let base = TestDatabase::from_env().await?;
165 let db = base.isolated(test_name).await?;
166 db.run_sql(internal_sql).await?;
167 db.migrate(migrations_dir).await?;
168 Ok(db)
169 }
170
171 pub fn pool(&self) -> &PgPool {
172 &self.pool
173 }
174
175 pub fn db_name(&self) -> &str {
176 &self.db_name
177 }
178
179 pub async fn execute(&self, sql: &str) -> Result<()> {
181 sqlx::query(sql)
182 .execute(&self.pool)
183 .await
184 .map_err(ForgeError::Database)?;
185 Ok(())
186 }
187
188 pub async fn run_sql(&self, sql: &str) -> Result<()> {
190 for stmt in split_sql_statements(sql) {
191 let stmt = stmt.trim();
192 if is_blank_sql(stmt) {
193 continue;
194 }
195 sqlx::query(stmt)
196 .execute(&self.pool)
197 .await
198 .map_err(|e| ForgeError::internal_with("Failed to execute SQL", e))?;
199 }
200 Ok(())
201 }
202
203 pub async fn cleanup(self) -> Result<()> {
205 self.pool.close().await;
206
207 let pool = sqlx::postgres::PgPoolOptions::new()
208 .max_connections(1)
209 .connect(&self.base_url)
210 .await
211 .map_err(ForgeError::Database)?;
212
213 if let Err(e) =
214 sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1")
215 .bind(&self.db_name)
216 .execute(&pool)
217 .await
218 {
219 tracing::warn!(db = %self.db_name, error = %e, "failed to terminate backend connections during test cleanup");
220 }
221
222 sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
223 .execute(&pool)
224 .await
225 .map_err(ForgeError::Database)?;
226
227 Ok(())
228 }
229
230 pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
232 if !migrations_dir.exists() {
233 debug!("Migrations directory does not exist: {:?}", migrations_dir);
234 return Ok(());
235 }
236
237 let mut migrations = Vec::new();
238
239 let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
240
241 for entry in entries {
242 let entry = entry.map_err(ForgeError::Io)?;
243 let path = entry.path();
244
245 if path.extension().map(|e| e == "sql").unwrap_or(false) {
246 let name = path
247 .file_stem()
248 .and_then(|s| s.to_str())
249 .ok_or_else(|| ForgeError::config("Invalid migration filename"))?
250 .to_string();
251
252 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
253 migrations.push((name, content));
254 }
255 }
256
257 migrations.sort_by(|a, b| a.0.cmp(&b.0));
258
259 debug!("Running {} migrations for test", migrations.len());
260
261 for (name, content) in migrations {
262 info!("Applying test migration: {}", name);
263
264 let up_sql = strip_up_markers(&content);
265
266 for stmt in split_sql_statements(&up_sql) {
267 let stmt = stmt.trim();
268 if is_blank_sql(stmt) {
269 continue;
270 }
271 sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
272 ForgeError::internal(format!("Failed to apply migration '{name}': {e}"))
273 })?;
274 }
275 }
276
277 Ok(())
278 }
279}
280
/// Reports whether `sql` contains nothing executable.
///
/// True when every line is whitespace-only or a `--` line comment (leading
/// whitespace allowed). An empty string has no lines and is therefore blank.
/// Block comments (`/* ... */`) are not recognised and count as content.
fn is_blank_sql(sql: &str) -> bool {
    let has_code = sql
        .lines()
        .map(str::trim)
        .any(|line| !line.is_empty() && !line.starts_with("--"));
    !has_code
}
287
/// Turns an arbitrary test name into a fragment safe to embed in a quoted
/// PostgreSQL database name.
///
/// ASCII letters and digits are kept; every other character, including
/// non-ASCII letters, becomes `_`. The result is capped at 32 characters.
/// Because every output character is ASCII, this is also a 32-byte cap, which
/// keeps generated names predictably within PostgreSQL's identifier length
/// limit regardless of the server's encoding.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .take(32)
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
294
/// Returns `url` with its database path replaced by `new_db`.
///
/// - The query string (from the first `?`) is split off first and re-appended
///   unchanged. This means `/` characters inside parameter values, such as
///   `sslrootcert=/etc/ssl/root.crt` or `host=/var/run/postgresql`, are never
///   mistaken for the database path.
/// - The path is searched for only after the `://` scheme separator. A URL
///   with no database (`postgres://host:5432`) therefore gets `/new_db`
///   appended instead of having its host overwritten.
/// - Input containing no `/` at all also gets `/new_db` appended.
fn replace_db_name(url: &str, new_db: &str) -> String {
    let (without_query, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Start of the authority (user:pass@host:port); 0 when there is no scheme.
    let authority_start = without_query.find("://").map_or(0, |idx| idx + 3);

    // Everything before the first `/` after the authority start is kept; the
    // old database path (if any) is discarded.
    let base = match without_query[authority_start..].find('/') {
        Some(idx) => &without_query[..authority_start + idx],
        None => without_query,
    };

    format!("{}/{}{}", base, new_db, query)
}
309
/// Removes `@up` marker comments from a migration file and returns the SQL to run.
///
/// A marker is a line that, after trimming, is a `--` comment whose first
/// token is `@up` in any ASCII case (whitespace after `--` optional), e.g.
/// `-- @up` or `--@UP`. Marker lines are dropped whole.
///
/// Everything else is kept verbatim, including line endings. That covers
/// comments such as `-- @updated_at ...` and string literals containing
/// `--@up`. The result is trimmed.
fn strip_up_markers(sql: &str) -> String {
    // A marker line is itself a comment, so dropping it never changes which SQL
    // runs. Matching a whole token (instead of substring replacement) avoids
    // turning text like `-- @updated_at` into executable `dated_at`.
    fn is_up_marker(line: &str) -> bool {
        let token = line
            .trim()
            .strip_prefix("--")
            .and_then(|rest| rest.split_whitespace().next());
        matches!(token, Some(tok) if tok.eq_ignore_ascii_case("@up"))
    }

    sql.split_inclusive('\n')
        .filter(|line| !is_up_marker(line))
        .collect::<String>()
        .trim()
        .to_string()
}
318
/// Splits a SQL script into individual statements on top-level `;`.
///
/// A small character scanner tracks lexical context, so a `;` inside any of the
/// following does not end a statement:
/// - `--` line comments (until the next `\n`),
/// - `/* ... */` block comments (not nested: the first `*/` closes),
/// - single-quoted string literals, with `''` as an escaped quote,
/// - dollar-quoted bodies such as `$$ ... $$` or `$tag$ ... $tag$`.
///
/// Each returned statement is trimmed and has trailing `;` removed, and empty
/// fragments are dropped. Comments remain in the statement text, so a
/// comment-only fragment (e.g. after the last `;`) is still returned; the
/// callers in this file filter those out with `is_blank_sql`.
///
/// NOTE(review): double-quoted identifiers, `E'...'` backslash escapes, and
/// nested block comments are not recognised, so a `;` or `'` inside them can
/// mis-split the script.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    // Text of the statement being accumulated (comments and quotes included).
    let mut current = String::new();
    // Lexical context flags; at most one is active at a time.
    let mut in_dollar_quote = false;
    // Full opening delimiter (e.g. "$$" or "$body$") of the active dollar quote.
    let mut dollar_tag = String::new();
    let mut in_line_comment = false;
    let mut in_block_comment = false;
    let mut in_string_literal = false;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        // Every consumed character is kept in the statement text.
        current.push(c);

        if in_line_comment {
            if c == '\n' {
                in_line_comment = false;
            }
            continue;
        }

        if in_block_comment {
            if c == '*' && chars.peek() == Some(&'/') {
                current.push(chars.next().expect("peeked char"));
                in_block_comment = false;
            }
            continue;
        }

        if in_string_literal {
            if c == '\'' {
                // `''` is an escaped quote inside the literal, not its end.
                if chars.peek() == Some(&'\'') {
                    current.push(chars.next().expect("peeked char"));
                } else {
                    in_string_literal = false;
                }
            }
            continue;
        }

        if in_dollar_quote {
            if c == '$' {
                // Read a candidate delimiter `$tag$`; the quote closes only when it
                // matches the opening delimiter exactly.
                let mut potential_tag = String::from("$");
                while let Some(&next_c) = chars.peek() {
                    if next_c == '$' {
                        potential_tag.push(chars.next().expect("peeked char"));
                        current.push('$');
                        break;
                    } else if next_c.is_alphanumeric() || next_c == '_' {
                        let ch = chars.next().expect("peeked char");
                        potential_tag.push(ch);
                        current.push(ch);
                    } else {
                        break;
                    }
                }
                if potential_tag.len() >= 2
                    && potential_tag.ends_with('$')
                    && potential_tag == dollar_tag
                {
                    in_dollar_quote = false;
                    dollar_tag.clear();
                }
            }
            continue;
        }

        // `--` starts a line comment.
        if c == '-' && chars.peek() == Some(&'-') {
            current.push(chars.next().expect("peeked char"));
            in_line_comment = true;
            continue;
        }

        // `/*` starts a block comment.
        if c == '/' && chars.peek() == Some(&'*') {
            current.push(chars.next().expect("peeked char"));
            in_block_comment = true;
            continue;
        }

        if c == '\'' {
            in_string_literal = true;
            continue;
        }

        if c == '$' {
            // Possible opening delimiter. A candidate that does not end in a second
            // `$` (e.g. a `$1` placeholder) leaves us outside a dollar quote; the
            // characters consumed while scanning stay in `current` either way.
            let mut potential_tag = String::from("$");
            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    potential_tag.push(chars.next().expect("peeked char"));
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    let ch = chars.next().expect("peeked char");
                    potential_tag.push(ch);
                    current.push(ch);
                } else {
                    break;
                }
            }
            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                in_dollar_quote = true;
                dollar_tag = potential_tag;
            }
            continue;
        }

        // A top-level `;` terminates the current statement.
        if c == ';' {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    // Flush a final statement that has no trailing `;`.
    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}
442
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this module; none of them need a database.
    use super::*;

    // Non-alphanumeric characters are replaced with `_`.
    #[test]
    fn test_sanitize_db_name() {
        assert_eq!(sanitize_db_name("my_test"), "my_test");
        assert_eq!(sanitize_db_name("my-test"), "my_test");
        assert_eq!(sanitize_db_name("my test"), "my_test");
        assert_eq!(sanitize_db_name("test::function"), "test__function");
    }

    // Only the database path changes; credentials, port and query string survive.
    #[test]
    fn test_replace_db_name() {
        assert_eq!(
            replace_db_name("postgres://localhost/olddb", "newdb"),
            "postgres://localhost/newdb"
        );
        assert_eq!(
            replace_db_name("postgres://user:pass@localhost:5432/olddb", "newdb"),
            "postgres://user:pass@localhost:5432/newdb"
        );
        assert_eq!(
            replace_db_name("postgres://localhost/olddb?sslmode=disable", "newdb"),
            "postgres://localhost/newdb?sslmode=disable"
        );
    }

    #[test]
    fn split_simple_statements() {
        let stmts = split_sql_statements("CREATE TABLE a (id int); CREATE TABLE b (id int);");
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].starts_with("CREATE TABLE a"));
        assert!(stmts[1].starts_with("CREATE TABLE b"));
    }

    // Semicolons inside a `$$ ... $$` function body must not split it.
    #[test]
    fn split_preserves_dollar_quoted_content() {
        let sql = r#"
        CREATE FUNCTION test() RETURNS void AS $$
        BEGIN
            INSERT INTO logs (msg) VALUES ('hello; world');
        END;
        $$ LANGUAGE plpgsql;
        SELECT 1;
        "#;
        let stmts = split_sql_statements(sql);
        assert_eq!(
            stmts.len(),
            2,
            "Should split into function + SELECT, not more"
        );
        assert!(
            stmts[0].contains("$$"),
            "Function body must include dollar quotes"
        );
    }

    #[test]
    fn split_handles_empty_input() {
        let stmts = split_sql_statements("");
        assert!(stmts.is_empty());
    }

    #[test]
    fn split_handles_no_trailing_semicolon() {
        let stmts = split_sql_statements("SELECT 1");
        assert_eq!(stmts.len(), 1);
        assert_eq!(stmts[0], "SELECT 1");
    }

    // Runs of bare semicolons produce no empty statements.
    #[test]
    fn split_skips_blank_statements() {
        let stmts = split_sql_statements("; ; SELECT 1; ;");
        assert_eq!(stmts.len(), 1);
        assert_eq!(stmts[0], "SELECT 1");
    }

    #[test]
    fn split_ignores_semicolons_in_line_comments() {
        let sql = "CREATE TABLE t (\n  id INT,\n  -- this; has a semicolon\n  name TEXT\n);\nSELECT 1;";
        let stmts = split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("name TEXT"));
    }

    #[test]
    fn split_ignores_semicolons_in_block_comments() {
        let sql = "CREATE TABLE t (id INT /* ; */ );\nSELECT 1;";
        let stmts = split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
    }

    #[test]
    fn split_ignores_semicolons_in_string_literals() {
        let sql = "INSERT INTO t VALUES ('a;b');\nSELECT 1;";
        let stmts = split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("'a;b'"));
    }

    #[test]
    fn strip_up_markers_drops_marker() {
        let content = "-- @up\nCREATE TABLE a (id int);";
        let up = strip_up_markers(content);
        assert!(!up.contains("@up"), "Up marker should be stripped");
        assert!(up.contains("CREATE TABLE"));
    }

    // Whitespace and `--` comment lines count as blank; any other content does not.
    #[test]
    fn blank_sql_detection() {
        assert!(is_blank_sql(""));
        assert!(is_blank_sql("   "));
        assert!(is_blank_sql("-- just a comment"));
        assert!(is_blank_sql("-- comment\n-- another"));
        assert!(!is_blank_sql("SELECT 1"));
        assert!(!is_blank_sql("-- comment\nSELECT 1"));
    }

    // Sanitized names are capped at 32 characters.
    #[test]
    fn sanitize_truncates_long_names() {
        let long_name = "a".repeat(100);
        let sanitized = sanitize_db_name(&long_name);
        assert_eq!(sanitized.len(), 32);
    }

    #[test]
    fn sanitize_handles_special_characters() {
        assert_eq!(
            sanitize_db_name("test/with:special!chars"),
            "test_with_special_chars"
        );
        assert_eq!(sanitize_db_name(""), "");
    }
}