#![cfg(any(feature = "sqlite", feature = "postgres", feature = "turso"))]
use std::env;
use sql_middleware::middleware::{
ConfigAndPool, MiddlewarePoolConnection, PgConfig, PlaceholderStyle, RowValues,
SqlMiddlewareDbError, TxOutcome, translate_placeholders,
};
use tokio::runtime::Runtime;
#[cfg(feature = "postgres")]
use sql_middleware::postgres::{
PostgresOptions, Prepared as PostgresPrepared, Tx as PostgresTx,
begin_transaction as begin_postgres_tx,
};
#[cfg(feature = "sqlite")]
use sql_middleware::sqlite::{
Prepared as SqlitePrepared, Tx as SqliteTx, begin_transaction as begin_sqlite_tx,
};
#[cfg(feature = "turso")]
use sql_middleware::turso::{
Prepared as TursoPrepared, Tx as TursoTx, begin_transaction as begin_turso_tx,
};
#[cfg(feature = "postgres")]
use sql_middleware::typed_postgres::{Idle as PgIdle, PgConnection, PgManager};
/// Builds the Postgres configuration used by the Postgres-backed runs of this test.
///
/// Defaults target the shared test server (`10.3.0.201:5432`, database `testing`,
/// user `testuser`). Each setting can be overridden without editing the source via
/// `TESTING_PG_HOST`, `TESTING_PG_PORT`, `TESTING_PG_DBNAME` and `TESTING_PG_USER`.
/// The password is read from `TESTING_PG_PASSWORD` and falls back to an empty string.
///
/// # Panics
/// Panics if `TESTING_PG_PORT` is set but cannot be parsed as a port number, so a typo
/// is reported instead of silently connecting to the default port.
#[cfg(feature = "postgres")]
fn postgres_config() -> PgConfig {
    // Env value if present (and valid Unicode), otherwise the built-in default.
    let var_or = |name: &str, default: &str| env::var(name).unwrap_or_else(|_| default.to_string());

    let mut cfg = PgConfig::new();
    cfg.dbname = Some(var_or("TESTING_PG_DBNAME", "testing"));
    cfg.host = Some(var_or("TESTING_PG_HOST", "10.3.0.201"));
    cfg.port = Some(match env::var("TESTING_PG_PORT") {
        Ok(port) => port
            .parse()
            .expect("TESTING_PG_PORT must be a valid port number"),
        Err(_) => 5432,
    });
    cfg.user = Some(var_or("TESTING_PG_USER", "testuser"));
    cfg.password = Some(env::var("TESTING_PG_PASSWORD").unwrap_or_default());
    cfg
}
/// A transaction opened on whichever backend the pooled connection belongs to.
///
/// Each variant wraps that backend's own transaction handle, whose lifetime `'conn`
/// is tied to the connection the transaction was started on.
enum BackendTx<'conn> {
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteTx<'conn>),
    #[cfg(feature = "postgres")]
    Postgres(PostgresTx<'conn>),
    #[cfg(feature = "turso")]
    Turso(TursoTx<'conn>),
}
/// A statement prepared inside a [`BackendTx`], tagged with the backend it came from.
///
/// It must be executed with a transaction of the same variant.
enum PreparedStmt {
    #[cfg(feature = "sqlite")]
    Sqlite(SqlitePrepared),
    #[cfg(feature = "postgres")]
    Postgres(PostgresPrepared),
    #[cfg(feature = "turso")]
    Turso(TursoPrepared),
}
impl BackendTx<'_> {
    /// Commits the wrapped transaction by delegating to the backend-specific `commit`.
    async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(inner) => inner.commit().await,
            #[cfg(feature = "postgres")]
            Self::Postgres(inner) => inner.commit().await,
            #[cfg(feature = "turso")]
            Self::Turso(inner) => inner.commit().await,
        }
    }

    /// Rolls back the wrapped transaction by delegating to the backend-specific `rollback`.
    async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        match self {
            #[cfg(feature = "sqlite")]
            Self::Sqlite(inner) => inner.rollback().await,
            #[cfg(feature = "postgres")]
            Self::Postgres(inner) => inner.rollback().await,
            #[cfg(feature = "turso")]
            Self::Turso(inner) => inner.rollback().await,
        }
    }
}
impl PreparedStmt {
    /// Executes this prepared statement inside `tx` with the positional `params`.
    ///
    /// Returns the row count reported by the backend's `execute_prepared`.
    ///
    /// # Panics
    /// Panics if `tx` and the statement belong to different backends. Callers obtain both
    /// from `prepare_backend_tx_and_stmt`, which always pairs matching variants.
    async fn execute_prepared(
        &mut self,
        tx: &mut BackendTx<'_>,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        match (tx, self) {
            #[cfg(feature = "turso")]
            (BackendTx::Turso(tx), PreparedStmt::Turso(stmt)) => {
                tx.execute_prepared(stmt, params).await
            }
            #[cfg(feature = "postgres")]
            (BackendTx::Postgres(tx), PreparedStmt::Postgres(stmt)) => {
                tx.execute_prepared(stmt, params).await
            }
            #[cfg(feature = "sqlite")]
            (BackendTx::Sqlite(tx), PreparedStmt::Sqlite(stmt)) => {
                tx.execute_prepared(stmt, params).await
            }
            // Mismatched pairs only exist when two or more backend features are enabled.
            // With a single backend the arms above are exhaustive and rustc would flag
            // this arm, breaking `-D warnings` builds.
            #[allow(unreachable_patterns)]
            _ => unreachable!("transaction and prepared variants should align"),
        }
    }
}
async fn run_execute_with_finalize(
mut tx: BackendTx<'_>,
mut stmt: PreparedStmt,
params: Vec<RowValues>,
) -> Result<usize, SqlMiddlewareDbError> {
let result = stmt.execute_prepared(&mut tx, ¶ms).await;
match result {
Ok(rows) => {
tx.commit().await?;
Ok(rows)
}
Err(e) => {
let _ = tx.rollback().await;
Err(e)
}
}
}
/// Begins a transaction on `conn` and prepares `base_query` inside it.
///
/// `base_query` is expected to use Postgres-style `$n` placeholders, as `run_roundtrip`
/// does. Placeholder handling differs by backend:
/// - Turso always translates the placeholders to SQLite style.
/// - SQLite translates only if the connection's `translate_placeholders` setting says so.
/// - Postgres uses the query unchanged.
///
/// The returned transaction and statement always carry the same backend variant, and
/// the transaction borrows `conn` for `'conn`.
///
/// # Errors
/// Propagates failures from beginning the transaction or preparing the statement.
/// Returns `SqlMiddlewareDbError::Unimplemented` for connection kinds this test does
/// not handle.
async fn prepare_backend_tx_and_stmt<'conn>(
    conn: &'conn mut MiddlewarePoolConnection,
    base_query: &str,
) -> Result<(BackendTx<'conn>, PreparedStmt), SqlMiddlewareDbError> {
    match conn {
        #[cfg(feature = "turso")]
        MiddlewarePoolConnection::Turso { conn, .. } => {
            let tx = begin_turso_tx(conn).await?;
            let q = translate_placeholders(base_query, PlaceholderStyle::Sqlite, true);
            let stmt = tx.prepare(q.as_ref()).await?;
            Ok((BackendTx::Turso(tx), PreparedStmt::Turso(stmt)))
        }
        #[cfg(feature = "postgres")]
        MiddlewarePoolConnection::Postgres { client, .. } => {
            let tx = begin_postgres_tx(client).await?;
            let stmt = tx.prepare(base_query).await?;
            Ok((BackendTx::Postgres(tx), PreparedStmt::Postgres(stmt)))
        }
        #[cfg(feature = "sqlite")]
        MiddlewarePoolConnection::Sqlite {
            translate_placeholders: translate_default,
            ..
        } => {
            // Copy the flag out so the pattern borrow ends before `conn` is handed
            // to `begin_sqlite_tx`.
            let translate_default = *translate_default;
            let tx = begin_sqlite_tx(conn).await?;
            let q = translate_placeholders(base_query, PlaceholderStyle::Sqlite, translate_default);
            let stmt = tx.prepare(q.as_ref())?;
            Ok((BackendTx::Sqlite(tx), PreparedStmt::Sqlite(stmt)))
        }
        // Reachable only when the crate exposes connection kinds beyond the enabled
        // features. Otherwise the arms above are exhaustive and rustc would flag this arm.
        #[allow(unreachable_patterns)]
        _ => Err(SqlMiddlewareDbError::Unimplemented(
            "expected Turso, Postgres, or SQLite connection".to_string(),
        )),
    }
}
/// Runs the shared commit/rollback scenario on a prepared pooled connection.
///
/// The scenario has three steps:
/// 1. Insert row `id = 1` in a transaction that is committed, asserting it affected one row.
/// 2. Insert the same `id` again, which must fail and be rolled back.
/// 3. Verify that exactly one row remains.
///
/// The caller must have created an empty `custom_logic_txn (id, note)` table with `id`
/// as the primary key, so that the second insert violates uniqueness.
///
/// # Panics
/// Panics if any of the scenario's assertions do not hold.
async fn run_roundtrip(conn: &mut MiddlewarePoolConnection) -> Result<(), SqlMiddlewareDbError> {
    let insert_query = "INSERT INTO custom_logic_txn (id, note) VALUES ($1, $2)";
    {
        let (tx, stmt) = prepare_backend_tx_and_stmt(conn, insert_query).await?;
        let rows = run_execute_with_finalize(
            tx,
            stmt,
            vec![RowValues::Int(1), RowValues::Text("ok".into())],
        )
        .await?;
        assert_eq!(rows, 1, "expected the first insert to affect exactly one row");
    }
    {
        let (tx, stmt) = prepare_backend_tx_and_stmt(conn, insert_query).await?;
        let res = run_execute_with_finalize(
            tx,
            stmt,
            vec![RowValues::Int(1), RowValues::Text("dup".into())],
        )
        .await;
        assert!(res.is_err(), "expected duplicate key to fail");
    }
    let rs = conn
        .query("SELECT COUNT(*) AS cnt FROM custom_logic_txn")
        .select()
        .await?;
    let count = *rs
        .results
        .first()
        .expect("COUNT(*) query returned no rows")
        .get("cnt")
        .expect("result row has no `cnt` column")
        .as_int()
        .expect("`cnt` column is not an integer");
    assert_eq!(count, 1, "the failed duplicate insert must have been rolled back");
    Ok(())
}
/// Runs the commit/rollback scenario through the typed Postgres connection API.
///
/// Steps:
/// 1. (Re)create `custom_logic_txn`.
/// 2. Commit one insert, asserting it affected one row.
/// 3. Check that a duplicate insert fails, then roll it back.
/// 4. Verify exactly one row remains.
/// 5. Drop the table.
///
/// After each transaction, `conn` is reassigned from the value returned by
/// `commit`/`rollback`.
///
/// # Panics
/// Panics if any of the scenario's assertions do not hold.
#[cfg(feature = "postgres")]
async fn run_typed_pg_roundtrip(
    mut conn: PgConnection<PgIdle>,
) -> Result<(), SqlMiddlewareDbError> {
    conn.execute_batch(
        "DROP TABLE IF EXISTS custom_logic_txn;
CREATE TABLE custom_logic_txn (id BIGINT PRIMARY KEY, note TEXT);",
    )
    .await?;
    {
        let mut tx = conn.begin().await?;
        let rows = tx
            .dml(
                "INSERT INTO custom_logic_txn (id, note) VALUES ($1, $2)",
                &[RowValues::Int(1), RowValues::Text("ok".into())],
            )
            .await?;
        assert_eq!(rows, 1, "expected the first insert to affect exactly one row");
        conn = tx.commit().await?;
    }
    {
        let mut tx = conn.begin().await?;
        let res = tx
            .dml(
                "INSERT INTO custom_logic_txn (id, note) VALUES ($1, $2)",
                &[RowValues::Int(1), RowValues::Text("dup".into())],
            )
            .await;
        assert!(res.is_err(), "expected duplicate key to fail");
        conn = tx.rollback().await?;
    }
    let rs = conn
        .select("SELECT COUNT(*) AS cnt FROM custom_logic_txn", &[])
        .await?;
    let count = *rs
        .results
        .first()
        .expect("COUNT(*) query returned no rows")
        .get("cnt")
        .expect("result row has no `cnt` column")
        .as_int()
        .expect("`cnt` column is not an integer");
    assert_eq!(count, 1, "the failed duplicate insert must have been rolled back");
    conn.execute_batch("DROP TABLE IF EXISTS custom_logic_txn;")
        .await?;
    Ok(())
}
/// Exercises the custom commit/rollback flow against every backend enabled at compile time.
///
/// Backends run sequentially, each on its own freshly built pool. The first error aborts
/// the test.
#[test]
fn custom_logic_between_transactions_across_backends() -> Result<(), Box<dyn std::error::Error>> {
    let rt = Runtime::new()?;
    rt.block_on(async {
        #[cfg(feature = "sqlite")]
        run_sqlite_backend().await?;
        #[cfg(feature = "postgres")]
        run_postgres_backend().await?;
        #[cfg(feature = "postgres")]
        run_typed_postgres_backend().await?;
        #[cfg(feature = "turso")]
        run_turso_backend().await?;
        Ok::<(), SqlMiddlewareDbError>(())
    })?;
    Ok(())
}

/// Runs the pooled round trip on a shared-cache in-memory SQLite database.
#[cfg(feature = "sqlite")]
async fn run_sqlite_backend() -> Result<(), SqlMiddlewareDbError> {
    let cap = ConfigAndPool::sqlite_builder("file::memory:?cache=shared".to_string())
        .build()
        .await?;
    let mut conn = cap.get_connection().await?;
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS custom_logic_txn (id INTEGER PRIMARY KEY, note TEXT);",
    )
    .await?;
    run_roundtrip(&mut conn).await?;
    println!("sqlite backend run successful");
    Ok(())
}

/// Runs the pooled round trip against the server from `postgres_config`. It recreates
/// `custom_logic_txn` beforehand and drops it afterwards.
#[cfg(feature = "postgres")]
async fn run_postgres_backend() -> Result<(), SqlMiddlewareDbError> {
    let cfg = postgres_config();
    let cap = ConfigAndPool::new_postgres(PostgresOptions::new(cfg)).await?;
    let mut conn = cap.get_connection().await?;
    conn.execute_batch(
        "DROP TABLE IF EXISTS custom_logic_txn;
CREATE TABLE custom_logic_txn (id BIGINT PRIMARY KEY, note TEXT);",
    )
    .await?;
    run_roundtrip(&mut conn).await?;
    conn.execute_batch("DROP TABLE IF EXISTS custom_logic_txn;")
        .await?;
    println!("postgres backend run successful");
    Ok(())
}

/// Runs the typed-connection round trip against the server from `postgres_config`.
#[cfg(feature = "postgres")]
async fn run_typed_postgres_backend() -> Result<(), SqlMiddlewareDbError> {
    let cfg = postgres_config();
    let pool = PgManager::new(cfg.to_tokio_config()).build_pool().await?;
    let typed_conn: PgConnection<PgIdle> = PgConnection::from_pool(&pool).await?;
    run_typed_pg_roundtrip(typed_conn).await?;
    println!("typed-postgres backend run successful");
    Ok(())
}

/// Runs the pooled round trip on an in-memory Turso database.
#[cfg(feature = "turso")]
async fn run_turso_backend() -> Result<(), SqlMiddlewareDbError> {
    let cap = ConfigAndPool::turso_builder(":memory:".to_string())
        .build()
        .await?;
    let mut conn = cap.get_connection().await?;
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS custom_logic_txn (id INTEGER PRIMARY KEY, note TEXT);",
    )
    .await?;
    run_roundtrip(&mut conn).await?;
    println!("turso backend run successful");
    Ok(())
}