Skip to main content

reinhardt_testkit/
testcase.rs

//! Base test case with common setup and assertions
//!
//! Modeled on Django REST Framework's (DRF) `APITestCase`.
4
5use crate::client::APIClient;
6use crate::resource::AsyncTestResource;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::RwLock;
10
11/// Error types that can occur during test teardown
12#[derive(Debug, Error)]
13pub enum TeardownError {
14	/// Failed to rollback one or more active transactions
15	#[error("Failed to rollback transactions: {0}")]
16	TransactionRollbackFailed(String),
17
18	/// Failed to close database connection
19	#[error("Failed to close database connection: {0}")]
20	ConnectionCloseFailed(String),
21
22	/// Failed to cleanup client state
23	#[error("Failed to cleanup client state: {0}")]
24	ClientCleanupFailed(String),
25}
26
/// Handle for tracking active test transactions
///
/// This struct tracks transaction state for monitoring and cleanup purposes.
/// Actual transaction management is handled by sqlx's Transaction type,
/// which automatically rolls back uncommitted transactions when dropped.
///
/// The handle is `Clone`, and clones are fully independent values: calling
/// `mark_committed` on one copy does not affect any other copy. In particular,
/// `APITestCase::begin_transaction` keeps one copy in its tracking list and
/// returns a separate clone to the caller.
#[cfg(feature = "testcontainers")]
#[derive(Debug, Clone)]
pub struct TransactionHandle {
	/// Unique identifier for the transaction (a UUID v4 string when created via `new`)
	id: String,
	/// Whether this copy of the handle has been marked as committed
	committed: bool,
}
40
#[cfg(feature = "testcontainers")]
impl TransactionHandle {
	/// Creates an uncommitted handle identified by a freshly generated UUID v4.
	pub fn new() -> Self {
		let id = uuid::Uuid::new_v4().to_string();
		Self {
			committed: false,
			id,
		}
	}

	/// Returns this handle's unique identifier.
	pub fn id(&self) -> &str {
		self.id.as_str()
	}

	/// Reports whether `mark_committed` has been called on this copy of the handle.
	pub fn is_committed(&self) -> bool {
		self.committed
	}

	/// Flags this copy of the handle as committed.
	///
	/// Clones of the handle are independent, so other copies are unaffected.
	pub fn mark_committed(&mut self) {
		self.committed = true;
	}
}
66
#[cfg(feature = "testcontainers")]
impl Default for TransactionHandle {
	/// Same as `TransactionHandle::new()`: an uncommitted handle with a fresh random ID.
	fn default() -> Self {
		TransactionHandle::new()
	}
}
73
/// Base test case for API testing
///
/// Provides:
/// - Pre-configured APIClient
/// - Automatic setup/teardown via AsyncTestResource
/// - Assertion helpers (on the client's responses, e.g. `assert_ok` in the example below)
/// - Optional TestContainer database integration
///
/// # Example
/// ```rust,no_run
/// # #[tokio::main]
/// # async fn main() {
/// use reinhardt_testkit::testcase::APITestCase;
/// use reinhardt_testkit::resource::AsyncTeardownGuard;
/// use rstest::*;
///
/// #[fixture]
/// async fn api_test() -> AsyncTeardownGuard<APITestCase> {
///     AsyncTeardownGuard::new().await
/// }
///
/// #[rstest]
/// #[tokio::test]
/// async fn test_list_users(#[future] api_test: AsyncTeardownGuard<APITestCase>) {
///     let case = api_test.await;
///     let response = case.client().await.get("/api/users/").await.unwrap();
///     response.assert_ok();
/// }
/// # }
/// ```
pub struct APITestCase {
	/// HTTP test client behind an async read/write lock; accessed through
	/// `client()` (shared) and `client_mut()` (exclusive).
	client: Arc<RwLock<APIClient>>,
	/// Database URL stored by `set_database_url`; `None` until set.
	#[cfg(feature = "testcontainers")]
	database_url: Arc<RwLock<Option<String>>>,
	/// Connection pool stored by `set_database_connection`; taken out and
	/// closed during teardown.
	#[cfg(feature = "testcontainers")]
	db_connection: Arc<RwLock<Option<sqlx::AnyPool>>>,
	/// Handles registered by `begin_transaction` that have not yet been
	/// removed by `commit_transaction`.
	#[cfg(feature = "testcontainers")]
	active_transactions: Arc<RwLock<Vec<TransactionHandle>>>,
}
113
114impl APITestCase {
115	/// Get the database connection URL (if configured)
116	#[cfg(feature = "testcontainers")]
117	pub async fn database_url(&self) -> Option<String> {
118		self.database_url.read().await.clone()
119	}
120
121	/// Get the test client
122	pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, APIClient> {
123		self.client.read().await
124	}
125
126	/// Get mutable access to the test client
127	pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, APIClient> {
128		self.client.write().await
129	}
130
131	/// Set the database URL (useful for TestContainers integration)
132	#[cfg(feature = "testcontainers")]
133	pub async fn set_database_url(&self, url: String) {
134		let mut db_url = self.database_url.write().await;
135		*db_url = Some(url);
136	}
137
138	/// Set the database connection pool
139	///
140	/// This method allows setting a pre-configured database connection pool
141	/// for use in tests. The pool will be properly closed during teardown.
142	///
143	/// # Example
144	/// ```rust,ignore
145	/// use sqlx::AnyPool;
146	///
147	/// let pool = AnyPool::connect("postgres://localhost/test").await?;
148	/// test_case.set_database_connection(pool).await;
149	/// ```
150	#[cfg(feature = "testcontainers")]
151	pub async fn set_database_connection(&self, pool: sqlx::AnyPool) {
152		let mut conn = self.db_connection.write().await;
153		*conn = Some(pool);
154	}
155
156	/// Get the database connection pool (if configured)
157	#[cfg(feature = "testcontainers")]
158	pub async fn db_connection(&self) -> Option<sqlx::AnyPool> {
159		self.db_connection.read().await.clone()
160	}
161
162	/// Begin a new tracked transaction
163	///
164	/// This method registers a new transaction handle for tracking purposes.
165	/// The actual sqlx::Transaction should be obtained from the pool directly.
166	/// The handle is used to track whether transactions are properly committed
167	/// or rolled back during teardown.
168	///
169	/// # Returns
170	/// A TransactionHandle that can be used to track the transaction state.
171	///
172	/// # Example
173	/// ```rust,ignore
174	/// let handle = test_case.begin_transaction().await;
175	/// // ... perform database operations with sqlx::Transaction ...
176	/// handle.mark_committed(); // Mark as committed if successful
177	/// ```
178	#[cfg(feature = "testcontainers")]
179	pub async fn begin_transaction(&self) -> TransactionHandle {
180		let handle = TransactionHandle::new();
181		let mut transactions = self.active_transactions.write().await;
182		transactions.push(handle.clone());
183		handle
184	}
185
186	/// Mark a transaction as committed by its ID
187	///
188	/// This removes the transaction from the active list, indicating
189	/// it was successfully committed and doesn't need rollback.
190	#[cfg(feature = "testcontainers")]
191	pub async fn commit_transaction(&self, transaction_id: &str) {
192		let mut transactions = self.active_transactions.write().await;
193		if let Some(pos) = transactions.iter().position(|t| t.id() == transaction_id) {
194			let mut handle = transactions.remove(pos);
195			handle.mark_committed();
196		}
197	}
198
199	/// Get the count of active (uncommitted) transactions
200	#[cfg(feature = "testcontainers")]
201	pub async fn active_transaction_count(&self) -> usize {
202		self.active_transactions.read().await.len()
203	}
204}
205
206#[async_trait::async_trait]
207impl AsyncTestResource for APITestCase {
208	async fn setup() -> Self {
209		Self {
210			client: Arc::new(RwLock::new(APIClient::new())),
211			#[cfg(feature = "testcontainers")]
212			database_url: Arc::new(RwLock::new(None)),
213			#[cfg(feature = "testcontainers")]
214			db_connection: Arc::new(RwLock::new(None)),
215			#[cfg(feature = "testcontainers")]
216			active_transactions: Arc::new(RwLock::new(Vec::new())),
217		}
218	}
219
220	async fn teardown(self) {
221		// Step 1: Clean up HTTP client state
222		{
223			let client = self.client.write().await;
224			client.cleanup().await;
225		}
226
227		// Step 2: Handle database cleanup (testcontainers feature only)
228		#[cfg(feature = "testcontainers")]
229		{
230			// Log any uncommitted transactions (they will be rolled back when pool closes)
231			let transactions = self.active_transactions.read().await;
232			let uncommitted_count = transactions.iter().filter(|t| !t.is_committed()).count();
233			if uncommitted_count > 0 {
234				// Uncommitted transactions will be automatically rolled back by sqlx
235				// when the pool is closed
236				tracing::debug!(
237					"Rolling back {} uncommitted transaction(s) during teardown",
238					uncommitted_count
239				);
240			}
241			drop(transactions);
242
243			// Close the database connection pool
244			let mut pool_guard = self.db_connection.write().await;
245			if let Some(pool) = pool_guard.take() {
246				// Close the pool gracefully - this will rollback any uncommitted transactions
247				pool.close().await;
248			}
249		}
250
251		// Step 3: Drop the client
252		drop(self.client);
253	}
254}
255
/// Helper macro for defining test cases with automatic setup/teardown
///
/// Expands to an `#[rstest] #[tokio::test]` async function that creates an
/// `AsyncTeardownGuard<APITestCase>`, binds the given parameter name to a
/// `&APITestCase` borrowed from that guard, and then runs the body. The guard
/// lives until the end of the generated function. The calling crate must
/// depend on `rstest` and `tokio`, because the expansion names them directly.
///
/// # Example
/// ```rust,no_run
/// # #[tokio::main]
/// # async fn main() {
/// # use reinhardt_testkit::test_case;
/// test_case! {
///     async fn test_get_users(case: &APITestCase) {
///         let client = case.client().await;
///         let response = client.get("/api/users/").await.unwrap();
///         response.assert_ok();
///     }
/// }
/// # }
/// ```
#[macro_export]
macro_rules! test_case {
	(
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		async fn $name() {
			use $crate::testcase::APITestCase;
			use $crate::resource::AsyncTeardownGuard;

			let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
			let $case = &*teardown_guard;

			$body

			// `teardown_guard` is dropped when this function returns; per the
			// guard's contract, that is what triggers `teardown()`.
		}
	};
}
293
/// Helper macro for defining authenticated test cases
///
/// Works like `test_case!`, except that before the body runs the client is
/// force-authenticated as a fixed user. That user is also bound to the second
/// parameter name as a `serde_json::Value` equal to
/// `{"id": 1, "username": "testuser"}`.
///
/// The expansion calls `serde_json::json!`, `rstest::rstest` and `tokio::test`
/// by path, so the calling crate must depend on `serde_json`, `rstest` and
/// `tokio`.
///
/// # Example
/// ```rust,ignore
/// authenticated_test_case! {
///     async fn test_users_as_testuser(case: &APITestCase, user: serde_json::Value) {
///         assert_eq!(user["username"], "testuser");
///         let response = case.client().await.get("/api/users/").await.unwrap();
///         response.assert_ok();
///     }
/// }
/// ```
#[macro_export]
macro_rules! authenticated_test_case {
	(
		async fn $name:ident($case:ident: &APITestCase, $user:ident: serde_json::Value) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		async fn $name() {
			use $crate::testcase::APITestCase;
			use $crate::resource::AsyncTeardownGuard;

			let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
			let $case = &*teardown_guard;

			// Fixed identity the client is authenticated as for the whole test.
			let $user = serde_json::json!({
				"id": 1,
				"username": "testuser",
			});
			$case
				.client()
				.await
				.force_authenticate(Some($user.clone()))
				.await;

			$body

			// `teardown_guard` is dropped when this function returns; per the
			// guard's contract, that is what triggers `teardown()`.
		}
	};
}
326
/// Helper macro for defining test cases with database containers
///
/// Requires `testcontainers` feature to be enabled.
///
/// This macro starts a PostgreSQL or MySQL container through the crate's
/// `containers::with_postgres` / `containers::with_mysql` helpers. Inside that
/// scope it creates an `APITestCase` (via `AsyncTeardownGuard`), stores the
/// container's connection URL with `set_database_url`, and runs the body. The
/// test panics (via `unwrap`) if the container helper returns an error.
///
/// # Examples
///
/// ## PostgreSQL Example
///
/// ```rust,ignore
/// use reinhardt_testkit::test_case_with_db;
/// use reinhardt_testkit::testcase::APITestCase;
///
/// test_case_with_db! {
///     postgres,
///     async fn test_users_with_postgres(case: &APITestCase) {
///         let db_url = case.database_url().await.unwrap();
///         // Database URL is automatically set
///         assert!(db_url.starts_with("postgres://"));
///
///         // Perform database operations...
///     }
/// }
/// ```
///
/// ## MySQL Example
///
/// ```rust,ignore
/// use reinhardt_testkit::test_case_with_db;
/// use reinhardt_testkit::testcase::APITestCase;
///
/// test_case_with_db! {
///     mysql,
///     async fn test_users_with_mysql(case: &APITestCase) {
///         let db_url = case.database_url().await.unwrap();
///         // Database URL is automatically set
///         assert!(db_url.starts_with("mysql://"));
///
///         // Perform database operations...
///     }
/// }
/// ```
#[cfg(feature = "testcontainers")]
#[macro_export]
macro_rules! test_case_with_db {
	(
		postgres,
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		$crate::test_case_with_db!(@with_container with_postgres, PostgresContainer, $name, $case, $body);
	};
	(
		mysql,
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		$crate::test_case_with_db!(@with_container with_mysql, MySqlContainer, $name, $case, $body);
	};
	// Internal arm shared by every backend; not part of the public syntax.
	(@with_container $runner:ident, $container:ident, $name:ident, $case:ident, $body:block) => {
		#[rstest::rstest]
		#[tokio::test]
		async fn $name() {
			use $crate::containers::{$runner, $container};
			use $crate::resource::AsyncTeardownGuard;
			use $crate::testcase::APITestCase;

			$runner(|db| async move {
				let teardown_guard = AsyncTeardownGuard::<APITestCase>::new().await;
				let $case = &*teardown_guard;
				$case.set_database_url(db.connection_url()).await;

				$body

				// `teardown_guard` is dropped at the end of this closure body.
				Ok(())
			})
			.await
			.unwrap();
		}
	};
}
427
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;

	// ========================================================================
	// TeardownError Display tests
	// ========================================================================

	#[rstest]
	fn test_teardown_error_transaction_rollback_display() {
		// Arrange
		let error = TeardownError::TransactionRollbackFailed("tx-123 failed".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(display, "Failed to rollback transactions: tx-123 failed");
	}

	#[rstest]
	fn test_teardown_error_connection_close_display() {
		// Arrange
		let error = TeardownError::ConnectionCloseFailed("connection refused".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(
			display,
			"Failed to close database connection: connection refused"
		);
	}

	#[rstest]
	fn test_teardown_error_client_cleanup_display() {
		// Arrange
		let error = TeardownError::ClientCleanupFailed("timeout".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(display, "Failed to cleanup client state: timeout");
	}

	#[rstest]
	fn test_teardown_error_debug() {
		// Arrange
		let error = TeardownError::TransactionRollbackFailed("debug test".to_string());

		// Act
		let debug = format!("{:?}", error);

		// Assert
		assert!(
			debug.contains("debug test"),
			"Debug output should contain the message, got: {}",
			debug
		);
	}

	// ========================================================================
	// APITestCase tests
	// ========================================================================

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_setup_creates_client() {
		// Arrange & Act
		let test_case = APITestCase::setup().await;

		// Assert
		let client = test_case.client().await;
		// Verify we can access the client (read guard obtained successfully)
		drop(client);
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_client_read_access() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act
		let client = test_case.client().await;

		// Assert
		// `client()` hands out a shared guard: further readers are admitted,
		// but writers are excluded until the guard is released.
		assert!(
			test_case.client.try_read().is_ok(),
			"A read guard should not block other readers"
		);
		assert!(
			test_case.client.try_write().is_err(),
			"A read guard should block writers"
		);
		drop(client);
		assert!(
			test_case.client.try_write().is_ok(),
			"Writers should be admitted once the read guard is released"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_client_mut_is_exclusive() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act
		let client = test_case.client_mut().await;

		// Assert
		assert!(
			test_case.client.try_read().is_err(),
			"A write guard should block readers"
		);
		drop(client);
		assert!(
			test_case.client.try_read().is_ok(),
			"Readers should be admitted once the write guard is released"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_teardown_completes() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act & Assert
		// teardown should complete without panicking
		test_case.teardown().await;
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_multiple_reads() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act
		// Reaching the assertions at all proves both read guards were acquired
		// concurrently without deadlock.
		let client1 = test_case.client().await;
		let client2 = test_case.client().await;

		// Assert
		assert!(
			test_case.client.try_write().is_err(),
			"Writers must be excluded while read guards are held"
		);
		drop(client1);
		assert!(
			test_case.client.try_write().is_err(),
			"The remaining read guard must still exclude writers"
		);
		drop(client2);
		assert!(
			test_case.client.try_write().is_ok(),
			"Writers should be admitted once all read guards are released"
		);
	}

	// ========================================================================
	// TransactionHandle / transaction tracking tests (testcontainers feature)
	// ========================================================================

	#[cfg(feature = "testcontainers")]
	mod testcontainers_tests {
		use super::*;
		use rstest::rstest;

		#[rstest]
		fn test_transaction_handle_new() {
			// Arrange & Act
			let handle = TransactionHandle::new();

			// Assert
			assert!(!handle.id().is_empty(), "ID should not be empty");
			assert!(!handle.is_committed(), "New handle should not be committed");
		}

		#[rstest]
		fn test_transaction_handle_mark_committed() {
			// Arrange
			let mut handle = TransactionHandle::new();

			// Act
			handle.mark_committed();

			// Assert
			assert!(handle.is_committed());
		}

		#[rstest]
		fn test_transaction_handle_default() {
			// Arrange & Act
			let handle = TransactionHandle::default();

			// Assert
			assert!(!handle.id().is_empty(), "Default ID should not be empty");
			assert!(
				!handle.is_committed(),
				"Default handle should not be committed"
			);
		}

		#[rstest]
		fn test_transaction_handle_id_is_uuid() {
			// Arrange & Act
			let handle = TransactionHandle::new();

			// Assert
			let id = handle.id();
			// UUID v4 format: 8-4-4-4-12 hex characters
			let parts: Vec<&str> = id.split('-').collect();
			assert_eq!(
				parts.len(),
				5,
				"UUID should have 5 parts separated by hyphens, got: {}",
				id
			);
			assert_eq!(parts[0].len(), 8);
			assert_eq!(parts[1].len(), 4);
			assert_eq!(parts[2].len(), 4);
			assert_eq!(parts[3].len(), 4);
			assert_eq!(parts[4].len(), 12);
		}

		#[rstest]
		fn test_transaction_handle_ids_are_unique() {
			// Arrange & Act
			let first = TransactionHandle::new();
			let second = TransactionHandle::new();

			// Assert
			assert_ne!(first.id(), second.id());
		}

		#[rstest]
		#[tokio::test]
		async fn test_begin_transaction_tracks_handle() {
			// Arrange
			let test_case = APITestCase::setup().await;

			// Act
			let handle = test_case.begin_transaction().await;

			// Assert
			assert_eq!(test_case.active_transaction_count().await, 1);
			assert!(!handle.is_committed());
		}

		#[rstest]
		#[tokio::test]
		async fn test_commit_transaction_stops_tracking_only_that_handle() {
			// Arrange
			let test_case = APITestCase::setup().await;
			let first = test_case.begin_transaction().await;
			let second = test_case.begin_transaction().await;

			// Act
			test_case.commit_transaction(first.id()).await;

			// Assert
			assert_eq!(test_case.active_transaction_count().await, 1);
			let remaining = test_case.active_transactions.read().await;
			assert_eq!(remaining[0].id(), second.id());
		}

		#[rstest]
		#[tokio::test]
		async fn test_commit_unknown_transaction_is_noop() {
			// Arrange
			let test_case = APITestCase::setup().await;
			test_case.begin_transaction().await;

			// Act
			test_case.commit_transaction("not-a-tracked-id").await;

			// Assert
			assert_eq!(test_case.active_transaction_count().await, 1);
		}

		#[rstest]
		#[tokio::test]
		async fn test_teardown_with_uncommitted_transaction_completes() {
			// Arrange
			let test_case = APITestCase::setup().await;
			test_case.begin_transaction().await;

			// Act & Assert
			// No pool is configured, so teardown only logs the pending handle.
			test_case.teardown().await;
		}
	}
}