Skip to main content

reinhardt_testkit/
testcase.rs

//! Base test case with common setup, teardown, and assertion support.
//!
//! Modeled after Django REST Framework's `APITestCase`.
4
5use crate::client::APIClient;
6use crate::resource::AsyncTestResource;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::RwLock;
10
11/// Error types that can occur during test teardown
12#[derive(Debug, Error)]
13pub enum TeardownError {
14	/// Failed to rollback one or more active transactions
15	#[error("Failed to rollback transactions: {0}")]
16	TransactionRollbackFailed(String),
17
18	/// Failed to close database connection
19	#[error("Failed to close database connection: {0}")]
20	ConnectionCloseFailed(String),
21
22	/// Failed to cleanup client state
23	#[error("Failed to cleanup client state: {0}")]
24	ClientCleanupFailed(String),
25}
26
/// Bookkeeping record for a transaction opened during a test.
///
/// A handle stores only an identifier and a committed flag; it does not own a
/// database transaction. The transaction itself is a `sqlx::Transaction`, which
/// rolls back automatically when dropped without being committed.
///
/// Handles are `Clone`, and every clone carries its own independent
/// `committed` flag.
#[cfg(feature = "testcontainers")]
#[derive(Clone, Debug)]
pub struct TransactionHandle {
	/// Identifier generated when the handle is created.
	id: String,
	/// Becomes `true` once the transaction is known to be committed.
	committed: bool,
}
40
#[cfg(feature = "testcontainers")]
impl TransactionHandle {
	/// Builds an uncommitted handle whose ID is a freshly generated UUIDv7 string.
	pub fn new() -> Self {
		let id = uuid::Uuid::now_v7().to_string();
		Self { id, committed: false }
	}

	/// Returns the identifier assigned to this handle at construction.
	pub fn id(&self) -> &str {
		self.id.as_str()
	}

	/// Reports whether [`mark_committed`](Self::mark_committed) has been called on this handle.
	pub fn is_committed(&self) -> bool {
		self.committed
	}

	/// Flags this handle as committed; nothing ever clears the flag again.
	pub fn mark_committed(&mut self) {
		self.committed = true;
	}
}
66
#[cfg(feature = "testcontainers")]
impl Default for TransactionHandle {
	/// Same as [`TransactionHandle::new`]: a new, uncommitted handle with a fresh ID.
	fn default() -> Self {
		TransactionHandle::new()
	}
}
73
74/// Base test case for API testing
75///
76/// Provides:
77/// - Pre-configured APIClient
78/// - Automatic setup/teardown via AsyncTestResource
79/// - Assertion helpers
80/// - Optional TestContainer database integration
81///
82/// # Example
83/// ```rust,no_run
84/// # #[tokio::main]
85/// # async fn main() {
86/// use reinhardt_testkit::testcase::APITestCase;
87/// use reinhardt_testkit::resource::AsyncTeardownGuard;
88/// use rstest::*;
89///
90/// #[fixture]
91/// async fn api_test() -> AsyncTeardownGuard<APITestCase> {
92///     AsyncTeardownGuard::new().await
93/// }
94///
95/// #[rstest]
96/// #[tokio::test]
97/// async fn test_list_users(#[future] api_test: AsyncTeardownGuard<APITestCase>) {
98///     let case = api_test.await;
99///     let response = case.client().await.get("/api/users/").await.unwrap();
100///     response.assert_ok();
101/// }
102/// # }
103/// ```
104pub struct APITestCase {
105	client: Arc<RwLock<APIClient>>,
106	#[cfg(feature = "testcontainers")]
107	database_url: Arc<RwLock<Option<String>>>,
108	#[cfg(feature = "testcontainers")]
109	db_connection: Arc<RwLock<Option<sqlx::AnyPool>>>,
110	#[cfg(feature = "testcontainers")]
111	active_transactions: Arc<RwLock<Vec<TransactionHandle>>>,
112}
113
114impl APITestCase {
115	/// Get the database connection URL (if configured)
116	#[cfg(feature = "testcontainers")]
117	pub async fn database_url(&self) -> Option<String> {
118		self.database_url.read().await.clone()
119	}
120
121	/// Get the test client
122	pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, APIClient> {
123		self.client.read().await
124	}
125
126	/// Get mutable access to the test client
127	pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, APIClient> {
128		self.client.write().await
129	}
130
131	/// Set the database URL (useful for TestContainers integration)
132	#[cfg(feature = "testcontainers")]
133	pub async fn set_database_url(&self, url: String) {
134		let mut db_url = self.database_url.write().await;
135		*db_url = Some(url);
136	}
137
138	/// Set the database connection pool
139	///
140	/// This method allows setting a pre-configured database connection pool
141	/// for use in tests. The pool will be properly closed during teardown.
142	///
143	/// # Example
144	/// ```rust,ignore
145	/// use sqlx::AnyPool;
146	///
147	/// let pool = AnyPool::connect("postgres://localhost/test").await?;
148	/// test_case.set_database_connection(pool).await;
149	/// ```
150	#[cfg(feature = "testcontainers")]
151	pub async fn set_database_connection(&self, pool: sqlx::AnyPool) {
152		let mut conn = self.db_connection.write().await;
153		*conn = Some(pool);
154	}
155
156	/// Get the database connection pool (if configured)
157	#[cfg(feature = "testcontainers")]
158	pub async fn db_connection(&self) -> Option<sqlx::AnyPool> {
159		self.db_connection.read().await.clone()
160	}
161
162	/// Begin a new tracked transaction
163	///
164	/// This method registers a new transaction handle for tracking purposes.
165	/// The actual sqlx::Transaction should be obtained from the pool directly.
166	/// The handle is used to track whether transactions are properly committed
167	/// or rolled back during teardown.
168	///
169	/// # Returns
170	/// A TransactionHandle that can be used to track the transaction state.
171	///
172	/// # Example
173	/// ```rust,ignore
174	/// let handle = test_case.begin_transaction().await;
175	/// // ... perform database operations with sqlx::Transaction ...
176	/// handle.mark_committed(); // Mark as committed if successful
177	/// ```
178	#[cfg(feature = "testcontainers")]
179	pub async fn begin_transaction(&self) -> TransactionHandle {
180		let handle = TransactionHandle::new();
181		let mut transactions = self.active_transactions.write().await;
182		transactions.push(handle.clone());
183		handle
184	}
185
186	/// Mark a transaction as committed by its ID
187	///
188	/// This removes the transaction from the active list, indicating
189	/// it was successfully committed and doesn't need rollback.
190	#[cfg(feature = "testcontainers")]
191	pub async fn commit_transaction(&self, transaction_id: &str) {
192		let mut transactions = self.active_transactions.write().await;
193		if let Some(pos) = transactions.iter().position(|t| t.id() == transaction_id) {
194			let mut handle = transactions.remove(pos);
195			handle.mark_committed();
196		}
197	}
198
199	/// Get the count of active (uncommitted) transactions
200	#[cfg(feature = "testcontainers")]
201	pub async fn active_transaction_count(&self) -> usize {
202		self.active_transactions.read().await.len()
203	}
204}
205
206#[async_trait::async_trait]
207impl AsyncTestResource for APITestCase {
208	async fn setup() -> Self {
209		Self {
210			client: Arc::new(RwLock::new(APIClient::new())),
211			#[cfg(feature = "testcontainers")]
212			database_url: Arc::new(RwLock::new(None)),
213			#[cfg(feature = "testcontainers")]
214			db_connection: Arc::new(RwLock::new(None)),
215			#[cfg(feature = "testcontainers")]
216			active_transactions: Arc::new(RwLock::new(Vec::new())),
217		}
218	}
219
220	async fn teardown(self) {
221		// Step 1: Clean up HTTP client state
222		{
223			let client = self.client.write().await;
224			client.cleanup().await;
225		}
226
227		// Step 2: Handle database cleanup (testcontainers feature only)
228		#[cfg(feature = "testcontainers")]
229		{
230			// Log any uncommitted transactions (they will be rolled back when pool closes)
231			let transactions = self.active_transactions.read().await;
232			let uncommitted_count = transactions.iter().filter(|t| !t.is_committed()).count();
233			if uncommitted_count > 0 {
234				// Uncommitted transactions will be automatically rolled back by sqlx
235				// when the pool is closed
236				tracing::debug!(
237					"Rolling back {} uncommitted transaction(s) during teardown",
238					uncommitted_count
239				);
240			}
241			drop(transactions);
242
243			// Close the database connection pool
244			let mut pool_guard = self.db_connection.write().await;
245			if let Some(pool) = pool_guard.take() {
246				// Close the pool gracefully - this will rollback any uncommitted transactions
247				pool.close().await;
248			}
249		}
250
251		// Step 3: Drop the client
252		drop(self.client);
253	}
254}
255
/// Helper macro for defining test cases with automatic setup/teardown.
///
/// Expands to an async test function annotated with `#[rstest::rstest]` and
/// `#[tokio::test]`. Inside the body, the identifier from the signature is bound
/// to `&APITestCase`, borrowed from an `AsyncTeardownGuard<APITestCase>`.
///
/// Outer attributes written before `async fn` (for example `#[ignore]` or doc
/// comments) are forwarded to the generated function.
///
/// The expansion refers to `rstest` and `tokio` by plain path, so the invoking
/// crate must depend on both.
///
/// # Example
/// ```rust,no_run
/// # #[tokio::main]
/// # async fn main() {
/// # use reinhardt_testkit::test_case;
/// test_case! {
///     async fn test_get_users(case: &APITestCase) {
///         let client = case.client().await;
///         let response = client.get("/api/users/").await.unwrap();
///         response.assert_ok();
///     }
/// }
/// # }
/// ```
#[macro_export]
macro_rules! test_case {
	(
		$(#[$attr:meta])*
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		$(#[$attr])*
		async fn $name() {
			use $crate::resource::AsyncTeardownGuard;
			use $crate::testcase::APITestCase;

			let guard = AsyncTeardownGuard::<APITestCase>::new().await;
			let $case = &*guard;

			// Run test
			$body

			// `guard` goes out of scope here; `AsyncTeardownGuard` is expected to run
			// `teardown()` on drop.
		}
	};
}
293
/// Helper macro for defining authenticated test cases.
///
/// Works like `test_case!`, but the client is force-authenticated as a fixed
/// test user before the body runs. The user is bound to the second parameter's
/// name as a `serde_json::Value` equal to `{"id": 1, "username": "testuser"}`.
///
/// Outer attributes written before `async fn` are forwarded to the generated
/// function.
///
/// The expansion refers to `rstest`, `tokio`, and `serde_json` by plain path, so
/// the invoking crate must depend on all three.
///
/// # Example
/// ```rust,ignore
/// use reinhardt_testkit::authenticated_test_case;
///
/// authenticated_test_case! {
///     async fn test_profile(case: &APITestCase, user: serde_json::Value) {
///         assert_eq!(user["username"], "testuser");
///         let response = case.client().await.get("/api/me/").await.unwrap();
///         response.assert_ok();
///     }
/// }
/// ```
#[macro_export]
macro_rules! authenticated_test_case {
	(
		$(#[$attr:meta])*
		async fn $name:ident($case:ident: &APITestCase, $user:ident: serde_json::Value) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		$(#[$attr])*
		async fn $name() {
			use $crate::resource::AsyncTeardownGuard;
			use $crate::testcase::APITestCase;

			let guard = AsyncTeardownGuard::<APITestCase>::new().await;
			let $case = &*guard;

			// Setup authentication
			let $user = serde_json::json!({
				"id": 1,
				"username": "testuser",
			});
			{
				let client = $case.client().await;
				// NOTE(review): the allow implies `APIClient::force_authenticate` is marked
				// deprecated; switch to its replacement once one is available.
				#[allow(deprecated)]
				client.force_authenticate(Some($user.clone())).await;
			}

			// Run test
			$body

			// `guard` goes out of scope here; `AsyncTeardownGuard` is expected to run
			// `teardown()` on drop.
		}
	};
}
327
/// Helper macro for defining test cases with database containers.
///
/// Requires the `testcontainers` feature to be enabled.
///
/// This macro starts a PostgreSQL or MySQL container via TestContainers, creates
/// an `APITestCase`, and stores the container's connection URL on it before the
/// body runs. The generated test unwraps the container helper's result, so a
/// container startup failure fails the test.
///
/// Outer attributes written before `async fn` (for example `#[ignore]` or doc
/// comments) are forwarded to the generated function.
///
/// # Examples
///
/// ## PostgreSQL Example
///
/// ```rust,ignore
/// use reinhardt_testkit::test_case_with_db;
/// use reinhardt_testkit::testcase::APITestCase;
///
/// test_case_with_db! {
///     postgres,
///     async fn test_users_with_postgres(case: &APITestCase) {
///         let db_url = case.database_url().await.unwrap();
///         // Database URL is automatically set
///         assert!(db_url.starts_with("postgres://"));
///
///         // Perform database operations...
///     }
/// }
/// ```
///
/// ## MySQL Example
///
/// ```rust,ignore
/// use reinhardt_testkit::test_case_with_db;
/// use reinhardt_testkit::testcase::APITestCase;
///
/// test_case_with_db! {
///     mysql,
///     async fn test_users_with_mysql(case: &APITestCase) {
///         let db_url = case.database_url().await.unwrap();
///         // Database URL is automatically set
///         assert!(db_url.starts_with("mysql://"));
///
///         // Perform database operations...
///     }
/// }
/// ```
#[cfg(feature = "testcontainers")]
#[macro_export]
macro_rules! test_case_with_db {
	(
		postgres,
		$(#[$attr:meta])*
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		$(#[$attr])*
		async fn $name() {
			// The expansion itself never names the container type. The import is kept
			// so test bodies can still refer to it; the allow prevents an
			// `unused_imports` warning at every expansion site.
			#[allow(unused_imports)]
			use $crate::containers::{with_postgres, PostgresContainer};
			use $crate::resource::AsyncTeardownGuard;
			use $crate::testcase::APITestCase;

			with_postgres(|db| async move {
				let guard = AsyncTeardownGuard::<APITestCase>::new().await;
				let $case = &*guard;
				$case.set_database_url(db.connection_url()).await;

				// Run test
				$body

				// `guard` goes out of scope at the end of this block; `AsyncTeardownGuard`
				// is expected to run `teardown()` on drop.
				Ok(())
			})
			.await
			.unwrap();
		}
	};
	(
		mysql,
		$(#[$attr:meta])*
		async fn $name:ident($case:ident: &APITestCase) $body:block
	) => {
		#[rstest::rstest]
		#[tokio::test]
		$(#[$attr])*
		async fn $name() {
			// The expansion itself never names the container type. The import is kept
			// so test bodies can still refer to it; the allow prevents an
			// `unused_imports` warning at every expansion site.
			#[allow(unused_imports)]
			use $crate::containers::{with_mysql, MySqlContainer};
			use $crate::resource::AsyncTeardownGuard;
			use $crate::testcase::APITestCase;

			with_mysql(|db| async move {
				let guard = AsyncTeardownGuard::<APITestCase>::new().await;
				let $case = &*guard;
				$case.set_database_url(db.connection_url()).await;

				// Run test
				$body

				// `guard` goes out of scope at the end of this block; `AsyncTeardownGuard`
				// is expected to run `teardown()` on drop.
				Ok(())
			})
			.await
			.unwrap();
		}
	};
}
428
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;

	// ========================================================================
	// TeardownError Display tests
	// ========================================================================

	#[rstest]
	fn test_teardown_error_transaction_rollback_display() {
		// Arrange
		let error = TeardownError::TransactionRollbackFailed("tx-123 failed".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(display, "Failed to rollback transactions: tx-123 failed");
	}

	#[rstest]
	fn test_teardown_error_connection_close_display() {
		// Arrange
		let error = TeardownError::ConnectionCloseFailed("connection refused".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(
			display,
			"Failed to close database connection: connection refused"
		);
	}

	#[rstest]
	fn test_teardown_error_client_cleanup_display() {
		// Arrange
		let error = TeardownError::ClientCleanupFailed("timeout".to_string());

		// Act
		let display = format!("{}", error);

		// Assert
		assert_eq!(display, "Failed to cleanup client state: timeout");
	}

	#[rstest]
	fn test_teardown_error_debug() {
		// Arrange
		let error = TeardownError::TransactionRollbackFailed("debug test".to_string());

		// Act
		let debug = format!("{:?}", error);

		// Assert
		assert!(
			debug.contains("debug test"),
			"Debug output should contain the message, got: {}",
			debug
		);
	}

	// ========================================================================
	// APITestCase tests
	// ========================================================================

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_setup_creates_client() {
		// Arrange & Act
		let test_case = APITestCase::setup().await;

		// Assert
		let client = test_case.client().await;
		// Verify we can access the client (read guard obtained successfully)
		drop(client);
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_client_read_access() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act
		let client = test_case.client().await;

		// Assert
		// Successfully obtained read guard - the client is accessible
		assert!(
			std::mem::size_of_val(&*client) > 0,
			"Client should have non-zero size"
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_teardown_completes() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act & Assert
		// teardown should complete without panicking
		test_case.teardown().await;
	}

	#[rstest]
	#[tokio::test]
	async fn test_api_test_case_multiple_reads() {
		// Arrange
		let test_case = APITestCase::setup().await;

		// Act
		let client1 = test_case.client().await;
		let client2 = test_case.client().await;

		// Assert
		// Both read guards should be held concurrently without deadlock
		assert!(
			std::mem::size_of_val(&*client1) > 0,
			"First client read should succeed"
		);
		assert!(
			std::mem::size_of_val(&*client2) > 0,
			"Second client read should succeed"
		);
	}

	// ========================================================================
	// TransactionHandle tests (testcontainers feature)
	// ========================================================================

	#[cfg(feature = "testcontainers")]
	mod testcontainers_tests {
		use super::*;
		use rstest::rstest;

		#[rstest]
		fn test_transaction_handle_new() {
			// Arrange & Act
			let handle = TransactionHandle::new();

			// Assert
			assert!(!handle.id().is_empty(), "ID should not be empty");
			assert!(!handle.is_committed(), "New handle should not be committed");
		}

		#[rstest]
		fn test_transaction_handle_mark_committed() {
			// Arrange
			let mut handle = TransactionHandle::new();

			// Act
			handle.mark_committed();

			// Assert
			assert!(handle.is_committed());
		}

		#[rstest]
		fn test_transaction_handle_default() {
			// Arrange & Act
			let handle = TransactionHandle::default();

			// Assert
			assert!(!handle.id().is_empty(), "Default ID should not be empty");
			assert!(
				!handle.is_committed(),
				"Default handle should not be committed"
			);
		}

		#[rstest]
		fn test_transaction_handle_id_is_uuid() {
			// Arrange & Act
			let handle = TransactionHandle::new();

			// Assert
			let id = handle.id();
			// `TransactionHandle::new` uses `Uuid::now_v7`, so the ID is a hyphenated
			// UUIDv7: 8-4-4-4-12 hex digits, with a version nibble of 7 at the start
			// of the third group.
			assert!(
				id.chars().all(|c| c == '-' || c.is_ascii_hexdigit()),
				"UUID should contain only hex digits and hyphens, got: {}",
				id
			);
			let parts: Vec<&str> = id.split('-').collect();
			assert_eq!(
				parts.len(),
				5,
				"UUID should have 5 parts separated by hyphens, got: {}",
				id
			);
			assert_eq!(parts[0].len(), 8);
			assert_eq!(parts[1].len(), 4);
			assert_eq!(parts[2].len(), 4);
			assert_eq!(parts[3].len(), 4);
			assert_eq!(parts[4].len(), 12);
			assert!(
				parts[2].starts_with('7'),
				"UUID version nibble should be 7 (UUIDv7), got: {}",
				id
			);
		}
	}
}