use std::sync::atomic::{AtomicU64, Ordering};
use bsql_driver_postgres::{Config, Connection};
use crate::error::{BsqlError, ConnectError};
use crate::pool::Pool;
/// Per-process counter that, combined with the process id, gives every test
/// schema created by this process a distinct name.
static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Handle to an isolated per-test PostgreSQL schema created by
/// [`setup_test_schema`].
///
/// Dropping the context runs `DROP SCHEMA IF EXISTS ... CASCADE` for the
/// schema on a fresh connection, so the value must be kept alive for as long
/// as the test uses [`TestContext::pool`].
#[must_use = "dropping the TestContext immediately drops its test schema"]
pub struct TestContext {
    /// Connection pool set up by `setup_test_schema` to resolve unqualified
    /// names in the test schema first (via `SET search_path`).
    pub pool: Pool,
    /// Generated schema name of the form `__bsql_test_<pid>_<counter>`.
    schema_name: String,
    /// Database URL, kept so the schema can be dropped on a new connection.
    db_url: String,
}
/// Debug-formats the context as `TestContext { schema: "<schema name>" }`.
///
/// Only the schema name is emitted. `pool` and `db_url` are left out, which
/// also keeps any credentials embedded in the database URL out of debug output.
impl std::fmt::Debug for TestContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TestContext");
        out.field("schema", &self.schema_name);
        out.finish()
    }
}
impl Drop for TestContext {
    /// Best-effort removal of the per-test schema. `CASCADE` also removes
    /// every object the fixtures or the test created inside it.
    ///
    /// `Drop` cannot return errors, so failures are not propagated. Instead a
    /// warning naming the schema is printed to stderr, so a leaked schema is
    /// noticed rather than silently accumulating in the database.
    fn drop(&mut self) {
        // Uses its own connection; `self.pool` is still alive here because
        // struct fields are dropped only after `drop` returns.
        fn drop_schema(db_url: &str, schema_name: &str) -> Result<(), String> {
            let config =
                Config::from_url(db_url).map_err(|e| format!("invalid database URL: {e}"))?;
            let mut conn =
                Connection::connect(&config).map_err(|e| format!("connection failed: {e}"))?;
            conn.simple_query(&format!("DROP SCHEMA IF EXISTS \"{schema_name}\" CASCADE"))
                .map_err(|e| format!("DROP SCHEMA failed: {e}"))?;
            Ok(())
        }

        if let Err(reason) = drop_schema(&self.db_url, &self.schema_name) {
            eprintln!(
                "warning: failed to drop test schema \"{}\": {reason}",
                self.schema_name
            );
        }
    }
}
pub async fn setup_test_schema(fixtures_sql: &[&str]) -> Result<TestContext, BsqlError> {
let db_url = std::env::var("BSQL_DATABASE_URL")
.or_else(|_| std::env::var("DATABASE_URL"))
.map_err(|_| {
ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]")
})?;
let schema_name = format!(
"__bsql_test_{}_{}",
std::process::id(),
TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
);
let config = Config::from_url(&db_url)
.map_err(|e| ConnectError::create(format!("invalid database URL: {e}")))?;
let mut conn = Connection::connect(&config)
.map_err(|e| ConnectError::create(format!("connection failed: {e}")))?;
conn.simple_query(&format!("CREATE SCHEMA \"{}\"", schema_name))
.map_err(|e| ConnectError::create(format!("failed to create test schema: {e}")))?;
conn.simple_query(&format!("SET search_path TO \"{}\", public", schema_name))
.map_err(|e| ConnectError::create(format!("failed to set search_path: {e}")))?;
for fixture_sql in fixtures_sql {
if !fixture_sql.trim().is_empty() {
conn.simple_query(fixture_sql)
.map_err(|e| ConnectError::create(format!("fixture failed: {e}")))?;
}
}
drop(conn);
let pool = Pool::connect(&db_url).await?;
pool.raw_execute(&format!("SET search_path TO \"{}\", public", schema_name))
.await?;
let warmup_sql = format!("SET search_path TO \"{}\", public", schema_name);
pool.set_warmup_sqls([warmup_sql]);
Ok(TestContext {
pool,
schema_name,
db_url,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Environment variables consulted by `setup_test_schema`.
    const DB_URL_VARS: [&str; 2] = ["BSQL_DATABASE_URL", "DATABASE_URL"];

    /// Serializes tests that mutate the process-wide environment.
    ///
    /// `cargo test` runs tests on parallel threads, and the environment is
    /// shared by the whole process. Without this lock, one test's `set_var` or
    /// `remove_var` can land between another test's setup and its call to
    /// `setup_test_schema`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and restores the database URL variables to their
    /// original values on drop, including when an assertion panics.
    ///
    /// The guard is held across `.await` in the async tests. That is fine
    /// because each `#[tokio::test]` runs on its own current-thread runtime,
    /// so waiting for the lock only blocks that test's thread.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A panicking env test poisons the mutex. The `()` it protects
            // cannot be left inconsistent, so take the guard anyway.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            let saved = DB_URL_VARS
                .iter()
                .map(|&key| (key, std::env::var_os(key)))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the `_lock` field is dropped, so the environment is
            // restored while the lock is still held.
            for (key, value) in &self.saved {
                match value {
                    Some(v) => std::env::set_var(key, v),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    /// Builds a schema name the same way `setup_test_schema` does, consuming
    /// one value from the shared `TEST_COUNTER`.
    fn next_schema_name() -> String {
        format!(
            "__bsql_test_{}_{}",
            std::process::id(),
            TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
        )
    }

    /// Asserts that `err` is the `Connect` variant and that its message
    /// contains `needle`.
    fn assert_connect_error_contains(err: BsqlError, needle: &str) {
        match err {
            BsqlError::Connect(ref ce) => assert!(
                ce.message.contains(needle),
                "expected {needle:?} in connect error message, got: {}",
                ce.message
            ),
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn schema_name_is_unique() {
        let name1 = next_schema_name();
        let name2 = next_schema_name();
        assert_ne!(name1, name2);
    }

    #[test]
    fn schema_name_contains_pid() {
        let name = next_schema_name();
        assert!(name.contains(&std::process::id().to_string()));
    }

    #[test]
    fn schema_name_starts_with_prefix() {
        let name = next_schema_name();
        assert!(name.starts_with("__bsql_test_"));
    }

    #[test]
    fn schema_names_never_collide_100_sequential() {
        let mut names = HashSet::new();
        for _ in 0..100 {
            let name = next_schema_name();
            assert!(names.insert(name.clone()), "duplicate schema name: {name}");
        }
        assert_eq!(names.len(), 100);
    }

    #[test]
    fn schema_name_is_valid_sql_identifier() {
        let name = next_schema_name();
        assert!(
            name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
            "schema name contains invalid chars: {name}"
        );
        assert!(
            name.starts_with('_') || name.starts_with(|c: char| c.is_ascii_alphabetic()),
            "schema name must start with letter or underscore: {name}"
        );
    }

    #[test]
    fn test_counter_is_monotonic() {
        let a = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let b = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(a < b);
        assert!(b < c);
    }

    #[test]
    fn counter_increments_atomically_across_threads() {
        // Spawn every thread before joining any, so the increments overlap.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    (0..10)
                        .map(|_| TEST_COUNTER.fetch_add(1, Ordering::Relaxed))
                        .collect::<Vec<u64>>()
                })
            })
            .collect();
        let mut vals: Vec<u64> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();
        assert_eq!(vals.len(), 100, "expected 100 counter values");
        let set: HashSet<u64> = vals.iter().copied().collect();
        assert_eq!(
            set.len(),
            100,
            "counter values must be unique across threads"
        );
        vals.sort_unstable();
        for window in vals.windows(2) {
            assert!(window[0] < window[1], "counter must be strictly increasing");
        }
    }

    #[test]
    fn multiple_schema_names_created_simultaneously_are_different() {
        let names: Vec<String> = (0..50).map(|_| next_schema_name()).collect();
        let set: HashSet<&String> = names.iter().collect();
        assert_eq!(set.len(), names.len(), "all schema names must be unique");
    }

    #[tokio::test]
    async fn missing_db_url_returns_clear_error() {
        let _env = EnvGuard::acquire();
        std::env::remove_var("BSQL_DATABASE_URL");
        std::env::remove_var("DATABASE_URL");
        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("BSQL_DATABASE_URL") && msg.contains("DATABASE_URL"),
            "error should mention both env vars, got: {msg}"
        );
    }

    #[tokio::test]
    async fn missing_bsql_database_url_falls_back_to_database_url() {
        let _env = EnvGuard::acquire();
        std::env::remove_var("BSQL_DATABASE_URL");
        std::env::set_var("DATABASE_URL", "not-a-url");
        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "should fail on URL parse after falling back to DATABASE_URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_returns_clear_error() {
        let _env = EnvGuard::acquire();
        std::env::set_var("BSQL_DATABASE_URL", "not-a-valid-url");
        std::env::remove_var("DATABASE_URL");
        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "error should mention invalid URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_not_postgres_scheme() {
        let _env = EnvGuard::acquire();
        std::env::set_var("BSQL_DATABASE_URL", "mysql://user:pass@localhost/db");
        std::env::remove_var("DATABASE_URL");
        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "non-postgres scheme should fail with clear error, got: {msg}"
        );
    }

    #[test]
    fn connection_refused_unreachable_host() {
        let url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(url).expect("URL should parse");
        let conn_result = Connection::connect(&config);
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        let err = ConnectError::create(format!("connection failed: {}", conn_result.unwrap_err()));
        let msg = err.to_string();
        assert!(
            msg.contains("connection failed"),
            "unreachable host should produce 'connection failed' error, got: {msg}"
        );
    }

    #[test]
    fn test_context_has_debug_impl() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<TestContext>();
    }

    /// Documents the expected Debug shape only. A real `TestContext` needs a
    /// live `Pool`, so this test does not exercise the `Debug` impl itself.
    #[test]
    fn test_context_debug_shows_schema_name() {
        let schema = "__bsql_test_12345_0";
        let expected = format!("TestContext {{ schema: {:?} }}", schema);
        assert!(expected.contains("TestContext"));
        assert!(expected.contains("schema"));
        assert!(expected.contains(schema));
    }

    /// Mirrors the cleanup path of `TestContext::drop` with an unparsable URL.
    #[test]
    fn drop_code_path_with_invalid_url_does_not_panic() {
        let db_url = "garbage-url";
        let schema_name = "__bsql_test_fake_0";
        if let Ok(config) = Config::from_url(db_url) {
            if let Ok(mut conn) = Connection::connect(&config) {
                let _ = conn.simple_query(&format!(
                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
                    schema_name
                ));
            }
        }
    }

    #[test]
    fn drop_with_garbage_url_does_not_panic() {
        let db_url = "not-a-postgres-url";
        let config_result = Config::from_url(db_url);
        assert!(config_result.is_err(), "garbage URL should not parse");
    }

    #[test]
    fn drop_with_valid_url_but_unreachable_host_does_not_panic() {
        let db_url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(db_url);
        assert!(config.is_ok(), "URL should parse");
        let conn_result = Connection::connect(&config.unwrap());
        assert!(conn_result.is_err(), "connection to port 1 should fail");
    }

    #[test]
    fn empty_fixture_string_is_skipped() {
        let fixture = "";
        assert!(fixture.trim().is_empty(), "empty string should be skipped");
    }

    #[test]
    fn whitespace_only_fixture_is_skipped() {
        let fixture = " \n\t \n ";
        assert!(
            fixture.trim().is_empty(),
            "whitespace-only fixture should be skipped"
        );
    }

    #[test]
    fn fixture_with_only_comments_is_not_empty() {
        let fixture = "-- just a comment\n/* block comment */";
        assert!(
            !fixture.trim().is_empty(),
            "comment-only fixture should NOT be skipped (PG handles it)"
        );
    }

    #[test]
    fn fixture_with_multiple_statements_passes_trim_check() {
        let fixture = "CREATE TABLE a (id INT);\nCREATE TABLE b (id INT);";
        assert!(!fixture.trim().is_empty());
    }

    #[test]
    fn missing_env_error_is_connect_variant() {
        assert_connect_error_contains(
            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]"),
            "BSQL_DATABASE_URL",
        );
    }

    #[test]
    fn invalid_url_error_is_connect_variant() {
        assert_connect_error_contains(
            ConnectError::create("invalid database URL: missing postgres:// prefix"),
            "invalid database URL",
        );
    }

    #[test]
    fn connection_failed_error_is_connect_variant() {
        assert_connect_error_contains(
            ConnectError::create("connection failed: Connection refused"),
            "connection failed",
        );
    }

    #[test]
    fn fixture_failed_error_is_connect_variant() {
        assert_connect_error_contains(
            ConnectError::create("fixture failed: syntax error at position 5"),
            "fixture failed",
        );
    }

    #[test]
    fn schema_creation_failed_error_is_connect_variant() {
        assert_connect_error_contains(
            ConnectError::create("failed to create test schema: permission denied"),
            "failed to create test schema",
        );
    }

    #[test]
    fn schema_name_has_three_parts() {
        let counter = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name = format!("__bsql_test_{}_{}", pid, counter);
        assert!(name.starts_with("__bsql_test_"));
        let suffix = &name["__bsql_test_".len()..];
        let parts: Vec<&str> = suffix.split('_').collect();
        assert_eq!(parts.len(), 2, "expected PID_COUNTER suffix, got: {suffix}");
        assert_eq!(parts[0], pid.to_string());
        assert_eq!(parts[1], counter.to_string());
    }

    #[test]
    fn schema_name_counter_part_increases() {
        let c1 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c2 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name1 = format!("__bsql_test_{}_{}", pid, c1);
        let name2 = format!("__bsql_test_{}_{}", pid, c2);
        let counter1: u64 = name1.rsplit('_').next().unwrap().parse().unwrap();
        let counter2: u64 = name2.rsplit('_').next().unwrap().parse().unwrap();
        assert!(counter2 > counter1);
    }

    #[tokio::test]
    async fn bsql_database_url_takes_priority_over_database_url() {
        let _env = EnvGuard::acquire();
        std::env::set_var("BSQL_DATABASE_URL", "not-postgres-bsql");
        std::env::set_var("DATABASE_URL", "postgres://user:pass@127.0.0.1:1/realdb");
        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "BSQL_DATABASE_URL should take priority, got: {msg}"
        );
    }
}