use std::sync::Arc;
use std::time::Duration;
use futures::future::BoxFuture;
use nautilus_core::Value;
use tokio::sync::Mutex;
use nautilus_dialect::Sql;
use crate::error::{ConnectorError as Error, Result};
use crate::row_stream::RowStream;
use crate::single_row::{fetch_single_row, SingleRowExpectation};
use crate::{Executor, Row};
/// Settings used when opening a transaction.
///
/// Start from [`TransactionOptions::default`] and override individual fields
/// with struct-update syntax.
#[derive(Debug, Clone)]
pub struct TransactionOptions {
    /// Time budget for the transaction (five seconds by default).
    ///
    /// NOTE(review): not read anywhere in this chunk — presumably enforced by
    /// whoever opens the transaction; confirm.
    pub timeout: Duration,
    /// Isolation level to request; `None` means no explicit level is set.
    pub isolation_level: Option<IsolationLevel>,
}

impl Default for TransactionOptions {
    /// A five-second timeout and no explicit isolation level.
    fn default() -> Self {
        let timeout = Duration::from_secs(5);
        TransactionOptions {
            isolation_level: None,
            timeout,
        }
    }
}
/// Transaction isolation level that can be requested through
/// [`TransactionOptions::isolation_level`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`
    ReadUncommitted,
    /// `READ COMMITTED`
    ReadCommitted,
    /// `REPEATABLE READ`
    RepeatableRead,
    /// `SERIALIZABLE`
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword phrase naming this level, for example
    /// `"REPEATABLE READ"` for [`IsolationLevel::RepeatableRead`].
    pub fn as_sql(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}
/// Backend-specific transaction slot owned by a `TransactionExecutor`.
///
/// Each variant wraps a [`TxHandle`] for one sqlx database driver.
enum TransactionInner {
    Postgres(TxHandle<sqlx::Postgres>),
    Mysql(TxHandle<sqlx::MySql>),
    Sqlite(TxHandle<sqlx::Sqlite>),
}
/// Shared, async-lockable slot holding a live sqlx transaction.
///
/// The slot holds `Some` until `commit`/`rollback` takes the transaction out;
/// afterwards it is `None` and operations report "Transaction already closed".
type TxHandle<DB> = Arc<Mutex<Option<sqlx::Transaction<'static, DB>>>>;
/// Executes statements inside one open database transaction.
///
/// Wraps a single backend-specific sqlx transaction. Each operation locks the
/// transaction for its duration; `commit`/`rollback` take it out of the slot,
/// after which further operations fail with "Transaction already closed".
pub struct TransactionExecutor {
// Backend dispatch; the transaction itself lives behind an `Arc<Mutex<Option<_>>>`.
inner: TransactionInner,
}
impl TransactionExecutor {
pub fn postgres(tx: sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
Self {
inner: TransactionInner::Postgres(Arc::new(Mutex::new(Some(tx)))),
}
}
pub fn mysql(tx: sqlx::Transaction<'static, sqlx::MySql>) -> Self {
Self {
inner: TransactionInner::Mysql(Arc::new(Mutex::new(Some(tx)))),
}
}
pub fn sqlite(tx: sqlx::Transaction<'static, sqlx::Sqlite>) -> Self {
Self {
inner: TransactionInner::Sqlite(Arc::new(Mutex::new(Some(tx)))),
}
}
async fn take_transaction<DB>(tx_arc: &TxHandle<DB>) -> Result<sqlx::Transaction<'static, DB>>
where
DB: sqlx::Database,
{
tx_arc
.lock()
.await
.take()
.ok_or_else(|| Error::database_msg("Transaction already closed"))
}
/// Reports whether the slot behind `tx_arc` still holds a transaction.
///
/// Waits for the lock, so the answer reflects the state after whichever
/// operation currently holds it has finished.
async fn transaction_is_open<DB>(tx_arc: &TxHandle<DB>) -> bool
where
    DB: sqlx::Database,
{
    let slot = tx_arc.lock().await;
    matches!(*slot, Some(_))
}
/// Builds a sqlx query for `sql_text` and binds each entry of `params`, in
/// order, through the backend-specific `bind` function.
///
/// `persistent` is forwarded unchanged to `Query::persistent`.
///
/// # Errors
/// Returns the first error produced by `bind`; remaining parameters are not
/// visited.
fn bind_query<'q, DB, Bind>(
    sql_text: &'q str,
    params: &'q [Value],
    persistent: bool,
    bind: Bind,
) -> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>
where
    DB: sqlx::Database + sqlx::database::HasStatementCache,
    for<'q2> <DB as sqlx::Database>::Arguments<'q2>: sqlx::IntoArguments<'q2, DB>,
    Bind: Fn(
        sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
        &'q Value,
    ) -> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>,
{
    let unbound = sqlx::query(sql_text).persistent(persistent);
    params
        .iter()
        .try_fold(unbound, |partial, value| bind(partial, value))
}
/// Executes one statement inside the transaction held by `tx_arc` and returns
/// how many rows it affected.
///
/// Parameters are bound (with `persistent` forwarded to `Query::persistent`)
/// before the transaction lock is acquired, so a bind error never waits on the
/// lock. The async mutex is then held across the `execute` await, which
/// serializes this statement with other operations on the same transaction.
///
/// `rows_affected` extracts the count from the backend-specific query result.
///
/// # Errors
/// - any error returned by `bind` for one of `params`;
/// - "Transaction already closed" if the transaction was already committed or
///   rolled back;
/// - the driver error wrapped with "Mutation failed" if execution fails.
fn execute_affected_on<DB, Bind, RowsAffected>(
    tx_arc: TxHandle<DB>,
    sql_text: String,
    params: Vec<Value>,
    persistent: bool,
    bind: Bind,
    rows_affected: RowsAffected,
) -> BoxFuture<'static, Result<usize>>
where
    DB: sqlx::Database + sqlx::database::HasStatementCache + Send + 'static,
    for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
    for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
    for<'q> Bind: Fn(
            sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
            &'q Value,
        ) -> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>
        + Copy
        + Send
        + 'static,
    RowsAffected: Fn(<DB as sqlx::Database>::QueryResult) -> u64 + Copy + Send + 'static,
{
    Box::pin(async move {
        let query = Self::bind_query::<DB, Bind>(&sql_text, &params, persistent, bind)?;
        let mut guard = tx_arc.lock().await;
        let tx = guard
            .as_mut()
            .ok_or_else(|| Error::database_msg("Transaction already closed"))?;
        use sqlx::Executor as _;
        let result = (&mut **tx)
            .execute(query)
            .await
            .map_err(|e| Error::database(e, "Mutation failed"))?;
        // The database round-trip is done; release the transaction for others.
        drop(guard);
        // `u64 as usize` silently truncates on 32-bit targets; saturate instead
        // so a huge count is never reported as a small one.
        Ok(usize::try_from(rows_affected(result)).unwrap_or(usize::MAX))
    })
}
/// Runs a row-returning query inside the transaction held by `tx_arc`,
/// collects every row with `fetch_all`, and converts them with `decode`.
///
/// Binding (with `persistent` forwarded to `Query::persistent`) happens before
/// the lock is taken. The lock is held across the fetch and released before
/// `decode` runs, so decoding does not block other users of the transaction.
///
/// # Errors
/// - any error returned by `bind`;
/// - "Transaction already closed" if the transaction was committed or rolled
///   back;
/// - the driver error wrapped with `query_context` if the query fails;
/// - any error returned by `decode`.
fn execute_collect_on<DB, Bind, Decode>(
    tx_arc: TxHandle<DB>,
    sql_text: String,
    params: Vec<Value>,
    persistent: bool,
    bind: Bind,
    decode: Decode,
    query_context: &'static str,
) -> BoxFuture<'static, Result<Vec<Row>>>
where
    DB: sqlx::Database + sqlx::database::HasStatementCache + Send + 'static,
    for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
    for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
    for<'q> Bind: Fn(
            sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
            &'q Value,
        ) -> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>
        + Copy
        + Send
        + 'static,
    Decode: Fn(&[<DB as sqlx::Database>::Row]) -> Result<Vec<Row>> + Send + 'static,
{
    Box::pin(async move {
        let query = Self::bind_query::<DB, Bind>(&sql_text, &params, persistent, bind)?;
        let mut guard = tx_arc.lock().await;
        let tx = guard
            .as_mut()
            .ok_or_else(|| Error::database_msg("Transaction already closed"))?;
        use sqlx::Executor as _;
        let rows = (&mut **tx)
            .fetch_all(query)
            .await
            .map_err(|e| Error::database(e, query_context))?;
        // Decoding is pure CPU work; don't hold the transaction while doing it.
        drop(guard);
        decode(&rows)
    })
}
fn execute_and_fetch_collect_on<DB, Bind, Decode>(
tx_arc: TxHandle<DB>,
mutation_text: String,
mutation_params: Vec<Value>,
fetch_text: String,
fetch_params: Vec<Value>,
bind: Bind,
decode: Decode,
) -> BoxFuture<'static, Result<Vec<Row>>>
where
DB: sqlx::Database + sqlx::database::HasStatementCache + Send + 'static,
for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
for<'q> Bind: Fn(
sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
&'q Value,
)
-> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>
+ Copy
+ Send
+ 'static,
Decode: Fn(&[<DB as sqlx::Database>::Row]) -> Result<Vec<Row>> + Send + 'static,
{
Box::pin(async move {
let mutation_query =
Self::bind_query::<DB, Bind>(&mutation_text, &mutation_params, true, bind)?;
let fetch_query = Self::bind_query::<DB, Bind>(&fetch_text, &fetch_params, true, bind)?;
let mut guard = tx_arc.lock().await;
let tx = guard
.as_mut()
.ok_or_else(|| Error::database_msg("Transaction already closed"))?;
use sqlx::Executor as _;
(&mut **tx)
.execute(mutation_query)
.await
.map_err(|e| Error::database(e, "Mutation failed"))?;
let rows = (&mut **tx)
.fetch_all(fetch_query)
.await
.map_err(|e| Error::database(e, "Fetch failed"))?;
drop(guard);
decode(&rows)
})
}
/// Runs a query expected to yield at most one row inside the transaction held
/// by `tx_arc`.
///
/// Unlike the other helpers, this one does not bind parameters itself.
/// `sql_text`, `params`, `bind`, `decode`, `query_context` and `expectation`
/// are forwarded to `fetch_single_row`, which runs against the transaction's
/// connection while the lock is held.
///
/// # Errors
/// Returns "Transaction already closed" if the transaction was committed or
/// rolled back. Any other error comes from `fetch_single_row`.
fn execute_single_on<DB, Bind, Decode>(
    tx_arc: TxHandle<DB>,
    sql_text: String,
    params: Vec<Value>,
    bind: Bind,
    decode: Decode,
    query_context: &'static str,
    expectation: SingleRowExpectation,
) -> BoxFuture<'static, Result<Option<Row>>>
where
    DB: sqlx::Database + Send + 'static,
    for<'c> &'c mut <DB as sqlx::Database>::Connection: sqlx::Executor<'c, Database = DB>,
    for<'q> <DB as sqlx::Database>::Arguments<'q>: sqlx::IntoArguments<'q, DB>,
    for<'q> Bind: Fn(
            sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
            &'q Value,
        ) -> Result<sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>>
        + Copy
        + Send
        + 'static,
    Decode: Fn(<DB as sqlx::Database>::Row) -> Result<Row> + Copy + Send + 'static,
{
    Box::pin(async move {
        let mut guard = tx_arc.lock().await;
        let tx = guard
            .as_mut()
            .ok_or_else(|| Error::database_msg("Transaction already closed"))?;
        fetch_single_row::<DB, _, _, _>(
            &mut **tx,
            &sql_text,
            &params,
            bind,
            decode,
            query_context,
            expectation,
        )
        .await
    })
}
pub async fn commit(&self) -> Result<()> {
match &self.inner {
TransactionInner::Postgres(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.commit()
.await
.map_err(|e| Error::database(e, "Commit failed"))
}
TransactionInner::Mysql(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.commit()
.await
.map_err(|e| Error::database(e, "Commit failed"))
}
TransactionInner::Sqlite(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.commit()
.await
.map_err(|e| Error::database(e, "Commit failed"))
}
}
}
pub async fn rollback(&self) -> Result<()> {
match &self.inner {
TransactionInner::Postgres(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.rollback()
.await
.map_err(|e| Error::database(e, "Rollback failed"))
}
TransactionInner::Mysql(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.rollback()
.await
.map_err(|e| Error::database(e, "Rollback failed"))
}
TransactionInner::Sqlite(mx) => {
let tx = Self::take_transaction(mx).await?;
tx.rollback()
.await
.map_err(|e| Error::database(e, "Rollback failed"))
}
}
}
/// Returns `true` while the transaction slot is still populated, i.e. before
/// `commit` or `rollback` has taken the transaction out.
pub async fn is_open(&self) -> bool {
    match &self.inner {
        TransactionInner::Postgres(handle) => Self::transaction_is_open(handle).await,
        TransactionInner::Mysql(handle) => Self::transaction_is_open(handle).await,
        TransactionInner::Sqlite(handle) => Self::transaction_is_open(handle).await,
    }
}
/// Executes `sql` inside the transaction and returns the number of affected
/// rows. The statement is built with `persistent` set to `true`.
///
/// # Errors
/// Fails if a parameter cannot be bound, if the transaction was already
/// committed or rolled back, or if the database rejects the statement
/// ("Mutation failed").
pub async fn execute_affected(&self, sql: &Sql) -> Result<usize> {
    let text = sql.text.clone();
    let params = sql.params.clone();
    let pending = match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_affected_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::postgres::bind_value,
            |outcome: sqlx::postgres::PgQueryResult| outcome.rows_affected(),
        ),
        TransactionInner::Mysql(handle) => Self::execute_affected_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::mysql::bind_value,
            |outcome: sqlx::mysql::MySqlQueryResult| outcome.rows_affected(),
        ),
        TransactionInner::Sqlite(handle) => Self::execute_affected_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::sqlite::bind_value,
            |outcome: sqlx::sqlite::SqliteQueryResult| outcome.rows_affected(),
        ),
    };
    pending.await
}
/// Runs `sql` inside the transaction with `persistent` set to `false`, so sqlx
/// is asked not to cache the prepared statement. Returns every decoded row.
///
/// # Errors
/// Fails if a parameter cannot be bound, if the transaction was already
/// committed or rolled back, if the query fails ("Query failed"), or if a row
/// cannot be decoded.
pub async fn execute_collect_unprepared(&self, sql: &Sql) -> Result<Vec<Row>> {
    let text = sql.text.clone();
    let params = sql.params.clone();
    let pending = match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            false,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Mysql(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            false,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Sqlite(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            false,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_rows,
            "Query failed",
        ),
    };
    pending.await
}
}
impl Executor for TransactionExecutor {
/// Rows are decoded into the crate's owned `Row` type, which carries no
/// lifetime, so they never borrow from the executor.
type Row<'conn>
= Row
where
Self: 'conn;
/// Every stream returned by this executor is built with
/// `RowStream::from_rows_future` over a future that collects all rows first.
type RowStream<'conn>
= RowStream<'conn>
where
Self: 'conn;
/// Runs `sql` inside the transaction (with `persistent` set to `true`) and
/// exposes the result as a [`RowStream`].
///
/// The stream is backed by a future that gathers the whole result set with
/// `fetch_all` before any row is available. The SQL text and parameters are
/// cloned up front, so that future does not borrow `sql`.
fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
    let text = sql.text.clone();
    let params = sql.params.clone();
    let rows = match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Mysql(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Sqlite(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_rows,
            "Query failed",
        ),
    };
    RowStream::from_rows_future(rows)
}
/// Owned-input form of `execute`.
///
/// The SQL text and parameters are moved (not cloned) into the query future,
/// which yields a `'static` [`RowStream`]. The rows are fully collected with
/// `fetch_all` before the stream produces anything.
fn execute_owned(&self, sql: Sql) -> RowStream<'static> {
    let (text, params) = (sql.text, sql.params);
    let rows = match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Mysql(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Sqlite(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_rows,
            "Query failed",
        ),
    };
    RowStream::from_rows_future(rows)
}
/// Runs `mutation` and then `fetch` inside the transaction, with no other
/// operation on this transaction in between, and streams the fetched rows.
///
/// Both statements are cloned up front, so the backing future does not borrow
/// `mutation` or `fetch`.
fn execute_and_fetch<'conn>(
    &'conn self,
    mutation: &'conn Sql,
    fetch: &'conn Sql,
) -> Self::RowStream<'conn> {
    let (mutation_text, mutation_params) = (mutation.text.clone(), mutation.params.clone());
    let (fetch_text, fetch_params) = (fetch.text.clone(), fetch.params.clone());
    let rows = match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_and_fetch_collect_on(
            Arc::clone(handle),
            mutation_text,
            mutation_params,
            fetch_text,
            fetch_params,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_rows,
        ),
        TransactionInner::Mysql(handle) => Self::execute_and_fetch_collect_on(
            Arc::clone(handle),
            mutation_text,
            mutation_params,
            fetch_text,
            fetch_params,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_rows,
        ),
        TransactionInner::Sqlite(handle) => Self::execute_and_fetch_collect_on(
            Arc::clone(handle),
            mutation_text,
            mutation_params,
            fetch_text,
            fetch_params,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_rows,
        ),
    };
    RowStream::from_rows_future(rows)
}
/// Runs `sql` inside the transaction (with `persistent` set to `true`) and
/// resolves to every decoded row.
///
/// The SQL text and parameters are cloned before the future is created, so
/// the future does not borrow `sql`.
fn execute_collect<'conn>(
    &'conn self,
    sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Vec<Self::Row<'conn>>>>
where
    Self: 'conn,
{
    let text = sql.text.clone();
    let params = sql.params.clone();
    match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Mysql(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_rows,
            "Query failed",
        ),
        TransactionInner::Sqlite(handle) => Self::execute_collect_on(
            Arc::clone(handle),
            text,
            params,
            true,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_rows,
            "Query failed",
        ),
    }
}
fn execute_one<'conn>(
&'conn self,
sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Self::Row<'conn>>>
where
Self: 'conn,
{
Box::pin(async move {
let row = match &self.inner {
TransactionInner::Postgres(tx_arc) => {
Self::execute_single_on(
Arc::clone(tx_arc),
sql.text.clone(),
sql.params.clone(),
crate::postgres::bind_value,
crate::postgres_stream::decode_row_internal,
"Query failed",
SingleRowExpectation::ExactlyOne,
)
.await?
}
TransactionInner::Mysql(tx_arc) => {
Self::execute_single_on(
Arc::clone(tx_arc),
sql.text.clone(),
sql.params.clone(),
crate::mysql::bind_value,
crate::mysql_stream::decode_row_internal,
"Query failed",
SingleRowExpectation::ExactlyOne,
)
.await?
}
TransactionInner::Sqlite(tx_arc) => {
Self::execute_single_on(
Arc::clone(tx_arc),
sql.text.clone(),
sql.params.clone(),
crate::sqlite::bind_value,
crate::sqlite_stream::decode_row_internal,
"Query failed",
SingleRowExpectation::ExactlyOne,
)
.await?
}
};
row.ok_or_else(|| Error::database_msg("Expected exactly one row, got 0"))
})
}
/// Runs `sql` inside the transaction and returns its row, if any.
///
/// The query runs with `SingleRowExpectation::ZeroOrOne`. The SQL text and
/// parameters are cloned before the future is created, so the future does not
/// borrow `sql`.
fn execute_optional<'conn>(
    &'conn self,
    sql: &'conn Sql,
) -> BoxFuture<'conn, Result<Option<Self::Row<'conn>>>>
where
    Self: 'conn,
{
    let text = sql.text.clone();
    let params = sql.params.clone();
    match &self.inner {
        TransactionInner::Postgres(handle) => Self::execute_single_on(
            Arc::clone(handle),
            text,
            params,
            crate::postgres::bind_value,
            crate::postgres_stream::decode_row_internal,
            "Query failed",
            SingleRowExpectation::ZeroOrOne,
        ),
        TransactionInner::Mysql(handle) => Self::execute_single_on(
            Arc::clone(handle),
            text,
            params,
            crate::mysql::bind_value,
            crate::mysql_stream::decode_row_internal,
            "Query failed",
            SingleRowExpectation::ZeroOrOne,
        ),
        TransactionInner::Sqlite(handle) => Self::execute_single_on(
            Arc::clone(handle),
            text,
            params,
            crate::sqlite::bind_value,
            crate::sqlite_stream::decode_row_internal,
            "Query failed",
            SingleRowExpectation::ZeroOrOne,
        ),
    }
}
}