reinhardt_testkit/server_fn/
transaction.rs1#![cfg(native)]
22
23use std::ops::Deref;
24
25use async_trait::async_trait;
26
/// A test-scoped wrapper around a connection that is expected to be running
/// inside a transaction.
///
/// The wrapper dereferences (mutably and immutably) to the wrapped connection,
/// so connection methods can be called on it directly.
///
/// NOTE(review): no `Drop` impl for this type exists in this file, so the
/// `commit_on_drop` and `completed` flags are recorded but never acted upon
/// here. Committing or rolling back must be done explicitly — confirm whether
/// drop-time handling is expected somewhere else.
#[derive(Debug)]
pub struct TestTransaction<C> {
    /// The wrapped connection. It is only `None` transiently inside
    /// `into_inner`, which forgets `self` right after taking it.
    connection: Option<C>,
    /// Set by the `commit_on_drop` builder method.
    commit_on_drop: bool,
    /// Set by `mark_completed` and `into_inner`.
    completed: bool,
}

impl<C> TestTransaction<C> {
    /// Wraps `connection`, with `commit_on_drop` and `completed` both unset.
    pub fn new(connection: C) -> Self {
        Self {
            connection: Some(connection),
            commit_on_drop: false,
            completed: false,
        }
    }

    /// Builder-style setter that turns on the `commit_on_drop` flag.
    #[must_use]
    pub fn commit_on_drop(mut self) -> Self {
        self.commit_on_drop = true;
        self
    }

    /// Returns a shared reference to the wrapped connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been taken; `into_inner` consumes
    /// the wrapper, so this is not reachable through the public API.
    pub fn connection(&self) -> &C {
        self.connection
            .as_ref()
            .expect("connection already consumed")
    }

    /// Returns an exclusive reference to the wrapped connection.
    ///
    /// # Panics
    ///
    /// Same condition as [`TestTransaction::connection`].
    pub fn connection_mut(&mut self) -> &mut C {
        self.connection
            .as_mut()
            .expect("connection already consumed")
    }

    /// Consumes the wrapper and hands back the connection, marking the
    /// transaction as completed.
    ///
    /// `mem::forget` skips the wrapper's destructor so that any drop-time
    /// handling added later never runs for a connection that was handed back.
    pub fn into_inner(mut self) -> C {
        self.completed = true;
        let connection = self.connection.take().expect("connection already consumed");
        std::mem::forget(self);
        connection
    }

    /// Records that the transaction has been explicitly finished.
    pub fn mark_completed(&mut self) {
        self.completed = true;
    }
}

impl<C> Deref for TestTransaction<C> {
    type Target = C;

    fn deref(&self) -> &Self::Target {
        // Resolves to the inherent method above, not to a method on `C`.
        self.connection()
    }
}

/// Mirrors [`TestTransaction::connection_mut`] so that `&mut self` methods of
/// the connection can be called directly on the wrapper.
impl<C> std::ops::DerefMut for TestTransaction<C> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.connection_mut()
    }
}
119
120#[async_trait]
122pub trait TestConnectionExt: Sized {
123 type Error;
125
126 async fn begin_test_transaction(self) -> Result<TestTransaction<Self>, Self::Error>;
128
129 async fn commit_transaction(self) -> Result<(), Self::Error>;
131
132 async fn rollback_transaction(self) -> Result<(), Self::Error>;
134}
135
136#[derive(Debug)]
141pub struct TestSavepoint {
142 pub name: String,
144 released: bool,
146}
147
148impl TestSavepoint {
149 pub fn new(name: impl Into<String>) -> Self {
151 Self {
152 name: name.into(),
153 released: false,
154 }
155 }
156
157 pub fn generate() -> Self {
159 Self::new(format!("sp_{}", uuid::Uuid::now_v7().simple()))
160 }
161
162 pub fn mark_released(&mut self) {
164 self.released = true;
165 }
166
167 pub fn is_released(&self) -> bool {
169 self.released
170 }
171}
172
173pub mod utils {
177 use reinhardt_query::{
178 Alias, ColumnRef, Expr, Iden, PostgresQueryBuilder, Query, QueryStatementBuilder,
179 };
180
181 fn quote_ident(name: &str) -> String {
183 let alias = Alias::new(name);
184 let mut buf = String::new();
185 alias.quoted('"', &mut buf);
186 buf
187 }
188
189 pub fn truncate_tables_sql(tables: &[&str]) -> String {
197 if tables.is_empty() {
198 return String::new();
199 }
200
201 let quoted_tables: Vec<String> = tables
202 .iter()
203 .map(|t| {
204 let query = Query::select()
206 .column(ColumnRef::asterisk())
207 .from(Alias::new(*t))
208 .to_string(PostgresQueryBuilder);
209 query
211 .strip_prefix("SELECT * FROM ")
212 .unwrap_or(t)
213 .to_string()
214 })
215 .collect();
216
217 format!(
218 "TRUNCATE TABLE {} RESTART IDENTITY CASCADE",
219 quoted_tables.join(", ")
220 )
221 }
222
223 pub fn delete_from_sql(table: &str, where_clause: Option<&str>) -> String {
225 let mut query = Query::delete();
226 query.from_table(Alias::new(table));
227
228 if let Some(clause) = where_clause {
229 query.and_where(Expr::cust(clause.to_string()));
230 }
231
232 query.to_string(PostgresQueryBuilder)
233 }
234
235 pub fn insert_test_data_sql(table: &str, columns: &[&str], values: &[&str]) -> String {
241 let quoted_table = quote_ident(table);
242 let quoted_cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
243 format!(
244 "INSERT INTO {} ({}) VALUES ({})",
245 quoted_table,
246 quoted_cols.join(", "),
247 values.join(", ")
248 )
249 }
250}
251
/// Settings for a test database, built with [`TestDatabaseConfig::new`]
/// (same as `Default`) plus the chained setters below.
#[derive(Debug, Clone)]
pub struct TestDatabaseConfig {
    /// Tables registered via `truncate`, in insertion order.
    pub truncate_tables: Vec<String>,
    /// Transaction flag; `true` by default, cleared by `without_transactions`.
    pub use_transactions: bool,
    /// Maximum connection count; `5` by default.
    pub max_connections: u32,
    /// Connection timeout in seconds; `30` by default.
    pub connection_timeout_secs: u64,
}

impl Default for TestDatabaseConfig {
    fn default() -> Self {
        const DEFAULT_MAX_CONNECTIONS: u32 = 5;
        const DEFAULT_TIMEOUT_SECS: u64 = 30;

        Self {
            truncate_tables: Vec::new(),
            use_transactions: true,
            max_connections: DEFAULT_MAX_CONNECTIONS,
            connection_timeout_secs: DEFAULT_TIMEOUT_SECS,
        }
    }
}

impl TestDatabaseConfig {
    /// Returns the default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `table` to `truncate_tables`.
    pub fn truncate(self, table: impl Into<String>) -> Self {
        let mut truncate_tables = self.truncate_tables;
        truncate_tables.push(table.into());
        Self {
            truncate_tables,
            ..self
        }
    }

    /// Clears the `use_transactions` flag.
    pub fn without_transactions(self) -> Self {
        Self {
            use_transactions: false,
            ..self
        }
    }

    /// Replaces `max_connections` with `count`.
    pub fn max_connections(self, count: u32) -> Self {
        Self {
            max_connections: count,
            ..self
        }
    }

    /// Replaces `connection_timeout_secs` with `secs`.
    pub fn connection_timeout(self, secs: u64) -> Self {
        Self {
            connection_timeout_secs: secs,
            ..self
        }
    }
}
306
307#[derive(Debug, Default)]
312pub struct TestDataSeeder {
313 statements: Vec<String>,
315}
316
317impl TestDataSeeder {
318 pub fn new() -> Self {
320 Self::default()
321 }
322
323 pub fn sql(mut self, statement: impl Into<String>) -> Self {
325 self.statements.push(statement.into());
326 self
327 }
328
329 pub fn insert(self, table: &str, columns: &[&str], values: &[&str]) -> Self {
331 self.sql(utils::insert_test_data_sql(table, columns, values))
332 }
333
334 pub fn statements(&self) -> &[String] {
336 &self.statements
337 }
338
339 pub fn build(&self) -> String {
341 self.statements.join(";\n")
342 }
343}
344
/// Runs a cleanup closure when dropped, unless it was disarmed first.
///
/// Bind the guard to a named variable (e.g. `let _guard = ...`). A bare
/// `CleanupGuard::new(f);` statement, or `let _ = CleanupGuard::new(f);`,
/// drops the guard — and runs `f` — immediately.
// The `F: FnOnce()` bound must stay on the struct: a `Drop` impl may not add
// bounds that the type definition itself does not have.
#[must_use = "the cleanup closure runs as soon as the guard is dropped; bind it to a named variable"]
pub struct CleanupGuard<F: FnOnce()> {
    /// Pending cleanup action; `None` once disarmed or already executed.
    cleanup: Option<F>,
}

impl<F: FnOnce()> CleanupGuard<F> {
    /// Arms a guard that will call `cleanup` once when it is dropped.
    pub fn new(cleanup: F) -> Self {
        Self {
            cleanup: Some(cleanup),
        }
    }

    /// Cancels the cleanup. The closure is dropped without being called.
    pub fn disarm(&mut self) {
        self.cleanup = None;
    }
}

impl<F: FnOnce()> Drop for CleanupGuard<F> {
    fn drop(&mut self) {
        // `take` moves the `FnOnce` out of the `Option` so it can be called
        // by value from behind `&mut self`.
        if let Some(cleanup) = self.cleanup.take() {
            cleanup();
        }
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_tables_sql() {
        let sql = utils::truncate_tables_sql(&["users", "posts"]);
        assert!(sql.contains("TRUNCATE TABLE"));
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"posts\""));
        assert!(sql.contains("CASCADE"));
    }

    #[test]
    fn test_truncate_tables_sql_empty() {
        let sql = utils::truncate_tables_sql(&[]);
        assert!(sql.is_empty());
    }

    #[test]
    fn test_delete_from_sql() {
        let sql = utils::delete_from_sql("users", None);
        assert_eq!(sql, "DELETE FROM \"users\"");

        let sql_with_where = utils::delete_from_sql("users", Some("id = 1"));
        assert_eq!(sql_with_where, "DELETE FROM \"users\" WHERE id = 1");
    }

    #[test]
    fn test_insert_test_data_sql() {
        let sql = utils::insert_test_data_sql(
            "users",
            &["name", "email"],
            &["'Alice'", "'alice@example.com'"],
        );
        assert!(sql.contains("INSERT INTO \"users\""));
        assert!(sql.contains("\"name\""));
        assert!(sql.contains("\"email\""));
        assert!(sql.contains("'Alice'"));
    }

    #[test]
    fn test_database_config() {
        let config = TestDatabaseConfig::new()
            .truncate("users")
            .truncate("posts")
            .max_connections(10)
            .connection_timeout(60);

        assert_eq!(config.truncate_tables.len(), 2);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.connection_timeout_secs, 60);
    }

    #[test]
    fn test_database_config_defaults() {
        let config = TestDatabaseConfig::default();
        assert!(config.truncate_tables.is_empty());
        assert!(config.use_transactions);
        assert_eq!(config.max_connections, 5);
        assert_eq!(config.connection_timeout_secs, 30);

        assert!(!TestDatabaseConfig::new().without_transactions().use_transactions);
    }

    #[test]
    fn test_data_seeder() {
        let seeder = TestDataSeeder::new()
            .insert("users", &["name"], &["'Alice'"])
            .insert("posts", &["title", "user_id"], &["'Hello'", "1"]);

        assert_eq!(seeder.statements().len(), 2);
    }

    #[test]
    fn test_data_seeder_build() {
        let seeder = TestDataSeeder::new().sql("SELECT 1").sql("SELECT 2");
        assert_eq!(seeder.build(), "SELECT 1;\nSELECT 2");

        assert_eq!(TestDataSeeder::new().build(), "");
    }

    #[test]
    fn test_transaction_wraps_connection() {
        let mut tx = TestTransaction::new(vec![1, 2, 3]);
        // `len` resolves through `Deref` to the wrapped `Vec`.
        assert_eq!(tx.len(), 3);

        tx.connection_mut().push(4);
        assert_eq!(tx.connection(), &vec![1, 2, 3, 4]);

        tx.mark_completed();
        let tx = tx.commit_on_drop();
        assert_eq!(tx.into_inner(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_savepoint() {
        let sp = TestSavepoint::generate();
        assert!(sp.name.starts_with("sp_"));
        assert!(!sp.is_released());

        let mut sp2 = TestSavepoint::new("my_savepoint");
        sp2.mark_released();
        assert!(sp2.is_released());
    }

    #[test]
    fn test_cleanup_guard() {
        use std::cell::RefCell;
        use std::rc::Rc;

        let cleaned = Rc::new(RefCell::new(false));
        let cleaned_clone = cleaned.clone();

        {
            let _guard = CleanupGuard::new(move || {
                *cleaned_clone.borrow_mut() = true;
            });
        }

        assert!(*cleaned.borrow());
    }

    #[test]
    fn test_cleanup_guard_disarm() {
        use std::cell::RefCell;
        use std::rc::Rc;

        let cleaned = Rc::new(RefCell::new(false));
        let cleaned_clone = cleaned.clone();

        {
            let mut guard = CleanupGuard::new(move || {
                *cleaned_clone.borrow_mut() = true;
            });
            guard.disarm();
        }

        assert!(!*cleaned.borrow());
    }
}
483}