Skip to main content

forge_core/testing/
db.rs

//! Database provisioning for tests.
//!
//! Provides PostgreSQL access for integration tests. The test database must be
//! configured explicitly, in one of two ways:
//! 1. Pass a URL directly via `from_url()`.
//! 2. Call `from_env()` to read it from `TEST_DATABASE_URL`.
//!
//! This design guards against accidentally running tests against a production
//! database: `DATABASE_URL` (e.g. from a `.env` file) is never read automatically.
10
11#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
12
13use sqlx::PgPool;
14use std::path::Path;
15use tracing::{debug, info};
16
17use crate::error::{ForgeError, Result};
18
19/// Database access for tests.
20///
21/// Test database configuration is intentionally explicit to prevent
22/// accidental use of production databases.
23///
24/// # Examples
25///
26/// ```ignore
27/// // Option 1: Explicit URL
28/// let db = TestDatabase::from_url("postgres://localhost/test_db").await?;
29///
30/// // Option 2: From TEST_DATABASE_URL env var
31/// let db = TestDatabase::from_env().await?;
32/// ```
33pub struct TestDatabase {
34    pool: PgPool,
35    url: String,
36}
37
38impl TestDatabase {
39    /// Connect to database at the given URL.
40    ///
41    /// Use this for explicit database configuration in tests.
42    pub async fn from_url(url: &str) -> Result<Self> {
43        let pool = sqlx::postgres::PgPoolOptions::new()
44            .max_connections(10)
45            .connect(url)
46            .await
47            .map_err(ForgeError::Sql)?;
48
49        Ok(Self {
50            pool,
51            url: url.to_string(),
52        })
53    }
54
55    /// Connect using TEST_DATABASE_URL environment variable.
56    ///
57    /// Note: This reads TEST_DATABASE_URL (not DATABASE_URL) to prevent
58    /// accidental use of production databases in tests.
59    pub async fn from_env() -> Result<Self> {
60        let url = std::env::var("TEST_DATABASE_URL").map_err(|_| {
61            ForgeError::Database(
62                "TEST_DATABASE_URL not set. Set it explicitly for database tests.".to_string(),
63            )
64        })?;
65        Self::from_url(&url).await
66    }
67
68    /// Get the connection pool.
69    pub fn pool(&self) -> &PgPool {
70        &self.pool
71    }
72
73    /// Get the database URL.
74    pub fn url(&self) -> &str {
75        &self.url
76    }
77
78    /// Run raw SQL to set up test data or schema.
79    pub async fn execute(&self, sql: &str) -> Result<()> {
80        sqlx::query(sql)
81            .execute(&self.pool)
82            .await
83            .map_err(ForgeError::Sql)?;
84        Ok(())
85    }
86
87    /// Creates a dedicated database for a single test, providing full isolation.
88    ///
89    /// Each call creates a new database with a unique name. Use this when tests
90    /// modify data and could interfere with each other.
91    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
92        let base_url = self.url.clone();
93        // UUID suffix prevents collisions when tests run in parallel
94        let db_name = format!(
95            "forge_test_{}_{}",
96            sanitize_db_name(test_name),
97            uuid::Uuid::new_v4().simple()
98        );
99
100        // Connect to default database to create the test database
101        let pool = sqlx::postgres::PgPoolOptions::new()
102            .max_connections(1)
103            .connect(&base_url)
104            .await
105            .map_err(ForgeError::Sql)?;
106
107        // Double-quoted identifier handles special characters in generated name
108        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
109            .execute(&pool)
110            .await
111            .map_err(ForgeError::Sql)?;
112
113        // Build URL for the new database by replacing the database name component
114        let test_url = replace_db_name(&base_url, &db_name);
115
116        let test_pool = sqlx::postgres::PgPoolOptions::new()
117            .max_connections(5)
118            .connect(&test_url)
119            .await
120            .map_err(ForgeError::Sql)?;
121
122        Ok(IsolatedTestDb {
123            pool: test_pool,
124            db_name,
125            base_url,
126        })
127    }
128}
129
130/// A test database that exists for the lifetime of a single test.
131///
132/// The database is automatically created on construction. Cleanup happens
133/// when `cleanup()` is called or when the database is reused in subsequent
134/// test runs (orphaned databases are cleaned up automatically).
135pub struct IsolatedTestDb {
136    pool: PgPool,
137    db_name: String,
138    base_url: String,
139}
140
141impl IsolatedTestDb {
142    /// Get the connection pool for this isolated database.
143    pub fn pool(&self) -> &PgPool {
144        &self.pool
145    }
146
147    /// Get the database name.
148    pub fn db_name(&self) -> &str {
149        &self.db_name
150    }
151
152    /// Run raw SQL to set up test data or schema.
153    pub async fn execute(&self, sql: &str) -> Result<()> {
154        sqlx::query(sql)
155            .execute(&self.pool)
156            .await
157            .map_err(ForgeError::Sql)?;
158        Ok(())
159    }
160
161    /// Run multi-statement SQL for setup.
162    ///
163    /// This handles SQL with multiple statements separated by semicolons,
164    /// including PL/pgSQL functions with dollar-quoted strings.
165    pub async fn run_sql(&self, sql: &str) -> Result<()> {
166        let statements = split_sql_statements(sql);
167        for statement in statements {
168            let statement = statement.trim();
169            if statement.is_empty()
170                || statement.lines().all(|l| {
171                    let l = l.trim();
172                    l.is_empty() || l.starts_with("--")
173                })
174            {
175                continue;
176            }
177
178            sqlx::query(statement)
179                .execute(&self.pool)
180                .await
181                .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {}", e)))?;
182        }
183        Ok(())
184    }
185
186    /// Cleanup the test database by dropping it.
187    ///
188    /// Call this at the end of your test if you want immediate cleanup.
189    /// Otherwise, orphaned databases will be cleaned up on subsequent test runs.
190    pub async fn cleanup(self) -> Result<()> {
191        // Close all connections first
192        self.pool.close().await;
193
194        // Connect to default database to drop the test database
195        let pool = sqlx::postgres::PgPoolOptions::new()
196            .max_connections(1)
197            .connect(&self.base_url)
198            .await
199            .map_err(ForgeError::Sql)?;
200
201        // Force disconnect other connections and drop
202        let _ = sqlx::query(&format!(
203            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
204            self.db_name
205        ))
206        .execute(&pool)
207        .await;
208
209        sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
210            .execute(&pool)
211            .await
212            .map_err(ForgeError::Sql)?;
213
214        Ok(())
215    }
216
217    /// Run migrations from a directory.
218    ///
219    /// Loads all `.sql` files from the directory, sorts them alphabetically,
220    /// and executes them in order. This is intended for test setup.
221    ///
222    /// # Example
223    ///
224    /// ```ignore
225    /// let base = TestDatabase::from_env().await?;
226    /// let db = base.isolated("my_test").await?;
227    /// db.migrate(Path::new("migrations")).await?;
228    /// ```
229    pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
230        if !migrations_dir.exists() {
231            debug!("Migrations directory does not exist: {:?}", migrations_dir);
232            return Ok(());
233        }
234
235        let mut migrations = Vec::new();
236
237        let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
238
239        for entry in entries {
240            let entry = entry.map_err(ForgeError::Io)?;
241            let path = entry.path();
242
243            if path.extension().map(|e| e == "sql").unwrap_or(false) {
244                let name = path
245                    .file_stem()
246                    .and_then(|s| s.to_str())
247                    .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
248                    .to_string();
249
250                let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
251                migrations.push((name, content));
252            }
253        }
254
255        // Sort by name (which includes the numeric prefix)
256        migrations.sort_by(|a, b| a.0.cmp(&b.0));
257
258        debug!("Running {} migrations for test", migrations.len());
259
260        for (name, content) in migrations {
261            info!("Applying test migration: {}", name);
262
263            // Parse content to extract up SQL (everything before -- @down marker)
264            let up_sql = parse_up_sql(&content);
265
266            // Split into individual statements and execute
267            let statements = split_sql_statements(&up_sql);
268            for statement in statements {
269                let statement = statement.trim();
270                if statement.is_empty()
271                    || statement.lines().all(|l| {
272                        let l = l.trim();
273                        l.is_empty() || l.starts_with("--")
274                    })
275                {
276                    continue;
277                }
278
279                sqlx::query(statement)
280                    .execute(&self.pool)
281                    .await
282                    .map_err(|e| {
283                        ForgeError::Database(format!("Failed to apply migration '{}': {}", name, e))
284                    })?;
285            }
286        }
287
288        Ok(())
289    }
290}
291
/// Sanitizes a test name for use as part of a database name.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and the
/// result is capped at 32 characters. Restricting the output to ASCII keeps its
/// byte length equal to its character count (Postgres limits identifier length
/// in bytes) and makes it safe to splice into a connection URL path without
/// percent-encoding.
fn sanitize_db_name(name: &str) -> String {
    name.chars()
        .take(32)
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
299
/// Replaces the database name in a connection URL.
///
/// The query string (from the first `?`) is preserved untouched, and the
/// database path segment is located after the `scheme://authority` part, so
/// slashes inside query values (e.g. `?host=/var/run/postgresql`) and URLs
/// without a database path (e.g. `postgres://localhost:5432`) are handled.
/// Input without a `scheme://` prefix has `/<new_db>` appended to whatever
/// precedes the first `/`, or to the whole input if there is none.
fn replace_db_name(url: &str, new_db: &str) -> String {
    let (base, query) = match url.find('?') {
        Some(idx) => url.split_at(idx),
        None => (url, ""),
    };

    // Skip the `scheme://` prefix so its slashes are not mistaken for the path.
    let authority_start = base.find("://").map_or(0, |idx| idx + 3);
    let prefix = match base[authority_start..].find('/') {
        Some(slash) => &base[..authority_start + slash],
        None => base,
    };

    format!("{}/{}{}", prefix, new_db, query)
}
316
/// Extracts the "up" half of a migration file.
///
/// A `-- @down` marker line ends the up migration: it and everything after it
/// are discarded, cutting at the earliest marker. `-- @up` marker lines are
/// removed. A marker must be the first word of a `--` comment line; it matches
/// case-insensitively, with or without a space after `--`. Ordinary comments
/// that merely start with the same letters (e.g. `-- @updated_at trigger`) are
/// kept intact. Without a `-- @down` marker the whole file is treated as up
/// SQL. Lines are re-joined with `\n` and the result is trimmed.
fn parse_up_sql(content: &str) -> String {
    enum Marker {
        Up,
        Down,
    }

    /// Classifies a line as an `@up`/`@down` marker comment, or `None` otherwise.
    fn marker(line: &str) -> Option<Marker> {
        let word = line
            .trim_start()
            .strip_prefix("--")?
            .split_whitespace()
            .next()?;
        if word.eq_ignore_ascii_case("@up") {
            Some(Marker::Up)
        } else if word.eq_ignore_ascii_case("@down") {
            Some(Marker::Down)
        } else {
            None
        }
    }

    let mut up = String::with_capacity(content.len());
    for line in content.lines() {
        match marker(line) {
            Some(Marker::Down) => break,
            // Marker lines are pure comments; dropping the whole line is safe.
            Some(Marker::Up) => {}
            None => {
                up.push_str(line);
                up.push('\n');
            }
        }
    }
    up.trim().to_string()
}
343
/// Splits SQL into individual statements on top-level semicolons.
///
/// Semicolons are ignored inside:
/// - dollar-quoted strings (`$$ … $$`, `$tag$ … $tag$`), so PL/pgSQL bodies stay intact;
/// - single-quoted literals (`''` escapes a quote; `E'…'` strings also honour `\` escapes);
/// - double-quoted identifiers (`""` escapes a quote);
/// - `--` line comments and (possibly nested) `/* … */` block comments.
///
/// A `$` that continues an identifier (e.g. `foo$bar`) or precedes a digit
/// (`$1`) does not start a dollar quote. Each statement is trimmed and returned
/// without its terminating semicolon; empty statements are dropped, while
/// comment-only statements are returned as-is for the caller to filter. An
/// unterminated quote or comment extends to the end of the input.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Bytes that may continue an unquoted identifier. Like the Postgres lexer,
    /// every byte >= 0x80 counts as an identifier character.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    /// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `start`, returns its length.
    fn dollar_tag_len(bytes: &[u8], start: usize) -> Option<usize> {
        let mut end = start + 1;
        // A tag follows identifier rules (no leading digit) but may not contain `$`.
        if let Some(&first) = bytes.get(end) {
            if first.is_ascii_alphabetic() || first == b'_' || first >= 0x80 {
                end += 1;
                while bytes
                    .get(end)
                    .is_some_and(|&b| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80)
                {
                    end += 1;
                }
            }
        }
        (bytes.get(end) == Some(&b'$')).then_some(end + 1 - start)
    }

    /// Returns the index just past the quote closing the literal opened at `start`.
    fn skip_quoted(bytes: &[u8], start: usize, quote: u8, backslash_escapes: bool) -> usize {
        let mut i = start + 1;
        while i < bytes.len() {
            let b = bytes[i];
            if backslash_escapes && b == b'\\' {
                i += 2;
            } else if b == quote {
                if bytes.get(i + 1) == Some(&quote) {
                    i += 2; // doubled quote is an escaped quote
                } else {
                    return i + 1;
                }
            } else {
                i += 1;
            }
        }
        bytes.len()
    }

    /// Returns the index just past the (possibly nested) block comment opened at `start`.
    fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
        let mut depth = 0usize;
        let mut i = start;
        while i < bytes.len() {
            match (bytes[i], bytes.get(i + 1).copied()) {
                (b'/', Some(b'*')) => {
                    depth += 1;
                    i += 2;
                }
                (b'*', Some(b'/')) => {
                    depth -= 1;
                    i += 2;
                    if depth == 0 {
                        return i;
                    }
                }
                _ => i += 1,
            }
        }
        bytes.len()
    }

    /// Appends the trimmed `segment` unless it is empty.
    fn push_statement(statements: &mut Vec<String>, segment: &str) {
        let statement = segment.trim();
        if !statement.is_empty() {
            statements.push(statement.to_string());
        }
    }

    // Every delimiter is ASCII and UTF-8 continuation bytes are >= 0x80, so a byte
    // scan is safe; slices are only taken at `;` positions (char boundaries).
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut stmt_start = 0;
    let mut i = 0;

    while i < bytes.len() {
        i = match bytes[i] {
            // Line comment: runs through the next newline (or end of input).
            b'-' if bytes.get(i + 1) == Some(&b'-') => bytes[i..]
                .iter()
                .position(|&b| b == b'\n')
                .map_or(bytes.len(), |nl| i + nl + 1),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i),
            b'\'' => {
                // `E'…'` / `e'…'` escape-string syntax, but not an identifier ending in `e`.
                let escape_string = i >= 1
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i < 2 || !is_ident_byte(bytes[i - 2]));
                skip_quoted(bytes, i, b'\'', escape_string)
            }
            b'"' => skip_quoted(bytes, i, b'"', false),
            b'$' if i == 0 || !is_ident_byte(bytes[i - 1]) => match dollar_tag_len(bytes, i) {
                Some(tag_len) => {
                    let tag = &bytes[i..i + tag_len];
                    let body = i + tag_len;
                    bytes[body..]
                        .windows(tag_len)
                        .position(|window| window == tag)
                        .map_or(bytes.len(), |pos| body + pos + tag_len)
                }
                None => i + 1,
            },
            b';' => {
                push_statement(&mut statements, &sql[stmt_start..i]);
                stmt_start = i + 1;
                i + 1
            }
            _ => i + 1,
        };
    }

    // Don't forget the last statement (might not end with ;)
    push_statement(&mut statements, &sql[stmt_start..]);
    statements
}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-alphanumeric characters collapse to underscores.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_db_name(input), expected, "sanitizing {input:?}");
        }
    }

    /// Only the database segment changes; credentials, port and query survive.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];
        for (url, expected) in cases {
            assert_eq!(replace_db_name(url, "newdb"), expected, "rewriting {url:?}");
        }
    }
}