logic_lock/
lib.rs

1#![deny(warnings)]
2#![deny(missing_docs)]
3
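//! # logic-lock
//!
//! MySQL named locks (`GET_LOCK` / `RELEASE_LOCK`) implemented over sea-orm.
//! A `Lock` owns the connection the lock was taken on and itself implements sea-orm's
//! connection traits, so it can be passed wherever a connection is expected.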
7
8use std::{future::Future, pin::Pin};
9
10use sea_orm::{
11    AccessMode, ConnectionTrait, DatabaseTransaction, DbBackend, DbErr, ExecResult, IsolationLevel,
12    QueryResult, Statement, StreamTrait, TransactionError, TransactionTrait, Value, Values,
13};
14
15use tracing::{error, instrument};
16
17/// Lock and Unlock error types
18pub mod error;
19
20/// Lock entity
21#[derive(Debug)]
22pub struct Lock<C>
23where
24    C: ConnectionTrait + std::fmt::Debug,
25{
26    key: String,
27    conn: Option<C>,
28}
29
// Evaluates `$e` with `$bind` bound to a reference to the value inside the `Option` `$val`.
//
// It is used on `Lock::conn`, which is `Some` for the whole life of a `Lock` and is only
// emptied by the consuming methods `release` and `into_inner`. Reaching `None` therefore
// means that invariant was broken, and the panic message says which one.
macro_rules! if_let_unreachable {
    ($val:expr, $bind:pat => $e:expr) => {
        if let Some($bind) = &$val {
            $e
        } else {
            unreachable!("`Lock` used after its connection was taken by `release` or `into_inner`")
        }
    };
}
39
40#[async_trait::async_trait]
41impl<C> ConnectionTrait for Lock<C>
42where
43    C: ConnectionTrait + std::fmt::Debug + Send,
44{
45    fn get_database_backend(&self) -> DbBackend {
46        if_let_unreachable!(self.conn, conn => conn.get_database_backend())
47    }
48
49    async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50        if_let_unreachable!(self.conn, conn => conn.execute(stmt).await)
51    }
52
53    async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
54        if_let_unreachable!(self.conn, conn => conn.execute_unprepared(sql).await)
55    }
56
57    async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
58        if_let_unreachable!(self.conn, conn => conn.query_one(stmt).await)
59    }
60
61    async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
62        if_let_unreachable!(self.conn, conn => conn.query_all(stmt).await)
63    }
64
65    fn support_returning(&self) -> bool {
66        if_let_unreachable!(self.conn, conn => conn.support_returning())
67    }
68
69    fn is_mock_connection(&self) -> bool {
70        if_let_unreachable!(self.conn, conn => conn.is_mock_connection())
71    }
72}
73
74impl<C> StreamTrait for Lock<C>
75where
76    C: ConnectionTrait + StreamTrait + std::fmt::Debug,
77{
78    type Stream<'a> = C::Stream<'a> where Self: 'a;
79
80    fn stream<'a>(
81        &'a self,
82        stmt: Statement,
83    ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
84        if_let_unreachable!(self.conn, conn => conn.stream(stmt))
85    }
86}
87
88#[async_trait::async_trait]
89impl<C> TransactionTrait for Lock<C>
90where
91    C: ConnectionTrait + TransactionTrait + std::fmt::Debug + Send,
92{
93    async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
94        if_let_unreachable!(self.conn, conn => conn.begin().await)
95    }
96
97    async fn begin_with_config(
98        &self,
99        isolation_level: Option<IsolationLevel>,
100        access_mode: Option<AccessMode>,
101    ) -> Result<DatabaseTransaction, DbErr> {
102        if_let_unreachable!(self.conn, conn => conn.begin_with_config(isolation_level, access_mode).await)
103    }
104
105    async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
106    where
107        F: for<'c> FnOnce(
108                &'c DatabaseTransaction,
109            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
110            + Send,
111        T: Send,
112        E: std::error::Error + Send,
113    {
114        if_let_unreachable!(self.conn, conn => conn.transaction(callback).await)
115    }
116
117    async fn transaction_with_config<F, T, E>(
118        &self,
119        callback: F,
120        isolation_level: Option<IsolationLevel>,
121        access_mode: Option<AccessMode>,
122    ) -> Result<T, TransactionError<E>>
123    where
124        F: for<'c> FnOnce(
125                &'c DatabaseTransaction,
126            ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
127            + Send,
128        T: Send,
129        E: std::error::Error + Send,
130    {
131        if_let_unreachable!(self.conn, conn => conn.transaction_with_config(callback, isolation_level, access_mode).await)
132    }
133}
134
135impl<C> Drop for Lock<C>
136where
137    C: ConnectionTrait + std::fmt::Debug,
138{
139    fn drop(&mut self) {
140        if self.conn.is_some() {
141            // panicing here could create a panic-while-panic situatiuon
142            error!("Dropping unreleased lock {}", self.key);
143        }
144    }
145}
146
147impl<C> Lock<C>
148where
149    C: ConnectionTrait + std::fmt::Debug,
150{
151    /// Lock builder
152    /// Takes anything can become a String as key, an owned connection (it can be a `sea_orm::DatabaseConnection`,
153    /// a `sea_orm::DatabaseTransaction or another `Lock` himself), and an optional timeout in seconds, defaulting to 1 second
154    #[instrument(level = "trace")]
155    pub async fn build<S>(key: S, conn: C, timeout: Option<u8>) -> Result<Lock<C>, error::Lock<C>>
156    where
157        S: Into<String> + std::fmt::Debug,
158    {
159        let key = key.into();
160        let mut stmt = Statement::from_string(
161            conn.get_database_backend(),
162            String::from("SELECT GET_LOCK(?, ?) AS res"),
163        );
164        stmt.values = Some(Values(vec![
165            Value::from(key.as_str()),
166            Value::from(timeout.unwrap_or(1)),
167        ]));
168        let res = match conn.query_one(stmt).await {
169            Ok(Some(res)) => res,
170            Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
171            Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
172        };
173        let lock = match res.try_get::<Option<bool>>("", "res") {
174            Ok(Some(res)) => res,
175            Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
176            Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
177        };
178
179        if lock {
180            Ok(Lock {
181                key,
182                conn: Some(conn),
183            })
184        } else {
185            Err(error::Lock::Failed(key, conn))
186        }
187    }
188
189    /// returns locked key
190    #[must_use]
191    pub fn get_key(&self) -> &str {
192        self.key.as_ref()
193    }
194
195    /// releases the lock, returning the owned connection on success
196    /// on error it will return the `Lock` himself alongside with the database error, if any
197    #[instrument(level = "trace")]
198    pub async fn release(mut self) -> Result<C, error::Unlock<C>> {
199        if_let_unreachable!(self.conn, conn => {
200            let mut stmt =
201                Statement::from_string(conn.get_database_backend(), String::from("SELECT RELEASE_LOCK(?) AS res"));
202            stmt.values = Some(Values(vec![Value::from(self.key.as_str())]));
203            let res = match conn.query_one(stmt).await {
204                Ok(Some(res)) => res,
205                Ok(None) => return Err(error::Unlock::DbErr(self, None)),
206                Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
207            };
208            let released = match res.try_get::<Option<bool>>("", "res") {
209                Ok(Some(res)) => res,
210                Ok(None) => return Err(error::Unlock::DbErr(self, None)),
211                Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
212            };
213
214            if released {
215                Ok(self.conn.take().unwrap())
216            }
217            else {
218                Err(error::Unlock::Failed(self))
219            }
220        })
221    }
222
223    /// forgets the lock and returns inner connection
224    /// WARNING: the lock will continue to live in the database session
225    #[must_use]
226    pub fn into_inner(mut self) -> C {
227        self.conn.take().unwrap()
228    }
229}
230
#[cfg(test)]
mod tests {
    // These tests need a reachable MySQL server: `DATABASE_URL`, or
    // `mysql://root@127.0.0.1/test` when it is unset.
    use sea_orm::{
        ConnectionTrait, Database, DatabaseConnection, DbErr, Statement, StreamTrait,
        TransactionTrait,
    };

    use tokio_stream::StreamExt;

    /// Metric callback: logs each query, whether it failed, and how long it took.
    fn metric_mysql(info: &sea_orm::metric::Info<'_>) {
        tracing::debug!(
            "mysql query{} took {}s: {}",
            if info.failed { " failed" } else { "" },
            info.elapsed.as_secs_f64(),
            info.statement.sql
        );
    }

    /// Connects to `DATABASE_URL` (or the local default) and installs `metric_mysql`.
    async fn get_conn() -> DatabaseConnection {
        let url = std::env::var("DATABASE_URL");
        let mut conn = Database::connect(url.as_deref().unwrap_or("mysql://root@127.0.0.1/test"))
            .await
            .unwrap();
        conn.set_metric_callback(metric_mysql);
        conn
    }

    /// Runs `SELECT 1 AS res` through any connection-like type and decodes `res` as a bool.
    async fn generic_method_who_needs_a_connection<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn
            .query_one(stmt)
            .await?
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))?;
        res.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Begins a transaction, locks `key` inside it, queries through the lock,
    /// then releases the lock and commits.
    async fn generic_method_who_creates_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    /// Streams `SELECT 1 AS res` through any streaming connection and decodes the first row.
    async fn generic_method_who_makes_a_stream<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + StreamTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn.stream(stmt).await?;
        let row = Box::pin(res)
            .next()
            .await
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))??;
        row.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Begins a transaction, locks `key` inside it, streams through the lock,
    /// then releases the lock and commits.
    async fn generic_method_who_makes_a_stream_inside_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    // Each test locks its own key. The harness runs tests in parallel on separate
    // sessions, and a shared key would make them contend for the same MySQL lock.

    #[tokio::test]
    async fn simple() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let lock = super::Lock::build("logic_lock_test_simple", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        let _conn = lock.release().await.unwrap();
        assert!(res.unwrap());
    }

    #[tokio::test]
    async fn transaction() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let res = generic_method_who_creates_a_transaction(&conn, "logic_lock_test_transaction")
            .await
            .unwrap();
        assert!(res);
    }

    #[tokio::test]
    async fn stream() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let lock = super::Lock::build("logic_lock_test_stream", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        let _conn = lock.release().await.unwrap();
        assert!(res.unwrap());
    }

    #[tokio::test]
    async fn transaction_stream() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let res = generic_method_who_makes_a_stream_inside_a_transaction(
            &conn,
            "logic_lock_test_transaction_stream",
        )
        .await
        .unwrap();
        assert!(res);
    }
}