Skip to main content

forge_core/testing/
db.rs

1//! Database provisioning for tests.
2//!
3//! Provides PostgreSQL access for integration tests. Database configuration
4//! options:
5//! 1. Pass a URL directly via `from_url()`
6//! 2. Use `from_env()` to explicitly read from TEST_DATABASE_URL
7//! 3. Use `embedded()` for automatic embedded PostgreSQL (requires `embedded-db` feature)
8//!
9//! This design prevents accidental use of production databases. The .env file
10//! DATABASE_URL is NEVER automatically read.
11
12#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
13
14use sqlx::PgPool;
15use std::path::Path;
16use tracing::{debug, info};
17
18use crate::error::{ForgeError, Result};
19
20#[cfg(feature = "embedded-db")]
21use tokio::sync::OnceCell;
22
/// Process-wide embedded PostgreSQL handle.
///
/// Initialized on the first call to `TestDatabase::embedded()` (set up and
/// started there) and shared by every later call.
23#[cfg(feature = "embedded-db")]
24static EMBEDDED_PG: OnceCell<postgresql_embedded::PostgreSQL> = OnceCell::const_new();
25
26/// Database access for tests.
27///
28/// Test database configuration is intentionally explicit to prevent
29/// accidental use of production databases.
30///
31/// # Examples
32///
33/// ```ignore
34/// // Option 1: Embedded Postgres (requires embedded-db feature)
35/// // Run with: cargo test --features embedded-db
36/// let db = TestDatabase::embedded().await?;
37///
38/// // Option 2: Explicit URL
39/// let db = TestDatabase::from_url("postgres://localhost/test_db").await?;
40///
41/// // Option 3: From TEST_DATABASE_URL env var
42/// let db = TestDatabase::from_env().await?;
43/// ```
44pub struct TestDatabase {
45    pool: PgPool,
46    url: String,
47}
48
49impl TestDatabase {
50    /// Connect to database at the given URL.
51    ///
52    /// Use this for explicit database configuration in tests.
53    pub async fn from_url(url: &str) -> Result<Self> {
54        let pool = sqlx::postgres::PgPoolOptions::new()
55            .max_connections(10)
56            .connect(url)
57            .await
58            .map_err(ForgeError::Sql)?;
59
60        Ok(Self {
61            pool,
62            url: url.to_string(),
63        })
64    }
65
66    /// Connect using TEST_DATABASE_URL environment variable.
67    ///
68    /// Note: This reads TEST_DATABASE_URL (not DATABASE_URL) to prevent
69    /// accidental use of production databases in tests.
70    pub async fn from_env() -> Result<Self> {
71        let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
72            ForgeError::Database(
73                "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
74            )
75        })?;
76        Self::from_url(&url).await
77    }
78
79    /// Start an embedded PostgreSQL instance.
80    ///
81    /// Downloads and starts a real PostgreSQL instance automatically.
82    /// Requires the `embedded-db` feature: `cargo test --features embedded-db`
83    #[cfg(feature = "embedded-db")]
84    pub async fn embedded() -> Result<Self> {
85        let pg = EMBEDDED_PG
86            .get_or_try_init(|| async {
87                let mut pg = postgresql_embedded::PostgreSQL::default();
88                pg.setup().await.map_err(|e| {
89                    ForgeError::Database(format!("Failed to setup embedded Postgres: {}", e))
90                })?;
91                pg.start().await.map_err(|e| {
92                    ForgeError::Database(format!("Failed to start embedded Postgres: {}", e))
93                })?;
94                Ok::<_, ForgeError>(pg)
95            })
96            .await?;
97
98        let url = pg.settings().url("postgres");
99        Self::from_url(&url).await
100    }
101
102    /// Get the connection pool.
103    pub fn pool(&self) -> &PgPool {
104        &self.pool
105    }
106
107    /// Get the database URL.
108    pub fn url(&self) -> &str {
109        &self.url
110    }
111
112    /// Run raw SQL to set up test data or schema.
113    pub async fn execute(&self, sql: &str) -> Result<()> {
114        sqlx::query(sql)
115            .execute(&self.pool)
116            .await
117            .map_err(ForgeError::Sql)?;
118        Ok(())
119    }
120
121    /// Creates a dedicated database for a single test, providing full isolation.
122    ///
123    /// Each call creates a new database with a unique name. Use this when tests
124    /// modify data and could interfere with each other.
125    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
126        let base_url = self.url.clone();
127        // UUID suffix prevents collisions when tests run in parallel
128        let db_name = format!(
129            "forge_test_{}_{}",
130            sanitize_db_name(test_name),
131            uuid::Uuid::new_v4().simple()
132        );
133
134        // Connect to default database to create the test database
135        let pool = sqlx::postgres::PgPoolOptions::new()
136            .max_connections(1)
137            .connect(&base_url)
138            .await
139            .map_err(ForgeError::Sql)?;
140
141        // Double-quoted identifier handles special characters in generated name
142        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
143            .execute(&pool)
144            .await
145            .map_err(ForgeError::Sql)?;
146
147        // Build URL for the new database by replacing the database name component
148        let test_url = replace_db_name(&base_url, &db_name);
149
150        let test_pool = sqlx::postgres::PgPoolOptions::new()
151            .max_connections(5)
152            .connect(&test_url)
153            .await
154            .map_err(ForgeError::Sql)?;
155
156        Ok(IsolatedTestDb {
157            pool: test_pool,
158            db_name,
159            base_url,
160        })
161    }
162}
163
164/// A test database that exists for the lifetime of a single test.
165///
166/// The database is automatically created on construction. Cleanup happens
167/// when `cleanup()` is called or when the database is reused in subsequent
168/// test runs (orphaned databases are cleaned up automatically).
169pub struct IsolatedTestDb {
170    pool: PgPool,
171    db_name: String,
172    base_url: String,
173}
174
175impl IsolatedTestDb {
176    /// Get the connection pool for this isolated database.
177    pub fn pool(&self) -> &PgPool {
178        &self.pool
179    }
180
181    /// Get the database name.
182    pub fn db_name(&self) -> &str {
183        &self.db_name
184    }
185
186    /// Run raw SQL to set up test data or schema.
187    pub async fn execute(&self, sql: &str) -> Result<()> {
188        sqlx::query(sql)
189            .execute(&self.pool)
190            .await
191            .map_err(ForgeError::Sql)?;
192        Ok(())
193    }
194
195    /// Run multi-statement SQL for setup.
196    ///
197    /// This handles SQL with multiple statements separated by semicolons,
198    /// including PL/pgSQL functions with dollar-quoted strings.
199    pub async fn run_sql(&self, sql: &str) -> Result<()> {
200        let statements = split_sql_statements(sql);
201        for statement in statements {
202            let statement = statement.trim();
203            if statement.is_empty()
204                || statement.lines().all(|l| {
205                    let l = l.trim();
206                    l.is_empty() || l.starts_with("--")
207                })
208            {
209                continue;
210            }
211
212            sqlx::query(statement)
213                .execute(&self.pool)
214                .await
215                .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
216        }
217        Ok(())
218    }
219
220    /// Cleanup the test database by dropping it.
221    ///
222    /// Call this at the end of your test if you want immediate cleanup.
223    /// Otherwise, orphaned databases will be cleaned up on subsequent test runs.
224    pub async fn cleanup(self) -> Result<()> {
225        // Close all connections first
226        self.pool.close().await;
227
228        // Connect to default database to drop the test database
229        let pool = sqlx::postgres::PgPoolOptions::new()
230            .max_connections(1)
231            .connect(&self.base_url)
232            .await
233            .map_err(ForgeError::Sql)?;
234
235        // Force disconnect other connections and drop
236        let _ = sqlx::query(&format!(
237            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
238            self.db_name
239        ))
240        .execute(&pool)
241        .await;
242
243        sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
244            .execute(&pool)
245            .await
246            .map_err(ForgeError::Sql)?;
247
248        Ok(())
249    }
250
251    /// Run migrations from a directory.
252    ///
253    /// Loads all `.sql` files from the directory, sorts them alphabetically,
254    /// and executes them in order. This is intended for test setup.
255    ///
256    /// # Example
257    ///
258    /// ```ignore
259    /// let base = TestDatabase::embedded().await?;
260    /// let db = base.isolated("my_test").await?;
261    /// db.migrate(Path::new("migrations")).await?;
262    /// ```
263    pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
264        if !migrations_dir.exists() {
265            debug!("Migrations directory does not exist: {:?}", migrations_dir);
266            return Ok(());
267        }
268
269        let mut migrations = Vec::new();
270
271        let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
272
273        for entry in entries {
274            let entry = entry.map_err(ForgeError::Io)?;
275            let path = entry.path();
276
277            if path.extension().map(|e| e == "sql").unwrap_or(false) {
278                let name = path
279                    .file_stem()
280                    .and_then(|s| s.to_str())
281                    .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
282                    .to_string();
283
284                let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
285                migrations.push((name, content));
286            }
287        }
288
289        // Sort by name (which includes the numeric prefix)
290        migrations.sort_by(|a, b| a.0.cmp(&b.0));
291
292        debug!("Running {} migrations for test", migrations.len());
293
294        for (name, content) in migrations {
295            info!("Applying test migration: {}", name);
296
297            // Parse content to extract up SQL (everything before -- @down marker)
298            let up_sql = parse_up_sql(&content);
299
300            // Split into individual statements and execute
301            let statements = split_sql_statements(&up_sql);
302            for statement in statements {
303                let statement = statement.trim();
304                if statement.is_empty()
305                    || statement.lines().all(|l| {
306                        let l = l.trim();
307                        l.is_empty() || l.starts_with("--")
308                    })
309                {
310                    continue;
311                }
312
313                sqlx::query(statement)
314                    .execute(&self.pool)
315                    .await
316                    .map_err(|e| {
317                        ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
318                    })?;
319            }
320        }
321
322        Ok(())
323    }
324}
325
/// Sanitize a test name for use inside a generated database name.
///
/// Every character that is not an ASCII letter or digit is replaced with `_`,
/// and the result is capped at 19 characters (19 bytes, since the output is
/// pure ASCII).
fn sanitize_db_name(name: &str) -> String {
    // PostgreSQL truncates identifiers to 63 bytes. `isolated()` wraps this
    // value as "forge_test_" (11) + name + "_" (1) + 32 hex UUID characters,
    // which leaves 63 - 44 = 19 bytes for the name. Staying within that budget
    // keeps the stored name identical to the one the server actually uses.
    const MAX_LEN: usize = 19;

    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_LEN)
        .collect()
}
333
/// Replace the database name in a connection URL.
///
/// The URL is read as `scheme://authority[/database][?query]`. Everything up
/// to the end of the authority is kept, the database segment is replaced with
/// `new_db` (or added when the URL has none), and any query string is
/// preserved unchanged. Slashes inside query values (for example
/// `sslrootcert=/etc/ca.pem`) are not mistaken for the path separator.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Skip the "://" so its slashes are never mistaken for the path separator.
    let authority_start = url.find("://").map_or(0, |idx| idx + 3);
    let rest = &url[authority_start..];

    // The authority ends at the first '/' (path) or '?' (query), whichever
    // comes first; with neither, the whole URL is the authority.
    let authority_len = rest
        .find(|c: char| c == '/' || c == '?')
        .unwrap_or(rest.len());
    let (authority, tail) = url.split_at(authority_start + authority_len);

    // Keep the query string (including its leading '?') if there is one.
    let query = tail.find('?').map_or("", |idx| &tail[idx..]);

    format!("{}/{}{}", authority, new_db, query)
}
350
/// Parse migration content, extracting only the up SQL.
///
/// The content is scanned line by line:
/// - a `-- @down` marker line ends the up section; it and everything after it
///   are discarded;
/// - `-- @up` marker lines are dropped;
/// - every other line is kept verbatim.
///
/// Markers are matched case-insensitively, with optional whitespace after
/// `--`, and must be the first thing on their line. The tag must end at a word
/// boundary, so an ordinary comment such as `-- @updated_at ...` is left
/// untouched. The result is trimmed. Without a `@down` marker the whole
/// content is treated as up SQL.
fn parse_up_sql(content: &str) -> String {
    /// Returns `true` when `line` is a `-- @<tag>` marker comment.
    fn is_marker(line: &str, tag: &str) -> bool {
        line.trim()
            .strip_prefix("--")
            .and_then(|rest| rest.trim_start().strip_prefix('@'))
            .map_or(false, |rest| match rest.get(..tag.len()) {
                Some(head) if head.eq_ignore_ascii_case(tag) => {
                    // Require a word boundary so "@update" is not read as "@up".
                    !rest[tag.len()..]
                        .chars()
                        .next()
                        .map_or(false, |c| c.is_alphanumeric() || c == '_')
                }
                _ => false,
            })
    }

    let mut up_sql = String::with_capacity(content.len());
    // `split_inclusive` keeps each line's own terminator, so the retained SQL
    // is copied through unchanged.
    for line in content.split_inclusive('\n') {
        if is_marker(line, "down") {
            break;
        }
        if is_marker(line, "up") {
            continue;
        }
        up_sql.push_str(line);
    }

    up_sql.trim().to_string()
}
377
/// Split SQL into individual statements on top-level semicolons.
///
/// A semicolon does not end a statement when it appears inside:
/// - a single-quoted string or double-quoted identifier (a doubled quote
///   works as an escaped quote),
/// - a `--` line comment or a `/* ... */` block comment (nesting supported),
/// - a dollar-quoted string (`$$ ... $$` or `$tag$ ... $tag$`), which is what
///   PL/pgSQL function bodies use.
///
/// Each returned statement is trimmed and has no trailing semicolon; empty
/// pieces are dropped. A final statement without a terminating semicolon is
/// included. Backslash escapes inside `E'...'` strings are not interpreted.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Bytes that may appear in an identifier or dollar-quote tag. Non-ASCII
    /// bytes count, so multi-byte UTF-8 characters are treated as letters.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
    }

    /// Byte length of the dollar-quote delimiter starting at `bytes[at]`
    /// (which must be `$`), or `None` if no delimiter starts there.
    fn dollar_tag_len(bytes: &[u8], at: usize) -> Option<usize> {
        let after = bytes.get(at + 1..)?;
        // `$1`, `$2`, ... are parameter placeholders, not tags.
        if after.first().map_or(false, u8::is_ascii_digit) {
            return None;
        }
        let ident_len = after.iter().take_while(|&&b| is_ident_byte(b)).count();
        if after.get(ident_len) == Some(&b'$') {
            Some(ident_len + 2)
        } else {
            None
        }
    }

    /// Normalize one raw piece and store it unless it is empty.
    fn push_statement(out: &mut Vec<String>, piece: &str) {
        let stmt = piece.trim().trim_end_matches(';').trim();
        if !stmt.is_empty() {
            out.push(stmt.to_string());
        }
    }

    // Every delimiter of interest is ASCII, and ASCII bytes never occur inside
    // a multi-byte UTF-8 sequence, so byte offsets found here are always valid
    // `str` slice boundaries.
    let bytes = sql.as_bytes();
    let len = bytes.len();
    let mut statements = Vec::new();
    let mut start = 0;
    let mut i = 0;

    while i < len {
        match bytes[i] {
            // Quoted string / identifier: jump past the closing quote. A
            // doubled quote simply closes and immediately reopens.
            b'\'' | b'"' => {
                let quote = bytes[i];
                i = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == quote)
                    .map_or(len, |off| i + off + 2);
            }
            // Line comment: runs to the end of the line.
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(len, |off| i + off + 1);
            }
            // Block comment: PostgreSQL allows these to nest.
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                let mut depth = 1usize;
                i += 2;
                while i < len && depth > 0 {
                    if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
                        depth += 1;
                        i += 2;
                    } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                        depth -= 1;
                        i += 2;
                    } else {
                        i += 1;
                    }
                }
            }
            // Dollar quote: only when `$` does not continue an identifier.
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => match dollar_tag_len(bytes, i) {
                Some(tag_len) => {
                    let tag = &bytes[i..i + tag_len];
                    let body_start = i + tag_len;
                    // Jump past the matching closing tag; an unterminated
                    // body swallows the rest of the input.
                    i = bytes[body_start..]
                        .windows(tag_len)
                        .position(|w| w == tag)
                        .map_or(len, |off| body_start + off + tag_len);
                }
                None => i += 1,
            },
            b';' => {
                push_statement(&mut statements, &sql[start..i]);
                i += 1;
                start = i;
            }
            _ => i += 1,
        }
    }

    // Don't forget the last statement (might not end with ;)
    push_statement(&mut statements, &sql[start..]);

    statements
}
441
442#[cfg(test)]
443mod tests {
444    use super::*;
445
446    #[test]
447    fn test_sanitize_db_name() {
448        assert_eq!(sanitize_db_name("my_test"), "my_test");
449        assert_eq!(sanitize_db_name("my-test"), "my_test");
450        assert_eq!(sanitize_db_name("my test"), "my_test");
451        assert_eq!(sanitize_db_name("test::function"), "test__function");
452    }
453
454    #[test]
455    fn test_replace_db_name() {
456        assert_eq!(
457            replace_db_name("postgres://localhost/olddb", "newdb"),
458            "postgres://localhost/newdb"
459        );
460        assert_eq!(
461            replace_db_name("postgres://user:pass@localhost:5432/olddb", "newdb"),
462            "postgres://user:pass@localhost:5432/newdb"
463        );
464        assert_eq!(
465            replace_db_name("postgres://localhost/olddb?sslmode=disable", "newdb"),
466            "postgres://localhost/newdb?sslmode=disable"
467        );
468    }
469}