1use std::sync::atomic::{AtomicU64, Ordering};
8
9use bsql_driver_postgres::{Config, Connection};
10
11use crate::error::{BsqlError, ConnectError};
12use crate::pool::Pool;
13
/// Process-wide sequence number mixed into every test schema name.
///
/// Paired with the process id, it keeps schema names distinct between the
/// tests of one process and between concurrently running test binaries.
/// `Relaxed` ordering is sufficient: `fetch_add` is an atomic
/// read-modify-write, so every caller receives a distinct value whatever the
/// memory ordering.
static TEST_COUNTER: AtomicU64 = AtomicU64::new(0_u64);
15
/// Per-test database context returned by [`setup_test_schema`].
///
/// Bundles the connection pool used by the test with what `Drop` needs to
/// remove the per-test schema again.
///
/// Field order matters: fields are dropped in declaration order, after
/// `Drop::drop` has run, so `pool` is still alive while the schema is dropped.
pub struct TestContext {
    /// Pool for the test's queries. `setup_test_schema` runs
    /// `SET search_path TO "<schema>", public` on it and registers the same
    /// statement as warmup SQL.
    pub pool: Pool,
    /// Name of the per-test schema, formatted as `__bsql_test_<pid>_<counter>`.
    schema_name: String,
    /// Database URL, kept so `Drop` can open its own connection to drop the schema.
    db_url: String,
}
24
25impl std::fmt::Debug for TestContext {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("TestContext")
28 .field("schema", &self.schema_name)
29 .finish()
30 }
31}
32
33impl Drop for TestContext {
34 fn drop(&mut self) {
35 if let Ok(config) = Config::from_url(&self.db_url) {
38 if let Ok(mut conn) = Connection::connect(&config) {
39 let _ = conn.simple_query(&format!(
40 "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
41 self.schema_name
42 ));
43 }
44 }
45 }
46}
47
48pub async fn setup_test_schema(fixtures_sql: &[&str]) -> Result<TestContext, BsqlError> {
54 let db_url = std::env::var("BSQL_DATABASE_URL")
55 .or_else(|_| std::env::var("DATABASE_URL"))
56 .map_err(|_| {
57 ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]")
58 })?;
59
60 let schema_name = format!(
61 "__bsql_test_{}_{}",
62 std::process::id(),
63 TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
64 );
65
66 let config = Config::from_url(&db_url)
68 .map_err(|e| ConnectError::create(format!("invalid database URL: {e}")))?;
69 let mut conn = Connection::connect(&config)
70 .map_err(|e| ConnectError::create(format!("connection failed: {e}")))?;
71
72 conn.simple_query(&format!("CREATE SCHEMA \"{}\"", schema_name))
74 .map_err(|e| ConnectError::create(format!("failed to create test schema: {e}")))?;
75
76 conn.simple_query(&format!("SET search_path TO \"{}\", public", schema_name))
78 .map_err(|e| ConnectError::create(format!("failed to set search_path: {e}")))?;
79
80 for fixture_sql in fixtures_sql {
82 if !fixture_sql.trim().is_empty() {
83 conn.simple_query(fixture_sql)
84 .map_err(|e| ConnectError::create(format!("fixture failed: {e}")))?;
85 }
86 }
87
88 drop(conn); let pool = Pool::connect(&db_url).await?;
93
94 pool.raw_execute(&format!("SET search_path TO \"{}\", public", schema_name))
97 .await?;
98
99 let warmup_sql = format!("SET search_path TO \"{}\", public", schema_name);
103 let leaked: &'static str = Box::leak(warmup_sql.into_boxed_str());
107 pool.set_warmup_sqls(&[leaked]);
108
109 Ok(TestContext {
110 pool,
111 schema_name,
112 db_url,
113 })
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // ------------------------------------------------------------------
    // Environment isolation
    // ------------------------------------------------------------------

    /// Serializes every test that reads or writes the database-URL env vars.
    ///
    /// The process environment is global and the test harness runs tests on
    /// parallel threads. Without this lock, one test's `set_var` or
    /// `remove_var` can be observed by another test's `setup_test_schema`
    /// call, which makes the env tests flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// The env vars `setup_test_schema` reads, in priority order.
    const URL_VARS: [&str; 2] = ["BSQL_DATABASE_URL", "DATABASE_URL"];

    /// Holds [`ENV_LOCK`] for the lifetime of a test.
    ///
    /// Restores the original values of [`URL_VARS`] when dropped, including
    /// when an assertion fails and the test unwinds.
    ///
    /// Bind it to a named variable (`let _env = …`), not `_`, or it is dropped
    /// immediately. Holding the std `MutexGuard` across `.await` is fine: each
    /// `#[tokio::test]` drives its future on its own current-thread runtime,
    /// so the lock is only contended between harness threads.
    struct EnvGuard {
        saved: [(&'static str, Option<OsString>); 2],
        // Fields drop after `Drop::drop`, so the lock is released only once
        // the environment has been restored.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new() -> Self {
            // A poisoned lock only means another env test panicked. Its guard
            // already restored the environment while unwinding, so the lock
            // is safe to reuse.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            // `var_os` so values that are not valid UTF-8 are restored exactly.
            let saved = URL_VARS.map(|name| (name, std::env::var_os(name)));
            EnvGuard { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, original) in &self.saved {
                match original {
                    Some(value) => std::env::set_var(name, value),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    /// Builds a schema name exactly the way `setup_test_schema` does.
    fn fresh_schema_name() -> String {
        format!(
            "__bsql_test_{}_{}",
            std::process::id(),
            TEST_COUNTER.fetch_add(1, Ordering::Relaxed),
        )
    }

    // ------------------------------------------------------------------
    // Schema naming
    // ------------------------------------------------------------------

    #[test]
    fn schema_name_is_unique() {
        let name1 = fresh_schema_name();
        let name2 = fresh_schema_name();
        assert_ne!(name1, name2);
    }

    #[test]
    fn schema_name_contains_pid() {
        let name = fresh_schema_name();
        assert!(name.contains(&std::process::id().to_string()));
    }

    #[test]
    fn schema_name_starts_with_prefix() {
        let name = fresh_schema_name();
        assert!(name.starts_with("__bsql_test_"));
    }

    #[test]
    fn schema_names_never_collide_100_sequential() {
        let mut names = HashSet::new();
        for _ in 0..100 {
            let name = fresh_schema_name();
            assert!(names.insert(name.clone()), "duplicate schema name: {name}");
        }
        assert_eq!(names.len(), 100);
    }

    #[test]
    fn schema_name_is_valid_sql_identifier() {
        let name = fresh_schema_name();
        assert!(
            name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'),
            "schema name contains invalid chars: {name}"
        );
        assert!(
            name.starts_with('_') || name.starts_with(|c: char| c.is_ascii_alphabetic()),
            "schema name must start with letter or underscore: {name}"
        );
    }

    #[test]
    fn test_counter_is_monotonic() {
        let a = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let b = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(a < b);
        assert!(b < c);
    }

    #[test]
    fn counter_increments_atomically_across_threads() {
        use std::sync::Arc;
        let results: Arc<Mutex<Vec<u64>>> = Arc::new(Mutex::new(Vec::new()));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let results = Arc::clone(&results);
            handles.push(std::thread::spawn(move || {
                for _ in 0..10 {
                    let val = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
                    results.lock().unwrap().push(val);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        let mut vals = results.lock().unwrap().clone();
        assert_eq!(vals.len(), 100, "expected 100 counter values");
        let set: HashSet<u64> = vals.iter().copied().collect();
        assert_eq!(
            set.len(),
            100,
            "counter values must be unique across threads"
        );
        vals.sort();
        for window in vals.windows(2) {
            assert!(window[0] < window[1], "counter must be strictly increasing");
        }
    }

    #[test]
    fn multiple_schema_names_created_simultaneously_are_different() {
        let names: Vec<String> = (0..50).map(|_| fresh_schema_name()).collect();
        let set: HashSet<&String> = names.iter().collect();
        assert_eq!(set.len(), names.len(), "all schema names must be unique");
    }

    #[test]
    fn schema_name_has_three_parts() {
        let counter = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name = format!("__bsql_test_{}_{}", pid, counter);
        assert!(name.starts_with("__bsql_test_"));
        let suffix = &name["__bsql_test_".len()..];
        let parts: Vec<&str> = suffix.split('_').collect();
        assert_eq!(parts.len(), 2, "expected PID_COUNTER suffix, got: {suffix}");
        assert_eq!(parts[0], pid.to_string());
        assert_eq!(parts[1], counter.to_string());
    }

    #[test]
    fn schema_name_counter_part_increases() {
        let c1 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let c2 = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        let name1 = format!("__bsql_test_{}_{}", pid, c1);
        let name2 = format!("__bsql_test_{}_{}", pid, c2);
        let counter1: u64 = name1.rsplit('_').next().unwrap().parse().unwrap();
        let counter2: u64 = name2.rsplit('_').next().unwrap().parse().unwrap();
        assert!(counter2 > counter1);
    }

    // ------------------------------------------------------------------
    // Environment / URL handling (all serialized through `EnvGuard`)
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn missing_db_url_returns_clear_error() {
        let _env = EnvGuard::new();
        std::env::remove_var("BSQL_DATABASE_URL");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("BSQL_DATABASE_URL") && msg.contains("DATABASE_URL"),
            "error should mention both env vars, got: {msg}"
        );
    }

    #[tokio::test]
    async fn missing_bsql_database_url_falls_back_to_database_url() {
        let _env = EnvGuard::new();
        std::env::remove_var("BSQL_DATABASE_URL");
        std::env::set_var("DATABASE_URL", "not-a-url");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "should fail on URL parse after falling back to DATABASE_URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_returns_clear_error() {
        let _env = EnvGuard::new();
        std::env::set_var("BSQL_DATABASE_URL", "not-a-valid-url");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "error should mention invalid URL, got: {msg}"
        );
    }

    #[tokio::test]
    async fn invalid_db_url_not_postgres_scheme() {
        let _env = EnvGuard::new();
        std::env::set_var("BSQL_DATABASE_URL", "mysql://user:pass@localhost/db");
        std::env::remove_var("DATABASE_URL");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "non-postgres scheme should fail with clear error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn bsql_database_url_takes_priority_over_database_url() {
        let _env = EnvGuard::new();
        // If DATABASE_URL were used instead, the failure would be a connection
        // error (port 1), not a URL parse error.
        std::env::set_var("BSQL_DATABASE_URL", "not-postgres-bsql");
        std::env::set_var("DATABASE_URL", "postgres://user:pass@127.0.0.1:1/realdb");

        let result = setup_test_schema(&[]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("invalid database URL"),
            "BSQL_DATABASE_URL should take priority, got: {msg}"
        );
    }

    // ------------------------------------------------------------------
    // Connection and Drop-path smoke tests (no env access)
    // ------------------------------------------------------------------

    #[test]
    fn connection_refused_unreachable_host() {
        let url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(url).expect("URL should parse");
        let conn_result = Connection::connect(&config);
        assert!(conn_result.is_err(), "connection to port 1 should fail");
        let err = ConnectError::create(format!("connection failed: {}", conn_result.unwrap_err()));
        let msg = err.to_string();
        assert!(
            msg.contains("connection failed"),
            "unreachable host should produce 'connection failed' error, got: {msg}"
        );
    }

    #[test]
    fn test_context_has_debug_impl() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<TestContext>();
    }

    #[test]
    fn test_context_debug_shows_schema_name() {
        // NOTE: building a real `TestContext` needs a live `Pool`. This test
        // therefore only pins the expected shape of the Debug output; it does
        // not exercise the `Debug` impl itself.
        let schema = "__bsql_test_12345_0";
        let expected = format!("TestContext {{ schema: {:?} }}", schema);
        assert!(expected.contains("TestContext"));
        assert!(expected.contains("schema"));
        assert!(expected.contains(schema));
    }

    #[test]
    fn drop_code_path_with_invalid_url_does_not_panic() {
        let db_url = "garbage-url";
        let schema_name = "__bsql_test_fake_0";
        if let Ok(config) = Config::from_url(db_url) {
            if let Ok(mut conn) = Connection::connect(&config) {
                let _ = conn.simple_query(&format!(
                    "DROP SCHEMA IF EXISTS \"{}\" CASCADE",
                    schema_name
                ));
            }
        }
    }

    #[test]
    fn drop_with_garbage_url_does_not_panic() {
        let db_url = "not-a-postgres-url";
        let config_result = Config::from_url(db_url);
        assert!(config_result.is_err(), "garbage URL should not parse");
    }

    #[test]
    fn drop_with_valid_url_but_unreachable_host_does_not_panic() {
        let db_url = "postgres://user:pass@127.0.0.1:1/testdb";
        let config = Config::from_url(db_url);
        assert!(config.is_ok(), "URL should parse");
        let conn_result = Connection::connect(&config.unwrap());
        assert!(conn_result.is_err(), "connection to port 1 should fail");
    }

    // ------------------------------------------------------------------
    // Fixture filtering
    // ------------------------------------------------------------------

    #[test]
    fn empty_fixture_string_is_skipped() {
        let fixture = "";
        assert!(fixture.trim().is_empty(), "empty string should be skipped");
    }

    #[test]
    fn whitespace_only_fixture_is_skipped() {
        let fixture = " \n\t \n ";
        assert!(
            fixture.trim().is_empty(),
            "whitespace-only fixture should be skipped"
        );
    }

    #[test]
    fn fixture_with_only_comments_is_not_empty() {
        let fixture = "-- just a comment\n/* block comment */";
        assert!(
            !fixture.trim().is_empty(),
            "comment-only fixture should NOT be skipped (PG handles it)"
        );
    }

    #[test]
    fn fixture_with_multiple_statements_passes_trim_check() {
        let fixture = "CREATE TABLE a (id INT);\nCREATE TABLE b (id INT);";
        assert!(!fixture.trim().is_empty());
    }

    // ------------------------------------------------------------------
    // Error variants
    // ------------------------------------------------------------------

    #[test]
    fn missing_env_error_is_connect_variant() {
        let err =
            ConnectError::create("BSQL_DATABASE_URL or DATABASE_URL must be set for #[bsql::test]");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("BSQL_DATABASE_URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn invalid_url_error_is_connect_variant() {
        let err = ConnectError::create("invalid database URL: missing postgres:// prefix");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("invalid database URL"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn connection_failed_error_is_connect_variant() {
        let err = ConnectError::create("connection failed: Connection refused");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("connection failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn fixture_failed_error_is_connect_variant() {
        let err = ConnectError::create("fixture failed: syntax error at position 5");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("fixture failed"));
            }
            _ => panic!("expected Connect variant"),
        }
    }

    #[test]
    fn schema_creation_failed_error_is_connect_variant() {
        let err = ConnectError::create("failed to create test schema: permission denied");
        match err {
            BsqlError::Connect(ref ce) => {
                assert!(ce.message.contains("failed to create test schema"));
            }
            _ => panic!("expected Connect variant"),
        }
    }
}