use crate::client::APIClient;
use crate::resource::AsyncTestResource;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::RwLock;
#[derive(Debug, Error)]
pub enum TeardownError {
#[error("Failed to rollback transactions: {0}")]
TransactionRollbackFailed(String),
#[error("Failed to close database connection: {0}")]
ConnectionCloseFailed(String),
#[error("Failed to cleanup client state: {0}")]
ClientCleanupFailed(String),
}
/// Bookkeeping record for a transaction started through
/// `APITestCase::begin_transaction`.
///
/// It stores only an identifier and a committed flag.
/// NOTE(review): no database transaction object is held here, so this type
/// tracks transactions rather than controlling them — confirm that is intended.
#[cfg(feature = "testcontainers")]
#[derive(Debug, Clone)]
pub struct TransactionHandle {
    /// Identifier of the transaction (a UUIDv7 string when created via `new`).
    id: String,
    /// Set to `true` by `mark_committed`; `false` for a freshly created handle.
    committed: bool,
}
#[cfg(feature = "testcontainers")]
impl TransactionHandle {
    /// Creates an uncommitted handle whose id is a newly generated UUIDv7
    /// rendered in its standard hyphenated string form.
    pub fn new() -> Self {
        let id = uuid::Uuid::now_v7().to_string();
        Self { id, committed: false }
    }

    /// Borrows the handle's identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether [`mark_committed`](Self::mark_committed) has been called.
    pub fn is_committed(&self) -> bool {
        let TransactionHandle { committed, .. } = self;
        *committed
    }

    /// Flags the handle as committed; repeated calls leave it committed.
    pub fn mark_committed(&mut self) {
        self.committed = true;
    }
}
/// Same as [`TransactionHandle::new`]: a fresh, uncommitted handle with a new id.
#[cfg(feature = "testcontainers")]
impl Default for TransactionHandle {
    fn default() -> Self {
        TransactionHandle::new()
    }
}
/// Per-test fixture holding an [`APIClient`] and, with the `testcontainers`
/// feature, optional database wiring plus a list of transactions that have
/// been begun but not yet committed.
///
/// Every field sits behind `Arc<RwLock<..>>`, which is why the accessor
/// methods only need `&self`. Note that field declaration order is also the
/// order in which the fields are dropped.
pub struct APITestCase {
    /// The API client used by the test.
    client: Arc<RwLock<APIClient>>,
    /// Database URL stored via `set_database_url`; `None` until set.
    #[cfg(feature = "testcontainers")]
    database_url: Arc<RwLock<Option<String>>>,
    /// Connection pool stored via `set_database_connection`; closed during
    /// teardown if present.
    #[cfg(feature = "testcontainers")]
    db_connection: Arc<RwLock<Option<sqlx::AnyPool>>>,
    /// Handles returned by `begin_transaction` that have not yet been passed
    /// to `commit_transaction` (committing removes them from this list).
    #[cfg(feature = "testcontainers")]
    active_transactions: Arc<RwLock<Vec<TransactionHandle>>>,
}
impl APITestCase {
    /// Returns a clone of the database URL stored by
    /// [`set_database_url`](Self::set_database_url), or `None` if none was set.
    #[cfg(feature = "testcontainers")]
    pub async fn database_url(&self) -> Option<String> {
        self.database_url.read().await.clone()
    }

    /// Acquires shared (read) access to the API client.
    ///
    /// The lock is held until the returned guard is dropped. Awaiting
    /// [`client_mut`](Self::client_mut) while this guard is still alive in the
    /// same task will never complete.
    pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, APIClient> {
        self.client.read().await
    }

    /// Acquires exclusive (write) access to the API client.
    ///
    /// Waits until every outstanding guard from [`client`](Self::client) has
    /// been dropped.
    pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, APIClient> {
        self.client.write().await
    }

    /// Stores `url` as the database URL, replacing any previous value.
    #[cfg(feature = "testcontainers")]
    pub async fn set_database_url(&self, url: String) {
        *self.database_url.write().await = Some(url);
    }

    /// Stores `pool` as the database connection, replacing any previous pool.
    ///
    /// A replaced pool is dropped here rather than explicitly closed.
    #[cfg(feature = "testcontainers")]
    pub async fn set_database_connection(&self, pool: sqlx::AnyPool) {
        *self.db_connection.write().await = Some(pool);
    }

    /// Returns a clone of the stored connection pool, or `None` if none was set.
    #[cfg(feature = "testcontainers")]
    pub async fn db_connection(&self) -> Option<sqlx::AnyPool> {
        self.db_connection.read().await.clone()
    }

    /// Creates a new [`TransactionHandle`], records it as active, and returns it.
    #[cfg(feature = "testcontainers")]
    pub async fn begin_transaction(&self) -> TransactionHandle {
        let handle = TransactionHandle::new();
        self.active_transactions.write().await.push(handle.clone());
        handle
    }

    /// Stops tracking the active transaction whose id equals `transaction_id`.
    ///
    /// An id that is not currently active (never begun, or already committed)
    /// leaves the active list untouched and is reported with a warning log
    /// instead of being ignored silently.
    #[cfg(feature = "testcontainers")]
    pub async fn commit_transaction(&self, transaction_id: &str) {
        let mut transactions = self.active_transactions.write().await;
        match transactions.iter().position(|t| t.id() == transaction_id) {
            // The active list only holds outstanding transactions, so a commit
            // simply removes the entry; the removed handle is not used further.
            Some(pos) => {
                let _ = transactions.remove(pos);
            }
            None => tracing::warn!(
                "commit_transaction called with unknown or already-committed transaction id {}",
                transaction_id
            ),
        }
    }

    /// Number of transactions begun but not yet committed.
    #[cfg(feature = "testcontainers")]
    pub async fn active_transaction_count(&self) -> usize {
        self.active_transactions.read().await.len()
    }
}
#[async_trait::async_trait]
impl AsyncTestResource for APITestCase {
    /// Builds a test case around a fresh `APIClient::new()`; with the
    /// `testcontainers` feature, database URL and pool start as `None` and the
    /// active-transaction list starts empty.
    async fn setup() -> Self {
        let client = Arc::new(RwLock::new(APIClient::new()));
        Self {
            client,
            #[cfg(feature = "testcontainers")]
            database_url: Arc::new(RwLock::new(None)),
            #[cfg(feature = "testcontainers")]
            db_connection: Arc::new(RwLock::new(None)),
            #[cfg(feature = "testcontainers")]
            active_transactions: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Releases the resources held by the test case, consuming it.
    ///
    /// First the client's `cleanup` runs while the client's write lock is held.
    /// Then, with the `testcontainers` feature, the number of still-tracked
    /// uncommitted transactions is logged at debug level, and the stored pool
    /// (if any) is taken out and closed.
    ///
    /// NOTE(review): no rollback statement is issued here despite the log
    /// wording; presumably closing the pool lets the database discard open
    /// transactions — confirm.
    async fn teardown(self) {
        // The write guard is a temporary of this statement, so it stays alive
        // (and the client stays exclusively locked) across the cleanup await.
        self.client.write().await.cleanup().await;

        #[cfg(feature = "testcontainers")]
        {
            // The read guard is released at the end of this `let` statement,
            // before the pool's write lock is taken below.
            let pending = self
                .active_transactions
                .read()
                .await
                .iter()
                .filter(|handle| !handle.is_committed())
                .count();
            if pending != 0 {
                tracing::debug!(
                    "Rolling back {} uncommitted transaction(s) during teardown",
                    pending
                );
            }

            let mut pool_slot = self.db_connection.write().await;
            if let Some(pool) = pool_slot.take() {
                pool.close().await;
            }
        }

        drop(self.client);
    }
}
/// Declares an async test (`rstest` + `tokio::test`) whose body receives a
/// borrowed `APITestCase`.
///
/// The test case is created through `AsyncTeardownGuard::<APITestCase>::new()`
/// before the body runs, and the guard lives until the end of the test.
///
/// ```ignore
/// test_case! {
///     async fn my_test(case: &APITestCase) {
///         let _client = case.client().await;
///     }
/// }
/// ```
#[macro_export]
macro_rules! test_case {
    (
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*teardown_guard;
            $body
        }
    };
}
/// Like `test_case!`, but also binds a user value for the body.
///
/// The user is a fixed JSON object `{"id": 1, "username": "testuser"}` built
/// with `serde_json::json!`. This macro itself performs no authentication
/// against the client.
///
/// ```ignore
/// authenticated_test_case! {
///     async fn my_test(case: &APITestCase, user: serde_json::Value) {
///         assert_eq!(user["username"], "testuser");
///     }
/// }
/// ```
#[macro_export]
macro_rules! authenticated_test_case {
    (
        async fn $name:ident($case:ident: &APITestCase, $user:ident: serde_json::Value) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*teardown_guard;
            let $user = serde_json::json!({ "id": 1, "username": "testuser" });
            $body
        }
    };
}
/// Declares an async test (`rstest` + `tokio::test`) that runs inside a
/// database container.
///
/// The leading token selects the backend: `postgres` uses `with_postgres`,
/// `mysql` uses `with_mysql`. Inside the container callback an `APITestCase`
/// is created through `AsyncTeardownGuard`. The container's `connection_url()`
/// is stored via `set_database_url` before the body runs.
///
/// # Panics
///
/// The generated test panics if the container helper returns an error.
#[cfg(feature = "testcontainers")]
#[macro_export]
macro_rules! test_case_with_db {
    (
        postgres,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            // `PostgresContainer` is not used by the expansion itself, but the
            // import is visible to `$body`, which may rely on it; allow it so
            // every generated test does not emit an unused-import warning.
            #[allow(unused_imports)]
            use $crate::containers::{with_postgres, PostgresContainer};
            use $crate::resource::AsyncTeardownGuard;
            use $crate::testcase::APITestCase;
            with_postgres(|db| async move {
                let guard = AsyncTeardownGuard::<APITestCase>::new().await;
                let $case = &*guard;
                $case.set_database_url(db.connection_url()).await;
                $body
                Ok(())
            })
            .await
            .expect("postgres-backed test failed");
        }
    };
    (
        mysql,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            // `MySqlContainer` is not used by the expansion itself, but the
            // import is visible to `$body`, which may rely on it; allow it so
            // every generated test does not emit an unused-import warning.
            #[allow(unused_imports)]
            use $crate::containers::{with_mysql, MySqlContainer};
            use $crate::resource::AsyncTeardownGuard;
            use $crate::testcase::APITestCase;
            with_mysql(|db| async move {
                let guard = AsyncTeardownGuard::<APITestCase>::new().await;
                let $case = &*guard;
                $case.set_database_url(db.connection_url()).await;
                $body
                Ok(())
            })
            .await
            .expect("mysql-backed test failed");
        }
    };
}
#[cfg(test)]
mod tests {
    //! Unit tests for `TeardownError`, `APITestCase` setup/teardown, and (with
    //! the `testcontainers` feature) transaction bookkeeping.

    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_teardown_error_transaction_rollback_display() {
        let error = TeardownError::TransactionRollbackFailed("tx-123 failed".to_string());
        let display = format!("{}", error);
        assert_eq!(display, "Failed to rollback transactions: tx-123 failed");
    }

    #[rstest]
    fn test_teardown_error_connection_close_display() {
        let error = TeardownError::ConnectionCloseFailed("connection refused".to_string());
        let display = format!("{}", error);
        assert_eq!(
            display,
            "Failed to close database connection: connection refused"
        );
    }

    #[rstest]
    fn test_teardown_error_client_cleanup_display() {
        let error = TeardownError::ClientCleanupFailed("timeout".to_string());
        let display = format!("{}", error);
        assert_eq!(display, "Failed to cleanup client state: timeout");
    }

    #[rstest]
    fn test_teardown_error_debug() {
        let error = TeardownError::TransactionRollbackFailed("debug test".to_string());
        let debug = format!("{:?}", error);
        assert!(
            debug.contains("debug test"),
            "Debug output should contain the message, got: {}",
            debug
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_setup_creates_client() {
        let test_case = APITestCase::setup().await;
        let client = test_case.client().await;
        drop(client);
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_client_read_access() {
        let test_case = APITestCase::setup().await;
        let client = test_case.client().await;
        assert!(
            std::mem::size_of_val(&*client) > 0,
            "Client should have non-zero size"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_teardown_completes() {
        let test_case = APITestCase::setup().await;
        test_case.teardown().await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_multiple_reads() {
        let test_case = APITestCase::setup().await;
        let client1 = test_case.client().await;
        let client2 = test_case.client().await;
        assert!(
            std::mem::size_of_val(&*client1) > 0,
            "First client read should succeed"
        );
        assert!(
            std::mem::size_of_val(&*client2) > 0,
            "Second client read should succeed"
        );
    }

    #[cfg(feature = "testcontainers")]
    mod testcontainers_tests {
        use super::*;
        use rstest::rstest;

        #[rstest]
        fn test_transaction_handle_new() {
            let handle = TransactionHandle::new();
            assert!(!handle.id().is_empty(), "ID should not be empty");
            assert!(!handle.is_committed(), "New handle should not be committed");
        }

        #[rstest]
        fn test_transaction_handle_mark_committed() {
            let mut handle = TransactionHandle::new();
            handle.mark_committed();
            assert!(handle.is_committed());
        }

        #[rstest]
        fn test_transaction_handle_default() {
            let handle = TransactionHandle::default();
            assert!(!handle.id().is_empty(), "Default ID should not be empty");
            assert!(
                !handle.is_committed(),
                "Default handle should not be committed"
            );
        }

        #[rstest]
        fn test_transaction_handle_id_is_uuid() {
            let handle = TransactionHandle::new();
            let id = handle.id();
            let parts: Vec<&str> = id.split('-').collect();
            assert_eq!(
                parts.len(),
                5,
                "UUID should have 5 parts separated by hyphens, got: {}",
                id
            );
            assert_eq!(parts[0].len(), 8);
            assert_eq!(parts[1].len(), 4);
            assert_eq!(parts[2].len(), 4);
            assert_eq!(parts[3].len(), 4);
            assert_eq!(parts[4].len(), 12);
            // The first hex digit of the third group is the UUID version.
            assert!(
                parts[2].starts_with('7'),
                "now_v7 should produce a version-7 UUID, got: {}",
                id
            );
        }

        #[rstest]
        #[tokio::test]
        async fn test_begin_transaction_is_tracked() {
            let test_case = APITestCase::setup().await;
            assert_eq!(test_case.active_transaction_count().await, 0);
            let handle = test_case.begin_transaction().await;
            assert!(!handle.is_committed());
            assert_eq!(test_case.active_transaction_count().await, 1);
            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_commit_transaction_stops_tracking() {
            let test_case = APITestCase::setup().await;
            let first = test_case.begin_transaction().await;
            let second = test_case.begin_transaction().await;
            assert_ne!(first.id(), second.id(), "Handles should have distinct ids");
            assert_eq!(test_case.active_transaction_count().await, 2);

            test_case.commit_transaction(first.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 1);

            test_case.commit_transaction(second.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 0);

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_commit_unknown_or_repeated_transaction_is_noop() {
            let test_case = APITestCase::setup().await;
            let handle = test_case.begin_transaction().await;

            test_case.commit_transaction("not-a-real-transaction-id").await;
            assert_eq!(test_case.active_transaction_count().await, 1);

            test_case.commit_transaction(handle.id()).await;
            test_case.commit_transaction(handle.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 0);

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_database_url_round_trip() {
            let test_case = APITestCase::setup().await;
            assert_eq!(test_case.database_url().await, None);
            assert!(test_case.db_connection().await.is_none());

            test_case
                .set_database_url("postgres://localhost/test".to_string())
                .await;
            assert_eq!(
                test_case.database_url().await.as_deref(),
                Some("postgres://localhost/test")
            );

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_teardown_with_pending_transactions_completes() {
            let test_case = APITestCase::setup().await;
            test_case.begin_transaction().await;
            test_case.begin_transaction().await;
            assert_eq!(test_case.active_transaction_count().await, 2);
            test_case.teardown().await;
        }
    }
}