Skip to main content

bsql_core/
test_support.rs

1//! Test infrastructure for `#[bsql::test]`.
2//!
3//! Creates isolated PostgreSQL schemas per test for parallel execution.
4//! Fixtures (SQL files) are applied to the schema before the test runs.
5//! Schema is dropped after the test -- even on panic.
6
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use bsql_driver_postgres::{Config, Connection};
10
11use crate::error::{BsqlError, ConnectError};
12use crate::pool::Pool;
13
/// Per-process sequence number; combined with the process id it gives every
/// generated `__bsql_test_*` schema name a unique suffix within this process.
static TEST_COUNTER: AtomicU64 = AtomicU64::new(0_u64);
15
16/// Test context holding the pool and cleanup info.
17/// Drops the schema on cleanup.
18pub struct TestContext {
19    /// The connection pool, scoped to the isolated test schema.
20    pub pool: Pool,
21    schema_name: String,
22    db_url: String,
23}
24
25impl std::fmt::Debug for TestContext {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("TestContext")
28            .field("schema", &self.schema_name)
29            .finish()
30    }
31}
32
33impl Drop for TestContext {
34    fn drop(&mut self) {
35        // Fresh connection for cleanup (pool connection may be broken after panic).
36        // Errors are intentionally ignored -- we are in a destructor.
37        if let Ok(config) = Config::from_url(&self.db_url) {
38            if let Ok(mut conn) = Connection::connect(&config) {
39                let _ = conn.simple_query(&format!(
40                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
41                    self.schema_name
42                ));
43            }
44        }
45    }
46}
47
48/// Set up an isolated test schema with fixtures.
49///
50/// Called by generated `#[bsql::test]` code. Not intended for direct use.
51///
52/// `fixtures_sql` contains compile-time embedded SQL strings from fixture files.
53pub async fn setup_test_schema(fixtures_sql: &[&str]) -> Result<TestContext, BsqlError> {
54    let db_url = std::env::var("BSQL_DATABASE_URL")
55        .or_else(|_| std::env::var("DATABASE_URL"))
56        .map_err(|_| {
57            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]")
58        })?;
59
60    let schema_name = format!(
61        "__bsql_test_{}_{}",
62        std::process::id(),
63        TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
64    );
65
66    // Setup connection: create schema, apply fixtures
67    let config = Config::from_url(&db_url)
68        .map_err(|e| ConnectError::create(format!("invalid database URL: {e}")))?;
69    let mut conn = Connection::connect(&config)
70        .map_err(|e| ConnectError::create(format!("connection failed: {e}")))?;
71
72    // Create isolated schema
73    conn.simple_query(&format!("CREATE SCHEMA \"{}\"", schema_name))
74        .map_err(|e| ConnectError::create(format!("failed to create test schema: {e}")))?;
75
76    // Set search_path to test schema (with public for extensions)
77    conn.simple_query(&format!("SET search_path TO \"{}\", public", schema_name))
78        .map_err(|e| ConnectError::create(format!("failed to set search_path: {e}")))?;
79
80    // Apply fixtures in order
81    for fixture_sql in fixtures_sql {
82        if !fixture_sql.trim().is_empty() {
83            conn.simple_query(fixture_sql)
84                .map_err(|e| ConnectError::create(format!("fixture failed: {e}")))?;
85        }
86    }
87
88    drop(conn); // Release setup connection
89
90    // Build pool. Connections are lazy, so we create the pool first,
91    // then immediately acquire one connection and set search_path on it.
92    let pool = Pool::connect(&db_url).await?;
93
94    // Acquire a connection and set search_path so all subsequent queries
95    // in this test run against the isolated schema.
96    pool.raw_execute(&format!("SET search_path TO \"{}\", public", schema_name))
97        .await?;
98
99    // Set warmup SQL so any *new* connections from this pool also get
100    // the correct search_path (the pool has max_size=10 by default,
101    // but for tests we typically only use 1 connection).
102    let warmup_sql = format!("SET search_path TO \"{}\", public", schema_name);
103    // Pool::set_warmup_sqls takes &[&str] but the string must live long enough.
104    // Since warmup is best-effort and tests are short-lived, we leak the string
105    // to get a 'static lifetime. The leak is bounded (one allocation per test).
106    let leaked: &'static str = Box::leak(warmup_sql.into_boxed_str());
107    pool.set_warmup_sqls(&[leaked]);
108
109    Ok(TestContext {
110        pool,
111        schema_name,
112        db_url,
113    })
114}
115
#[cfg(test)]
// The env-var tests hold `EnvGuard` (which wraps a `std::sync::MutexGuard`) across
// `setup_test_schema(..).await`. Each `#[tokio::test]` drives its own runtime on its
// own thread, so a test blocked on the lock never stalls the runtime of the test
// that holds it.
#[allow(clippy::await_holding_lock)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    // ---------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------

    /// Environment variables read by `setup_test_schema`.
    const DB_ENV_VARS: [&str; 2] = ["BSQL_DATABASE_URL", "DATABASE_URL"];

    /// Serializes the tests in this module that mutate the process environment.
    ///
    /// The test harness runs tests on parallel threads and the environment is
    /// process-global, so without this lock one test's `remove_var` can race another
    /// test's `set_var` and make either of them fail spuriously.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds [`ENV_LOCK`] for the duration of a test and restores the database env
    /// vars to their original values on drop, including when an assertion fails
    /// and the test unwinds.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A failing test poisons the lock. The protected data is `()`, so it is
            // always sound to continue with the inner guard.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // `var_os` (not `var`) so non-UTF-8 values are restored faithfully.
            let saved = DB_ENV_VARS
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            EnvGuard { saved, _lock: lock }
        }

        fn set(&self, key: &str, value: &str) {
            std::env::set_var(key, value);
        }

        fn remove(&self, key: &str) {
            std::env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the `_lock` field is dropped, so restoration completes
            // while other env tests are still locked out.
            for (key, original) in &self.saved {
                match original {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    /// Builds a schema name in the same format `setup_test_schema` uses.
    fn next_schema_name() -> String {
        format!(
            "__bsql_test_{}_{}",
            std::process::id(),
            TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
        )
    }

    // ---------------------------------------------------------------
    // Schema lifecycle
    // ---------------------------------------------------------------

    #[test]
    fn schema_name_is_unique() {
        let name1 = next_schema_name();
        let name2 = next_schema_name();
        assert_ne!(name1, name2);
    }

    #[test]
    fn schema_name_contains_pid() {
        let name = next_schema_name();
        assert!(name.contains(&std::process::id().to_string()));
    }

    #[test]
    fn schema_name_starts_with_prefix() {
        let name = next_schema_name();
        assert!(name.starts_with("__bsql_test_"));
    }

    #[test]
    fn schema_names_never_collide_100_sequential() {
        let mut names = HashSet::new();
        for _ in 0..100 {
            let name = next_schema_name();
            assert!(names.insert(name.clone()), "duplicate schema name: {name}");
        }
        assert_eq!(names.len(), 100);
    }

    #[test]
    fn schema_name_is_valid_sql_identifier() {
        let name = next_schema_name();
        // Valid SQL identifier: starts with letter or underscore, then alphanumeric/underscore
        assert!(
            name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
            "schema name contains invalid chars: {name}"
        );
        assert!(
            name.starts_with('_') || name.starts_with(|c: char| c.is_ascii_alphabetic()),
            "schema name must start with letter or underscore: {name}"
        );
    }

    // ---------------------------------------------------------------
    // Counter atomicity
    // ---------------------------------------------------------------

    #[test]
    fn test_counter_is_monotonic() {
        let a = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let b = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(a < b);
        assert!(b < c);
    }

    #[test]
    fn counter_increments_atomically_across_threads() {
        // Collect the handles first so all threads are spawned (and can race)
        // before any of them is joined.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    (0..10)
                        .map(|_| TEST_COUNTER.fetch_add(1, Ordering::Relaxed))
                        .collect::<Vec<u64>>()
                })
            })
            .collect();
        let mut vals: Vec<u64> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();
        assert_eq!(vals.len(), 100, "expected 100 counter values");
        // All values must be unique (no duplicates from racing threads)
        let set: HashSet<u64> = vals.iter().copied().collect();
        assert_eq!(
            set.len(),
            100,
            "counter values must be unique across threads"
        );
        // Sorted values must be strictly increasing
        vals.sort_unstable();
        for window in vals.windows(2) {
            assert!(window[0] < window[1], "counter must be strictly increasing");
        }
    }

    // ---------------------------------------------------------------
    // Concurrency — multiple TestContexts
    // ---------------------------------------------------------------

    #[test]
    fn multiple_schema_names_created_simultaneously_are_different() {
        // Simulate what happens when multiple tests call setup at the same instant
        let names: Vec<String> = (0..50).map(|_| next_schema_name()).collect();
        let set: HashSet<&String> = names.iter().collect();
        assert_eq!(set.len(), names.len(), "all schema names must be unique");
    }

    // ---------------------------------------------------------------
    // Setup error paths (env-mutating tests are serialized by EnvGuard)
    // ---------------------------------------------------------------

    #[tokio::test]
    async fn missing_db_url_returns_clear_error() {
        let env = EnvGuard::acquire();
        env.remove("BSQL_DATABASE_URL");
        env.remove("DATABASE_URL");

        let err = setup_test_schema(&[])
            .await
            .expect_err("setup must fail without a database URL");
        let msg = err.to_string();
        assert!(
            msg.contains("BSQL_DATABASE_URL") && msg.contains("DATABASE_URL"),
            "error should mention both env vars, got: {msg}"
        );
    }

    #[tokio::test]
    async fn missing_bsql_database_url_falls_back_to_database_url() {
        let env = EnvGuard::acquire();
        env.remove("BSQL_DATABASE_URL");
        // Set DATABASE_URL to something invalid so we get past env-check but fail on connect
        env.set("DATABASE_URL", "not-a-url");

        let err = setup_test_schema(&[])
            .await
            .expect_err("setup must fail on an unparseable URL");
        // Should fail on URL parse, not on missing env var
        let msg = err.to_string();
        assert!(
            msg.contains("invalid database URL"),
            "should fail on URL parse after falling back to DATABASE_URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_returns_clear_error() {
        let env = EnvGuard::acquire();
        env.set("BSQL_DATABASE_URL", "not-a-valid-url");
        env.remove("DATABASE_URL");

        let err = setup_test_schema(&[])
            .await
            .expect_err("setup must fail on an invalid URL");
        let msg = err.to_string();
        assert!(
            msg.contains("invalid database URL"),
            "error should mention invalid URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_not_postgres_scheme() {
        let env = EnvGuard::acquire();
        env.set("BSQL_DATABASE_URL", "mysql://user:pass@localhost/db");
        env.remove("DATABASE_URL");

        let err = setup_test_schema(&[])
            .await
            .expect_err("setup must reject a non-postgres scheme");
        let msg = err.to_string();
        assert!(
            msg.contains("invalid database URL"),
            "non-postgres scheme should fail with clear error, got: {msg}"
        );
    }

    #[test]
    fn connection_refused_unreachable_host() {
        // Test the connection-refused path directly, bypassing env-var setup
        // so this test does not need to take the env lock.
        let url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(url).expect("URL should parse");
        let conn_result = Connection::connect(&config);
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        // Verify the error maps to a ConnectError with "connection failed" message
        // (this is the exact error path that setup_test_schema takes)
        let err = ConnectError::create(format!("connection failed: {}", conn_result.unwrap_err()));
        let msg = err.to_string();
        assert!(
            msg.contains("connection failed"),
            "unreachable host should produce 'connection failed' error, got: {msg}"
        );
    }

    // ---------------------------------------------------------------
    // TestContext Debug
    // ---------------------------------------------------------------

    #[test]
    fn test_context_has_debug_impl() {
        // Verify that TestContext implements Debug (compile-time check).
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<TestContext>();
    }

    #[test]
    fn test_context_debug_shows_schema_name() {
        // NOTE: this does not call the real Debug impl -- a TestContext needs a live
        // Pool, which needs a database. It only pins the expected shape of the
        // output (type name, `schema` field, schema value).
        let schema = "__bsql_test_12345_0";
        let expected = format!("TestContext {{ schema: {:?} }}", schema);
        assert!(expected.contains("TestContext"));
        assert!(expected.contains("schema"));
        assert!(expected.contains(schema));
    }

    // ---------------------------------------------------------------
    // Drop behavior
    // ---------------------------------------------------------------

    #[test]
    fn drop_code_path_with_invalid_url_does_not_panic() {
        // We can't construct a TestContext without a real Pool (async), so we
        // exercise the Drop code path manually with the same connect-then-drop steps.
        let db_url = "garbage-url";
        let schema_name = "__bsql_test_fake_0";
        // Step 1: Config::from_url — should fail for a garbage URL
        if let Ok(config) = Config::from_url(db_url) {
            // Step 2: Connection::connect — would fail but we shouldn't reach here
            if let Ok(mut conn) = Connection::connect(&config) {
                let _ = conn.simple_query(&format!(
                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
                    schema_name
                ));
            }
        }
        // If we get here without panicking, the drop path is safe.
    }

    #[test]
    fn drop_with_garbage_url_does_not_panic() {
        // An invalid URL means Config::from_url returns Err, so Drop's first step
        // fails and teardown ends without panicking.
        let db_url = "not-a-postgres-url";
        let config_result = Config::from_url(db_url);
        assert!(config_result.is_err(), "garbage URL should not parse");
    }

    #[test]
    fn drop_with_valid_url_but_unreachable_host_does_not_panic() {
        // Even if Config::from_url succeeds, Connection::connect can fail.
        // Drop should handle this gracefully.
        let db_url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(db_url);
        assert!(config.is_ok(), "URL should parse");
        let conn_result = Connection::connect(&config.unwrap());
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        // Drop's connect step fails here, so teardown ends without panicking.
    }

    // ---------------------------------------------------------------
    // Fixture edge cases (tested via the setup function's logic)
    // ---------------------------------------------------------------

    #[test]
    fn empty_fixture_string_is_skipped() {
        // The setup function skips fixtures whose trimmed text is empty.
        let fixture = "";
        assert!(fixture.trim().is_empty(), "empty string should be skipped");
    }

    #[test]
    fn whitespace_only_fixture_is_skipped() {
        let fixture = "   \n\t  \n  ";
        assert!(
            fixture.trim().is_empty(),
            "whitespace-only fixture should be skipped"
        );
    }

    #[test]
    fn fixture_with_only_comments_is_not_empty() {
        // SQL comments are not whitespace, so they pass the trim check.
        // PostgreSQL will accept them as valid SQL (no-op).
        let fixture = "-- just a comment\n/* block comment */";
        assert!(
            !fixture.trim().is_empty(),
            "comment-only fixture should NOT be skipped (PG handles it)"
        );
    }

    #[test]
    fn fixture_with_multiple_statements_passes_trim_check() {
        let fixture = "CREATE TABLE a (id INT);\nCREATE TABLE b (id INT);";
        assert!(!fixture.trim().is_empty());
    }

    // ---------------------------------------------------------------
    // Error type verification
    // ---------------------------------------------------------------

    #[test]
    fn missing_env_error_is_connect_variant() {
        let err =
            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("BSQL_DATABASE_URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn invalid_url_error_is_connect_variant() {
        let err = ConnectError::create("invalid database URL: missing postgres:// prefix");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("invalid database URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn connection_failed_error_is_connect_variant() {
        let err = ConnectError::create("connection failed: Connection refused");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("connection failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn fixture_failed_error_is_connect_variant() {
        let err = ConnectError::create("fixture failed: syntax error at position 5");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("fixture failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn schema_creation_failed_error_is_connect_variant() {
        let err = ConnectError::create("failed to create test schema: permission denied");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("failed to create test schema"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    // ---------------------------------------------------------------
    // Schema name format deep verification
    // ---------------------------------------------------------------

    #[test]
    fn schema_name_has_three_parts() {
        let counter = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name = format!("__bsql_test_{}_{}", pid, counter);
        // Parts: prefix "__bsql_test", pid, counter
        let suffix = name
            .strip_prefix("__bsql_test_")
            .expect("name must start with the __bsql_test_ prefix");
        let parts: Vec<&str> = suffix.split('_').collect();
        assert_eq!(parts.len(), 2, "expected PID_COUNTER suffix, got: {suffix}");
        assert_eq!(parts[0], pid.to_string());
        assert_eq!(parts[1], counter.to_string());
    }

    #[test]
    fn schema_name_counter_part_increases() {
        let c1 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c2 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name1 = format!("__bsql_test_{}_{}", pid, c1);
        let name2 = format!("__bsql_test_{}_{}", pid, c2);
        // Extract counter from name
        let counter1: u64 = name1.rsplit('_').next().unwrap().parse().unwrap();
        let counter2: u64 = name2.rsplit('_').next().unwrap().parse().unwrap();
        assert!(counter2 > counter1);
    }

    // ---------------------------------------------------------------
    // BSQL_DATABASE_URL takes priority over DATABASE_URL
    // ---------------------------------------------------------------

    #[tokio::test]
    async fn bsql_database_url_takes_priority_over_database_url() {
        let env = EnvGuard::acquire();

        // Set both — BSQL_DATABASE_URL should win. It is unparseable while
        // DATABASE_URL is parseable (but unreachable), so the error kind reveals
        // which one was used.
        env.set("BSQL_DATABASE_URL", "not-postgres-bsql");
        env.set("DATABASE_URL", "postgres://user:pass@127.0.0.1:1/realdb");

        let err = setup_test_schema(&[])
            .await
            .expect_err("setup must fail with an invalid BSQL_DATABASE_URL");
        let msg = err.to_string();
        // Should fail because BSQL_DATABASE_URL is not a valid postgres URL
        assert!(
            msg.contains("invalid database URL"),
            "BSQL_DATABASE_URL should take priority, got: {msg}"
        );
    }
}