reinhardt_testkit/
testcase.rs1use crate::client::APIClient;
6use crate::resource::AsyncTestResource;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::RwLock;
10
11#[derive(Debug, Error)]
13pub enum TeardownError {
14 #[error("Failed to rollback transactions: {0}")]
16 TransactionRollbackFailed(String),
17
18 #[error("Failed to close database connection: {0}")]
20 ConnectionCloseFailed(String),
21
22 #[error("Failed to cleanup client state: {0}")]
24 ClientCleanupFailed(String),
25}
26
/// Bookkeeping record for a transaction registered via `APITestCase::begin_transaction`.
///
/// A handle holds only an identifier and a committed flag. It does not own or
/// wrap a database-level transaction.
#[cfg(feature = "testcontainers")]
#[derive(Debug, Clone)]
pub struct TransactionHandle {
    /// Unique identifier, generated as a UUIDv7 string by `new`.
    id: String,
    /// Set to `true` once `mark_committed` has been called on this handle.
    committed: bool,
}

#[cfg(feature = "testcontainers")]
impl TransactionHandle {
    /// Creates an uncommitted handle with a freshly generated UUIDv7 identifier.
    pub fn new() -> Self {
        let fresh_id = uuid::Uuid::now_v7().to_string();
        Self {
            id: fresh_id,
            committed: false,
        }
    }

    /// Returns the identifier used to look this handle up (e.g. in `commit_transaction`).
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether `mark_committed` has been called on this particular value.
    pub fn is_committed(&self) -> bool {
        self.committed
    }

    /// Flags this handle as committed. Only this value is affected, not its clones.
    pub fn mark_committed(&mut self) {
        self.committed = true;
    }
}

/// Equivalent to `TransactionHandle::new()`: a fresh, uncommitted handle.
#[cfg(feature = "testcontainers")]
impl Default for TransactionHandle {
    fn default() -> Self {
        Self::new()
    }
}
73
74pub struct APITestCase {
105 client: Arc<RwLock<APIClient>>,
106 #[cfg(feature = "testcontainers")]
107 database_url: Arc<RwLock<Option<String>>>,
108 #[cfg(feature = "testcontainers")]
109 db_connection: Arc<RwLock<Option<sqlx::AnyPool>>>,
110 #[cfg(feature = "testcontainers")]
111 active_transactions: Arc<RwLock<Vec<TransactionHandle>>>,
112}
113
114impl APITestCase {
115 #[cfg(feature = "testcontainers")]
117 pub async fn database_url(&self) -> Option<String> {
118 self.database_url.read().await.clone()
119 }
120
121 pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, APIClient> {
123 self.client.read().await
124 }
125
126 pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, APIClient> {
128 self.client.write().await
129 }
130
131 #[cfg(feature = "testcontainers")]
133 pub async fn set_database_url(&self, url: String) {
134 let mut db_url = self.database_url.write().await;
135 *db_url = Some(url);
136 }
137
138 #[cfg(feature = "testcontainers")]
151 pub async fn set_database_connection(&self, pool: sqlx::AnyPool) {
152 let mut conn = self.db_connection.write().await;
153 *conn = Some(pool);
154 }
155
156 #[cfg(feature = "testcontainers")]
158 pub async fn db_connection(&self) -> Option<sqlx::AnyPool> {
159 self.db_connection.read().await.clone()
160 }
161
162 #[cfg(feature = "testcontainers")]
179 pub async fn begin_transaction(&self) -> TransactionHandle {
180 let handle = TransactionHandle::new();
181 let mut transactions = self.active_transactions.write().await;
182 transactions.push(handle.clone());
183 handle
184 }
185
186 #[cfg(feature = "testcontainers")]
191 pub async fn commit_transaction(&self, transaction_id: &str) {
192 let mut transactions = self.active_transactions.write().await;
193 if let Some(pos) = transactions.iter().position(|t| t.id() == transaction_id) {
194 let mut handle = transactions.remove(pos);
195 handle.mark_committed();
196 }
197 }
198
199 #[cfg(feature = "testcontainers")]
201 pub async fn active_transaction_count(&self) -> usize {
202 self.active_transactions.read().await.len()
203 }
204}
205
206#[async_trait::async_trait]
207impl AsyncTestResource for APITestCase {
208 async fn setup() -> Self {
209 Self {
210 client: Arc::new(RwLock::new(APIClient::new())),
211 #[cfg(feature = "testcontainers")]
212 database_url: Arc::new(RwLock::new(None)),
213 #[cfg(feature = "testcontainers")]
214 db_connection: Arc::new(RwLock::new(None)),
215 #[cfg(feature = "testcontainers")]
216 active_transactions: Arc::new(RwLock::new(Vec::new())),
217 }
218 }
219
220 async fn teardown(self) {
221 {
223 let client = self.client.write().await;
224 client.cleanup().await;
225 }
226
227 #[cfg(feature = "testcontainers")]
229 {
230 let transactions = self.active_transactions.read().await;
232 let uncommitted_count = transactions.iter().filter(|t| !t.is_committed()).count();
233 if uncommitted_count > 0 {
234 tracing::debug!(
237 "Rolling back {} uncommitted transaction(s) during teardown",
238 uncommitted_count
239 );
240 }
241 drop(transactions);
242
243 let mut pool_guard = self.db_connection.write().await;
245 if let Some(pool) = pool_guard.take() {
246 pool.close().await;
248 }
249 }
250
251 drop(self.client);
253 }
254}
255
/// Declares an async test that receives a freshly set-up `APITestCase`.
///
/// The generated function is annotated with `#[rstest::rstest]` and
/// `#[tokio::test]`. It builds an `AsyncTeardownGuard::<APITestCase>`, binds
/// `$case` to a reference to the guarded test case, and then runs the body.
/// Teardown handling is whatever `AsyncTeardownGuard` does when it goes out of
/// scope (see `crate::resource`).
///
/// `APITestCase` and `AsyncTeardownGuard` are brought into scope for the body.
///
/// ```ignore
/// test_case! {
///     async fn my_test(case: &APITestCase) {
///         let client = case.client().await;
///     }
/// }
/// ```
#[macro_export]
macro_rules! test_case {
    (
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*teardown_guard;

            $body
        }
    };
}
293
/// Like `test_case!`, but the API client is force-authenticated before the body runs.
///
/// `$user` is bound to the fixed JSON value `{"id": 1, "username": "testuser"}`.
/// A clone of it is passed to `force_authenticate`; that call is deprecated, so it
/// is wrapped in `#[allow(deprecated)]`. The client read lock is released before
/// the body starts.
#[macro_export]
macro_rules! authenticated_test_case {
    (
        async fn $name:ident($case:ident: &APITestCase, $user:ident: serde_json::Value) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::testcase::APITestCase;
            use $crate::resource::AsyncTeardownGuard;

            let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
            let $case = &*teardown_guard;

            let $user = serde_json::json!({
                "id": 1,
                "username": "testuser",
            });

            // The read guard is a statement-scoped temporary: held through the
            // authentication await, released before the body runs.
            #[allow(deprecated)]
            $case.client().await.force_authenticate(Some($user.clone())).await;

            $body
        }
    };
}
327
/// Declares an async test that runs against a throwaway database container.
///
/// The first token selects the backend: `postgres` or `mysql`. The generated test:
/// 1. Starts the container through `with_postgres` / `with_mysql`.
/// 2. Builds an `AsyncTeardownGuard::<APITestCase>` and binds `$case` to it.
/// 3. Records the container's connection URL with `set_database_url`.
/// 4. Runs the body.
///
/// The body executes inside an `async move` block that ends with `Ok(())`. The
/// result of `with_postgres` / `with_mysql` is then `.unwrap()`ed, so an error
/// fails the test.
///
/// The container type name (`PostgresContainer` / `MySqlContainer`),
/// `APITestCase`, and `AsyncTeardownGuard` are in scope for the body.
#[cfg(feature = "testcontainers")]
#[macro_export]
macro_rules! test_case_with_db {
    (
        postgres,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::containers::with_postgres;
            // The generated code never names the container type directly; it stays
            // imported for test bodies that do. The allow prevents an
            // `unused_imports` warning in every test that does not.
            #[allow(unused_imports)]
            use $crate::containers::PostgresContainer;
            use $crate::resource::AsyncTeardownGuard;
            use $crate::testcase::APITestCase;

            with_postgres(|db| async move {
                let guard = AsyncTeardownGuard::<APITestCase>::new().await;
                let $case = &*guard;
                $case.set_database_url(db.connection_url()).await;

                $body

                Ok(())
            })
            .await
            .unwrap();
        }
    };
    (
        mysql,
        async fn $name:ident($case:ident: &APITestCase) $body:block
    ) => {
        #[rstest::rstest]
        #[tokio::test]
        async fn $name() {
            use $crate::containers::with_mysql;
            // See the postgres arm: kept for test bodies, silenced when unused.
            #[allow(unused_imports)]
            use $crate::containers::MySqlContainer;
            use $crate::resource::AsyncTeardownGuard;
            use $crate::testcase::APITestCase;

            with_mysql(|db| async move {
                let guard = AsyncTeardownGuard::<APITestCase>::new().await;
                let $case = &*guard;
                $case.set_database_url(db.connection_url()).await;

                $body

                Ok(())
            })
            .await
            .unwrap();
        }
    };
}
428
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    fn test_teardown_error_transaction_rollback_display() {
        // Arrange
        let error = TeardownError::TransactionRollbackFailed("tx-123 failed".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(display, "Failed to rollback transactions: tx-123 failed");
    }

    #[rstest]
    fn test_teardown_error_connection_close_display() {
        // Arrange
        let error = TeardownError::ConnectionCloseFailed("connection refused".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(
            display,
            "Failed to close database connection: connection refused"
        );
    }

    #[rstest]
    fn test_teardown_error_client_cleanup_display() {
        // Arrange
        let error = TeardownError::ClientCleanupFailed("timeout".to_string());

        // Act
        let display = format!("{}", error);

        // Assert
        assert_eq!(display, "Failed to cleanup client state: timeout");
    }

    #[rstest]
    fn test_teardown_error_debug() {
        // Arrange
        let error = TeardownError::TransactionRollbackFailed("debug test".to_string());

        // Act
        let debug = format!("{:?}", error);

        // Assert
        assert!(
            debug.contains("debug test"),
            "Debug output should contain the message, got: {}",
            debug
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_setup_creates_client() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act: acquiring the read lock must succeed on a fresh fixture.
        let client = test_case.client().await;

        // Assert
        drop(client);
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_client_read_access() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act
        let client = test_case.client().await;

        // Assert
        assert!(
            std::mem::size_of_val(&*client) > 0,
            "Client should have non-zero size"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_teardown_completes() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act / Assert: teardown must finish without panicking or hanging.
        test_case.teardown().await;
    }

    #[rstest]
    #[tokio::test]
    async fn test_api_test_case_multiple_reads() {
        // Arrange
        let test_case = APITestCase::setup().await;

        // Act: two read guards held at the same time must not block each other.
        let client1 = test_case.client().await;
        let client2 = test_case.client().await;

        // Assert
        assert!(
            std::mem::size_of_val(&*client1) > 0,
            "First client read should succeed"
        );
        assert!(
            std::mem::size_of_val(&*client2) > 0,
            "Second client read should succeed"
        );
    }

    #[cfg(feature = "testcontainers")]
    mod testcontainers_tests {
        use super::*;
        use rstest::rstest;

        #[rstest]
        fn test_transaction_handle_new() {
            // Arrange / Act
            let handle = TransactionHandle::new();

            // Assert
            assert!(!handle.id().is_empty(), "ID should not be empty");
            assert!(!handle.is_committed(), "New handle should not be committed");
        }

        #[rstest]
        fn test_transaction_handle_mark_committed() {
            // Arrange
            let mut handle = TransactionHandle::new();

            // Act
            handle.mark_committed();

            // Assert
            assert!(handle.is_committed());
        }

        #[rstest]
        fn test_transaction_handle_default() {
            // Arrange / Act
            let handle = TransactionHandle::default();

            // Assert
            assert!(!handle.id().is_empty(), "Default ID should not be empty");
            assert!(
                !handle.is_committed(),
                "Default handle should not be committed"
            );
        }

        #[rstest]
        fn test_transaction_handle_id_is_uuid() {
            // Arrange
            let handle = TransactionHandle::new();

            // Act
            let id = handle.id();
            let parts: Vec<&str> = id.split('-').collect();

            // Assert
            assert_eq!(
                parts.len(),
                5,
                "UUID should have 5 parts separated by hyphens, got: {}",
                id
            );
            assert_eq!(parts[0].len(), 8);
            assert_eq!(parts[1].len(), 4);
            assert_eq!(parts[2].len(), 4);
            assert_eq!(parts[3].len(), 4);
            assert_eq!(parts[4].len(), 12);
        }

        #[rstest]
        fn test_transaction_handle_id_is_uuid_v7() {
            // Arrange
            let handle = TransactionHandle::new();

            // Act: in the hyphenated form the version nibble is the first character
            // of the third group (index 8 + 1 + 4 + 1 = 14).
            let version = handle.id().chars().nth(14);

            // Assert
            assert_eq!(version, Some('7'), "expected a UUIDv7, got: {}", handle.id());
        }

        #[rstest]
        #[tokio::test]
        async fn test_begin_and_commit_transaction_updates_count() {
            // Arrange
            let test_case = APITestCase::setup().await;

            // Act
            let first = test_case.begin_transaction().await;
            let second = test_case.begin_transaction().await;

            // Assert
            assert_ne!(first.id(), second.id(), "handles should have distinct IDs");
            assert_eq!(test_case.active_transaction_count().await, 2);

            test_case.commit_transaction(first.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 1);

            test_case.commit_transaction(second.id()).await;
            assert_eq!(test_case.active_transaction_count().await, 0);

            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_commit_unknown_transaction_leaves_tracking_unchanged() {
            // Arrange
            let test_case = APITestCase::setup().await;
            let _handle = test_case.begin_transaction().await;

            // Act
            test_case.commit_transaction("not-a-real-transaction-id").await;

            // Assert
            assert_eq!(test_case.active_transaction_count().await, 1);
        }

        #[rstest]
        #[tokio::test]
        async fn test_commit_same_transaction_twice_is_noop() {
            // Arrange
            let test_case = APITestCase::setup().await;
            let handle = test_case.begin_transaction().await;

            // Act
            test_case.commit_transaction(handle.id()).await;
            test_case.commit_transaction(handle.id()).await;

            // Assert
            assert_eq!(test_case.active_transaction_count().await, 0);
        }

        #[rstest]
        #[tokio::test]
        async fn test_teardown_with_uncommitted_transaction_completes() {
            // Arrange
            let test_case = APITestCase::setup().await;
            let _handle = test_case.begin_transaction().await;

            // Act / Assert: teardown must cope with tracked, uncommitted handles.
            test_case.teardown().await;
        }

        #[rstest]
        #[tokio::test]
        async fn test_database_url_roundtrip() {
            // Arrange
            let test_case = APITestCase::setup().await;
            assert_eq!(test_case.database_url().await, None);
            assert!(test_case.db_connection().await.is_none());

            // Act
            test_case
                .set_database_url("postgres://localhost/test".to_string())
                .await;

            // Assert
            assert_eq!(
                test_case.database_url().await.as_deref(),
                Some("postgres://localhost/test")
            );
        }
    }
}