#![deny(warnings)]
#![deny(missing_docs)]
use std::{future::Future, pin::Pin};
use sea_orm::{
AccessMode, ConnectionTrait, DatabaseTransaction, DbBackend, DbErr, ExecResult, IsolationLevel,
QueryResult, Statement, StreamTrait, TransactionError, TransactionTrait, Value, Values,
};
use tracing::{error, instrument};
/// Error types returned by [`Lock::build`] ([`error::Lock`]) and
/// [`Lock::release`] ([`error::Unlock`]).
pub mod error;
/// A MySQL named lock taken with `GET_LOCK`, wrapping the connection that holds it.
///
/// While the lock is held, `Lock<C>` implements [`ConnectionTrait`] (and
/// [`StreamTrait`] / [`TransactionTrait`] when `C` does) by forwarding every
/// call to the wrapped connection. It can therefore be handed to code that is
/// generic over a connection.
///
/// Create one with [`Lock::build`]. Give the connection back with
/// [`Lock::release`], which runs `RELEASE_LOCK`, or with [`Lock::into_inner`],
/// which leaves the lock held. Dropping a `Lock` that still owns its
/// connection logs an error.
#[derive(Debug)]
pub struct Lock<C>
where
    C: ConnectionTrait + std::fmt::Debug,
{
    /// Name of the lock, passed to `GET_LOCK` and `RELEASE_LOCK`.
    key: String,
    /// The wrapped connection. It is `Some` from `build` until either a
    /// successful `release` or `into_inner` takes it out.
    conn: Option<C>,
}
// Evaluates `$e` with `$bind` bound to the contents of the `Option` `$val`.
// The `Option` is borrowed, so `$val` is left in place.
//
// Used on `Lock::conn`, which is `Some` for the whole life of a usable `Lock`.
// Only a successful `release` and `into_inner` take the connection out, and
// both consume the lock. A `None` here therefore means that invariant was
// broken, and the panic message says so.
macro_rules! if_let_unreachable {
    ($val:expr, $bind:pat => $e:expr) => {
        if let Some($bind) = &$val {
            $e
        } else {
            unreachable!("Lock used after its connection was taken")
        }
    };
}
/// Runs every statement on the wrapped connection, so code that is generic
/// over a connection executes its queries while the lock is held.
#[async_trait::async_trait]
impl<C> ConnectionTrait for Lock<C>
where
    C: ConnectionTrait + std::fmt::Debug + Send,
{
    // Synchronous capability queries: answered by the wrapped connection.

    fn get_database_backend(&self) -> DbBackend {
        if_let_unreachable!(self.conn, inner => inner.get_database_backend())
    }

    fn support_returning(&self) -> bool {
        if_let_unreachable!(self.conn, inner => inner.support_returning())
    }

    fn is_mock_connection(&self) -> bool {
        if_let_unreachable!(self.conn, inner => inner.is_mock_connection())
    }

    // Statement execution: awaited on the wrapped connection.

    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        if_let_unreachable!(self.conn, inner => inner.execute(stmt).await)
    }

    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        if_let_unreachable!(self.conn, inner => inner.execute_unprepared(sql).await)
    }

    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        if_let_unreachable!(self.conn, inner => inner.query_one(stmt).await)
    }

    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        if_let_unreachable!(self.conn, inner => inner.query_all(stmt).await)
    }
}
/// Streams query results from the wrapped connection. The stream type is
/// the wrapped connection's own stream type.
impl<C> StreamTrait for Lock<C>
where
    C: ConnectionTrait + StreamTrait + std::fmt::Debug,
{
    type Stream<'a> = <C as StreamTrait>::Stream<'a> where Self: 'a;

    fn stream<'a>(
        &'a self,
        stmt: Statement,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
        // No `.await` here: the inner connection's boxed future is handed back as-is.
        if_let_unreachable!(self.conn, inner => inner.stream(stmt))
    }
}
/// Starts transactions on the wrapped connection while the lock is held.
#[async_trait::async_trait]
impl<C> TransactionTrait for Lock<C>
where
    C: ConnectionTrait + TransactionTrait + std::fmt::Debug + Send,
{
    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
        if_let_unreachable!(self.conn, inner => inner.begin().await)
    }

    async fn begin_with_config(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<DatabaseTransaction, DbErr> {
        if_let_unreachable!(self.conn, inner => {
            inner.begin_with_config(isolation_level, access_mode).await
        })
    }

    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(
                &'c DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
            + Send,
        T: Send,
        E: std::error::Error + Send,
    {
        if_let_unreachable!(self.conn, inner => inner.transaction(callback).await)
    }

    async fn transaction_with_config<F, T, E>(
        &self,
        callback: F,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(
                &'c DatabaseTransaction,
            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
            + Send,
        T: Send,
        E: std::error::Error + Send,
    {
        if_let_unreachable!(self.conn, inner => {
            inner
                .transaction_with_config(callback, isolation_level, access_mode)
                .await
        })
    }
}
/// Logs an error when a `Lock` is dropped while it still owns its connection,
/// i.e. neither a successful `release` nor `into_inner` took the connection out.
impl<C> Drop for Lock<C>
where
    C: ConnectionTrait + std::fmt::Debug,
{
    fn drop(&mut self) {
        // A connection that was already handed back leaves `None` behind: nothing to report.
        if self.conn.is_none() {
            return;
        }
        error!("Dropping unreleased lock {}", self.key);
    }
}
impl<C> Lock<C>
where
C: ConnectionTrait + std::fmt::Debug,
{
#[instrument(level = "trace")]
pub async fn build<S>(key: S, conn: C, timeout: Option<u8>) -> Result<Lock<C>, error::Lock<C>>
where
S: Into<String> + std::fmt::Debug,
{
let key = key.into();
let mut stmt = Statement::from_string(
conn.get_database_backend(),
String::from("SELECT GET_LOCK(?, ?) AS res"),
);
stmt.values = Some(Values(vec![
Value::from(key.as_str()),
Value::from(timeout.unwrap_or(1)),
]));
let res = match conn.query_one(stmt).await {
Ok(Some(res)) => res,
Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
};
let lock = match res.try_get::<Option<bool>>("", "res") {
Ok(Some(res)) => res,
Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
};
if lock {
Ok(Lock {
key,
conn: Some(conn),
})
} else {
Err(error::Lock::Failed(key, conn))
}
}
#[must_use]
pub fn get_key(&self) -> &str {
self.key.as_ref()
}
#[instrument(level = "trace")]
pub async fn release(mut self) -> Result<C, error::Unlock<C>> {
if_let_unreachable!(self.conn, conn => {
let mut stmt =
Statement::from_string(conn.get_database_backend(), String::from("SELECT RELEASE_LOCK(?) AS res"));
stmt.values = Some(Values(vec![Value::from(self.key.as_str())]));
let res = match conn.query_one(stmt).await {
Ok(Some(res)) => res,
Ok(None) => return Err(error::Unlock::DbErr(self, None)),
Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
};
let released = match res.try_get::<Option<bool>>("", "res") {
Ok(Some(res)) => res,
Ok(None) => return Err(error::Unlock::DbErr(self, None)),
Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
};
if released {
Ok(self.conn.take().unwrap())
}
else {
Err(error::Unlock::Failed(self))
}
})
}
#[must_use]
pub fn into_inner(mut self) -> C {
self.conn.take().unwrap()
}
}
#[cfg(test)]
mod tests {
    //! Integration tests against a live MySQL server (`DATABASE_URL`).
    //!
    //! The test harness runs tests on parallel threads, so every test locks
    //! its own name. A shared name would make tests wait on, and possibly
    //! time out against, each other's locks.

    use sea_orm::{
        ConnectionTrait, Database, DatabaseConnection, DbErr, Statement, StreamTrait,
        TransactionTrait,
    };
    use tokio_stream::StreamExt;

    /// Metric callback that logs every statement the test connection runs,
    /// with its duration.
    fn metric_mysql(info: &sea_orm::metric::Info<'_>) {
        tracing::debug!(
            "mysql query{} took {}s: {}",
            if info.failed { " failed" } else { "" },
            info.elapsed.as_secs_f64(),
            info.statement.sql
        );
    }

    /// Connects to `DATABASE_URL`, defaulting to `mysql://root@127.0.0.1/test`.
    async fn get_conn() -> DatabaseConnection {
        let url = std::env::var("DATABASE_URL");
        let mut conn = Database::connect(url.as_deref().unwrap_or("mysql://root@127.0.0.1/test"))
            .await
            .unwrap();
        conn.set_metric_callback(metric_mysql);
        conn
    }

    /// Runs `SELECT 1 AS res` through any connection-like value and decodes
    /// `res` as a bool.
    async fn generic_method_who_needs_a_connection<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn
            .query_one(stmt)
            .await?
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))?;
        res.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Begins a transaction on `conn` and locks `key` inside it. It then runs
    /// a query through the lock, releases the lock and commits.
    async fn generic_method_who_creates_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    /// Streams `SELECT 1 AS res` through `conn` and decodes the first row's
    /// `res` as a bool.
    async fn generic_method_who_makes_a_stream<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + StreamTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn.stream(stmt).await?;
        let row = Box::pin(res)
            .next()
            .await
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))??;
        row.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Same as [`generic_method_who_creates_a_transaction`], but queries
    /// through a stream on the locked transaction.
    async fn generic_method_who_makes_a_stream_inside_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    #[tokio::test]
    async fn simple() {
        tracing_subscriber::fmt::try_init().ok();
        let conn = get_conn().await;
        let lock = super::Lock::build("sea_orm_lock_test_simple", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        // Release before inspecting `res` so a failed query still frees the lock.
        assert!(lock.release().await.is_ok());
        assert!(res.unwrap());
    }

    #[tokio::test]
    async fn transaction() {
        tracing_subscriber::fmt::try_init().ok();
        let conn = get_conn().await;
        assert!(
            generic_method_who_creates_a_transaction(&conn, "sea_orm_lock_test_transaction")
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn stream() {
        tracing_subscriber::fmt::try_init().ok();
        let conn = get_conn().await;
        let lock = super::Lock::build("sea_orm_lock_test_stream", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        // Release before inspecting `res` so a failed stream still frees the lock.
        assert!(lock.release().await.is_ok());
        assert!(res.unwrap());
    }

    #[tokio::test]
    async fn transaction_stream() {
        tracing_subscriber::fmt::try_init().ok();
        let conn = get_conn().await;
        assert!(generic_method_who_makes_a_stream_inside_a_transaction(
            &conn,
            "sea_orm_lock_test_transaction_stream"
        )
        .await
        .unwrap());
    }
}