1#![doc(issue_tracker_base_url = "https://github.com/lazureykis/mongo-lock/issues")]
2
3mod error;
33mod util;
34
35pub use error::Error;
36
/// Database used when the connection string does not name a default database.
const DEFAULT_DB_NAME: &str = "mongo-lock";
/// Collection holding one document per lock, keyed by the lock name in `_id`.
const COLLECTION_NAME: &str = "locks";
39
use mongodb::bson::oid::ObjectId;
use mongodb::bson::{doc, Document};
use mongodb::error::{ErrorKind, WriteError, WriteFailure};
use mongodb::options::{IndexOptions, UpdateOptions};
use mongodb::sync::{Client, Collection};
use mongodb::IndexModel;
use std::time::Duration;
46
47#[inline]
48fn collection(mongo: &Client) -> Collection<Document> {
49 mongo
50 .default_database()
51 .unwrap_or_else(|| mongo.database(DEFAULT_DB_NAME))
52 .collection(COLLECTION_NAME)
53}
54
55pub fn prepare_database(mongo: &Client) -> Result<(), Error> {
62 let options = IndexOptions::builder()
63 .expire_after(Some(Duration::from_secs(0)))
64 .build();
65
66 let model = IndexModel::builder()
67 .keys(doc! {"expiresAt": 1})
68 .options(options)
69 .build();
70
71 collection(mongo).create_index(model, None)?;
72
73 Ok(())
74}
75
76pub struct Lock {
78 mongo: Client,
79 id: String,
80 acquired: bool,
81}
82
83impl Lock {
84 pub fn try_acquire(mongo: &Client, key: &str, ttl: Duration) -> Result<Option<Lock>, Error> {
86 let (now, expires_at) = util::now_and_expires_at(ttl);
87
88 let query = doc! {
90 "_id": key,
91 "expiresAt": {"$lte": now},
92 };
93
94 let update = doc! {
95 "$set": {
96 "expiresAt": expires_at,
97 },
98 "$setOnInsert": {
99 "_id": key,
100 },
101 };
102
103 let options = UpdateOptions::builder().upsert(true).build();
104
105 match collection(mongo).update_one(query, update, options) {
106 Ok(result) => {
107 if result.upserted_id.is_some() || result.modified_count == 1 {
108 Ok(Some(Lock {
109 mongo: mongo.clone(),
110 id: key.to_string(),
111 acquired: true,
112 }))
113 } else {
114 Ok(None)
115 }
116 }
117 Err(err) => {
118 if let ErrorKind::Write(WriteFailure::WriteError(WriteError {
119 code: 11000, ..
120 })) = *err.kind
121 {
122 Ok(None)
123 } else {
124 Err(err.into())
125 }
126 }
127 }
128 }
129
130 pub fn try_acquire_with_timeout(
134 mongo: &Client,
135 key: &str,
136 key_ttl: Duration,
137 lock_wait_timeout: Duration,
138 lock_poll_interval: Duration,
139 ) -> Result<Option<Lock>, Error> {
140 let start = std::time::Instant::now();
141 loop {
142 match Lock::try_acquire(mongo, key, key_ttl)? {
143 Some(lock) => return Ok(Some(lock)),
144 None => {
145 if start.elapsed() > lock_wait_timeout {
146 return Ok(None);
147 }
148 std::thread::sleep(lock_poll_interval);
149 }
150 }
151 }
152 }
153
154 fn release(&mut self) -> Result<bool, mongodb::error::Error> {
155 if self.acquired {
156 let result = collection(&self.mongo).delete_one(doc! {"_id": &self.id}, None)?;
157
158 self.acquired = false;
159
160 Ok(result.deleted_count == 1)
161 } else {
162 Ok(false)
163 }
164 }
165}
166
167impl Drop for Lock {
168 fn drop(&mut self) {
169 self.release().ok();
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to a MongoDB server on localhost and ensures the TTL index exists.
    fn connect() -> Client {
        let mongo = Client::with_uri_str("mongodb://localhost").unwrap();
        prepare_database(&mongo).unwrap();
        mongo
    }

    /// Produces a random 30-character alphanumeric key so that tests never
    /// contend for the same lock.
    fn gen_random_key() -> String {
        use rand::{distributions::Alphanumeric, thread_rng, Rng};

        let mut rng = thread_rng();
        (0..30)
            .map(|_| char::from(rng.sample(Alphanumeric)))
            .collect()
    }

    #[test]
    fn simple_locks() {
        let mongo = connect();
        let ttl = Duration::from_secs(5);
        let first_key = gen_random_key();
        let second_key = gen_random_key();

        let first = Lock::try_acquire(&mongo, &first_key, ttl).unwrap();
        assert!(first.is_some());

        // A live lock cannot be taken a second time.
        assert!(Lock::try_acquire(&mongo, &first_key, ttl)
            .unwrap()
            .is_none());

        assert!(first.unwrap().release().unwrap());

        // After an explicit release the key is free again.
        let first = Lock::try_acquire(&mongo, &first_key, ttl).unwrap();
        assert!(first.is_some());

        // Different keys do not interfere with each other.
        let second = Lock::try_acquire(&mongo, &second_key, ttl).unwrap();
        assert!(second.is_some());

        first.unwrap().release().unwrap();
        second.unwrap().release().unwrap();
    }

    #[test]
    fn with_ttl() {
        let mongo = connect();
        let key = gen_random_key();
        let ttl = Duration::from_secs(1);
        let acquire = || Lock::try_acquire(&mongo, &key, ttl).unwrap();

        // Deliberately kept alive until the end of the test.
        let held = acquire();
        assert!(held.is_some());

        assert!(acquire().is_none());

        // Once the TTL has passed, the key can be taken over even though
        // `held` has not been dropped yet.
        std::thread::sleep(ttl);

        assert!(acquire().is_some());
    }

    #[test]
    fn with_ttl_and_retry() {
        let mongo = connect();
        let key = gen_random_key();

        let held = Lock::try_acquire(&mongo, &key, Duration::from_secs(1)).unwrap();
        assert!(held.is_some());

        let started = std::time::Instant::now();

        let retried = Lock::try_acquire_with_timeout(
            &mongo,
            &key,
            Duration::from_secs(1),
            Duration::from_secs(3),
            Duration::from_millis(100),
        )
        .unwrap();

        assert!(retried.is_some());

        // The retry can only succeed once the first lock's 1 s TTL has run out.
        assert!(started.elapsed() > Duration::from_secs(1));
    }

    #[test]
    fn dropped_locks() {
        let mongo = connect();
        let key = gen_random_key();
        let acquire = || Lock::try_acquire(&mongo, &key, Duration::from_secs(1)).unwrap();

        // A lock that is dropped straight away frees the key for the next caller.
        for _ in 0..2 {
            assert!(acquire().is_some());
        }

        let first = acquire();
        let second = acquire();

        assert!(first.is_some());
        assert!(second.is_none());

        drop(first);

        assert!(acquire().is_some());
    }
}