#![cfg(native)]
use std::ops::Deref;
use async_trait::async_trait;
/// Wrapper that owns a connection for the duration of a test transaction.
///
/// The wrapper only records two flags — whether the caller asked for a commit
/// (`commit_on_drop`) and whether the transaction was finished (`completed`).
/// It issues no SQL itself and has no `Drop` impl; NOTE(review): whatever
/// acts on these flags lives outside this type — confirm with the
/// `TestConnectionExt` implementations.
#[derive(Debug)]
pub struct TestTransaction<C> {
    /// Always `Some` while the wrapper is alive; it is an `Option` only so
    /// `into_inner` can move the connection out.
    connection: Option<C>,
    /// Set by the `commit_on_drop` builder method.
    commit_on_drop: bool,
    /// Set by `mark_completed` / `into_inner`.
    completed: bool,
}

impl<C> TestTransaction<C> {
    /// Wraps `connection`; both the commit and completed flags start `false`.
    pub fn new(connection: C) -> Self {
        Self {
            connection: Some(connection),
            commit_on_drop: false,
            completed: false,
        }
    }

    /// Builder-style switch that sets the commit-on-drop flag
    /// (readable via `commits_on_drop`).
    #[must_use]
    pub fn commit_on_drop(mut self) -> Self {
        self.commit_on_drop = true;
        self
    }

    /// Returns `true` if the `commit_on_drop` builder method was called.
    pub fn commits_on_drop(&self) -> bool {
        self.commit_on_drop
    }

    /// Returns `true` once `mark_completed` has been called.
    pub fn is_completed(&self) -> bool {
        self.completed
    }

    /// Borrows the wrapped connection.
    ///
    /// # Panics
    /// Panics if the connection was moved out, which the public API cannot
    /// cause because `into_inner` consumes the wrapper.
    pub fn connection(&self) -> &C {
        self.connection
            .as_ref()
            .expect("connection already consumed")
    }

    /// Mutably borrows the wrapped connection.
    ///
    /// # Panics
    /// Same condition as `connection`.
    pub fn connection_mut(&mut self) -> &mut C {
        self.connection
            .as_mut()
            .expect("connection already consumed")
    }

    /// Consumes the wrapper and returns the connection, marking the
    /// transaction completed first.
    pub fn into_inner(mut self) -> C {
        self.completed = true;
        let connection = self.connection.take().expect("connection already consumed");
        // Skip the wrapper's destructor: the connection has been moved out and
        // only plain `bool`s remain, so nothing is leaked.
        std::mem::forget(self);
        connection
    }

    /// Records that the transaction has been finished by the caller.
    pub fn mark_completed(&mut self) {
        self.completed = true;
    }
}

impl<C> Deref for TestTransaction<C> {
    type Target = C;

    /// Gives read access to the wrapped connection, so a `TestTransaction<C>`
    /// can be used where `&C` is expected.
    fn deref(&self) -> &Self::Target {
        self.connection()
    }
}
/// Extension trait for connection types that can be driven through a test
/// transaction.
///
/// Every method takes the connection by value (`self`), so calling one
/// consumes the connection; `Sized` is required because methods take and
/// return `Self` by value. NOTE(review): the SQL each method issues is up to
/// the implementor and is not visible here.
#[async_trait]
pub trait TestConnectionExt: Sized {
    /// Error type returned by every operation.
    type Error;
    /// Begins a test transaction (per its name; exact behavior is
    /// implementation-defined) and returns the connection wrapped in a
    /// `TestTransaction`.
    async fn begin_test_transaction(self) -> Result<TestTransaction<Self>, Self::Error>;
    /// Commits the current transaction — presumably the one opened by
    /// `begin_test_transaction`; confirm against implementations.
    async fn commit_transaction(self) -> Result<(), Self::Error>;
    /// Rolls back the current transaction — presumably the one opened by
    /// `begin_test_transaction`; confirm against implementations.
    async fn rollback_transaction(self) -> Result<(), Self::Error>;
}
/// Name-and-state holder for a savepoint used in tests.
///
/// The type tracks only the name and a local "released" flag; it executes
/// no SQL itself.
#[derive(Debug)]
pub struct TestSavepoint {
    /// Savepoint identifier.
    pub name: String,
    /// Set by `mark_released`; never cleared.
    released: bool,
}

impl TestSavepoint {
    /// Creates an unreleased savepoint with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            name,
            released: false,
        }
    }

    /// Creates an unreleased savepoint named `sp_<uuid>`, where the UUID is a
    /// time-ordered v7 value rendered in its hyphen-less "simple" form.
    pub fn generate() -> Self {
        let id = uuid::Uuid::now_v7();
        let generated_name = format!("sp_{}", id.simple());
        Self::new(generated_name)
    }

    /// Flags this savepoint as released.
    pub fn mark_released(&mut self) {
        self.released = true;
    }

    /// Reports whether `mark_released` has been called.
    pub fn is_released(&self) -> bool {
        self.released
    }
}
/// SQL-string helpers for test setup and teardown, rendered for PostgreSQL.
///
/// NOTE: `where_clause` (in `delete_from_sql`) and `values` (in
/// `insert_test_data_sql`) are spliced into the SQL verbatim. These helpers
/// are for trusted test fixtures only — never pass untrusted input.
pub mod utils {
    use reinhardt_query::{Alias, Expr, Iden, PostgresQueryBuilder, Query, QueryStatementBuilder};

    /// Quotes `name` as a single double-quoted SQL identifier.
    ///
    /// Escaping of embedded quote characters is delegated to
    /// `reinhardt_query`'s `Iden::quoted` (presumably doubling them — confirm).
    fn quote_ident(name: &str) -> String {
        let alias = Alias::new(name);
        let mut buf = String::new();
        alias.quoted('"', &mut buf);
        buf
    }

    /// Builds one `TRUNCATE TABLE ... RESTART IDENTITY CASCADE` statement
    /// covering every table in `tables`, in the given order.
    ///
    /// Each entry is quoted as a whole by `quote_ident`, so a dotted name is
    /// treated as one identifier rather than split into schema and table.
    /// Returns an empty string when `tables` is empty.
    pub fn truncate_tables_sql(tables: &[&str]) -> String {
        if tables.is_empty() {
            return String::new();
        }
        // Quote directly instead of rendering a throwaway SELECT and stripping
        // its prefix, which silently fell back to an unquoted name whenever
        // the builder's output format did not match.
        let quoted_tables: Vec<String> = tables.iter().map(|t| quote_ident(t)).collect();
        format!(
            "TRUNCATE TABLE {} RESTART IDENTITY CASCADE",
            quoted_tables.join(", ")
        )
    }

    /// Builds `DELETE FROM "<table>"`, optionally followed by
    /// `WHERE <where_clause>` with the clause inserted verbatim.
    pub fn delete_from_sql(table: &str, where_clause: Option<&str>) -> String {
        let mut query = Query::delete();
        query.from_table(Alias::new(table));
        if let Some(clause) = where_clause {
            query.and_where(Expr::cust(clause.to_string()));
        }
        query.to_string(PostgresQueryBuilder)
    }

    /// Builds a single-row `INSERT` statement.
    ///
    /// `table` and `columns` are quoted as identifiers; `values` are SQL
    /// literals/expressions inserted verbatim and should line up with
    /// `columns`. PostgreSQL rejects an empty `()` column list. An empty
    /// `columns` therefore omits the list, so values match the table's
    /// columns positionally. When both slices are empty, the statement uses
    /// `DEFAULT VALUES`.
    pub fn insert_test_data_sql(table: &str, columns: &[&str], values: &[&str]) -> String {
        let quoted_table = quote_ident(table);
        if columns.is_empty() && values.is_empty() {
            return format!("INSERT INTO {} DEFAULT VALUES", quoted_table);
        }
        let column_list = if columns.is_empty() {
            String::new()
        } else {
            let quoted_cols: Vec<String> = columns.iter().map(|c| quote_ident(c)).collect();
            format!(" ({})", quoted_cols.join(", "))
        };
        format!(
            "INSERT INTO {}{} VALUES ({})",
            quoted_table,
            column_list,
            values.join(", ")
        )
    }
}
/// Settings for a test database: tables to truncate, whether to use
/// transactions, and connection limits/timeouts.
#[derive(Debug, Clone)]
pub struct TestDatabaseConfig {
    /// Tables to truncate, in the order they were added.
    pub truncate_tables: Vec<String>,
    /// Whether tests should run inside transactions (default `true`).
    pub use_transactions: bool,
    /// Connection-count limit (default `5`); enforcement is up to the consumer.
    pub max_connections: u32,
    /// Connection timeout in seconds (default `30`).
    pub connection_timeout_secs: u64,
}

impl Default for TestDatabaseConfig {
    /// No tables to truncate, transactions on, 5 connections, 30 s timeout.
    fn default() -> Self {
        let truncate_tables = Vec::new();
        Self {
            truncate_tables,
            use_transactions: true,
            max_connections: 5,
            connection_timeout_secs: 30,
        }
    }
}

impl TestDatabaseConfig {
    /// Equivalent to `TestDatabaseConfig::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `table` to the truncation list.
    pub fn truncate(self, table: impl Into<String>) -> Self {
        let mut config = self;
        config.truncate_tables.push(table.into());
        config
    }

    /// Turns transaction wrapping off.
    pub fn without_transactions(self) -> Self {
        Self {
            use_transactions: false,
            ..self
        }
    }

    /// Overrides the connection-count limit.
    pub fn max_connections(self, count: u32) -> Self {
        Self {
            max_connections: count,
            ..self
        }
    }

    /// Overrides the connection timeout, in seconds.
    pub fn connection_timeout(self, secs: u64) -> Self {
        Self {
            connection_timeout_secs: secs,
            ..self
        }
    }
}
/// Accumulates SQL statements used to seed test data.
#[derive(Debug, Default)]
pub struct TestDataSeeder {
    /// Statements in the order they were added.
    statements: Vec<String>,
}

impl TestDataSeeder {
    /// Creates a seeder with no statements.
    pub fn new() -> Self {
        Self {
            statements: Vec::new(),
        }
    }

    /// Appends a raw SQL statement. Statements should not carry their own
    /// trailing `;` because `build` inserts the separators.
    pub fn sql(self, statement: impl Into<String>) -> Self {
        let Self { mut statements } = self;
        statements.push(statement.into());
        Self { statements }
    }

    /// Appends an `INSERT` statement produced by `utils::insert_test_data_sql`.
    pub fn insert(self, table: &str, columns: &[&str], values: &[&str]) -> Self {
        let statement = utils::insert_test_data_sql(table, columns, values);
        self.sql(statement)
    }

    /// Returns the statements collected so far.
    pub fn statements(&self) -> &[String] {
        self.statements.as_slice()
    }

    /// Joins all statements with `";\n"` into one script. No separator
    /// follows the last statement; an empty seeder yields `""`.
    pub fn build(&self) -> String {
        self.statements.join(";\n")
    }
}
/// Runs a closure when dropped, unless it was disarmed first.
///
/// Bind the guard to a named variable (`let _guard = ...`). Discarding it, or
/// binding it to `_`, drops it at the end of the statement, so the cleanup
/// runs immediately; `#[must_use]` makes the compiler flag the discard case.
///
/// The `F: FnOnce()` bound sits on the struct because a `Drop` impl may not
/// require bounds the struct itself lacks.
#[must_use = "the cleanup runs as soon as the guard is dropped; bind it to a named variable"]
pub struct CleanupGuard<F: FnOnce()> {
    /// `Some` until the cleanup has run or the guard was disarmed.
    cleanup: Option<F>,
}

impl<F: FnOnce()> CleanupGuard<F> {
    /// Arms a guard that calls `cleanup` at most once, when the guard is dropped.
    pub fn new(cleanup: F) -> Self {
        Self {
            cleanup: Some(cleanup),
        }
    }

    /// Cancels the cleanup; the closure is dropped without being called.
    /// Calling this more than once is harmless.
    pub fn disarm(&mut self) {
        self.cleanup = None;
    }
}

impl<F: FnOnce()> Drop for CleanupGuard<F> {
    /// Invokes the cleanup closure if the guard is still armed.
    fn drop(&mut self) {
        if let Some(cleanup) = self.cleanup.take() {
            cleanup();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_tables_sql() {
        let sql = utils::truncate_tables_sql(&["users", "posts"]);
        assert!(sql.contains("TRUNCATE TABLE"));
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"posts\""));
        assert!(sql.contains("CASCADE"));
        // Pin the whole statement so table order and separators are checked too.
        assert_eq!(
            sql,
            "TRUNCATE TABLE \"users\", \"posts\" RESTART IDENTITY CASCADE"
        );
    }

    #[test]
    fn test_truncate_tables_sql_empty() {
        let sql = utils::truncate_tables_sql(&[]);
        assert!(sql.is_empty());
    }

    #[test]
    fn test_delete_from_sql() {
        let sql = utils::delete_from_sql("users", None);
        assert_eq!(sql, "DELETE FROM \"users\"");
        let sql_with_where = utils::delete_from_sql("users", Some("id = 1"));
        assert_eq!(sql_with_where, "DELETE FROM \"users\" WHERE id = 1");
    }

    #[test]
    fn test_insert_test_data_sql() {
        let sql = utils::insert_test_data_sql(
            "users",
            &["name", "email"],
            &["'Alice'", "'alice@example.com'"],
        );
        assert!(sql.contains("INSERT INTO \"users\""));
        assert!(sql.contains("\"name\""));
        assert!(sql.contains("\"email\""));
        assert!(sql.contains("'Alice'"));
        assert_eq!(
            sql,
            "INSERT INTO \"users\" (\"name\", \"email\") VALUES ('Alice', 'alice@example.com')"
        );
    }

    #[test]
    fn test_database_config() {
        let config = TestDatabaseConfig::new()
            .truncate("users")
            .truncate("posts")
            .max_connections(10)
            .connection_timeout(60);
        assert_eq!(config.truncate_tables.len(), 2);
        assert_eq!(config.max_connections, 10);
        assert_eq!(config.connection_timeout_secs, 60);
    }

    #[test]
    fn test_database_config_defaults() {
        let config = TestDatabaseConfig::default();
        assert!(config.truncate_tables.is_empty());
        assert!(config.use_transactions);
        assert_eq!(config.max_connections, 5);
        assert_eq!(config.connection_timeout_secs, 30);
        assert!(!TestDatabaseConfig::new().without_transactions().use_transactions);
    }

    #[test]
    fn test_data_seeder() {
        let seeder = TestDataSeeder::new()
            .insert("users", &["name"], &["'Alice'"])
            .insert("posts", &["title", "user_id"], &["'Hello'", "1"]);
        assert_eq!(seeder.statements().len(), 2);
    }

    #[test]
    fn test_data_seeder_build() {
        assert_eq!(TestDataSeeder::new().build(), "");
        let seeder = TestDataSeeder::new()
            .sql("SELECT 1")
            .sql(String::from("SELECT 2"));
        assert_eq!(seeder.statements().to_vec(), vec!["SELECT 1", "SELECT 2"]);
        assert_eq!(seeder.build(), "SELECT 1;\nSELECT 2");
    }

    #[test]
    fn test_transaction_wrapper() {
        let mut tx = TestTransaction::new(vec![1]);
        tx.connection_mut().push(2);
        assert_eq!(tx.connection(), &vec![1, 2]);
        // `Deref` exposes the wrapped connection's methods directly.
        assert_eq!(tx.len(), 2);
        let mut tx = tx.commit_on_drop();
        tx.mark_completed();
        assert_eq!(tx.into_inner(), vec![1, 2]);
    }

    #[test]
    fn test_savepoint() {
        let sp = TestSavepoint::generate();
        assert!(sp.name.starts_with("sp_"));
        assert!(!sp.is_released());
        let mut sp2 = TestSavepoint::new("my_savepoint");
        sp2.mark_released();
        assert!(sp2.is_released());
    }

    #[test]
    fn test_savepoint_generate_is_unique() {
        let a = TestSavepoint::generate();
        let b = TestSavepoint::generate();
        assert_ne!(a.name, b.name);
        // "sp_" followed by the 32 hex digits of a hyphen-less UUID.
        assert_eq!(a.name.len(), "sp_".len() + 32);
    }

    #[test]
    fn test_cleanup_guard() {
        use std::cell::RefCell;
        use std::rc::Rc;
        let cleaned = Rc::new(RefCell::new(false));
        let cleaned_clone = cleaned.clone();
        {
            let _guard = CleanupGuard::new(move || {
                *cleaned_clone.borrow_mut() = true;
            });
        }
        assert!(*cleaned.borrow());
    }

    #[test]
    fn test_cleanup_guard_disarm() {
        use std::cell::RefCell;
        use std::rc::Rc;
        let cleaned = Rc::new(RefCell::new(false));
        let cleaned_clone = cleaned.clone();
        {
            let mut guard = CleanupGuard::new(move || {
                *cleaned_clone.borrow_mut() = true;
            });
            guard.disarm();
        }
        assert!(!*cleaned.borrow());
    }
}