reinhardt_testkit/
testcase.rs1use crate::client::APIClient;
6use crate::resource::AsyncTestResource;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::RwLock;
10
11#[derive(Debug, Error)]
13pub enum TeardownError {
14 #[error("Failed to rollback transactions: {0}")]
16 TransactionRollbackFailed(String),
17
18 #[error("Failed to close database connection: {0}")]
20 ConnectionCloseFailed(String),
21
22 #[error("Failed to cleanup client state: {0}")]
24 ClientCleanupFailed(String),
25}
26
/// Lightweight token identifying a transaction tracked by an [`APITestCase`].
///
/// The handle carries only an id and a commit flag; it does not own a
/// database transaction.
///
/// NOTE(review): `APITestCase` removes a handle from its active list when it
/// is committed, so the copies it still tracks never observe `committed == true`.
#[cfg(feature = "testcontainers")]
#[derive(Clone, Debug)]
pub struct TransactionHandle {
    /// Unique identifier (a v4 UUID string when built via `new`).
    id: String,
    /// Whether `mark_committed` has been called on this copy.
    committed: bool,
}
40
#[cfg(feature = "testcontainers")]
impl TransactionHandle {
    /// Creates an uncommitted handle whose id is a freshly generated v4 UUID.
    pub fn new() -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            id,
            committed: false,
        }
    }

    /// Borrows the handle's unique identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether this particular copy has been marked as committed.
    pub fn is_committed(&self) -> bool {
        self.committed
    }

    /// Flags this copy as committed.
    ///
    /// Only this value is affected; clones made earlier keep their own flag.
    pub fn mark_committed(&mut self) {
        self.committed = true;
    }
}
66
#[cfg(feature = "testcontainers")]
impl Default for TransactionHandle {
    /// Equivalent to [`TransactionHandle::new`]: a fresh, uncommitted handle
    /// with a random id.
    fn default() -> Self {
        TransactionHandle::new()
    }
}
73
74pub struct APITestCase {
105 client: Arc<RwLock<APIClient>>,
106 #[cfg(feature = "testcontainers")]
107 database_url: Arc<RwLock<Option<String>>>,
108 #[cfg(feature = "testcontainers")]
109 db_connection: Arc<RwLock<Option<sqlx::AnyPool>>>,
110 #[cfg(feature = "testcontainers")]
111 active_transactions: Arc<RwLock<Vec<TransactionHandle>>>,
112}
113
114impl APITestCase {
115 #[cfg(feature = "testcontainers")]
117 pub async fn database_url(&self) -> Option<String> {
118 self.database_url.read().await.clone()
119 }
120
121 pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, APIClient> {
123 self.client.read().await
124 }
125
126 pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, APIClient> {
128 self.client.write().await
129 }
130
131 #[cfg(feature = "testcontainers")]
133 pub async fn set_database_url(&self, url: String) {
134 let mut db_url = self.database_url.write().await;
135 *db_url = Some(url);
136 }
137
138 #[cfg(feature = "testcontainers")]
151 pub async fn set_database_connection(&self, pool: sqlx::AnyPool) {
152 let mut conn = self.db_connection.write().await;
153 *conn = Some(pool);
154 }
155
156 #[cfg(feature = "testcontainers")]
158 pub async fn db_connection(&self) -> Option<sqlx::AnyPool> {
159 self.db_connection.read().await.clone()
160 }
161
162 #[cfg(feature = "testcontainers")]
179 pub async fn begin_transaction(&self) -> TransactionHandle {
180 let handle = TransactionHandle::new();
181 let mut transactions = self.active_transactions.write().await;
182 transactions.push(handle.clone());
183 handle
184 }
185
186 #[cfg(feature = "testcontainers")]
191 pub async fn commit_transaction(&self, transaction_id: &str) {
192 let mut transactions = self.active_transactions.write().await;
193 if let Some(pos) = transactions.iter().position(|t| t.id() == transaction_id) {
194 let mut handle = transactions.remove(pos);
195 handle.mark_committed();
196 }
197 }
198
199 #[cfg(feature = "testcontainers")]
201 pub async fn active_transaction_count(&self) -> usize {
202 self.active_transactions.read().await.len()
203 }
204}
205
206#[async_trait::async_trait]
207impl AsyncTestResource for APITestCase {
208 async fn setup() -> Self {
209 Self {
210 client: Arc::new(RwLock::new(APIClient::new())),
211 #[cfg(feature = "testcontainers")]
212 database_url: Arc::new(RwLock::new(None)),
213 #[cfg(feature = "testcontainers")]
214 db_connection: Arc::new(RwLock::new(None)),
215 #[cfg(feature = "testcontainers")]
216 active_transactions: Arc::new(RwLock::new(Vec::new())),
217 }
218 }
219
220 async fn teardown(self) {
221 {
223 let client = self.client.write().await;
224 client.cleanup().await;
225 }
226
227 #[cfg(feature = "testcontainers")]
229 {
230 let transactions = self.active_transactions.read().await;
232 let uncommitted_count = transactions.iter().filter(|t| !t.is_committed()).count();
233 if uncommitted_count > 0 {
234 tracing::debug!(
237 "Rolling back {} uncommitted transaction(s) during teardown",
238 uncommitted_count
239 );
240 }
241 drop(transactions);
242
243 let mut pool_guard = self.db_connection.write().await;
245 if let Some(pool) = pool_guard.take() {
246 pool.close().await;
248 }
249 }
250
251 drop(self.client);
253 }
254}
255
/// Declares an async `#[tokio::test]` that receives a fresh `APITestCase`.
///
/// The test case is built by `AsyncTeardownGuard::<APITestCase>::new()`, and
/// that guard stays alive until the end of the generated test function.
/// `$case` is bound to a `&APITestCase` borrowed from the guard.
///
/// ```ignore
/// test_case! {
///     async fn test_something(case: &APITestCase) {
///         let _client = case.client().await;
///     }
/// }
/// ```
#[macro_export]
macro_rules! test_case {
    (
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let fixture = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*fixture;

            $body
        }
    };
}
293
/// Like `test_case!`, but the client is force-authenticated before the body runs.
///
/// `$user` is bound to a fixed JSON user (`{"id": 1, "username": "testuser"}`).
/// A clone of it is passed to `force_authenticate` on the test case's client
/// before `$body` executes.
///
/// ```ignore
/// authenticated_test_case! {
///     async fn test_me(case: &APITestCase, user: serde_json::Value) {
///         assert_eq!(user["username"], "testuser");
///     }
/// }
/// ```
#[macro_export]
macro_rules! authenticated_test_case {
    (
        async fn $name:ident($case:ident: &APITestCase, $user:ident: serde_json::Value) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let fixture = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*fixture;

            // Fixed identity injected into the client ahead of the body.
            let $user = serde_json::json!({
                "id": 1,
                "username": "testuser",
            });
            // The read guard is a temporary and is released at the end of this statement.
            $case
                .client()
                .await
                .force_authenticate(Some($user.clone()))
                .await;

            $body
        }
    };
}
326
/// Declares an async test that runs against a throwaway database container.
///
/// The leading token selects the backend, either `postgres` or `mysql`.
/// The body runs inside `with_postgres` / `with_mysql` from the crate's
/// `containers` module. `$case` is bound to a fresh `APITestCase` whose
/// database URL is set to the container's `connection_url()`. No connection
/// pool is stored on the case; call `set_database_connection` in the body if
/// one is needed.
///
/// If the container helper returns an error, the test panics with that error.
///
/// ```ignore
/// test_case_with_db! {
///     postgres,
///     async fn test_users(case: &APITestCase) {
///         assert!(case.database_url().await.is_some());
///     }
/// }
/// ```
#[cfg(feature = "testcontainers")]
#[macro_export]
macro_rules! test_case_with_db {
    (
        postgres,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        $crate::test_case_with_db!(@run with_postgres, PostgresContainer, $name, $case, $body);
    };
    (
        mysql,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        $crate::test_case_with_db!(@run with_mysql, MySqlContainer, $name, $case, $body);
    };
    // Internal arm shared by all backends; not intended to be invoked directly.
    (@run $with_db:ident, $container:ident, $name:ident, $case:ident, $body:block) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            // The container type is not used by the expansion itself, but `use`
            // items are visible to `$body`, so it stays importable there.
            #[allow(unused_imports)]
            use $crate::containers::{$with_db, $container};
            use $crate::resource::AsyncTeardownGuard;
            use $crate::testcase::APITestCase;

            $with_db(|db| async move {
                // The guard is created inside the closure, so it is dropped
                // before the container helper returns.
                let guard = AsyncTeardownGuard::<APITestCase>::new().await;
                let $case = &*guard;
                $case.set_database_url(db.connection_url()).await;

                $body

                Ok(())
            })
            .await
            .expect("test_case_with_db: database-backed test failed");
        }
    };
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_teardown_error_transaction_rollback_display() {
        // Arrange
        let error = TeardownError::TransactionRollbackFailed("tx-123 failed".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(display, "Failed to rollback transactions: tx-123 failed");
    }

    #[rstest]
    fn test_teardown_error_connection_close_display() {
        // Arrange
        let error = TeardownError::ConnectionCloseFailed("connection refused".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(
            display,
            "Failed to close database connection: connection refused"
        );
    }

    #[rstest]
    fn test_teardown_error_client_cleanup_display() {
        // Arrange
        let error = TeardownError::ClientCleanupFailed("timeout".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(display, "Failed to cleanup client state: timeout");
    }

    #[rstest]
    fn test_teardown_error_debug() {
        // Arrange
        let error = TeardownError::TransactionRollbackFailed("debug test".to_string());

        // Act
        let debug = format!("{:?}", error);

        // Assert
        assert!(
            debug.contains("debug test"),
            "Debug output should contain the message, got: {}",
            debug
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_setup_creates_client() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act: acquiring the client must succeed on a freshly set-up case.
        let client = test_case.client().await;
        drop(client);

        // Cleanup
        test_case.teardown().await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_client_read_access() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act
        let client = test_case.client().await;

        // Assert
        assert!(
            std::mem::size_of_val(&*client) > 0,
            "Client should have non-zero size"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_write_after_read_released() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act: a released read guard must not block a subsequent writer.
        {
            let _reader = test_case.client().await;
        }
        let writer = test_case.client_mut().await;
        drop(writer);

        // Cleanup
        test_case.teardown().await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_teardown_completes() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act / Assert: teardown must finish without panicking.
        test_case.teardown().await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_multiple_reads() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act: two simultaneous read guards must both be granted.
        let client1 = test_case.client().await;
        let client2 = test_case.client().await;

        // Assert
        assert!(
            std::mem::size_of_val(&*client1) > 0,
            "First client read should succeed"
        );
        assert!(
            std::mem::size_of_val(&*client2) > 0,
            "Second client read should succeed"
        );
    }

    #[cfg(feature = "testcontainers")]
    mod testcontainers_tests {
        use super::*;
        use rstest::rstest;

        #[rstest]
        fn test_transaction_handle_new() {
            let handle = TransactionHandle::new();

            assert!(!handle.id().is_empty(), "ID should not be empty");
            assert!(!handle.is_committed(), "New handle should not be committed");
        }

        #[rstest]
        fn test_transaction_handle_mark_committed() {
            let mut handle = TransactionHandle::new();

            handle.mark_committed();

            assert!(handle.is_committed());
        }

        #[rstest]
        fn test_transaction_handle_default() {
            let handle = TransactionHandle::default();

            assert!(!handle.id().is_empty(), "Default ID should not be empty");
            assert!(
                !handle.is_committed(),
                "Default handle should not be committed"
            );
        }

        #[rstest]
        fn test_transaction_handle_id_is_uuid() {
            let handle = TransactionHandle::new();

            let id = handle.id();
            let parts: Vec<&str> = id.split('-').collect();
            assert_eq!(
                parts.len(),
                5,
                "UUID should have 5 parts separated by hyphens, got: {}",
                id
            );
            assert_eq!(parts[0].len(), 8);
            assert_eq!(parts[1].len(), 4);
            assert_eq!(parts[2].len(), 4);
            assert_eq!(parts[3].len(), 4);
            assert_eq!(parts[4].len(), 12);
        }

        #[rstest]
        #[tokio::test]
        async fn test_begin_and_commit_transaction_updates_count() {
            // Arrange
            let test_case = APITestCase::setup().await;

            // Act
            let first = test_case.begin_transaction().await;
            let second = test_case.begin_transaction().await;

            // Assert
            assert_eq!(test_case.active_transaction_count().await, 2);
            assert_ne!(first.id(), second.id(), "Handles should have distinct IDs");

            test_case.commit_transaction(first.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 1);

            test_case.commit_transaction(second.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 0);

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_commit_unknown_or_repeated_transaction_is_noop() {
            // Arrange
            let test_case = APITestCase::setup().await;
            let handle = test_case.begin_transaction().await;

            // Act / Assert: an unknown id leaves the active list untouched.
            test_case.commit_transaction("not-a-real-id").await;
            assert_eq!(test_case.active_transaction_count().await, 1);

            // Committing the same id twice removes it only once.
            test_case.commit_transaction(handle.id()).await;
            test_case.commit_transaction(handle.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 0);

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_database_url_roundtrip() {
            // Arrange
            let test_case = APITestCase::setup().await;
            assert_eq!(test_case.database_url().await, None);

            // Act
            test_case
                .set_database_url("postgres://localhost/test".to_string())
                .await;

            // Assert
            assert_eq!(
                test_case.database_url().await.as_deref(),
                Some("postgres://localhost/test")
            );

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_db_connection_defaults_to_none() {
            let test_case = APITestCase::setup().await;

            assert!(test_case.db_connection().await.is_none());

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_teardown_with_uncommitted_transactions_completes() {
            // Arrange
            let test_case = APITestCase::setup().await;
            let _pending = test_case.begin_transaction().await;

            // Act / Assert: teardown must finish even with tracked transactions.
            test_case.teardown().await;
        }
    }
}