Skip to main content

forge_core/testing/
db.rs

1//! Database provisioning for tests.
2//!
3//! Provides PostgreSQL access for integration tests. Database configuration
4//! options:
5//! 1. Pass a URL directly via `from_url()`
6//! 2. Use `from_env()` to explicitly read from TEST_DATABASE_URL
7//! 3. Use `embedded()` for automatic embedded PostgreSQL (requires `embedded-db` feature)
8//!
9//! This design prevents accidental use of production databases. The .env file
10//! DATABASE_URL is NEVER automatically read.
11
12use sqlx::PgPool;
13use std::path::Path;
14use tracing::{debug, info};
15
16use crate::error::{ForgeError, Result};
17
18#[cfg(feature = "embedded-db")]
19use tokio::sync::OnceCell;
20
21#[cfg(feature = "embedded-db")]
22static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
23
/// Database access for tests.
///
/// Test database configuration is intentionally explicit to prevent
/// accidental use of production databases: a value is only obtained by
/// calling one of the constructors shown below.
///
/// # Examples
///
/// ```ignore
/// // Option 1: Embedded Postgres (requires embedded-db feature)
/// // Run with: cargo test --features embedded-db
/// let db = TestDatabase::embedded().await?;
///
/// // Option 2: Explicit URL
/// let db = TestDatabase::from_url("postgres://localhost/test_db").await?;
///
/// // Option 3: From TEST_DATABASE_URL env var
/// let db = TestDatabase::from_env().await?;
/// ```
pub struct TestDatabase {
    /// Connection pool opened against `url`.
    pool: PgPool,
    /// URL the pool was created from; `isolated()` derives per-test URLs from it.
    url: String,
}
46
47impl TestDatabase {
48    /// Connect to database at the given URL.
49    ///
50    /// Use this for explicit database configuration in tests.
51    pub async fn from_url(url: &str) -> Result<Self> {
52        let pool = sqlx::postgres::PgPoolOptions::new()
53            .max_connections(10)
54            .connect(url)
55            .await
56            .map_err(ForgeError::Sql)?;
57
58        Ok(Self {
59            pool,
60            url: url.to_string(),
61        })
62    }
63
64    /// Connect using TEST_DATABASE_URL environment variable.
65    ///
66    /// Note: This reads TEST_DATABASE_URL (not DATABASE_URL) to prevent
67    /// accidental use of production databases in tests.
68    pub async fn from_env() -> Result<Self> {
69        let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
70            ForgeError::Database(
71                "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
72            )
73        })?;
74        Self::from_url(&url).await
75    }
76
77    /// Start an embedded PostgreSQL instance.
78    ///
79    /// Downloads and starts a real PostgreSQL instance automatically.
80    /// Requires the `embedded-db` feature: `cargo test --features embedded-db`
81    #[cfg(feature = "embedded-db")]
82    pub async fn embedded() -> Result<Self> {
83        let pg = EMBEDDED_PG
84            .get_or_try_init(|| async {
85                let mut pg = postgresql_embedded::PostgreSQL::default();
86                pg.setup().await.map_err(|e| {
87                    ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
88                })?;
89                pg.start().await.map_err(|e| {
90                    ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
91                })?;
92                Ok::<_, ForgeError>(pg)
93            })
94            .await?;
95
96        let url = pg.settings().url("postgres");
97        Self::from_url(&url).await
98    }
99
    /// Borrow the connection pool backing this test database.
    ///
    /// The returned pool is the one opened by the constructor against `url()`.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
104
105    /// Get the database URL.
106    pub fn url(&self) -> &str {
107        &self.url
108    }
109
110    /// Run raw SQL to set up test data or schema.
111    pub async fn execute(&self, sql: &str) -> Result<()> {
112        sqlx::query(sql)
113            .execute(&self.pool)
114            .await
115            .map_err(ForgeError::Sql)?;
116        Ok(())
117    }
118
119    /// Creates a dedicated database for a single test, providing full isolation.
120    ///
121    /// Each call creates a new database with a unique name. Use this when tests
122    /// modify data and could interfere with each other.
123    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
124        let base_url = self.url.clone();
125        // UUID suffix prevents collisions when tests run in parallel
126        let db_name = format!(
127            "forge_test_{}_{}",
128            sanitize_db_name(test_name),
129            uuid::Uuid::new_v4().simple()
130        );
131
132        // Connect to default database to create the test database
133        let pool = sqlx::postgres::PgPoolOptions::new()
134            .max_connections(1)
135            .connect(&base_url)
136            .await
137            .map_err(ForgeError::Sql)?;
138
139        // Double-quoted identifier handles special characters in generated name
140        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
141            .execute(&pool)
142            .await
143            .map_err(ForgeError::Sql)?;
144
145        // Build URL for the new database by replacing the database name component
146        let test_url = replace_db_name(&base_url, &db_name);
147
148        let test_pool = sqlx::postgres::PgPoolOptions::new()
149            .max_connections(5)
150            .connect(&test_url)
151            .await
152            .map_err(ForgeError::Sql)?;
153
154        Ok(IsolatedTestDb {
155            pool: test_pool,
156            db_name,
157            base_url,
158        })
159    }
160}
161
/// A test database that exists for the lifetime of a single test.
///
/// The database is created by `TestDatabase::isolated` before this value is
/// returned. It is dropped again when `cleanup()` is called.
///
/// NOTE(review): no automatic removal of orphaned databases is visible in
/// this file (each name carries a fresh UUID, so names are not reused) —
/// confirm where, if anywhere, leftover databases are cleaned up.
pub struct IsolatedTestDb {
    /// Pool connected to the per-test database.
    pool: PgPool,
    /// Name of the per-test database.
    db_name: String,
    /// URL of the base database; `cleanup()` connects here to drop `db_name`.
    base_url: String,
}
172
173impl IsolatedTestDb {
    /// Borrow the connection pool that is connected to this isolated database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
178
179    /// Get the database name.
180    pub fn db_name(&self) -> &str {
181        &self.db_name
182    }
183
184    /// Run raw SQL to set up test data or schema.
185    pub async fn execute(&self, sql: &str) -> Result<()> {
186        sqlx::query(sql)
187            .execute(&self.pool)
188            .await
189            .map_err(ForgeError::Sql)?;
190        Ok(())
191    }
192
193    /// Run multi-statement SQL for setup.
194    ///
195    /// This handles SQL with multiple statements separated by semicolons,
196    /// including PL/pgSQL functions with dollar-quoted strings.
197    pub async fn run_sql(&self, sql: &str) -> Result<()> {
198        let statements = split_sql_statements(sql);
199        for statement in statements {
200            let statement = statement.trim();
201            if statement.is_empty()
202                || statement.lines().all(|l| {
203                    let l = l.trim();
204                    l.is_empty() || l.starts_with("--")
205                })
206            {
207                continue;
208            }
209
210            sqlx::query(statement)
211                .execute(&self.pool)
212                .await
213                .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
214        }
215        Ok(())
216    }
217
218    /// Cleanup the test database by dropping it.
219    ///
220    /// Call this at the end of your test if you want immediate cleanup.
221    /// Otherwise, orphaned databases will be cleaned up on subsequent test runs.
222    pub async fn cleanup(self) -> Result<()> {
223        // Close all connections first
224        self.pool.close().await;
225
226        // Connect to default database to drop the test database
227        let pool = sqlx::postgres::PgPoolOptions::new()
228            .max_connections(1)
229            .connect(&self.base_url)
230            .await
231            .map_err(ForgeError::Sql)?;
232
233        // Force disconnect other connections and drop
234        let _ = sqlx::query(&format!(
235            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
236            self.db_name
237        ))
238        .execute(&pool)
239        .await;
240
241        sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
242            .execute(&pool)
243            .await
244            .map_err(ForgeError::Sql)?;
245
246        Ok(())
247    }
248
249    /// Run migrations from a directory.
250    ///
251    /// Loads all `.sql` files from the directory, sorts them alphabetically,
252    /// and executes them in order. This is intended for test setup.
253    ///
254    /// # Example
255    ///
256    /// ```ignore
257    /// let base = TestDatabase::embedded().await?;
258    /// let db = base.isolated("my_test").await?;
259    /// db.migrate(Path::new("migrations")).await?;
260    /// ```
261    pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
262        if !migrations_dir.exists() {
263            debug!("Migrations directory does not exist: {:?}", migrations_dir);
264            return Ok(());
265        }
266
267        let mut migrations = Vec::new();
268
269        let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
270
271        for entry in entries {
272            let entry = entry.map_err(ForgeError::Io)?;
273            let path = entry.path();
274
275            if path.extension().map(|e| e == "sql").unwrap_or(false) {
276                let name = path
277                    .file_stem()
278                    .and_then(|s| s.to_str())
279                    .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
280                    .to_string();
281
282                let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
283                migrations.push((name, content));
284            }
285        }
286
287        // Sort by name (which includes the numeric prefix)
288        migrations.sort_by(|a, b| a.0.cmp(&b.0));
289
290        debug!("Running {} migrations for test", migrations.len());
291
292        for (name, content) in migrations {
293            info!("Applying test migration: {}", name);
294
295            // Parse content to extract up SQL (everything before -- @down marker)
296            let up_sql = parse_up_sql(&content);
297
298            // Split into individual statements and execute
299            let statements = split_sql_statements(&up_sql);
300            for statement in statements {
301                let statement = statement.trim();
302                if statement.is_empty()
303                    || statement.lines().all(|l| {
304                        let l = l.trim();
305                        l.is_empty() || l.starts_with("--")
306                    })
307                {
308                    continue;
309                }
310
311                sqlx::query(statement)
312                    .execute(&self.pool)
313                    .await
314                    .map_err(|e| {
315                        ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
316                    })?;
317            }
318        }
319
320        Ok(())
321    }
322}
323
/// Sanitize a test name for use inside a generated database name.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and the
/// result is capped so that the full generated name
/// (`forge_test_<sanitized>_<32 hex digits>`) fits PostgreSQL's 63-byte
/// identifier limit. Restricting the output to ASCII keeps the byte length
/// equal to the character count, which makes that cap reliable.
///
/// An empty input yields an empty string.
fn sanitize_db_name(name: &str) -> String {
    // PostgreSQL identifiers are limited to NAMEDATALEN - 1 = 63 bytes.
    const MAX_IDENTIFIER_BYTES: usize = 63;
    const PREFIX_BYTES: usize = "forge_test_".len();
    // `_` separator plus the 32 hex digits of a simple-format UUID.
    const SUFFIX_BYTES: usize = 1 + 32;
    const MAX_NAME_BYTES: usize = MAX_IDENTIFIER_BYTES - PREFIX_BYTES - SUFFIX_BYTES;

    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_NAME_BYTES)
        .collect()
}
331
/// Replace the database name in a connection URL.
///
/// The URL is treated as `scheme://authority[/dbname][?query]`:
///
/// * the query string (from the first `?`) is preserved verbatim, even when
///   it contains `/` (e.g. `sslrootcert=/etc/ca.pem`);
/// * the database name is everything between the first `/` after the
///   authority and the query string;
/// * when the URL has no path at all (`postgres://localhost`), `/new_db` is
///   appended after the authority rather than replacing the host.
///
/// Input without a `://` scheme separator is handled the same way, starting
/// from the beginning of the string.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Split off the query first so slashes inside parameter values are ignored.
    let (location, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Skip past `scheme://` so its slashes are not mistaken for the path separator.
    let authority_start = location.find("://").map_or(0, |idx| idx + 3);

    match location[authority_start..].find('/') {
        Some(rel) => {
            // Keep everything up to and including the path separator.
            let prefix = &location[..=authority_start + rel];
            format!("{}{}{}", prefix, new_db, query)
        }
        None => format!("{}/{}{}", location, new_db, query),
    }
}
348
/// Extract the "up" SQL from migration file content.
///
/// The content is cut at the first `@down` marker that matches, trying the
/// spellings `-- @down`, `--@down`, `-- @DOWN`, `--@DOWN` in that order; when
/// none is present the whole content is used. Every `@up` marker (same four
/// spellings) is then removed and the result is trimmed.
fn parse_up_sql(content: &str) -> String {
    const DOWN_MARKERS: [&str; 4] = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];
    const UP_MARKERS: [&str; 4] = ["-- @up", "--@up", "-- @UP", "--@UP"];

    // The first spelling (in list order) that occurs decides where to cut.
    let up_section = DOWN_MARKERS
        .iter()
        .find_map(|marker| content.find(*marker))
        .map_or(content, |cut| &content[..cut]);

    UP_MARKERS
        .iter()
        .fold(up_section.to_string(), |sql, marker| sql.replace(*marker, ""))
        .trim()
        .to_string()
}
375
/// Split SQL into individual statements on top-level semicolons.
///
/// A semicolon does not end a statement when it appears inside:
///
/// * a dollar-quoted string (`$$ ... $$` or `$tag$ ... $tag$`), which is how
///   PL/pgSQL function bodies are usually written;
/// * a single-quoted string literal or a double-quoted identifier
///   (doubled quotes such as `'it''s'` are handled);
/// * a `--` line comment or a (possibly nested) `/* ... */` block comment.
///
/// Each returned statement is trimmed and has no trailing semicolon; empty
/// pieces are dropped. Comments are kept as part of the statement text they
/// precede. An unterminated quote, comment, or dollar-quoted string extends
/// to the end of the input.
///
/// Backslash escapes inside `E'...'` strings are not interpreted.
fn split_sql_statements(sql: &str) -> Vec<String> {
    // All delimiters are ASCII, so scanning bytes is safe: an ASCII byte never
    // occurs inside a multi-byte UTF-8 sequence, which keeps every slice
    // boundary used below on a character boundary.
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut start = 0; // byte offset where the current statement begins
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'\'' | b'"' => {
                // Jump past the matching closing quote. A doubled quote simply
                // closes and immediately reopens the literal.
                let quote = bytes[i];
                i = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == quote)
                    .map_or(bytes.len(), |rel| i + 1 + rel + 1);
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                // Line comment: ignore everything up to the newline.
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(bytes.len(), |rel| i + rel);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = skip_block_comment(bytes, i);
            }
            b'$' => {
                i = match dollar_quote_tag_len(&sql[i..]) {
                    Some(tag_len) => {
                        // Jump past the next occurrence of the same delimiter.
                        let body_start = i + tag_len;
                        let tag = &sql[i..body_start];
                        sql[body_start..]
                            .find(tag)
                            .map_or(bytes.len(), |rel| body_start + rel + tag_len)
                    }
                    None => i + 1,
                };
            }
            b';' => {
                push_statement(&mut statements, &sql[start..i]);
                i += 1;
                start = i;
            }
            _ => i += 1,
        }
    }

    // Don't forget the last statement (might not end with ;)
    push_statement(&mut statements, &sql[start..]);

    statements
}

/// Normalize one raw piece and append it to `out` unless it is empty.
fn push_statement(out: &mut Vec<String>, raw: &str) {
    let stmt = raw.trim().trim_end_matches(';').trim();
    if !stmt.is_empty() {
        out.push(stmt.to_string());
    }
}

/// If `rest` (which starts with `$`) begins with a dollar-quote delimiter
/// such as `$$` or `$tag$`, return the delimiter's length in bytes.
///
/// A tag consists of letters, digits and `_` and must not start with a digit,
/// so positional parameters like `$1` are not mistaken for delimiters.
fn dollar_quote_tag_len(rest: &str) -> Option<usize> {
    let mut tag_chars = rest.char_indices().skip(1).peekable();

    if tag_chars.peek().map_or(false, |(_, c)| c.is_ascii_digit()) {
        return None;
    }

    for (idx, c) in tag_chars {
        if c == '$' {
            return Some(idx + 1);
        }
        if !(c.is_alphanumeric() || c == '_') {
            return None;
        }
    }

    None
}

/// Return the offset just past the block comment that opens at `bytes[i]`
/// (`/*`), honouring nested comments; `bytes.len()` if it is unterminated.
fn skip_block_comment(bytes: &[u8], mut i: usize) -> usize {
    let mut depth = 0usize;

    while i < bytes.len() {
        if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
            depth += 1;
            i += 2;
        } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
            // `depth` is at least 1 here because the scan starts on `/*`.
            depth -= 1;
            i += 2;
            if depth == 0 {
                return i;
            }
        } else {
            i += 1;
        }
    }

    bytes.len()
}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters other than letters and digits are replaced with `_`.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];

        for (input, expected) in cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    /// Only the database component of the URL changes; credentials, port and
    /// query parameters are preserved.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];

        for (url, expected) in cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}