use std::sync::Arc;
use tokio::sync::Mutex;
use crate::middleware::{CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
use super::{
config::MssqlClient,
query::{build_result_set, query_map_optional},
};
/// A SQL statement paired with the MSSQL client it runs on.
///
/// Both fields are reference-counted, so `clone()` is cheap. Clones share the same
/// client (behind an async `Mutex`) and the same SQL text rather than copying them.
#[derive(Clone)]
pub struct MssqlNonTxPreparedStatement {
    /// SQL text handed to the query helpers on every execution.
    sql: Arc<String>,
    /// Connection shared by all clones; callers lock it for each execution.
    client: Arc<Mutex<MssqlClient>>,
}
impl std::fmt::Debug for MssqlNonTxPreparedStatement {
    /// Formats the statement for diagnostics.
    ///
    /// The client is rendered as the fixed placeholder `"<MssqlClient>"` instead of
    /// being formatted itself; only the SQL text is shown verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MssqlNonTxPreparedStatement");
        out.field("client", &"<MssqlClient>");
        out.field("sql", &self.sql);
        out.finish()
    }
}
impl MssqlNonTxPreparedStatement {
    /// Wraps `client` and `sql` so the statement can be executed repeatedly.
    ///
    /// Despite the name, this performs no I/O. The SQL text is copied and passed to
    /// the query helpers again on every execution. Whether the server reuses a plan
    /// depends on those helpers (NOTE(review): confirm). The client is moved behind
    /// an `Arc<Mutex<_>>`, so clones of the returned value share this one connection
    /// and their executions take turns on the lock.
    ///
    /// Marked `#[must_use]` because discarding the result drops the only handle to
    /// `client`.
    #[must_use]
    pub fn prepare(client: MssqlClient, sql: &str) -> Self {
        Self {
            client: Arc::new(Mutex::new(client)),
            sql: Arc::new(sql.to_owned()),
        }
    }

    /// Starts a query builder with no parameters bound.
    ///
    /// Bind values with `params(..)` and finish with `all`, `optional`, `one`,
    /// `map_one`, or `map_optional`.
    #[must_use]
    pub fn select(&self) -> MssqlPreparedSelect<'_, '_> {
        MssqlPreparedSelect {
            statement: self,
            params: &[],
        }
    }

    /// Starts an execute builder (rows-affected style) with no parameters bound.
    ///
    /// Bind values with `params(..)` and finish with `run`.
    #[must_use]
    pub fn execute(&self) -> MssqlPreparedExecute<'_, '_> {
        MssqlPreparedExecute {
            statement: self,
            params: &[],
        }
    }

    /// Runs the statement with `params` and materializes the result via
    /// `build_result_set`.
    ///
    /// The client lock is held until the result set has been built.
    ///
    /// # Errors
    /// Propagates any error returned by `build_result_set`.
    pub(crate) async fn query(
        &self,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        let mut client = self.client.lock().await;
        build_result_set(&mut client, &self.sql, params).await
    }

    /// Like `query`, then narrows the result with `ResultSet::into_optional`.
    ///
    /// Presumably this yields the first row if there is one; confirm against
    /// `ResultSet::into_optional`.
    ///
    /// # Errors
    /// Propagates any error from `query`.
    pub(crate) async fn query_optional(
        &self,
        params: &[RowValues],
    ) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
        self.query(params).await.map(ResultSet::into_optional)
    }

    /// Like `query`, then narrows the result with `ResultSet::into_one`.
    ///
    /// # Errors
    /// Propagates any error from `query` or from `ResultSet::into_one`. The latter
    /// presumably fails when no row was returned; confirm.
    pub(crate) async fn query_one(
        &self,
        params: &[RowValues],
    ) -> Result<CustomDbRow, SqlMiddlewareDbError> {
        self.query(params).await?.into_one()
    }

    /// Like `query_map_optional`, but treats a missing row as an error.
    ///
    /// # Errors
    /// - `SqlMiddlewareDbError::ExecutionError("query returned no rows")` when
    ///   `query_map_optional` yields `None`.
    /// - Any error propagated by `query_map_optional`, presumably including errors
    ///   returned by `mapper`.
    pub(crate) async fn query_map_one<T, F>(
        &self,
        params: &[RowValues],
        mapper: F,
    ) -> Result<T, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        self.query_map_optional(params, mapper)
            .await?
            .ok_or_else(|| SqlMiddlewareDbError::ExecutionError("query returned no rows".into()))
    }

    /// Runs the statement and maps a row directly from the driver's `tiberius::Row`
    /// with `mapper`, skipping `ResultSet` construction.
    ///
    /// The client lock is held for the whole call, including while `mapper` runs.
    ///
    /// # Errors
    /// Propagates any error from the `query::query_map_optional` helper.
    pub(crate) async fn query_map_optional<T, F>(
        &self,
        params: &[RowValues],
        mapper: F,
    ) -> Result<Option<T>, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        let mut client = self.client.lock().await;
        // A bare name never resolves to a method, so this calls the free function
        // imported from `super::query`, not this method recursively.
        query_map_optional(&mut client, &self.sql, params, mapper).await
    }

    /// Executes the statement with `params` and returns the total rows affected.
    ///
    /// The driver may report several counts; all of them are summed.
    ///
    /// # Errors
    /// - `SqlMiddlewareDbError::ExecutionError` if the driver fails to execute.
    /// - `SqlMiddlewareDbError::ExecutionError` if the summed count does not fit in
    ///   `usize`.
    pub(crate) async fn execute_values(
        &self,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        let mut client = self.client.lock().await;
        let query_builder = super::query::bind_query_params(&self.sql, params);
        let exec_result = query_builder.execute(&mut *client).await.map_err(|e| {
            SqlMiddlewareDbError::ExecutionError(format!("MSSQL prepared execute error: {e}"))
        })?;
        let rows_affected: u64 = exec_result.rows_affected().iter().sum();
        // Narrowing u64 -> usize can fail on 32-bit targets; report rather than truncate.
        usize::try_from(rows_affected).map_err(|e| {
            SqlMiddlewareDbError::ExecutionError(format!("Invalid rows affected count: {e}"))
        })
    }

    /// Returns the SQL text this statement was created with.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }
}
/// Builder for a rows-affected execution of an `MssqlNonTxPreparedStatement`.
///
/// Obtained from `MssqlNonTxPreparedStatement::execute`. Attach values with
/// `params` and finish with `run`.
pub struct MssqlPreparedExecute<'stmt, 'params> {
    /// Values handed to the statement when `run` is called.
    params: &'params [RowValues],
    /// Statement (and its shared client) this builder will execute.
    statement: &'stmt MssqlNonTxPreparedStatement,
}
impl<'stmt, 'params> MssqlPreparedExecute<'stmt, 'params> {
    /// Replaces the bound values with `params`, keeping the same statement.
    ///
    /// The returned builder borrows the new slice for `'next`, independent of the
    /// previous slice's lifetime.
    #[must_use]
    pub fn params<'next>(self, params: &'next [RowValues]) -> MssqlPreparedExecute<'stmt, 'next> {
        let statement = self.statement;
        MssqlPreparedExecute { params, statement }
    }

    /// Executes the statement with the bound values and returns the rows-affected
    /// count reported by `execute_values`.
    ///
    /// # Errors
    /// Propagates any error from the statement's `execute_values`.
    pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
        let Self { statement, params } = self;
        statement.execute_values(params).await
    }
}
/// Builder for a row-returning query on an `MssqlNonTxPreparedStatement`.
///
/// Obtained from `MssqlNonTxPreparedStatement::select`. Attach values with
/// `params`, then finish with `all`, `optional`, `one`, `map_one`, or
/// `map_optional`.
pub struct MssqlPreparedSelect<'stmt, 'params> {
    /// Values handed to the statement when a terminal method runs.
    params: &'params [RowValues],
    /// Statement (and its shared client) this builder will query.
    statement: &'stmt MssqlNonTxPreparedStatement,
}
impl<'stmt, 'params> MssqlPreparedSelect<'stmt, 'params> {
    /// Replaces the bound values with `params`, keeping the same statement.
    ///
    /// The returned builder borrows the new slice for `'next`, independent of the
    /// previous slice's lifetime.
    #[must_use]
    pub fn params<'next>(self, params: &'next [RowValues]) -> MssqlPreparedSelect<'stmt, 'next> {
        let Self { statement, .. } = self;
        MssqlPreparedSelect { statement, params }
    }

    /// Runs the query and returns the full result set.
    ///
    /// # Errors
    /// Propagates any error from the statement's `query`.
    pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
        let Self { statement, params } = self;
        statement.query(params).await
    }

    /// Runs the query and returns at most one row, as decided by the statement's
    /// `query_optional`.
    ///
    /// # Errors
    /// Propagates any error from the statement's `query_optional`.
    pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
        let Self { statement, params } = self;
        statement.query_optional(params).await
    }

    /// Runs the query and returns exactly one row, as decided by the statement's
    /// `query_one`.
    ///
    /// # Errors
    /// Propagates any error from the statement's `query_one`.
    pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
        let Self { statement, params } = self;
        statement.query_one(params).await
    }

    /// Runs the query and converts the row with `mapper`, requiring a row to exist.
    ///
    /// # Errors
    /// Propagates any error from the statement's `query_map_one`, including its
    /// "no rows" error.
    pub async fn map_one<T, F>(self, mapper: F) -> Result<T, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        let Self { statement, params } = self;
        statement.query_map_one(params, mapper).await
    }

    /// Runs the query and converts the row with `mapper`, returning `None` when the
    /// statement's `query_map_optional` finds no row.
    ///
    /// # Errors
    /// Propagates any error from the statement's `query_map_optional`.
    pub async fn map_optional<T, F>(self, mapper: F) -> Result<Option<T>, SqlMiddlewareDbError>
    where
        F: FnOnce(&tiberius::Row) -> Result<T, SqlMiddlewareDbError>,
    {
        let Self { statement, params } = self;
        statement.query_map_optional(params, mapper).await
    }
}