Skip to main content

forge_core/testing/
db.rs

1//! Database provisioning for tests.
2//!
3//! Deliberately avoids reading DATABASE_URL to prevent accidental production use.
4
5#![allow(clippy::unwrap_used, clippy::indexing_slicing)]
6
7use sqlx::PgPool;
8use std::path::Path;
9#[cfg(feature = "testcontainers")]
10use std::sync::Arc;
11use tracing::{debug, info};
12
13use crate::error::{ForgeError, Result};
14
/// Shared handle to the Postgres container backing a test database, if any.
///
/// The inner `Option` is `None` when the database was reached through an
/// explicit URL and `Some` when this module started the container itself.
/// The `Arc` lets a `TestDatabase` and every `IsolatedTestDb` derived from it
/// hold the same handle.
// NOTE(review): presumably the container is torn down once the last handle is
// dropped — confirm against the testcontainers documentation.
#[cfg(feature = "testcontainers")]
type PgContainer =
    Arc<Option<testcontainers::ContainerAsync<testcontainers_modules::postgres::Postgres>>>;
18
/// Database access for tests.
///
/// Holds a connection pool to a base Postgres database together with the URL
/// the pool was opened from. Build one with `from_url` or `from_env`, then call
/// `isolated` to create a per-test database on the same server.
///
/// # Examples
///
/// ```ignore
/// let db = TestDatabase::from_url("postgres://localhost/test_db").await?;
/// let db = TestDatabase::from_env().await?;
/// ```
pub struct TestDatabase {
    // Pool connected to the base database.
    pool: PgPool,
    // URL the pool was opened with; `isolated()` derives per-test URLs from it.
    url: String,
    // Container handle; wraps `None` when the database was reached via a URL.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
33
34impl TestDatabase {
35    /// Connect to database at the given URL.
36    pub async fn from_url(url: &str) -> Result<Self> {
37        let pool = sqlx::postgres::PgPoolOptions::new()
38            .max_connections(10)
39            .connect(url)
40            .await
41            .map_err(ForgeError::Sql)?;
42
43        Ok(Self {
44            pool,
45            url: url.to_string(),
46            #[cfg(feature = "testcontainers")]
47            _container: Arc::new(None),
48        })
49    }
50
51    /// Connect using `TEST_DATABASE_URL`, or start a container if the
52    /// `testcontainers` feature is enabled and the var is unset.
53    pub async fn from_env() -> Result<Self> {
54        match std::env::var("TEST_DATABASE_URL") {
55            Ok(url) => Self::from_url(&url).await,
56            Err(_) => {
57                #[cfg(feature = "testcontainers")]
58                {
59                    Self::from_container().await
60                }
61                #[cfg(not(feature = "testcontainers"))]
62                {
63                    Err(ForgeError::Database(
64                        "TEST_DATABASE_URL not set. Set it explicitly for database tests, \
65                         or enable the `testcontainers` feature for automatic provisioning."
66                            .to_string(),
67                    ))
68                }
69            }
70        }
71    }
72
73    #[cfg(feature = "testcontainers")]
74    async fn from_container() -> Result<Self> {
75        use testcontainers::ImageExt;
76        use testcontainers::runners::AsyncRunner;
77        use testcontainers_modules::postgres::Postgres;
78
79        // PG 13+ required for gen_random_uuid() without pgcrypto
80        let container = Postgres::default()
81            .with_tag("18-alpine")
82            .start()
83            .await
84            .map_err(|e| ForgeError::Database(format!("Failed to start PG container: {e}")))?;
85
86        let port = container
87            .get_host_port_ipv4(5432)
88            .await
89            .map_err(|e| ForgeError::Database(format!("Failed to get container port: {e}")))?;
90
91        let url = format!("postgres://postgres:postgres@localhost:{port}/postgres");
92        let pool = sqlx::postgres::PgPoolOptions::new()
93            .max_connections(10)
94            .acquire_timeout(std::time::Duration::from_secs(30))
95            .connect(&url)
96            .await
97            .map_err(ForgeError::Sql)?;
98
99        Ok(Self {
100            pool,
101            url,
102            _container: Arc::new(Some(container)),
103        })
104    }
105
    /// Get the connection pool.
    ///
    /// Returns a reference to the pool opened against the base database by the
    /// constructor.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
110
    /// Get the database URL.
    ///
    /// This is the URL the pool was opened with: the caller-supplied string
    /// for `from_url`, or the generated container URL otherwise.
    pub fn url(&self) -> &str {
        &self.url
    }
115
116    /// Run raw SQL to set up test data or schema.
117    pub async fn execute(&self, sql: &str) -> Result<()> {
118        sqlx::query(sql)
119            .execute(&self.pool)
120            .await
121            .map_err(ForgeError::Sql)?;
122        Ok(())
123    }
124
125    /// Creates a dedicated database for a single test, providing full isolation.
126    ///
127    /// Each call creates a new database with a unique name. Use this when tests
128    /// modify data and could interfere with each other.
129    pub async fn isolated(&self, test_name: &str) -> Result<IsolatedTestDb> {
130        let base_url = self.url.clone();
131        // UUID suffix prevents collisions when tests run in parallel
132        let db_name = format!(
133            "forge_test_{}_{}",
134            sanitize_db_name(test_name),
135            uuid::Uuid::new_v4().simple()
136        );
137
138        // Connect to default database to create the test database
139        let pool = sqlx::postgres::PgPoolOptions::new()
140            .max_connections(1)
141            .connect(&base_url)
142            .await
143            .map_err(ForgeError::Sql)?;
144
145        // Double-quoted identifier handles special characters in generated name
146        sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name))
147            .execute(&pool)
148            .await
149            .map_err(ForgeError::Sql)?;
150
151        // Build URL for the new database by replacing the database name component
152        let test_url = replace_db_name(&base_url, &db_name);
153
154        let test_pool = sqlx::postgres::PgPoolOptions::new()
155            .max_connections(5)
156            .connect(&test_url)
157            .await
158            .map_err(ForgeError::Sql)?;
159
160        Ok(IsolatedTestDb {
161            pool: test_pool,
162            db_name,
163            base_url,
164            #[cfg(feature = "testcontainers")]
165            _container: self._container.clone(),
166        })
167    }
168}
169
/// A test database that exists for the lifetime of a single test.
///
/// Created by `TestDatabase::isolated` or `IsolatedTestDb::setup`. The database
/// is dropped when `cleanup()` is called.
///
/// NOTE(review): earlier wording said orphaned databases are cleaned up
/// automatically on subsequent test runs; no such cleanup is visible in this
/// file — confirm where it lives before relying on it.
pub struct IsolatedTestDb {
    // Pool connected to the per-test database.
    pool: PgPool,
    // Name of the per-test database, as passed to CREATE DATABASE.
    db_name: String,
    // URL of the base database; `cleanup()` reconnects there to drop this one.
    base_url: String,
    // Container handle shared with the originating `TestDatabase`.
    #[cfg(feature = "testcontainers")]
    _container: PgContainer,
}
182
183impl IsolatedTestDb {
184    /// Create a fully initialized test database in one call.
185    ///
186    /// Combines `TestDatabase::from_env()`, `isolated()`, `run_sql(internal_sql)`,
187    /// and `migrate()` into a single convenience method.
188    ///
189    /// ```ignore
190    /// let db = IsolatedTestDb::setup(
191    ///     "my_test",
192    ///     &forge::get_internal_sql(),
193    ///     Path::new("migrations"),
194    /// ).await?;
195    /// let pool = db.pool().clone();
196    /// ```
197    pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result<Self> {
198        let base = TestDatabase::from_env().await?;
199        let db = base.isolated(test_name).await?;
200        db.run_sql(internal_sql).await?;
201        db.migrate(migrations_dir).await?;
202        Ok(db)
203    }
204
    /// Get the connection pool for this isolated database.
    ///
    /// The pool is connected to the per-test database, not the base database.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }
209
    /// Get the database name.
    ///
    /// This is the generated name used when the database was created.
    pub fn db_name(&self) -> &str {
        &self.db_name
    }
214
215    /// Run raw SQL to set up test data or schema.
216    pub async fn execute(&self, sql: &str) -> Result<()> {
217        sqlx::query(sql)
218            .execute(&self.pool)
219            .await
220            .map_err(ForgeError::Sql)?;
221        Ok(())
222    }
223
224    /// Run multi-statement SQL for setup.
225    ///
226    /// This handles SQL with multiple statements separated by semicolons,
227    /// including PL/pgSQL functions with dollar-quoted strings.
228    pub async fn run_sql(&self, sql: &str) -> Result<()> {
229        for stmt in split_sql_statements(sql) {
230            let stmt = stmt.trim();
231            if is_blank_sql(stmt) {
232                continue;
233            }
234            sqlx::query(stmt)
235                .execute(&self.pool)
236                .await
237                .map_err(|e| ForgeError::Database(format!("Failed to execute SQL: {e}")))?;
238        }
239        Ok(())
240    }
241
242    /// Cleanup the test database by dropping it.
243    ///
244    /// Call this at the end of your test if you want immediate cleanup.
245    /// Otherwise, orphaned databases will be cleaned up on subsequent test runs.
246    pub async fn cleanup(self) -> Result<()> {
247        // Close all connections first
248        self.pool.close().await;
249
250        // Connect to default database to drop the test database
251        let pool = sqlx::postgres::PgPoolOptions::new()
252            .max_connections(1)
253            .connect(&self.base_url)
254            .await
255            .map_err(ForgeError::Sql)?;
256
257        // Force disconnect other connections and drop
258        let _ = sqlx::query(&format!(
259            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
260            self.db_name
261        ))
262        .execute(&pool)
263        .await;
264
265        sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name))
266            .execute(&pool)
267            .await
268            .map_err(ForgeError::Sql)?;
269
270        Ok(())
271    }
272
273    /// Run migrations from a directory.
274    ///
275    /// Loads all `.sql` files from the directory, sorts them alphabetically,
276    /// and executes them in order. This is intended for test setup.
277    ///
278    /// # Example
279    ///
280    /// ```ignore
281    /// let base = TestDatabase::from_env().await?;
282    /// let db = base.isolated("my_test").await?;
283    /// db.migrate(Path::new("migrations")).await?;
284    /// ```
285    pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> {
286        if !migrations_dir.exists() {
287            debug!("Migrations directory does not exist: {:?}", migrations_dir);
288            return Ok(());
289        }
290
291        let mut migrations = Vec::new();
292
293        let entries = std::fs::read_dir(migrations_dir).map_err(ForgeError::Io)?;
294
295        for entry in entries {
296            let entry = entry.map_err(ForgeError::Io)?;
297            let path = entry.path();
298
299            if path.extension().map(|e| e == "sql").unwrap_or(false) {
300                let name = path
301                    .file_stem()
302                    .and_then(|s| s.to_str())
303                    .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
304                    .to_string();
305
306                let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
307                migrations.push((name, content));
308            }
309        }
310
311        // Sort by name (which includes the numeric prefix)
312        migrations.sort_by(|a, b| a.0.cmp(&b.0));
313
314        debug!("Running {} migrations for test", migrations.len());
315
316        for (name, content) in migrations {
317            info!("Applying test migration: {}", name);
318
319            // Parse content to extract up SQL (everything before -- @down marker)
320            let up_sql = parse_up_sql(&content);
321
322            // Split into individual statements and execute
323            for stmt in split_sql_statements(&up_sql) {
324                let stmt = stmt.trim();
325                if is_blank_sql(stmt) {
326                    continue;
327                }
328                sqlx::query(stmt).execute(&self.pool).await.map_err(|e| {
329                    ForgeError::Database(format!("Failed to apply migration '{name}': {e}"))
330                })?;
331            }
332        }
333
334        Ok(())
335    }
336}
337
/// Report whether `sql` contains nothing to execute.
///
/// True when every line is empty, whitespace-only, or starts (after leading
/// whitespace) with `--`. An empty string is blank.
fn is_blank_sql(sql: &str) -> bool {
    let has_content = sql
        .lines()
        .map(str::trim)
        .any(|line| !line.is_empty() && !line.starts_with("--"));
    !has_content
}
344
/// Sanitize a test name for use in a database name.
///
/// Every character that is not an ASCII letter or digit becomes `_`, and the
/// result is capped so that the full generated name
/// `forge_test_<sanitized>_<32 hex digits>` stays within PostgreSQL's 63-byte
/// identifier limit. Longer identifiers are truncated by the server, which
/// would make the created database's name differ from the one used to connect
/// to it and to look it up during cleanup.
fn sanitize_db_name(name: &str) -> String {
    // 63 bytes total, minus the fixed prefix, the `_` separator and the UUID.
    const MAX_LEN: usize = 63 - "forge_test_".len() - 1 - 32;

    name.chars()
        // ASCII only: each kept character is exactly one byte, so the char
        // count below is also the byte count.
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .take(MAX_LEN)
        .collect()
}
352
/// Replace the database name in a connection URL.
///
/// The query string (everything from the first `?`) is carried over
/// unchanged. The database name is the part of the remaining URL after the
/// last `/` that follows the `scheme://` prefix; if there is no such `/`
/// (no database in the URL), `/<new_db>` is appended to the host part.
fn replace_db_name(url: &str, new_db: &str) -> String {
    // Split the query off first so that slashes inside parameter values
    // (e.g. `sslrootcert=/etc/ca.pem`) are never mistaken for the path separator.
    let (head, query) = match url.find('?') {
        Some(query_idx) => url.split_at(query_idx),
        None => (url, ""),
    };

    // Skip the `//` of the scheme so a URL without a path does not match it.
    let authority_start = head.find("://").map_or(0, |idx| idx + "://".len());

    match head[authority_start..].rfind('/') {
        Some(rel_idx) => {
            let base = &head[..=authority_start + rel_idx];
            format!("{base}{new_db}{query}")
        }
        None => format!("{head}/{new_db}{query}"),
    }
}
369
/// Parse migration content, extracting only the up SQL (before -- @down marker).
///
/// The content is cut at the earliest occurrence of any recognised down
/// marker (`-- @down`, `--@down`, `-- @DOWN`, `--@DOWN`); without a marker the
/// whole content is the up SQL. Up markers are then removed and the result is
/// trimmed.
fn parse_up_sql(content: &str) -> String {
    const DOWN_MARKERS: [&str; 4] = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];

    // Take the smallest offset across all spellings: the marker that appears
    // first in the file wins, regardless of its position in the list above.
    let cut = DOWN_MARKERS
        .iter()
        .filter_map(|marker| content.find(marker))
        .min();
    let up_part = cut.map_or(content, |idx| &content[..idx]);

    strip_up_markers(up_part)
}

/// Remove every `-- @up` / `--@up` / `-- @UP` / `--@UP` marker from `sql` and
/// trim surrounding whitespace.
fn strip_up_markers(sql: &str) -> String {
    sql.replace("-- @up", "")
        .replace("--@up", "")
        .replace("-- @UP", "")
        .replace("--@UP", "")
        .trim()
        .to_string()
}
389
/// Split SQL into individual statements on top-level semicolons.
///
/// A semicolon only ends a statement when it is outside of:
/// - dollar-quoted strings (`$$ ... $$` or `$tag$ ... $tag$`), so PL/pgSQL
///   function bodies stay intact;
/// - single-quoted string literals and double-quoted identifiers (a doubled
///   quote inside them is handled naturally as close-then-reopen);
/// - `--` line comments and `/* ... */` block comments (which may nest).
///
/// Returned statements are trimmed, carry no trailing semicolon, and are
/// never empty. Comments are kept as part of the statement text they precede.
/// Backslash escapes (as in `E'...'` strings) are not interpreted.
fn split_sql_statements(sql: &str) -> Vec<String> {
    // Lexical context of the character currently being scanned.
    enum Lex {
        Code,
        SingleQuoted,
        DoubleQuoted,
        LineComment,
        // Nesting depth of the enclosing block comments.
        BlockComment(u32),
        // `tag` is the full delimiter (e.g. `$$` or `$body$`); `body_start` is
        // the byte offset in `current` right after the opening delimiter.
        DollarQuoted { tag: String, body_start: usize },
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    let mut state = Lex::Code;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        state = match state {
            Lex::Code => match c {
                '\'' => Lex::SingleQuoted,
                '"' => Lex::DoubleQuoted,
                '-' if chars.peek() == Some(&'-') => {
                    chars.next();
                    current.push('-');
                    Lex::LineComment
                }
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    current.push('*');
                    Lex::BlockComment(1)
                }
                '$' => {
                    // A delimiter is `$`, optional identifier characters, `$`.
                    let mut tag = String::from("$");
                    while let Some(tag_char) =
                        chars.next_if(|ch| ch.is_alphanumeric() || *ch == '_')
                    {
                        tag.push(tag_char);
                        current.push(tag_char);
                    }
                    if chars.next_if_eq(&'$').is_some() {
                        tag.push('$');
                        current.push('$');
                        Lex::DollarQuoted {
                            tag,
                            body_start: current.len(),
                        }
                    } else {
                        // Something like `$1`: not a dollar quote.
                        Lex::Code
                    }
                }
                ';' => {
                    let stmt = current.trim().trim_end_matches(';').trim();
                    if !stmt.is_empty() {
                        statements.push(stmt.to_string());
                    }
                    current.clear();
                    Lex::Code
                }
                _ => Lex::Code,
            },
            Lex::SingleQuoted => {
                if c == '\'' {
                    Lex::Code
                } else {
                    Lex::SingleQuoted
                }
            }
            Lex::DoubleQuoted => {
                if c == '"' {
                    Lex::Code
                } else {
                    Lex::DoubleQuoted
                }
            }
            Lex::LineComment => {
                if c == '\n' {
                    Lex::Code
                } else {
                    Lex::LineComment
                }
            }
            Lex::BlockComment(depth) => {
                if c == '*' && chars.peek() == Some(&'/') {
                    chars.next();
                    current.push('/');
                    if depth <= 1 {
                        Lex::Code
                    } else {
                        Lex::BlockComment(depth - 1)
                    }
                } else if c == '/' && chars.peek() == Some(&'*') {
                    chars.next();
                    current.push('*');
                    Lex::BlockComment(depth + 1)
                } else {
                    Lex::BlockComment(depth)
                }
            }
            Lex::DollarQuoted { tag, body_start } => {
                // The quote ends at the first place the body ends with the
                // delimiter; the length check keeps the match from overlapping
                // the opening delimiter itself.
                let closes = c == '$'
                    && current.len() >= body_start + tag.len()
                    && current.ends_with(tag.as_str());
                if closes {
                    Lex::Code
                } else {
                    Lex::DollarQuoted { tag, body_start }
                }
            }
        };
    }

    // Don't forget the last statement (might not end with ;)
    let stmt = current.trim().trim_end_matches(';').trim();
    if !stmt.is_empty() {
        statements.push(stmt.to_string());
    }

    statements
}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-alphanumeric characters are replaced with underscores.
    #[test]
    fn test_sanitize_db_name() {
        let cases = [
            ("my_test", "my_test"),
            ("my-test", "my_test"),
            ("my test", "my_test"),
            ("test::function", "test__function"),
        ];

        for &(input, expected) in &cases {
            assert_eq!(sanitize_db_name(input), expected);
        }
    }

    /// The database component is swapped while credentials, port and query
    /// parameters are preserved.
    #[test]
    fn test_replace_db_name() {
        let cases = [
            ("postgres://localhost/olddb", "postgres://localhost/newdb"),
            (
                "postgres://user:pass@localhost:5432/olddb",
                "postgres://user:pass@localhost:5432/newdb",
            ),
            (
                "postgres://localhost/olddb?sslmode=disable",
                "postgres://localhost/newdb?sslmode=disable",
            ),
        ];

        for &(url, expected) in &cases {
            assert_eq!(replace_db_name(url, "newdb"), expected);
        }
    }
}