Skip to main content

bsql_core/
test_support.rs

//! Test infrastructure for `#[bsql::test]`.
//!
//! Each test gets its own isolated PostgreSQL schema, so tests can run in
//! parallel against a single database. Fixtures (SQL files) are applied to that
//! schema before the test body runs. The schema is dropped after the test, even
//! if the test panics. Cleanup runs in a destructor, so this assumes
//! `panic = "unwind"`.
6
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use bsql_driver_postgres::{Config, Connection};
10
11use crate::error::{BsqlError, ConnectError};
12use crate::pool::Pool;
13
/// Per-process sequence number used in test schema names
/// (`__bsql_test_<pid>_<counter>`). Combined with the process id, it keeps
/// concurrently running tests from choosing the same schema.
///
/// `Ordering::Relaxed` is sufficient: callers only need every `fetch_add`
/// result to be distinct, not any ordering relative to other memory operations.
///
/// NOTE(review): uniqueness only holds among live processes. A schema leaked by
/// a process that exited without running destructors (abort, SIGKILL) can
/// collide with a later process that is assigned the same pid.
static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);
15
/// Handle for one `#[bsql::test]` run: a connection pool bound to an isolated
/// schema created by [`setup_test_schema`].
///
/// Dropping the context drops the schema (`DROP SCHEMA IF EXISTS ... CASCADE`)
/// over a fresh connection. Failures during that cleanup are ignored.
pub struct TestContext {
    /// The connection pool, scoped to the isolated test schema.
    pub pool: Pool,
    /// Name of the schema created for this test; removed on drop.
    schema_name: String,
    /// Database URL, kept so `Drop` can open its own cleanup connection.
    db_url: String,
}
24
25impl std::fmt::Debug for TestContext {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("TestContext")
28            .field("schema", &self.schema_name)
29            .finish()
30    }
31}
32
33impl Drop for TestContext {
34    fn drop(&mut self) {
35        // Fresh connection for cleanup (pool connection may be broken after panic).
36        // Errors are intentionally ignored -- we are in a destructor.
37        if let Ok(config) = Config::from_url(&self.db_url) {
38            if let Ok(mut conn) = Connection::connect(&config) {
39                let _ = conn.simple_query(&format!(
40                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
41                    self.schema_name
42                ));
43            }
44        }
45    }
46}
47
48/// Set up an isolated test schema with fixtures.
49///
50/// Called by generated `#[bsql::test]` code. Not intended for direct use.
51///
52/// `fixtures_sql` contains compile-time embedded SQL strings from fixture files.
53pub async fn setup_test_schema(fixtures_sql: &[&str]) -> Result<TestContext, BsqlError> {
54    let db_url = std::env::var("BSQL_DATABASE_URL")
55        .or_else(|_| std::env::var("DATABASE_URL"))
56        .map_err(|_| {
57            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]")
58        })?;
59
60    let schema_name = format!(
61        "__bsql_test_{}_{}",
62        std::process::id(),
63        TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
64    );
65
66    // Setup connection: create schema, apply fixtures
67    let config = Config::from_url(&db_url)
68        .map_err(|e| ConnectError::create(format!("invalid database URL: {e}")))?;
69    let mut conn = Connection::connect(&config)
70        .map_err(|e| ConnectError::create(format!("connection failed: {e}")))?;
71
72    // Create isolated schema
73    conn.simple_query(&format!("CREATE SCHEMA \"{}\"", schema_name))
74        .map_err(|e| ConnectError::create(format!("failed to create test schema: {e}")))?;
75
76    // Set search_path to test schema (with public for extensions)
77    conn.simple_query(&format!("SET search_path TO \"{}\", public", schema_name))
78        .map_err(|e| ConnectError::create(format!("failed to set search_path: {e}")))?;
79
80    // Apply fixtures in order
81    for fixture_sql in fixtures_sql {
82        if !fixture_sql.trim().is_empty() {
83            conn.simple_query(fixture_sql)
84                .map_err(|e| ConnectError::create(format!("fixture failed: {e}")))?;
85        }
86    }
87
88    drop(conn); // Release setup connection
89
90    // Build pool. Connections are lazy, so we create the pool first,
91    // then immediately acquire one connection and set search_path on it.
92    let pool = Pool::connect(&db_url).await?;
93
94    // Acquire a connection and set search_path so all subsequent queries
95    // in this test run against the isolated schema.
96    pool.raw_execute(&format!("SET search_path TO \"{}\", public", schema_name))
97        .await?;
98
99    // Set warmup SQL so any *new* connections from this pool also get
100    // the correct search_path (the pool has max_size=10 by default,
101    // but for tests we typically only use 1 connection).
102    let warmup_sql = format!("SET search_path TO \"{}\", public", schema_name);
103    // set_warmup_sqls copies strings internally (into Box<str>), so &str
104    // only needs to live for the duration of this call. No leak needed.
105    pool.set_warmup_sqls([warmup_sql]);
106
107    Ok(TestContext {
108        pool,
109        schema_name,
110        db_url,
111    })
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::{Mutex, MutexGuard};

    // ---------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------

    /// Builds a schema name in the same format `setup_test_schema` uses,
    /// consuming one value from `TEST_COUNTER`.
    fn next_schema_name() -> String {
        format!(
            "__bsql_test_{}_{}",
            std::process::id(),
            TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
        )
    }

    /// Serializes the tests in this module that mutate `BSQL_DATABASE_URL` /
    /// `DATABASE_URL`.
    ///
    /// Environment variables are process-global, but `cargo test` runs tests on
    /// parallel threads. Without this lock, one env test can observe values that
    /// another env test just set and fail spuriously.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] and restores both database-URL env vars when dropped.
    /// The original values come back even if an assertion in the test panics.
    ///
    /// The async tests keep this guard alive across `.await`. That cannot
    /// deadlock: each `#[tokio::test]` drives its future on its own runtime, so
    /// a test waiting for the lock only blocks its own thread.
    struct EnvGuard {
        bsql_database_url: Option<String>,
        database_url: Option<String>,
        // Fields are dropped after `Drop::drop` returns, so the lock is
        // released only once the environment has been restored.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the env lock, then snapshots both variables.
        fn acquire() -> Self {
            // A panicking env test poisons the lock. Its guard has already
            // restored the environment by then, so the poison carries no
            // inconsistent state and can be ignored.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            EnvGuard {
                bsql_database_url: std::env::var("BSQL_DATABASE_URL").ok(),
                database_url: std::env::var("DATABASE_URL").ok(),
                _lock: lock,
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            restore_env_var("BSQL_DATABASE_URL", self.bsql_database_url.as_deref());
            restore_env_var("DATABASE_URL", self.database_url.as_deref());
        }
    }

    /// Sets `key` back to `value`, or removes it if it was originally unset.
    fn restore_env_var(key: &str, value: Option<&str>) {
        match value {
            Some(v) => std::env::set_var(key, v),
            None => std::env::remove_var(key),
        }
    }

    // ---------------------------------------------------------------
    // Schema lifecycle
    // ---------------------------------------------------------------

    #[test]
    fn schema_name_is_unique() {
        let name1 = next_schema_name();
        let name2 = next_schema_name();
        assert_ne!(name1, name2);
    }

    #[test]
    fn schema_name_contains_pid() {
        let name = next_schema_name();
        assert!(name.contains(&std::process::id().to_string()));
    }

    #[test]
    fn schema_name_starts_with_prefix() {
        let name = next_schema_name();
        assert!(name.starts_with("__bsql_test_"));
    }

    #[test]
    fn schema_names_never_collide_100_sequential() {
        let mut names = HashSet::new();
        for _ in 0..100 {
            let name = next_schema_name();
            assert!(names.insert(name.clone()), "duplicate schema name: {name}");
        }
        assert_eq!(names.len(), 100);
    }

    #[test]
    fn schema_name_is_valid_sql_identifier() {
        let name = next_schema_name();
        // Valid SQL identifier: starts with letter or underscore, then alphanumeric/underscore
        assert!(
            name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
            "schema name contains invalid chars: {name}"
        );
        assert!(
            name.starts_with('_') || name.starts_with(|c: char| c.is_ascii_alphabetic()),
            "schema name must start with letter or underscore: {name}"
        );
    }

    #[test]
    fn schema_name_fits_postgres_identifier_limit() {
        // PostgreSQL truncates identifiers longer than 63 bytes (NAMEDATALEN - 1).
        // Truncation could make two distinct names collide, so check the
        // worst-case pid and counter values.
        let longest = format!("__bsql_test_{}_{}", u32::MAX, u64::MAX);
        assert!(longest.len() <= 63, "schema name too long: {longest}");
    }

    // ---------------------------------------------------------------
    // Counter atomicity
    // ---------------------------------------------------------------

    #[test]
    fn test_counter_is_monotonic() {
        let a = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let b = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(a < b);
        assert!(b < c);
    }

    #[test]
    fn counter_increments_atomically_across_threads() {
        use std::sync::Arc;
        let results: Arc<std::sync::Mutex<Vec<u64>>> = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let results = Arc::clone(&results);
            handles.push(std::thread::spawn(move || {
                for _ in 0..10 {
                    let val = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
                    results.lock().unwrap().push(val);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let mut vals = results.lock().unwrap().clone();
        assert_eq!(vals.len(), 100, "expected 100 counter values");
        // All values must be unique (no duplicates from racing threads)
        let set: HashSet<u64> = vals.iter().copied().collect();
        assert_eq!(
            set.len(),
            100,
            "counter values must be unique across threads"
        );
        // Sorted values must be strictly increasing
        vals.sort();
        for window in vals.windows(2) {
            assert!(window[0] < window[1], "counter must be strictly increasing");
        }
    }

    // ---------------------------------------------------------------
    // Concurrency — multiple TestContexts
    // ---------------------------------------------------------------

    #[test]
    fn multiple_schema_names_created_simultaneously_are_different() {
        // Simulate what happens when multiple tests call setup at the same instant
        let names: Vec<String> = (0..50).map(|_| next_schema_name()).collect();
        let set: HashSet<&String> = names.iter().collect();
        assert_eq!(set.len(), names.len(), "all schema names must be unique");
    }

    // ---------------------------------------------------------------
    // Setup error paths
    // ---------------------------------------------------------------

    #[tokio::test]
    async fn missing_db_url_returns_clear_error() {
        // Serializes env access and restores both vars on drop, even on panic.
        let _env = EnvGuard::acquire();
        std::env::remove_var("BSQL_DATABASE_URL");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("BSQL_DATABASE_URL") && msg.contains("DATABASE_URL"),
            "error should mention both env vars, got: {msg}"
        );
    }

    #[tokio::test]
    async fn missing_bsql_database_url_falls_back_to_database_url() {
        let _env = EnvGuard::acquire();
        std::env::remove_var("BSQL_DATABASE_URL");
        // Set DATABASE_URL to something invalid so we get past env-check but fail on connect
        std::env::set_var("DATABASE_URL", "not-a-url");

        let result = setup_test_schema(&[]).await;
        // Should fail on URL parse, not on missing env var
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "should fail on URL parse after falling back to DATABASE_URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_returns_clear_error() {
        let _env = EnvGuard::acquire();
        std::env::set_var("BSQL_DATABASE_URL", "not-a-valid-url");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("invalid database URL"),
            "error should mention invalid URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_not_postgres_scheme() {
        let _env = EnvGuard::acquire();
        std::env::set_var("BSQL_DATABASE_URL", "mysql://user:pass@localhost/db");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "non-postgres scheme should fail with clear error, got: {msg}"
        );
    }

    #[test]
    fn connection_refused_unreachable_host() {
        // Test the connection-refused path directly, bypassing env-var setup,
        // so this test needs neither the env lock nor a configured database.
        let url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(url).expect("URL should parse");
        let conn_result = Connection::connect(&config);
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        // Verify the error maps to a ConnectError with "connection failed" message
        // (this is the exact error path that setup_test_schema takes)
        let err = ConnectError::create(format!("connection failed: {}", conn_result.unwrap_err()));
        let msg = err.to_string();
        assert!(
            msg.contains("connection failed"),
            "unreachable host should produce 'connection failed' error, got: {msg}"
        );
    }

    // ---------------------------------------------------------------
    // TestContext Debug
    // ---------------------------------------------------------------

    #[test]
    fn test_context_has_debug_impl() {
        // Verify that TestContext implements Debug (compile-time check).
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<TestContext>();
    }

    #[test]
    fn test_context_debug_shows_schema_name() {
        // A real TestContext needs a live Pool, so this only documents the
        // expected `Debug` shape (`TestContext { schema: "..." }`). It does not
        // exercise the impl itself.
        let schema = "__bsql_test_12345_0";
        let expected = format!("TestContext {{ schema: {:?} }}", schema);
        assert!(expected.contains("TestContext"));
        assert!(expected.contains("schema"));
        assert!(expected.contains(schema));
    }

    // ---------------------------------------------------------------
    // Drop behavior
    // ---------------------------------------------------------------

    #[test]
    fn drop_code_path_with_invalid_url_does_not_panic() {
        // We can't construct a TestContext without a real Pool (async), so we
        // exercise the exact Drop code path manually. This is the same logic
        // that TestContext::drop executes.
        let db_url = "garbage-url";
        let schema_name = "__bsql_test_fake_0";
        // Step 1: Config::from_url — should fail for a garbage URL
        if let Ok(config) = Config::from_url(db_url) {
            // Step 2: Connection::connect — would fail but we shouldn't reach here
            if let Ok(mut conn) = Connection::connect(&config) {
                let _ = conn.simple_query(&format!(
                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
                    schema_name
                ));
            }
        }
        // If we get here without panicking, the drop path is safe.
    }

    #[test]
    fn drop_with_garbage_url_does_not_panic() {
        // Directly exercise the Drop code path with an invalid URL.
        // This ensures Config::from_url failure doesn't cause a panic in Drop.
        //
        // An invalid URL means Config::from_url returns Err, so Drop exits
        // at its first check without touching the network.
        let db_url = "not-a-postgres-url";
        let config_result = Config::from_url(db_url);
        assert!(config_result.is_err(), "garbage URL should not parse");
    }

    #[test]
    fn drop_with_valid_url_but_unreachable_host_does_not_panic() {
        // Even if Config::from_url succeeds, Connection::connect can fail.
        // Drop should handle this gracefully.
        let db_url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(db_url);
        assert!(config.is_ok(), "URL should parse");
        let conn_result = Connection::connect(&config.unwrap());
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        // Drop would exit at its second check (connect failure) — no panic.
    }

    // ---------------------------------------------------------------
    // Fixture edge cases (tested via the setup function's logic)
    // ---------------------------------------------------------------

    #[test]
    fn empty_fixture_string_is_skipped() {
        // The setup function skips empty fixtures: `if !fixture_sql.trim().is_empty()`
        // Verify the logic directly.
        let fixture = "";
        assert!(fixture.trim().is_empty(), "empty string should be skipped");
    }

    #[test]
    fn whitespace_only_fixture_is_skipped() {
        let fixture = "   \n\t  \n  ";
        assert!(
            fixture.trim().is_empty(),
            "whitespace-only fixture should be skipped"
        );
    }

    #[test]
    fn fixture_with_only_comments_is_not_empty() {
        // SQL comments are not whitespace, so they pass the trim check.
        // PostgreSQL will accept them as valid SQL (no-op).
        let fixture = "-- just a comment\n/* block comment */";
        assert!(
            !fixture.trim().is_empty(),
            "comment-only fixture should NOT be skipped (PG handles it)"
        );
    }

    #[test]
    fn fixture_with_multiple_statements_passes_trim_check() {
        let fixture = "CREATE TABLE a (id INT);\nCREATE TABLE b (id INT);";
        assert!(!fixture.trim().is_empty());
    }

    // ---------------------------------------------------------------
    // Error type verification
    // ---------------------------------------------------------------

    #[test]
    fn missing_env_error_is_connect_variant() {
        let err =
            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("BSQL_DATABASE_URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn invalid_url_error_is_connect_variant() {
        let err = ConnectError::create("invalid database URL: missing postgres:// prefix");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("invalid database URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn connection_failed_error_is_connect_variant() {
        let err = ConnectError::create("connection failed: Connection refused");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("connection failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn fixture_failed_error_is_connect_variant() {
        let err = ConnectError::create("fixture failed: syntax error at position 5");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("fixture failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn schema_creation_failed_error_is_connect_variant() {
        let err = ConnectError::create("failed to create test schema: permission denied");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("failed to create test schema"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    // ---------------------------------------------------------------
    // Schema name format deep verification
    // ---------------------------------------------------------------

    #[test]
    fn schema_name_has_three_parts() {
        let counter = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name = format!("__bsql_test_{}_{}", pid, counter);
        // Parts: prefix "__bsql_test", pid, counter
        assert!(name.starts_with("__bsql_test_"));
        let suffix = &name["__bsql_test_".len()..];
        let parts: Vec<&str> = suffix.split('_').collect();
        assert_eq!(parts.len(), 2, "expected PID_COUNTER suffix, got: {suffix}");
        assert_eq!(parts[0], pid.to_string());
        assert_eq!(parts[1], counter.to_string());
    }

    #[test]
    fn schema_name_counter_part_increases() {
        let c1 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c2 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name1 = format!("__bsql_test_{}_{}", pid, c1);
        let name2 = format!("__bsql_test_{}_{}", pid, c2);
        // Extract counter from name
        let counter1: u64 = name1.rsplit('_').next().unwrap().parse().unwrap();
        let counter2: u64 = name2.rsplit('_').next().unwrap().parse().unwrap();
        assert!(counter2 > counter1);
    }

    // ---------------------------------------------------------------
    // BSQL_DATABASE_URL takes priority over DATABASE_URL
    // ---------------------------------------------------------------

    #[tokio::test]
    async fn bsql_database_url_takes_priority_over_database_url() {
        let _env = EnvGuard::acquire();

        // Set both — BSQL_DATABASE_URL should win
        // Use an invalid URL so we can see which one is used in the error
        std::env::set_var("BSQL_DATABASE_URL", "not-postgres-bsql");
        std::env::set_var("DATABASE_URL", "postgres://user:pass@127.0.0.1:1/realdb");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        // Should fail because BSQL_DATABASE_URL is not a valid postgres URL
        assert!(
            msg.contains("invalid database URL"),
            "BSQL_DATABASE_URL should take priority, got: {msg}"
        );
    }
}