Skip to main content

forge_core/testing/
db.rs

1//! Database provisioning for tests.
2//!
3//! Deliberately avoids reading DATABASE_URL to prevent accidental production use.
4
5#![allow(
6    clippy::unwrap_used,
7    clippy::indexing_slicing,
8    clippy::disallowed_methods
9)]
10// Test harness runs dynamic DDL (CREATE/DROP DATABASE, schema reset) where the
11// query macros can't see the table at compile time.
12
13use sqlx::PgPool;
14use std::path::Path;
15#[cfg(feature = "testcontainers")]
16use std::sync::Arc;
17use tracing::{debug, info};
18
19use crate::error::{ForgeError, Result};
20
/// Shared handle to the Postgres test container, if one was started.
///
/// The inner `Option` is `None` when the database was reached through an
/// explicit URL and `Some` when `from_container()` started a container.
#[cfg(feature = "testcontainers")]
type PgContainer =
    Arc<Option<testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>>>;
24
/// Database access for tests.
///
/// Bundles a connection pool with the URL it was opened with, so that
/// `isolated()` can create per-test databases on the same server.
pub struct TestDatabase {
    /// Pool connected to the base database at `url`.
    pool: PgPool,
    /// Connection URL the pool was opened with.
    url: String,
    /// Container handle stored next to the pool; holds `None` when the
    /// database was reached through a URL rather than a started container.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
32
33impl TestDatabase {
34    pub async fn from_url(url: &str) -> Result<Self> {
35        let pool = sqlx::postgres::PgPoolOptions::new()
36            .max_connections(10)
37            .connect(url)
38            .await
39            .map_err(ForgeError::Database)?;
40
41        Ok(Self {
42            pool,
43            url: url.to_string(),
44            #[cfg(feature = "testcontainers")]
45            _container: Arc::new(None),
46        })
47    }
48
49    /// Connect using `TEST_DATABASE_URL`, or start a container if the
50    /// `testcontainers` feature is enabled and the var is unset.
51    pub async fn from_env() -> Result<Self> {
52        match std::env::var("TEST_DATABASE_URL") {
53            Ok(url) => Self::from_url(&url).await,
54            Err(_) => {
55                #[cfg(feature = "testcontainers")]
56                {
57                    Self::from_container().await
58                }
59                #[cfg(not(feature = "testcontainers"))]
60                {
61                    Err(ForgeError::internal(
62                        "TEST_DATABASE_URL not set. Set it explicitly for database tests, \
63                         or enable the `testcontainers` feature for automatic provisioning.",
64                    ))
65                }
66            }
67        }
68    }
69
70    #[cfg(feature = "testcontainers")]
71    async fn from_container() -> Result<Self> {
72        use testcontainers::ImageExt;
73        use testcontainers::runners::AsyncRunner;
74        use testcontainers_modules::postgres::Postgres;
75
76        let container = Postgres::default()
77            .with_tag("18-alpine")
78            .start()
79            .await
80            .map_err(|e| ForgeError::internal_with("Failed to start PG container", e))?;
81
82        let port = container
83            .get_host_port_ipv4(5432)
84            .await
85            .map_err(|e| ForgeError::internal_with("Failed to get container port", e))?;
86
87        let url = format!("postgres://postgres:postgres@localhost:{port}/postgres");
88        let pool = sqlx::postgres::PgPoolOptions::new()
89            .max_connections(10)
90            .acquire_timeout(std::time::Duration::from_secs(30))
91            .connect(&url)
92            .await
93            .map_err(ForgeError::Database)?;
94
95        Ok(Self {
96            pool,
97            url,
98            _container: Arc::new(Some(container)),
99        })
100    }
101
    /// Borrow the connection pool for the base database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
105
    /// Connection URL the base pool was opened with.
    pub fn url(&self) -> &str {
        &self.url
    }
109
110    /// Run raw SQL for test setup.
111    pub async fn execute(&self, sql: &str) -> Result<()> {
112        sqlx::query(sql)
113            .execute(&self.pool)
114            .await
115            .map_err(ForgeError::Database)?;
116        Ok(())
117    }
118
119    /// Create a dedicated database for a single test, providing full isolation.
120    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
121        let base_url = self.url.clone();
122        let db_name = format!(
123            "forge_test_{}_{}",
124            sanitize_db_name(test_name),
125            uuid::Uuid::new_v4().simple()
126        );
127
128        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
129            .execute(&self.pool)
130            .await
131            .map_err(ForgeError::Database)?;
132
133        let test_url = replace_db_name(&base_url, &db_name);
134
135        let test_pool = sqlx::postgres::PgPoolOptions::new()
136            .max_connections(5)
137            .connect(&test_url)
138            .await
139            .map_err(ForgeError::Database)?;
140
141        Ok(IsolatedTestDb {
142            pool: test_pool,
143            db_name,
144            base_url,
145            #[cfg(feature = "testcontainers")]
146            _container: self._container.clone(),
147        })
148    }
149}
150
/// A test database scoped to a single test. Call `cleanup()` to drop it immediately,
/// or rely on future test runs to clean up orphaned databases.
///
/// NOTE(review): no `Drop` impl or orphan sweep is visible in this file —
/// confirm where leftover `forge_test_*` databases actually get removed.
pub struct IsolatedTestDb {
    /// Pool connected to the per-test database.
    pool: PgPool,
    /// Name the per-test database was created under.
    db_name: String,
    /// URL of the base database; `cleanup()` reconnects here to drop `db_name`.
    base_url: String,
    /// Clone of the parent's container handle.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
160
161impl IsolatedTestDb {
162    /// Convenience: `from_env()` → `isolated()` → `run_sql(internal_sql)` → `migrate()`.
163    pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result<Self> {
164        let base = TestDatabase::from_env().await?;
165        let db = base.isolated(test_name).await?;
166        db.run_sql(internal_sql).await?;
167        db.migrate(migrations_dir).await?;
168        Ok(db)
169    }
170
    /// Borrow the pool connected to the isolated database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
174
    /// Name of the per-test database this handle was created for.
    pub fn db_name(&self) -> &str {
        &self.db_name
    }
178
179    /// Run raw SQL for test setup.
180    pub async fn execute(&self, sql: &str) -> Result<()> {
181        sqlx::query(sql)
182            .execute(&self.pool)
183            .await
184            .map_err(ForgeError::Database)?;
185        Ok(())
186    }
187
188    /// Run multi-statement SQL, handling PL/pgSQL dollar-quoted strings.
189    pub async fn run_sql(&self, sql: &str) -> Result<()> {
190        for stmt in split_sql_statements(sql) {
191            let stmt = stmt.trim();
192            if is_blank_sql(stmt) {
193                continue;
194            }
195            sqlx::query(stmt)
196                .execute(&self.pool)
197                .await
198                .map_err(|e| ForgeError::internal_with("Failed to execute SQL", e))?;
199        }
200        Ok(())
201    }
202
203    /// Drop the isolated database and close all connections.
204    pub async fn cleanup(self) -> Result<()> {
205        self.pool.close().await;
206
207        let pool = sqlx::postgres::PgPoolOptions::new()
208            .max_connections(1)
209            .connect(&self.base_url)
210            .await
211            .map_err(ForgeError::Database)?;
212
213        if let Err(e) =
214            sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1")
215                .bind(&self.db_name)
216                .execute(&pool)
217                .await
218        {
219            tracing::warn!(db = %self.db_name, error = %e, "failed to terminate backend connections during test cleanup");
220        }
221
222        sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
223            .execute(&pool)
224            .await
225            .map_err(ForgeError::Database)?;
226
227        Ok(())
228    }
229
230    /// Run migrations: loads all `.sql` files from the directory, sorts alphabetically, executes in order.
231    pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
232        if !migrations_dir.exists() {
233            debug!("Migrations directory does not exist: {:?}", migrations_dir);
234            return Ok(());
235        }
236
237        let mut migrations = Vec::new();
238
239        let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
240
241        for entry in entries {
242            let entry = entry.map_err(ForgeError::Io)?;
243            let path = entry.path();
244
245            if path.extension().map(|e| e == "sql").unwrap_or(false) {
246                let name = path
247                    .file_stem()
248                    .and_then(|s| s.to_str())
249                    .ok_or_else(|| ForgeError::config("Invalid migration filename"))?
250                    .to_string();
251
252                let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
253                migrations.push((name, content));
254            }
255        }
256
257        migrations.sort_by(|a, b| a.0.cmp(&b.0));
258
259        debug!("Running {} migrations for test", migrations.len());
260
261        for (name, content) in migrations {
262            info!("Applying test migration: {}", name);
263
264            let up_sql = strip_up_markers(&content);
265
266            for stmt in split_sql_statements(&up_sql) {
267                let stmt = stmt.trim();
268                if is_blank_sql(stmt) {
269                    continue;
270                }
271                sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
272                    ForgeError::internal(format!("Failed to apply migration '{name}': {e}"))
273                })?;
274            }
275        }
276
277        Ok(())
278    }
279}
280
/// Report whether `sql` contains nothing to execute.
///
/// A statement is blank when every line is either whitespace or starts
/// (after trimming) with `--`. An empty string is blank.
fn is_blank_sql(sql: &str) -> bool {
    for line in sql.lines() {
        let line = line.trim();
        if !line.is_empty() && !line.starts_with("--") {
            return false;
        }
    }
    true
}
287
/// Turn an arbitrary test name into a fragment that is safe inside a database name.
///
/// Only the first 32 characters are kept. ASCII letters and digits pass
/// through; every other character, including non-ASCII letters, becomes `_`.
/// Restricting the output to ASCII makes the 32-character cap a 32-byte cap,
/// which matters because PostgreSQL limits identifiers by bytes, not characters.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .take(32)
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
294
/// Return `url` with its database name (the last path segment) replaced by `new_db`.
///
/// The query string, if any, is carried over unchanged. The path is located
/// only between the end of `scheme://` and the first `?`, so slashes in the
/// scheme separator or in query values (e.g. `sslrootcert=/etc/ca.pem`) are
/// never mistaken for the path separator. A URL without a path simply gets
/// `/<new_db>` appended after the authority.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Set the query aside first; it is re-attached verbatim (including `?`).
    let (location, query) = match url.find('?') {
        Some(at) => url.split_at(at),
        None => (url, ""),
    };

    // Skip past "scheme://" so its slashes cannot be picked as the path slash.
    let authority_start = location.find("://").map_or(0, |at| at + "://".len());

    let server = match location[authority_start..].rfind('/') {
        Some(at) => &location[..authority_start + at],
        None => location,
    };

    format!("{server}/{new_db}{query}")
}
309
/// Remove every occurrence of the up-migration marker from `sql` and trim the result.
///
/// Four spellings are recognised: `-- @up`, `--@up`, `-- @UP` and `--@UP`.
/// They are removed in that order wherever they occur in the text.
fn strip_up_markers(sql: &str) -> String {
    const MARKERS: [&str; 4] = ["-- @up", "--@up", "-- @UP", "--@UP"];

    let stripped = MARKERS
        .iter()
        .fold(sql.to_owned(), |text, marker| text.replace(marker, ""));

    stripped.trim().to_owned()
}
318
/// Split SQL into individual statements, respecting dollar-quoted strings,
/// line comments, block comments (which may nest), string literals and
/// double-quoted identifiers.
///
/// A `;` ends a statement only when it appears in plain SQL. Each returned
/// statement is trimmed and has its trailing `;` removed; comments inside a
/// statement are kept, and empty statements are dropped. Text after the last
/// `;` is returned as a final statement if it is non-empty.
///
/// Backslash escapes (as in `E'...'` strings) are not interpreted; a quote is
/// only escaped by doubling it.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Lexical context the scanner is currently inside.
    #[derive(Clone, Copy)]
    enum Scan {
        /// Plain SQL; the only state in which `;` ends a statement.
        Code,
        /// After `--`, until the end of the line.
        LineComment,
        /// Inside `/* ... */`; `depth` counts the currently open `/*`.
        BlockComment { depth: usize },
        /// Inside `'...'` or `"..."`; a doubled delimiter is an escaped one.
        Quoted { delim: char },
        /// Inside `$tag$ ... $tag$`; the opening tag is kept in `dollar_tag`.
        DollarQuoted,
    }

    /// Move the accumulated text into `out` as one trimmed statement, if non-empty.
    fn flush(current: &mut String, out: &mut Vec<String>) {
        let stmt = current.trim().trim_end_matches(';').trim();
        if !stmt.is_empty() {
            out.push(stmt.to_string());
        }
        current.clear();
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    let mut dollar_tag = String::new();
    let mut state = Scan::Code;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        // Every character is kept; the state only decides whether `;` splits.
        current.push(c);

        match state {
            Scan::LineComment => {
                if c == '\n' {
                    state = Scan::Code;
                }
            }

            Scan::BlockComment { depth } => match c {
                '*' if chars.next_if_eq(&'/').is_some() => {
                    current.push('/');
                    state = if depth > 1 {
                        Scan::BlockComment { depth: depth - 1 }
                    } else {
                        Scan::Code
                    };
                }
                // PostgreSQL block comments nest, so an inner `/*` needs its own `*/`.
                '/' if chars.next_if_eq(&'*').is_some() => {
                    current.push('*');
                    state = Scan::BlockComment { depth: depth + 1 };
                }
                _ => {}
            },

            Scan::Quoted { delim } => {
                if c == delim {
                    match chars.next_if_eq(&delim) {
                        // Doubled delimiter: an escaped quote, still inside.
                        Some(escaped) => current.push(escaped),
                        None => state = Scan::Code,
                    }
                }
            }

            Scan::DollarQuoted => {
                if c == '$' {
                    // Look ahead without consuming: the quote closes only if the
                    // text from this `$` spells the opening tag exactly. Not
                    // consuming on a mismatch lets a later `$` still start the
                    // real terminator (e.g. the `$$` at the end of `$x$$`).
                    let mut lookahead = chars.clone();
                    let closes = dollar_tag
                        .chars()
                        .skip(1)
                        .all(|expected| lookahead.next() == Some(expected));
                    if closes {
                        for _ in dollar_tag.chars().skip(1) {
                            if let Some(tag_char) = chars.next() {
                                current.push(tag_char);
                            }
                        }
                        dollar_tag.clear();
                        state = Scan::Code;
                    }
                }
            }

            Scan::Code => match c {
                '-' if chars.next_if_eq(&'-').is_some() => {
                    current.push('-');
                    state = Scan::LineComment;
                }
                '/' if chars.next_if_eq(&'*').is_some() => {
                    current.push('*');
                    state = Scan::BlockComment { depth: 1 };
                }
                '\'' | '"' => state = Scan::Quoted { delim: c },
                '$' => {
                    // Read a candidate tag: `$`, then letters/digits/underscores,
                    // then a closing `$`. Anything else (such as the bind
                    // parameter `$1`) is not a dollar quote.
                    let mut tag = String::from("$");
                    while let Some(next_c) =
                        chars.next_if(|n| *n == '$' || *n == '_' || n.is_alphanumeric())
                    {
                        tag.push(next_c);
                        current.push(next_c);
                        if next_c == '$' {
                            break;
                        }
                    }
                    if tag.len() >= 2 && tag.ends_with('$') {
                        dollar_tag = tag;
                        state = Scan::DollarQuoted;
                    }
                }
                ';' => flush(&mut current, &mut statements),
                _ => {}
            },
        }
    }

    // Whatever follows the last `;` is a statement too.
    flush(&mut current, &mut statements);

    statements
}
442
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers; nothing here needs a database.

    use super::*;

    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for (url, expected) in cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }

    #[test]
    fn split_simple_statements() {
        let parts = split_sql_statements("CREATE TABLE a (id int); CREATE TABLE b (id int);");
        assert_eq!(parts.len(), 2);
        assert!(parts[0].starts_with("CREATE TABLE a"));
        assert!(parts[1].starts_with("CREATE TABLE b"));
    }

    #[test]
    fn split_preserves_dollar_quoted_content() {
        let script = r#"
            CREATE FUNCTION test() RETURNS void AS $$
            BEGIN
                INSERT INTO logs (msg) VALUES ('hello; world');
            END;
            $$ LANGUAGE plpgsql;
            SELECT 1;
        "#;
        let parts = split_sql_statements(script);
        assert_eq!(
            parts.len(),
            2,
            "Should split into function + SELECT, not more"
        );
        assert!(
            parts[0].contains("$$"),
            "Function body must include dollar quotes"
        );
    }

    #[test]
    fn split_handles_empty_input() {
        assert!(split_sql_statements("").is_empty());
    }

    #[test]
    fn split_handles_no_trailing_semicolon() {
        let parts = split_sql_statements("SELECT 1");
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], "SELECT 1");
    }

    #[test]
    fn split_skips_blank_statements() {
        let parts = split_sql_statements("; ; SELECT 1; ;");
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0], "SELECT 1");
    }

    #[test]
    fn split_ignores_semicolons_in_line_comments() {
        let script = "CREATE TABLE t (\n    id INT,\n    -- this; has a semicolon\n    name TEXT\n);\nSELECT 1;";
        let parts = split_sql_statements(script);
        assert_eq!(parts.len(), 2);
        assert!(parts[0].contains("name TEXT"));
    }

    #[test]
    fn split_ignores_semicolons_in_block_comments() {
        let parts = split_sql_statements("CREATE TABLE t (id INT /* ; */ );\nSELECT 1;");
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn split_ignores_semicolons_in_string_literals() {
        let parts = split_sql_statements("INSERT INTO t VALUES ('a;b');\nSELECT 1;");
        assert_eq!(parts.len(), 2);
        assert!(parts[0].contains("'a;b'"));
    }

    #[test]
    fn strip_up_markers_drops_marker() {
        let stripped = strip_up_markers("-- @up\nCREATE TABLE a (id int);");
        assert!(!stripped.contains("@up"), "Up marker should be stripped");
        assert!(stripped.contains("CREATE TABLE"));
    }

    #[test]
    fn blank_sql_detection() {
        for blank in ["", "   ", "-- just a comment", "-- comment\n-- another"] {
            assert!(is_blank_sql(blank));
        }
        for not_blank in ["SELECT 1", "-- comment\nSELECT 1"] {
            assert!(!is_blank_sql(not_blank));
        }
    }

    #[test]
    fn sanitize_truncates_long_names() {
        let sanitized = sanitize_db_name(&"a".repeat(100));
        assert_eq!(sanitized.len(), 32);
    }

    #[test]
    fn sanitize_handles_special_characters() {
        assert_eq!(
            sanitize_db_name("test/with:special!chars"),
            "test_with_special_chars"
        );
        assert_eq!(sanitize_db_name(""), "");
    }
}