1#![doc(issue_tracker_base_url = "https://github.com/lazureykis/mongo-lock-async/issues")]
2
3mod error;
35mod util;
36
37pub use error::Error;
38use mongodb::bson::{doc, Document};
39use mongodb::error::{ErrorKind, WriteError, WriteFailure};
40use mongodb::options::IndexOptions;
41use mongodb::{Client, Collection, IndexModel};
42use std::time::Duration;
43
/// Name of the collection that holds one document per lock key.
const COLLECTION_NAME: &str = "locks";
/// Database used when the client has no default database configured.
const DEFAULT_DB_NAME: &str = "mongo-lock";
46
47#[inline]
48fn collection(mongo: &Client) -> Collection<Document> {
49 mongo
50 .default_database()
51 .unwrap_or_else(|| mongo.database(DEFAULT_DB_NAME))
52 .collection(COLLECTION_NAME)
53}
54
55pub struct Lock {
57 mongo: Client,
58 id: String,
59}
60
61impl Lock {
62 pub async fn try_acquire(
64 mongo: &Client,
65 key: &str,
66 ttl: Duration,
67 ) -> Result<Option<Lock>, Error> {
68 let (now, expires_at) = util::now_and_expires_at(ttl);
69
70 let query = doc! {
72 "_id": key,
73 "expiresAt": {"$lte": now},
74 };
75
76 let update = doc! {
77 "$set": {
78 "expiresAt": expires_at,
79 },
80 "$setOnInsert": {
81 "_id": key,
82 },
83 };
84
85 match collection(mongo)
86 .update_one(query, update)
87 .upsert(true)
88 .await
89 {
90 Ok(result) => {
91 if result.upserted_id.is_some() || result.modified_count == 1 {
92 Ok(Some(Lock {
93 mongo: mongo.clone(),
94 id: key.to_string(),
95 }))
96 } else {
97 Ok(None)
98 }
99 }
100 Err(err) => {
101 if let ErrorKind::Write(WriteFailure::WriteError(WriteError {
102 code: 11000, ..
103 })) = *err.kind
104 {
105 Ok(None)
106 } else {
107 Err(err.into())
108 }
109 }
110 }
111 }
112
113 pub async fn try_acquire_with_timeout(
117 mongo: &Client,
118 key: &str,
119 key_ttl: Duration,
120 lock_wait_timeout: Duration,
121 lock_poll_interval: Duration,
122 ) -> Result<Option<Lock>, Error> {
123 let start = std::time::Instant::now();
124 loop {
125 match Self::try_acquire(mongo, key, key_ttl).await {
126 Ok(Some(lock)) => return Ok(Some(lock)),
127 Ok(None) => {
128 if start.elapsed() > lock_wait_timeout {
129 return Ok(None);
130 }
131 tokio::time::sleep(lock_poll_interval).await;
132 }
133 Err(err) => return Err(err),
134 }
135
136 if start.elapsed() > lock_wait_timeout {
137 return Err("Cannot acquire lock".into());
138 }
139 }
140 }
141
142 pub async fn release(&self) -> Result<bool, Error> {
144 let result = collection(&self.mongo)
145 .delete_one(doc! {"_id": &self.id})
146 .await?;
147
148 Ok(result.deleted_count == 1)
149 }
150}
151
152pub async fn prepare_database(mongo: &Client) -> Result<(), Error> {
159 let options = IndexOptions::builder()
160 .expire_after(Some(Duration::from_secs(0)))
161 .build();
162
163 let model = IndexModel::builder()
164 .keys(doc! {"expiresAt": 1})
165 .options(options)
166 .build();
167
168 collection(mongo).create_index(model).await?;
169
170 Ok(())
171}
172
#[cfg(test)]
mod tests {
    use tokio::time::Instant;

    use super::*;

    /// Returns a random 30-character alphanumeric string to use as a lock key.
    fn gen_random_key() -> String {
        use rand::{distributions::Alphanumeric, thread_rng, Rng};

        let mut key = String::with_capacity(30);
        key.extend(
            thread_rng()
                .sample_iter(&Alphanumeric)
                .take(30)
                .map(char::from),
        );
        key
    }

    /// Connects to the MongoDB server on localhost and prepares the lock index.
    async fn connect() -> Client {
        let mongo = Client::with_uri_str("mongodb://localhost").await.unwrap();
        prepare_database(&mongo).await.unwrap();
        mongo
    }

    /// Runs one acquisition attempt with a TTL of `ttl_secs` seconds,
    /// panicking on any driver error.
    async fn acquire(mongo: &Client, key: &str, ttl_secs: u64) -> Option<Lock> {
        Lock::try_acquire(mongo, key, Duration::from_secs(ttl_secs))
            .await
            .unwrap()
    }

    /// Acquire, contend, release and re-acquire on two independent keys.
    #[tokio::test]
    async fn simple_locks() {
        let mongo = connect().await;

        let key1 = gen_random_key();
        let key2 = gen_random_key();

        let first = acquire(&mongo, &key1, 5).await;
        assert!(first.is_some());

        // A second attempt on a held key must fail.
        let duplicate = acquire(&mongo, &key1, 5).await;
        assert!(duplicate.is_none());

        let released = first.unwrap().release().await.unwrap();
        assert!(released);

        // Once released, the same key can be taken again.
        let reacquired = acquire(&mongo, &key1, 5).await;
        assert!(reacquired.is_some());

        let other = acquire(&mongo, &key2, 5).await;
        assert!(other.is_some());

        reacquired.unwrap().release().await.unwrap();
        other.unwrap().release().await.unwrap();
    }

    /// A lock becomes acquirable again after its TTL has passed.
    #[tokio::test]
    async fn with_ttl() {
        let mongo = connect().await;
        let key = gen_random_key();

        assert!(acquire(&mongo, &key, 1).await.is_some());
        assert!(acquire(&mongo, &key, 1).await.is_none());

        tokio::time::sleep(Duration::from_secs(1)).await;

        assert!(acquire(&mongo, &key, 1).await.is_some());
    }

    /// Waiting acquisition succeeds once the previous holder's TTL expires.
    #[tokio::test]
    async fn wait_for_lock() {
        let mongo = connect().await;
        let key = gen_random_key();

        assert!(acquire(&mongo, &key, 3).await.is_some());

        let started = Instant::now();
        let waited = Lock::try_acquire_with_timeout(
            &mongo,
            &key,
            Duration::from_secs(3),
            Duration::from_secs(5),
            Duration::from_millis(100),
        )
        .await
        .unwrap();
        assert!(waited.is_some());

        // The first lock had a 3 second TTL, so the wait cannot have been short.
        assert!(started.elapsed() > Duration::from_secs(2));

        assert!(acquire(&mongo, &key, 1).await.is_none());
    }
}