use tiberius::Query;
use crate::middleware::{CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
use crate::tx_outcome::TxOutcome;
use super::config::MssqlClient;
use super::query::{build_result_set, convert_affected_rows, query_map_optional};
/// An MSSQL transaction running on a borrowed [`MssqlClient`].
///
/// Created by [`begin_transaction`], which sends `BEGIN TRANSACTION`. Finish it with
/// [`Tx::commit`] or [`Tx::rollback`]. This module defines no `Drop` impl for `Tx`, so
/// dropping it without calling either method sends no `ROLLBACK`. The server-side
/// transaction then stays open unless something else ends it.
#[must_use = "a Tx dropped without commit() or rollback() leaves its transaction open"]
pub struct Tx<'a> {
    client: &'a mut MssqlClient,
    // Set by `begin_transaction`; cleared once COMMIT or ROLLBACK succeeds.
    open: bool,
}
/// SQL text captured by [`Tx::prepare`] for repeated use inside a transaction.
///
/// Creating one sends nothing to the server. The stored text is bound to parameters
/// and sent each time the statement is run.
#[derive(Debug, Clone)]
pub struct Prepared {
    sql: String,
}
/// Opens a transaction on `client` by sending `BEGIN TRANSACTION` as a simple query.
///
/// The returned [`Tx`] mutably borrows `client` until it is committed or rolled back.
///
/// # Errors
/// Returns [`SqlMiddlewareDbError::ExecutionError`] when the driver reports a failure
/// for the `BEGIN TRANSACTION` statement.
pub async fn begin_transaction(client: &mut MssqlClient) -> Result<Tx<'_>, SqlMiddlewareDbError> {
    // The stream temporary is dropped at the end of this `if let`, which releases the
    // borrow of `client` before it is moved into the `Tx` below.
    if let Err(driver_err) = client.simple_query("BEGIN TRANSACTION").await {
        return Err(SqlMiddlewareDbError::ExecutionError(format!(
            "MSSQL begin transaction error: {driver_err}"
        )));
    }
    Ok(Tx { client, open: true })
}
impl<'conn> Tx<'conn> {
    /// Captures `sql` as a [`Prepared`] statement for use with [`Tx::select`] and
    /// [`Tx::execute`].
    ///
    /// Nothing is sent to the server here. The text is stored and bound to parameters
    /// each time the statement is run.
    ///
    /// # Errors
    /// This implementation never fails. NOTE(review): the `Result` return presumably
    /// keeps the signature aligned with other backends — confirm.
    pub fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
        Ok(Prepared {
            sql: sql.to_string(),
        })
    }

    /// Starts a row-returning query on `prepared`, with no parameters bound yet.
    ///
    /// Chain `.params(..)` to bind values, then finish with `.all()`, `.one()`,
    /// `.optional()`, `.map_one(..)` or `.map_optional(..)`.
    #[must_use]
    pub fn select<'tx, 'prepared>(
        &'tx mut self,
        prepared: &'prepared Prepared,
    ) -> PreparedSelect<'tx, 'prepared, 'static, 'conn> {
        PreparedSelect {
            tx: self,
            prepared,
            params: &[],
        }
    }

    /// Starts a DML execution of `prepared`, with no parameters bound yet.
    ///
    /// Chain `.params(..)` to bind values, then finish with `.run()`.
    #[must_use]
    pub fn execute<'tx, 'prepared>(
        &'tx mut self,
        prepared: &'prepared Prepared,
    ) -> PreparedExecute<'tx, 'prepared, 'static, 'conn> {
        PreparedExecute {
            tx: self,
            prepared,
            params: &[],
        }
    }

    /// Runs `sql` without parameters inside the transaction and discards the result.
    ///
    /// NOTE(review): tiberius sends `Query::execute` as an `sp_executesql` RPC, which
    /// scopes local `#temp` tables and `SET` options to that one call. BEGIN, COMMIT
    /// and ROLLBACK in this module instead use `simple_query` (a plain SQL batch).
    /// Confirm the RPC path is intended for batches.
    ///
    /// # Errors
    /// Returns [`SqlMiddlewareDbError::ExecutionError`] if the driver reports an error.
    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
        Query::new(sql).execute(self.client).await.map_err(|e| {
            SqlMiddlewareDbError::ExecutionError(format!("MSSQL tx execute_batch error: {e}"))
        })?;
        Ok(())
    }

    /// Binds `params` to `query`, executes it, and returns the affected-row count.
    ///
    /// The count is the sum of the per-statement counts the driver reports,
    /// converted to `usize` by `convert_affected_rows`.
    ///
    /// # Errors
    /// Returns [`SqlMiddlewareDbError::ExecutionError`] on driver failure. Also
    /// propagates any error from `convert_affected_rows`.
    pub async fn execute_dml(
        &mut self,
        query: &str,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        let query_builder = super::query::bind_query_params(query, params);
        let exec_result = query_builder.execute(self.client).await.map_err(|e| {
            SqlMiddlewareDbError::ExecutionError(format!("MSSQL tx execute error: {e}"))
        })?;
        let rows_affected: u64 = exec_result.rows_affected().iter().sum();
        convert_affected_rows(rows_affected)
    }

    /// Same as [`Tx::execute_dml`], using the SQL text held by `prepared`.
    ///
    /// Delegating keeps the prepared and ad-hoc DML paths identical, including the
    /// error text.
    pub(crate) async fn execute_prepared(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        self.execute_dml(&prepared.sql, params).await
    }

    /// Same as [`Tx::query`], using the SQL text held by `prepared`.
    pub(crate) async fn query_prepared(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        self.query(&prepared.sql, params).await
    }

    /// Runs the prepared query and converts the result with `ResultSet::into_optional`.
    pub(crate) async fn query_prepared_optional(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
        self.query_prepared(prepared, params)
            .await
            .map(ResultSet::into_optional)
    }

    /// Runs the prepared query and converts the result with `ResultSet::into_one`.
    pub(crate) async fn query_prepared_one(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<CustomDbRow, SqlMiddlewareDbError> {
        self.query_prepared(prepared, params).await?.into_one()
    }

    /// Like [`Tx::query_prepared_map_optional`], but a missing row is an error.
    ///
    /// # Errors
    /// Returns `ExecutionError("query returned no rows")` when no row is produced,
    /// and propagates any error from the query or from `mapper`.
    pub(crate) async fn query_prepared_map_one<T, F>(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
        mapper: F,
    ) -> Result<T, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        self.query_prepared_map_optional(prepared, params, mapper)
            .await?
            .ok_or_else(|| SqlMiddlewareDbError::ExecutionError("query returned no rows".into()))
    }

    /// Runs the prepared query through `query_map_optional`.
    ///
    /// `mapper` is applied to the raw tiberius row, if one is produced.
    pub(crate) async fn query_prepared_map_optional<T, F>(
        &mut self,
        prepared: &Prepared,
        params: &[RowValues],
        mapper: F,
    ) -> Result<Option<T>, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        query_map_optional(self.client, &prepared.sql, params, mapper).await
    }

    /// Runs `query` with `params` inside the transaction.
    ///
    /// Returns the [`ResultSet`] produced by `build_result_set`.
    pub async fn query(
        &mut self,
        query: &str,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        build_result_set(self.client, query, params).await
    }

    /// Sends `COMMIT TRANSACTION` and consumes the transaction.
    ///
    /// Returns `TxOutcome::without_restored_connection()` on success.
    ///
    /// # Errors
    /// Returns [`SqlMiddlewareDbError::ExecutionError`] if the driver reports a failure.
    pub async fn commit(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        // Within this module `open` is always true here: `begin_transaction` sets it,
        // and only commit/rollback (which consume `self`) clear it. The check is
        // defensive.
        if self.open {
            self.client
                .simple_query("COMMIT TRANSACTION")
                .await
                .map_err(|e| {
                    SqlMiddlewareDbError::ExecutionError(format!("MSSQL commit error: {e}"))
                })?;
            self.open = false;
        }
        Ok(TxOutcome::without_restored_connection())
    }

    /// Sends `ROLLBACK TRANSACTION` and consumes the transaction.
    ///
    /// Returns `TxOutcome::without_restored_connection()` on success.
    ///
    /// # Errors
    /// Returns [`SqlMiddlewareDbError::ExecutionError`] if the driver reports a failure.
    pub async fn rollback(mut self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        // Same defensive guard as in `commit`.
        if self.open {
            self.client
                .simple_query("ROLLBACK TRANSACTION")
                .await
                .map_err(|e| {
                    SqlMiddlewareDbError::ExecutionError(format!("MSSQL rollback error: {e}"))
                })?;
            self.open = false;
        }
        Ok(TxOutcome::without_restored_connection())
    }
}
/// Builder returned by [`Tx::execute`].
///
/// Binds parameters to a [`Prepared`] statement and runs it as DML inside the
/// borrowed transaction.
pub struct PreparedExecute<'tx, 'prepared, 'params, 'conn> {
    /// Transaction the statement runs in.
    tx: &'tx mut Tx<'conn>,
    /// Statement whose SQL text is executed.
    prepared: &'prepared Prepared,
    /// Placeholder values; `Tx::execute` starts this out empty.
    params: &'params [RowValues],
}

impl<'tx, 'prepared, 'params, 'conn> PreparedExecute<'tx, 'prepared, 'params, 'conn> {
    /// Swaps in a new parameter slice and re-types the builder to that slice's lifetime.
    #[must_use]
    pub fn params<'next>(
        self,
        params: &'next [RowValues],
    ) -> PreparedExecute<'tx, 'prepared, 'next, 'conn> {
        let Self { tx, prepared, .. } = self;
        PreparedExecute {
            tx,
            prepared,
            params,
        }
    }

    /// Executes the statement with the bound parameters and returns the affected-row count.
    ///
    /// # Errors
    /// Propagates execution and row-count conversion errors from the transaction.
    pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.execute_prepared(prepared, params).await
    }
}
/// Builder returned by [`Tx::select`].
///
/// Binds parameters to a [`Prepared`] query and fetches its rows inside the borrowed
/// transaction.
pub struct PreparedSelect<'tx, 'prepared, 'params, 'conn> {
    /// Transaction the query runs in.
    tx: &'tx mut Tx<'conn>,
    /// Statement whose SQL text is queried.
    prepared: &'prepared Prepared,
    /// Placeholder values; `Tx::select` starts this out empty.
    params: &'params [RowValues],
}

impl<'tx, 'prepared, 'params, 'conn> PreparedSelect<'tx, 'prepared, 'params, 'conn> {
    /// Swaps in a new parameter slice and re-types the builder to that slice's lifetime.
    #[must_use]
    pub fn params<'next>(
        self,
        params: &'next [RowValues],
    ) -> PreparedSelect<'tx, 'prepared, 'next, 'conn> {
        let Self { tx, prepared, .. } = self;
        PreparedSelect {
            tx,
            prepared,
            params,
        }
    }

    /// Runs the query and returns the whole [`ResultSet`].
    ///
    /// # Errors
    /// Propagates any error raised while running the query or building the result.
    pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.query_prepared(prepared, params).await
    }

    /// Runs the query and converts the result with `ResultSet::into_optional`.
    ///
    /// # Errors
    /// Propagates any error raised while running the query.
    pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.query_prepared_optional(prepared, params).await
    }

    /// Runs the query and converts the result with `ResultSet::into_one`.
    ///
    /// # Errors
    /// Propagates query errors and whatever `ResultSet::into_one` reports.
    pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.query_prepared_one(prepared, params).await
    }

    /// Runs the query and maps the produced row with `mapper`.
    ///
    /// # Errors
    /// Returns `ExecutionError("query returned no rows")` when no row is produced.
    /// Also propagates any error from the query or from `mapper`.
    pub async fn map_one<T, F>(self, mapper: F) -> Result<T, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.query_prepared_map_one(prepared, params, mapper).await
    }

    /// Runs the query and maps the produced row, if any, with `mapper`.
    ///
    /// # Errors
    /// Propagates any error from the query or from `mapper`.
    pub async fn map_optional<T, F>(self, mapper: F) -> Result<Option<T>, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        let Self {
            tx,
            prepared,
            params,
        } = self;
        tx.query_prepared_map_optional(prepared, params, mapper)
            .await
    }
}