use std::sync::Arc;
use std::time::Duration;
use crossbeam_channel::Sender;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::entity::SqlEntity;
use crate::error::SqlError;
use crate::params::{OdbcParam, PkValue};
use crate::pool::connection::OdbcConn;
use crate::pool::PooledConn;
/// An open database transaction running on a connection taken out of the pool.
///
/// Unless a commit succeeded, dropping the value rolls the transaction back. In every case
/// the connection is flagged for reset, sent back through `return_tx`, and one waiter on
/// `notify` is woken.
pub struct Transaction<'pool> {
    /// The underlying connection; `None` once it has been handed back to the pool.
    conn: Option<OdbcConn>,
    /// Set after a successful commit so that drop skips the automatic rollback.
    committed: bool,
    /// Channel the connection is sent back through when the transaction ends.
    return_tx: Sender<OdbcConn>,
    /// Signalled (one waiter) after the connection has been sent back.
    notify: Arc<Notify>,
    /// Carries the `'pool` lifetime parameter without storing a reference.
    _phantom: std::marker::PhantomData<&'pool ()>,
}
impl<'pool> Transaction<'pool> {
pub(crate) fn begin(guard: PooledConn) -> Result<Self, SqlError> {
let (mut conn, return_tx, notify) = guard.take();
conn.execute_non_query_sync("SET IMPLICIT_TRANSACTIONS OFF; BEGIN TRANSACTION", &[])
.map_err(|e| SqlError::odbc(0, format!("BEGIN TRANSACTION failed: {e}")))?;
debug!("Transaction started");
Ok(Self {
conn: Some(conn),
return_tx,
notify,
committed: false,
_phantom: std::marker::PhantomData,
})
}
fn conn_mut(&mut self) -> Result<&mut OdbcConn, SqlError> {
self.conn.as_mut().ok_or(SqlError::InvalidTransactionState)
}
pub async fn insert<T: SqlEntity>(
&mut self,
entity: &T,
_token: &CancellationToken,
) -> Result<i64, SqlError> {
let params = entity.to_params();
let conn = self.conn_mut()?;
if T::PK_IS_IDENTITY {
conn.execute_insert_sync(T::INSERT_SQL, ¶ms)
} else {
conn.execute_non_query_sync(T::INSERT_SQL, ¶ms)?;
Ok(match entity.pk_value() {
PkValue::I32(v) => v as i64,
PkValue::I64(v) => v,
PkValue::Str(_) => 0,
PkValue::Guid(_) => 0,
})
}
}
pub async fn update<T: SqlEntity>(
&mut self,
entity: &T,
_token: &CancellationToken,
) -> Result<(), SqlError> {
let params = entity.to_params();
let conn = self.conn_mut()?;
conn.execute_non_query_sync(T::UPDATE_SQL, ¶ms)?;
Ok(())
}
pub async fn delete<T: SqlEntity>(
&mut self,
id: impl Into<PkValue>,
_token: &CancellationToken,
) -> Result<(), SqlError> {
let pk = id.into();
let params = [OdbcParam::new(T::PK_COLUMN, pk.as_param())];
let conn = self.conn_mut()?;
conn.execute_non_query_sync(T::DELETE_SQL, ¶ms)?;
Ok(())
}
pub async fn execute_raw(
&mut self,
sql: &str,
params: &[OdbcParam],
_token: &CancellationToken,
) -> Result<usize, SqlError> {
let conn = self.conn_mut()?;
conn.execute_non_query_sync(sql, params)
}
pub async fn commit(mut self) -> Result<(), SqlError> {
let conn = self.conn.as_mut().ok_or(SqlError::InvalidTransactionState)?;
conn.commit_sync()?;
self.committed = true;
debug!("Transaction committed");
Ok(())
}
pub fn rollback(&mut self) {
if let Some(conn) = self.conn.as_mut() {
conn.rollback_sync();
warn!("Transaction rolled back");
}
}
}
impl Drop for Transaction<'_> {
    /// Finishes the transaction if it still holds a connection.
    ///
    /// Uncommitted work is rolled back first. The connection is then marked `needs_reset`
    /// and sent back on `return_tx` (a failed send is ignored), and one waiter on `notify`
    /// is woken.
    fn drop(&mut self) {
        if let Some(mut conn) = self.conn.take() {
            if !self.committed {
                conn.rollback_sync();
                debug!("Transaction auto-rolled back on drop");
            }
            conn.needs_reset = true;
            let _ = self.return_tx.send(conn);
            self.notify.notify_one();
        }
    }
}
/// Runs `f` inside a fresh transaction, committing on success and retrying on deadlock.
///
/// Each attempt does the following:
/// 1. Checks out a connection from `pool` and begins a [`Transaction`].
/// 2. Awaits `f(&mut tx)`; on `Ok`, commits and returns the value.
/// 3. If `f` fails with an error whose `is_deadlock()` is true and retries remain:
///    - drops the transaction, which rolls it back and returns the connection;
///    - records the deadlock on the pool;
///    - sleeps with exponential backoff (100 ms, 200 ms, 400 ms, …, capped at 51.2 s);
///    - retries, up to `max_retries` times.
///
/// # Errors
///
/// - Errors from `pool.checkout`, `Transaction::begin` or `commit` are propagated unchanged.
/// - A non-deadlock error from `f` is returned immediately.
/// - A deadlock after `max_retries` retries yields `SqlError::DeadlockRetryExhausted`.
///   Its `attempts` field carries the number of retries performed.
///
/// NOTE(review): `Fut` is a single type not tied to the `&mut Transaction` borrow. The
/// future returned by `f` therefore cannot keep using `tx` across an `.await`. Confirm this
/// matches how callers use it.
pub async fn with_retry<F, Fut, T>(
    pool: &crate::pool::Pool,
    token: &CancellationToken,
    max_retries: u8,
    mut f: F,
) -> Result<T, SqlError>
where
    F: FnMut(&mut Transaction<'_>) -> Fut,
    Fut: std::future::Future<Output = Result<T, SqlError>>,
{
    // Backoff is BACKOFF_BASE_MS * 2^attempt. The exponent is capped because an uncapped
    // `1u64 << attempts` panics in debug builds once `attempts >= 64` (reachable since
    // `max_retries` is a `u8`). It also produces absurdly long sleeps well before that.
    const BACKOFF_BASE_MS: u64 = 50;
    const MAX_BACKOFF_EXPONENT: u8 = 10;

    let mut attempts = 0u8;
    loop {
        let conn = pool.checkout(token).await?;
        let mut tx = Transaction::begin(conn)?;
        match f(&mut tx).await {
            Ok(value) => {
                tx.commit().await?;
                return Ok(value);
            }
            Err(e) if e.is_deadlock() && attempts < max_retries => {
                // Roll back and return the connection now, rather than holding it idle
                // for the whole backoff period.
                drop(tx);
                attempts += 1;
                pool.record_deadlock();
                warn!(attempt = attempts, "Deadlock detected, retrying");
                let backoff = Duration::from_millis(
                    BACKOFF_BASE_MS << attempts.min(MAX_BACKOFF_EXPONENT),
                );
                tokio::time::sleep(backoff).await;
            }
            Err(e) if e.is_deadlock() => {
                pool.record_deadlock();
                return Err(SqlError::DeadlockRetryExhausted { attempts });
            }
            Err(e) => return Err(e),
        }
    }
}