Skip to main content

reinhardt_testkit/server_fn/
transaction.rs

1//! Transaction management utilities for server function testing.
2//!
3//! This module provides utilities for managing database transactions in tests,
4//! including automatic rollback to ensure test isolation.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use reinhardt_testkit::server_fn::transaction::TestTransaction;
10//!
11//! #[rstest]
12//! #[tokio::test]
13//! async fn test_with_rollback(postgres_suite: SuiteGuard<PostgresSuiteResource>) {
14//!     let tx = TestTransaction::begin(postgres_suite.pool()).await.unwrap();
15//!
16//!     // Perform database operations using tx.connection()
17//!     // All changes will be rolled back when tx is dropped
18//! }
19//! ```
20
21#![cfg(native)]
22
23use std::ops::Deref;
24
25use async_trait::async_trait;
26
/// A wrapper that owns a connection or transaction handle for the span of a test.
///
/// The wrapper keeps two pieces of bookkeeping next to the handle:
///
/// - whether the caller asked for the work to be committed rather than rolled
///   back (see [`TestTransaction::commit_on_drop`] / [`TestTransaction::commits_on_drop`]);
/// - whether the transaction was already finished explicitly
///   (see [`TestTransaction::mark_completed`] / [`TestTransaction::is_completed`]).
///
/// The wrapper never issues SQL itself. Code that actually ends the transaction
/// (for example an implementation of [`TestConnectionExt`]) is expected to
/// consult those two flags.
///
/// NOTE(review): no `Drop` impl for `TestTransaction` is visible in this module,
/// so what happens when the wrapper goes out of scope appears to be decided by
/// dropping the wrapped handle `C`. Confirm that `C` rolls back on drop before
/// relying on automatic rollback for test isolation.
///
/// # Example
///
/// ```rust,ignore
/// // `conn` is any type implementing `TestConnectionExt`.
/// let mut tx = conn.begin_test_transaction().await?;
///
/// // Run statements through `tx.connection()` / `tx.connection_mut()`.
///
/// // Take the handle back and end the transaction explicitly.
/// tx.into_inner().rollback_transaction().await?;
/// ```
#[derive(Debug)]
pub struct TestTransaction<C> {
	/// The underlying connection or transaction handle.
	///
	/// Only `into_inner` empties this slot, and it consumes the wrapper while
	/// doing so, so a live wrapper always holds `Some`.
	connection: Option<C>,
	/// Whether to commit instead of rollback.
	commit_on_drop: bool,
	/// Whether the transaction has been explicitly completed.
	completed: bool,
}

impl<C> TestTransaction<C> {
	/// Wrap `connection` in a new test transaction.
	///
	/// The wrapper starts out not completed and not configured to commit.
	pub fn new(connection: C) -> Self {
		Self {
			connection: Some(connection),
			commit_on_drop: false,
			completed: false,
		}
	}

	/// Configure the transaction to commit instead of rollback.
	///
	/// Use this with caution - it defeats the purpose of test isolation.
	#[must_use = "this consumes the transaction and returns the reconfigured one"]
	pub fn commit_on_drop(mut self) -> Self {
		self.commit_on_drop = true;
		self
	}

	/// Report whether [`TestTransaction::commit_on_drop`] was requested.
	pub fn commits_on_drop(&self) -> bool {
		self.commit_on_drop
	}

	/// Report whether the transaction was marked as explicitly completed.
	pub fn is_completed(&self) -> bool {
		self.completed
	}

	/// Get a reference to the underlying connection.
	pub fn connection(&self) -> &C {
		self.connection
			.as_ref()
			.expect("connection already consumed")
	}

	/// Get a mutable reference to the underlying connection.
	pub fn connection_mut(&mut self) -> &mut C {
		self.connection
			.as_mut()
			.expect("connection already consumed")
	}

	/// Consume the transaction and return the underlying connection.
	///
	/// The wrapper is marked completed and then forgotten, so no drop logic
	/// runs for the wrapper itself; ending the transaction becomes the
	/// caller's responsibility.
	pub fn into_inner(mut self) -> C {
		self.completed = true;
		let connection = self.connection.take().expect("connection already consumed");
		// The handle has been moved out, so forgetting the shell leaks nothing
		// and guarantees no wrapper-level drop logic touches the transaction.
		std::mem::forget(self);
		connection
	}

	/// Mark the transaction as completed (no rollback on drop).
	pub fn mark_completed(&mut self) {
		self.completed = true;
	}
}

impl<C> Deref for TestTransaction<C> {
	type Target = C;

	/// Dereference to the wrapped handle; same as [`TestTransaction::connection`].
	fn deref(&self) -> &Self::Target {
		self.connection()
	}
}
119
/// Trait for types that can be used as test transaction connections.
///
/// Every method takes `self` by value, so the connection handle is consumed
/// by whichever operation is invoked. The `Sized` bound is what allows
/// `TestTransaction<Self>` to hold the handle by value.
#[async_trait]
pub trait TestConnectionExt: Sized {
	/// The error type returned by all three operations.
	type Error;

	/// Begin a new transaction that will rollback on drop.
	///
	/// On success the handle is returned wrapped in a [`TestTransaction`].
	// NOTE(review): the rollback-on-drop guarantee must come from the
	// implementor; nothing in this trait enforces it.
	async fn begin_test_transaction(self) -> Result<TestTransaction<Self>, Self::Error>;

	/// Commit the transaction explicitly, consuming the handle.
	async fn commit_transaction(self) -> Result<(), Self::Error>;

	/// Rollback the transaction explicitly, consuming the handle.
	async fn rollback_transaction(self) -> Result<(), Self::Error>;
}
135
/// Wrapper for managing savepoints within a test.
///
/// Savepoints allow you to create checkpoints within a transaction
/// and rollback to them without rolling back the entire transaction.
///
/// This type is pure bookkeeping: it stores the savepoint's name and a
/// "released" flag. Issuing the corresponding SQL is left to the caller.
#[derive(Debug)]
pub struct TestSavepoint {
	/// The savepoint name.
	pub name: String,
	/// Whether the savepoint has been released.
	///
	/// Private; read it through `is_released` and set it through `mark_released`.
	released: bool,
}
147
148impl TestSavepoint {
149	/// Create a new savepoint with the given name.
150	pub fn new(name: impl Into<String>) -> Self {
151		Self {
152			name: name.into(),
153			released: false,
154		}
155	}
156
157	/// Generate a unique savepoint name.
158	pub fn generate() -> Self {
159		Self::new(format!("sp_{}", uuid::Uuid::now_v7().simple()))
160	}
161
162	/// Mark the savepoint as released.
163	pub fn mark_released(&mut self) {
164		self.released = true;
165	}
166
167	/// Check if the savepoint has been released.
168	pub fn is_released(&self) -> bool {
169		self.released
170	}
171}
172
173// Fixes #872: Migrate SQL utility functions to use reinhardt-query instead of
174// raw string interpolation to prevent SQL injection.
175/// Test database utilities for common operations.
176pub mod utils {
177	use reinhardt_query::{
178		Alias, ColumnRef, Expr, Iden, PostgresQueryBuilder, Query, QueryStatementBuilder,
179	};
180
181	/// Quote an identifier for PostgreSQL using `reinhardt_query::Iden`.
182	fn quote_ident(name: &str) -> String {
183		let alias = Alias::new(name);
184		let mut buf = String::new();
185		alias.quoted('"', &mut buf);
186		buf
187	}
188
189	/// Truncate all tables in the given list.
190	///
191	/// This is useful for cleaning up between tests when not using
192	/// transaction rollback.
193	///
194	/// Note: reinhardt-query does not natively support TRUNCATE, so this uses
195	/// properly quoted identifiers via `Alias`.
196	pub fn truncate_tables_sql(tables: &[&str]) -> String {
197		if tables.is_empty() {
198			return String::new();
199		}
200
201		let quoted_tables: Vec<String> = tables
202			.iter()
203			.map(|t| {
204				// Use a SELECT query to get the properly quoted identifier
205				let query = Query::select()
206					.column(ColumnRef::asterisk())
207					.from(Alias::new(*t))
208					.to_string(PostgresQueryBuilder);
209				// Extract quoted table name from "SELECT * FROM <table>"
210				query
211					.strip_prefix("SELECT * FROM ")
212					.unwrap_or(t)
213					.to_string()
214			})
215			.collect();
216
217		format!(
218			"TRUNCATE TABLE {} RESTART IDENTITY CASCADE",
219			quoted_tables.join(", ")
220		)
221	}
222
223	/// Generate a DELETE statement for cleaning up a table.
224	pub fn delete_from_sql(table: &str, where_clause: Option<&str>) -> String {
225		let mut query = Query::delete();
226		query.from_table(Alias::new(table));
227
228		if let Some(clause) = where_clause {
229			query.and_where(Expr::cust(clause.to_string()));
230		}
231
232		query.to_string(PostgresQueryBuilder)
233	}
234
235	/// Generate an INSERT statement for test data.
236	///
237	/// Values are inserted as raw SQL expressions (e.g., `'Alice'`, `NOW()`),
238	/// so they are NOT parameterised. Table and column names are properly quoted
239	/// via `reinhardt_query::Iden`.
240	pub fn insert_test_data_sql(table: &str, columns: &[&str], values: &[&str]) -> String {
241		let quoted_table = quote_ident(table);
242		let quoted_cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
243		format!(
244			"INSERT INTO {} ({}) VALUES ({})",
245			quoted_table,
246			quoted_cols.join(", "),
247			values.join(", ")
248		)
249	}
250}
251
/// Settings that describe how a test database should be prepared.
///
/// Start from [`TestDatabaseConfig::new`] (identical to `Default`) and refine
/// the value with the chainable setters; each setter consumes the
/// configuration and returns the updated one.
#[derive(Debug, Clone)]
pub struct TestDatabaseConfig {
	/// Tables to truncate before each test (if not using transactions).
	pub truncate_tables: Vec<String>,
	/// Whether to use transactions for test isolation.
	pub use_transactions: bool,
	/// Maximum number of connections for the test pool.
	pub max_connections: u32,
	/// Connection timeout in seconds.
	pub connection_timeout_secs: u64,
}

impl Default for TestDatabaseConfig {
	/// No tables to truncate, transactions enabled, 5 connections, 30 second timeout.
	fn default() -> Self {
		let truncate_tables = Vec::new();
		Self {
			truncate_tables,
			use_transactions: true,
			max_connections: 5,
			connection_timeout_secs: 30,
		}
	}
}

impl TestDatabaseConfig {
	/// Create a configuration holding the default settings.
	pub fn new() -> Self {
		Default::default()
	}

	/// Append `table` to the list of tables to truncate.
	pub fn truncate(self, table: impl Into<String>) -> Self {
		let mut truncate_tables = self.truncate_tables;
		truncate_tables.push(table.into());
		Self {
			truncate_tables,
			..self
		}
	}

	/// Turn off transaction-based isolation.
	pub fn without_transactions(self) -> Self {
		Self {
			use_transactions: false,
			..self
		}
	}

	/// Set the maximum number of connections to `count`.
	pub fn max_connections(self, count: u32) -> Self {
		Self {
			max_connections: count,
			..self
		}
	}

	/// Set the connection timeout to `secs` seconds.
	pub fn connection_timeout(self, secs: u64) -> Self {
		Self {
			connection_timeout_secs: secs,
			..self
		}
	}
}
306
/// Helper for seeding test data.
///
/// This provides a fluent interface for inserting test data
/// that will be available during the test.
///
/// The seeder only accumulates SQL text; running the statements against a
/// database is left to the caller.
#[derive(Debug, Default)]
pub struct TestDataSeeder {
	/// SQL statements to execute for seeding, in the order they were added.
	statements: Vec<String>,
}
316
317impl TestDataSeeder {
318	/// Create a new seeder.
319	pub fn new() -> Self {
320		Self::default()
321	}
322
323	/// Add a raw SQL statement.
324	pub fn sql(mut self, statement: impl Into<String>) -> Self {
325		self.statements.push(statement.into());
326		self
327	}
328
329	/// Add an INSERT statement.
330	pub fn insert(self, table: &str, columns: &[&str], values: &[&str]) -> Self {
331		self.sql(utils::insert_test_data_sql(table, columns, values))
332	}
333
334	/// Get all statements to execute.
335	pub fn statements(&self) -> &[String] {
336		&self.statements
337	}
338
339	/// Build the combined SQL for all seed operations.
340	pub fn build(&self) -> String {
341		self.statements.join(";\n")
342	}
343}
344
/// Scope guard that invokes a cleanup closure when it is dropped.
///
/// Because the closure runs from `Drop`, it also runs while a panicking test
/// unwinds, unless the guard was disarmed first.
pub struct CleanupGuard<F>
where
	F: FnOnce(),
{
	/// The pending cleanup action; `None` once it has run or been disarmed.
	cleanup: Option<F>,
}

impl<F> CleanupGuard<F>
where
	F: FnOnce(),
{
	/// Arm a new guard that will call `cleanup` on drop.
	pub fn new(cleanup: F) -> Self {
		let cleanup = Some(cleanup);
		Self { cleanup }
	}

	/// Disarm the guard (don't run cleanup on drop).
	///
	/// The stored closure is discarded immediately without being called.
	pub fn disarm(&mut self) {
		drop(self.cleanup.take());
	}
}

impl<F> Drop for CleanupGuard<F>
where
	F: FnOnce(),
{
	fn drop(&mut self) {
		// `take` leaves `None` behind, so the closure can run at most once.
		if let Some(run_cleanup) = self.cleanup.take() {
			run_cleanup();
		}
	}
}
374
#[cfg(test)]
mod tests {
	use super::*;

	/// The TRUNCATE statement names every table, quoted, and cascades.
	#[test]
	fn test_truncate_tables_sql() {
		let statement = utils::truncate_tables_sql(&["users", "posts"]);
		for fragment in ["TRUNCATE TABLE", "\"users\"", "\"posts\"", "CASCADE"] {
			assert!(statement.contains(fragment));
		}
	}

	/// An empty table list produces an empty statement.
	#[test]
	fn test_truncate_tables_sql_empty() {
		let statement = utils::truncate_tables_sql(&[]);
		assert!(statement.is_empty());
	}

	/// DELETE is rendered with and without a WHERE clause.
	#[test]
	fn test_delete_from_sql() {
		let unconditional = utils::delete_from_sql("users", None);
		assert_eq!(unconditional, "DELETE FROM \"users\"");

		let filtered = utils::delete_from_sql("users", Some("id = 1"));
		assert_eq!(filtered, "DELETE FROM \"users\" WHERE id = 1");
	}

	/// INSERT quotes the table and columns and passes values through.
	#[test]
	fn test_insert_test_data_sql() {
		let columns = ["name", "email"];
		let values = ["'Alice'", "'alice@example.com'"];
		let statement = utils::insert_test_data_sql("users", &columns, &values);

		for fragment in ["INSERT INTO \"users\"", "\"name\"", "\"email\"", "'Alice'"] {
			assert!(statement.contains(fragment));
		}
	}

	/// Builder calls accumulate into the resulting configuration.
	#[test]
	fn test_database_config() {
		let builder = TestDatabaseConfig::new();
		let builder = builder.truncate("users").truncate("posts");
		let config = builder.max_connections(10).connection_timeout(60);

		assert_eq!(config.truncate_tables.len(), 2);
		assert_eq!(config.max_connections, 10);
		assert_eq!(config.connection_timeout_secs, 60);
	}

	/// Each `insert` call queues exactly one statement.
	#[test]
	fn test_data_seeder() {
		let seeder = TestDataSeeder::new();
		let seeder = seeder.insert("users", &["name"], &["'Alice'"]);
		let seeder = seeder.insert("posts", &["title", "user_id"], &["'Hello'", "1"]);

		assert_eq!(seeder.statements().len(), 2);
	}

	/// Generated names carry the `sp_` prefix; release tracking works.
	#[test]
	fn test_savepoint() {
		let generated = TestSavepoint::generate();
		assert!(generated.name.starts_with("sp_"));
		assert!(!generated.is_released());

		let mut named = TestSavepoint::new("my_savepoint");
		named.mark_released();
		assert!(named.is_released());
	}

	/// Dropping an armed guard runs the cleanup closure.
	#[test]
	fn test_cleanup_guard() {
		use std::cell::Cell;
		use std::rc::Rc;

		let cleaned = Rc::new(Cell::new(false));
		let flag = Rc::clone(&cleaned);

		{
			let _guard = CleanupGuard::new(move || flag.set(true));
		}

		assert!(cleaned.get());
	}

	/// Dropping a disarmed guard does not run the cleanup closure.
	#[test]
	fn test_cleanup_guard_disarm() {
		use std::cell::Cell;
		use std::rc::Rc;

		let cleaned = Rc::new(Cell::new(false));
		let flag = Rc::clone(&cleaned);

		{
			let mut guard = CleanupGuard::new(move || flag.set(true));
			guard.disarm();
		}

		assert!(!cleaned.get());
	}
}