1#![deny(warnings)]
2#![deny(missing_docs)]
3
4use std::{future::Future, pin::Pin};
9
10use sea_orm::{
11 AccessMode, ConnectionTrait, DatabaseTransaction, DbBackend, DbErr, ExecResult, IsolationLevel,
12 QueryResult, Statement, StreamTrait, TransactionError, TransactionTrait, Value, Values,
13};
14
15use tracing::{error, instrument};
16
17pub mod error;
19
20#[derive(Debug)]
22pub struct Lock<C>
23where
24 C: ConnectionTrait + std::fmt::Debug,
25{
26 key: String,
27 conn: Option<C>,
28}
29
/// Evaluates `$e` with `$bind` bound to a reference to the value inside the `Option` `$val`.
///
/// This is how `Lock` reaches its wrapped connection. The connection is only taken out by
/// methods that consume the `Lock`, so hitting the `None` branch means that invariant was
/// broken. In that case the macro panics with a message naming the invariant.
macro_rules! if_let_unreachable {
    ($val:expr, $bind:pat => $e:expr) => {
        if let Some($bind) = &$val {
            $e
        } else {
            unreachable!("Lock connection accessed after it was taken")
        }
    };
}
39
40#[async_trait::async_trait]
41impl<C> ConnectionTrait for Lock<C>
42where
43 C: ConnectionTrait + std::fmt::Debug + Send,
44{
45 fn get_database_backend(&self) -> DbBackend {
46 if_let_unreachable!(self.conn, conn => conn.get_database_backend())
47 }
48
49 async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
50 if_let_unreachable!(self.conn, conn => conn.execute(stmt).await)
51 }
52
53 async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
54 if_let_unreachable!(self.conn, conn => conn.execute_unprepared(sql).await)
55 }
56
57 async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
58 if_let_unreachable!(self.conn, conn => conn.query_one(stmt).await)
59 }
60
61 async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
62 if_let_unreachable!(self.conn, conn => conn.query_all(stmt).await)
63 }
64
65 fn support_returning(&self) -> bool {
66 if_let_unreachable!(self.conn, conn => conn.support_returning())
67 }
68
69 fn is_mock_connection(&self) -> bool {
70 if_let_unreachable!(self.conn, conn => conn.is_mock_connection())
71 }
72}
73
74impl<C> StreamTrait for Lock<C>
75where
76 C: ConnectionTrait + StreamTrait + std::fmt::Debug,
77{
78 type Stream<'a> = C::Stream<'a> where Self: 'a;
79
80 fn stream<'a>(
81 &'a self,
82 stmt: Statement,
83 ) -> Pin<Box<dyn Future<Output = Result<Self::Stream<'a>, DbErr>> + 'a + Send>> {
84 if_let_unreachable!(self.conn, conn => conn.stream(stmt))
85 }
86}
87
88#[async_trait::async_trait]
89impl<C> TransactionTrait for Lock<C>
90where
91 C: ConnectionTrait + TransactionTrait + std::fmt::Debug + Send,
92{
93 async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
94 if_let_unreachable!(self.conn, conn => conn.begin().await)
95 }
96
97 async fn begin_with_config(
98 &self,
99 isolation_level: Option<IsolationLevel>,
100 access_mode: Option<AccessMode>,
101 ) -> Result<DatabaseTransaction, DbErr> {
102 if_let_unreachable!(self.conn, conn => conn.begin_with_config(isolation_level, access_mode).await)
103 }
104
105 async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
106 where
107 F: for<'c> FnOnce(
108 &'c DatabaseTransaction,
109 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
110 + Send,
111 T: Send,
112 E: std::error::Error + Send,
113 {
114 if_let_unreachable!(self.conn, conn => conn.transaction(callback).await)
115 }
116
117 async fn transaction_with_config<F, T, E>(
118 &self,
119 callback: F,
120 isolation_level: Option<IsolationLevel>,
121 access_mode: Option<AccessMode>,
122 ) -> Result<T, TransactionError<E>>
123 where
124 F: for<'c> FnOnce(
125 &'c DatabaseTransaction,
126 ) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>>
127 + Send,
128 T: Send,
129 E: std::error::Error + Send,
130 {
131 if_let_unreachable!(self.conn, conn => conn.transaction_with_config(callback, isolation_level, access_mode).await)
132 }
133}
134
135impl<C> Drop for Lock<C>
136where
137 C: ConnectionTrait + std::fmt::Debug,
138{
139 fn drop(&mut self) {
140 if self.conn.is_some() {
141 error!("Dropping unreleased lock {}", self.key);
143 }
144 }
145}
146
147impl<C> Lock<C>
148where
149 C: ConnectionTrait + std::fmt::Debug,
150{
151 #[instrument(level = "trace")]
155 pub async fn build<S>(key: S, conn: C, timeout: Option<u8>) -> Result<Lock<C>, error::Lock<C>>
156 where
157 S: Into<String> + std::fmt::Debug,
158 {
159 let key = key.into();
160 let mut stmt = Statement::from_string(
161 conn.get_database_backend(),
162 String::from("SELECT GET_LOCK(?, ?) AS res"),
163 );
164 stmt.values = Some(Values(vec![
165 Value::from(key.as_str()),
166 Value::from(timeout.unwrap_or(1)),
167 ]));
168 let res = match conn.query_one(stmt).await {
169 Ok(Some(res)) => res,
170 Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
171 Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
172 };
173 let lock = match res.try_get::<Option<bool>>("", "res") {
174 Ok(Some(res)) => res,
175 Ok(None) => return Err(error::Lock::DbErr(key, conn, None)),
176 Err(e) => return Err(error::Lock::DbErr(key, conn, Some(e))),
177 };
178
179 if lock {
180 Ok(Lock {
181 key,
182 conn: Some(conn),
183 })
184 } else {
185 Err(error::Lock::Failed(key, conn))
186 }
187 }
188
189 #[must_use]
191 pub fn get_key(&self) -> &str {
192 self.key.as_ref()
193 }
194
195 #[instrument(level = "trace")]
198 pub async fn release(mut self) -> Result<C, error::Unlock<C>> {
199 if_let_unreachable!(self.conn, conn => {
200 let mut stmt =
201 Statement::from_string(conn.get_database_backend(), String::from("SELECT RELEASE_LOCK(?) AS res"));
202 stmt.values = Some(Values(vec![Value::from(self.key.as_str())]));
203 let res = match conn.query_one(stmt).await {
204 Ok(Some(res)) => res,
205 Ok(None) => return Err(error::Unlock::DbErr(self, None)),
206 Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
207 };
208 let released = match res.try_get::<Option<bool>>("", "res") {
209 Ok(Some(res)) => res,
210 Ok(None) => return Err(error::Unlock::DbErr(self, None)),
211 Err(e) => return Err(error::Unlock::DbErr(self, Some(e))),
212 };
213
214 if released {
215 Ok(self.conn.take().unwrap())
216 }
217 else {
218 Err(error::Unlock::Failed(self))
219 }
220 })
221 }
222
223 #[must_use]
226 pub fn into_inner(mut self) -> C {
227 self.conn.take().unwrap()
228 }
229}
230
#[cfg(test)]
mod tests {
    use sea_orm::{
        ConnectOptions, ConnectionTrait, Database, DatabaseConnection, DbErr, Statement,
        StreamTrait, TransactionTrait,
    };

    use tokio_stream::StreamExt;

    /// Metric callback: logs every statement run on the test connection at `debug` level.
    fn metric_mysql(info: &sea_orm::metric::Info<'_>) {
        tracing::debug!(
            "mysql query{} took {}s: {}",
            if info.failed { " failed" } else { "" },
            info.elapsed.as_secs_f64(),
            info.statement.sql
        );
    }

    /// Connects to `DATABASE_URL`, falling back to a local MySQL `test` database.
    ///
    /// The pool is capped at one connection. MySQL named locks belong to the session that
    /// took them, so `GET_LOCK` and `RELEASE_LOCK` issued through a bare
    /// `DatabaseConnection` must not be spread over several pooled sessions.
    async fn get_conn() -> DatabaseConnection {
        let url = std::env::var("DATABASE_URL")
            .unwrap_or_else(|_| String::from("mysql://root@127.0.0.1/test"));
        let mut options = ConnectOptions::new(url);
        options.max_connections(1);
        let mut conn = Database::connect(options).await.unwrap();
        conn.set_metric_callback(metric_mysql);
        conn
    }

    /// Runs `SELECT 1 AS res` through any connection and decodes `res` as a boolean.
    async fn generic_method_who_needs_a_connection<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn
            .query_one(stmt)
            .await?
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))?;
        res.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Opens a transaction, takes the lock `key` on it, queries through the lock, then
    /// releases the lock and commits.
    async fn generic_method_who_creates_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    /// Streams `SELECT 1 AS res` and decodes the first row's `res` as a boolean.
    async fn generic_method_who_makes_a_stream<C>(conn: &C) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + StreamTrait + std::fmt::Debug,
    {
        let stmt =
            Statement::from_string(conn.get_database_backend(), String::from("SELECT 1 AS res"));
        let res = conn.stream(stmt).await?;
        let row = Box::pin(res)
            .next()
            .await
            .ok_or_else(|| DbErr::RecordNotFound(String::from("1")))??;
        row.try_get::<Option<bool>>("", "res")?
            .ok_or_else(|| DbErr::Custom(String::from("Unknown error")))
    }

    /// Like `generic_method_who_creates_a_transaction`, but queries through a stream.
    async fn generic_method_who_makes_a_stream_inside_a_transaction<C>(
        conn: &C,
        key: &str,
    ) -> Result<bool, DbErr>
    where
        C: ConnectionTrait + TransactionTrait + std::fmt::Debug,
    {
        let txn = conn.begin().await?;
        let lock = super::Lock::build(key, txn, None).await.unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        let txn = lock.release().await.unwrap();
        txn.commit().await?;
        res
    }

    // Each test uses its own lock name. Tests run in parallel on separate sessions, so a
    // shared name could make one test wait on (or time out behind) another.

    #[tokio::test]
    async fn simple() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let lock = super::Lock::build("sea_orm_lock_test_simple", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_needs_a_connection(&lock).await;
        assert!(lock.release().await.is_ok());
        res.unwrap();
    }

    #[tokio::test]
    async fn transaction() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        generic_method_who_creates_a_transaction(&conn, "sea_orm_lock_test_transaction")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn stream() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        let lock = super::Lock::build("sea_orm_lock_test_stream", conn, None)
            .await
            .unwrap();
        let res = generic_method_who_makes_a_stream(&lock).await;
        assert!(lock.release().await.is_ok());
        res.unwrap();
    }

    #[tokio::test]
    async fn transaction_stream() {
        tracing_subscriber::fmt::try_init().ok();

        let conn = get_conn().await;

        generic_method_who_makes_a_stream_inside_a_transaction(
            &conn,
            "sea_orm_lock_test_transaction_stream",
        )
        .await
        .unwrap();
    }
}