mongo_lock/
lib.rs

#![doc(issue_tracker_base_url = "https://github.com/lazureykis/mongo-lock/issues")]

//! Distributed mutually exclusive locks in MongoDB.
//!
//! This crate contains only a synchronous implementation.
//! If you need an async version, use the [`mongo-lock-async`](https://crates.io/crates/mongo-lock-async) crate.
//!
//! This implementation relies on system time. Ensure that NTP clients on your servers are configured properly.
//!
//! A lock is only held for the TTL passed when acquiring it: once that time has passed,
//! another caller may take the lock over, so protected work must finish before the TTL expires.
//!
//! Usage (the example connects to a MongoDB server at `mongodb://localhost`):
//! ```rust
//! let mongo = mongodb::sync::Client::with_uri_str("mongodb://localhost").unwrap();
//!
//! // We need to ensure that the MongoDB collection has a proper index.
//! mongo_lock::prepare_database(&mongo).unwrap();
//!
//! if let Ok(Some(lock)) = mongo_lock::Lock::try_acquire(
//!     &mongo,
//!     "my-key",
//!     std::time::Duration::from_secs(30),
//! ) {
//!     println!("Lock acquired.");
//!
//!     // The lock will be released automatically after leaving the scope.
//! }
//! ```
31
32mod error;
33mod util;
34
35pub use error::Error;
36
37const COLLECTION_NAME: &str = "locks";
38const DEFAULT_DB_NAME: &str = "mongo-lock";
39
40use mongodb::bson::{doc, Document};
41use mongodb::error::{ErrorKind, WriteError, WriteFailure};
42use mongodb::options::{IndexOptions, UpdateOptions};
43use mongodb::sync::{Client, Collection};
44use mongodb::IndexModel;
45use std::time::Duration;
46
47#[inline]
48fn collection(mongo: &Client) -> Collection<Document> {
49    mongo
50        .default_database()
51        .unwrap_or_else(|| mongo.database(DEFAULT_DB_NAME))
52        .collection(COLLECTION_NAME)
53}
54
55/// Prepares MongoDB collection to store locks.
56///
57/// Creates TTL index to remove old records after they expire.
58///
59/// The [Lock] itself does not relies on this index,
60/// because MongoDB can remove documents with some significant delay.
61pub fn prepare_database(mongo: &Client) -> Result<(), Error> {
62    let options = IndexOptions::builder()
63        .expire_after(Some(Duration::from_secs(0)))
64        .build();
65
66    let model = IndexModel::builder()
67        .keys(doc! {"expiresAt": 1})
68        .options(options)
69        .build();
70
71    collection(mongo).create_index(model, None)?;
72
73    Ok(())
74}
75
76/// Distributed mutex lock.
77pub struct Lock {
78    mongo: Client,
79    id: String,
80    acquired: bool,
81}
82
83impl Lock {
84    /// Tries to acquire the lock with the given key.
85    pub fn try_acquire(mongo: &Client, key: &str, ttl: Duration) -> Result<Option<Lock>, Error> {
86        let (now, expires_at) = util::now_and_expires_at(ttl);
87
88        // Update expired locks if mongodb didn't clean it yet.
89        let query = doc! {
90            "_id": key,
91            "expiresAt": {"$lte": now},
92        };
93
94        let update = doc! {
95            "$set": {
96                "expiresAt": expires_at,
97            },
98            "$setOnInsert": {
99                "_id": key,
100            },
101        };
102
103        let options = UpdateOptions::builder().upsert(true).build();
104
105        match collection(mongo).update_one(query, update, options) {
106            Ok(result) => {
107                if result.upserted_id.is_some() || result.modified_count == 1 {
108                    Ok(Some(Lock {
109                        mongo: mongo.clone(),
110                        id: key.to_string(),
111                        acquired: true,
112                    }))
113                } else {
114                    Ok(None)
115                }
116            }
117            Err(err) => {
118                if let ErrorKind::Write(WriteFailure::WriteError(WriteError {
119                    code: 11000, ..
120                })) = *err.kind
121                {
122                    Ok(None)
123                } else {
124                    Err(err.into())
125                }
126            }
127        }
128    }
129
130    /// Tries to acquire the lock with the given key.
131    /// If the lock is already acquired, waits for it to be released
132    /// up to `lock_wait_timeout` time checking every `lock_poll_interval`.
133    pub fn try_acquire_with_timeout(
134        mongo: &Client,
135        key: &str,
136        key_ttl: Duration,
137        lock_wait_timeout: Duration,
138        lock_poll_interval: Duration,
139    ) -> Result<Option<Lock>, Error> {
140        let start = std::time::Instant::now();
141        loop {
142            match Lock::try_acquire(mongo, key, key_ttl)? {
143                Some(lock) => return Ok(Some(lock)),
144                None => {
145                    if start.elapsed() > lock_wait_timeout {
146                        return Ok(None);
147                    }
148                    std::thread::sleep(lock_poll_interval);
149                }
150            }
151        }
152    }
153
154    fn release(&mut self) -> Result<bool, mongodb::error::Error> {
155        if self.acquired {
156            let result = collection(&self.mongo).delete_one(doc! {"_id": &self.id}, None)?;
157
158            self.acquired = false;
159
160            Ok(result.deleted_count == 1)
161        } else {
162            Ok(false)
163        }
164    }
165}
166
167impl Drop for Lock {
168    fn drop(&mut self) {
169        self.release().ok();
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a random 30-character alphanumeric key, so tests never share a lock.
    fn random_key() -> String {
        use rand::{distributions::Alphanumeric, thread_rng, Rng};

        let mut key = String::with_capacity(30);
        key.extend(thread_rng().sample_iter(&Alphanumeric).take(30).map(char::from));
        key
    }

    /// Connects to the local MongoDB server and ensures the TTL index exists.
    fn connect() -> Client {
        let mongo = Client::with_uri_str("mongodb://localhost").unwrap();
        prepare_database(&mongo).unwrap();
        mongo
    }

    #[test]
    fn simple_locks() {
        let mongo = connect();
        let ttl = Duration::from_secs(5);
        let key1 = random_key();
        let key2 = random_key();

        let held = Lock::try_acquire(&mongo, &key1, ttl).unwrap();
        assert!(held.is_some());

        // While `held` is alive, the same key cannot be taken again.
        let duplicate = Lock::try_acquire(&mongo, &key1, ttl).unwrap();
        assert!(duplicate.is_none());

        // An explicit release reports that the lock document was deleted.
        assert!(held.unwrap().release().unwrap());

        // After the release the key is free again; other keys are independent.
        let reacquired = Lock::try_acquire(&mongo, &key1, ttl).unwrap();
        assert!(reacquired.is_some());

        let other = Lock::try_acquire(&mongo, &key2, ttl).unwrap();
        assert!(other.is_some());

        reacquired.unwrap().release().unwrap();
        other.unwrap().release().unwrap();
    }

    #[test]
    fn with_ttl() {
        let mongo = connect();
        let ttl = Duration::from_secs(1);
        let key = random_key();

        // `held` stays alive for the whole test; only its TTL frees the key.
        let held = Lock::try_acquire(&mongo, &key, ttl).unwrap();
        assert!(held.is_some());

        assert!(Lock::try_acquire(&mongo, &key, ttl).unwrap().is_none());

        std::thread::sleep(ttl);

        assert!(Lock::try_acquire(&mongo, &key, ttl).unwrap().is_some());
    }

    #[test]
    fn with_ttl_and_retry() {
        let mongo = connect();
        let ttl = Duration::from_secs(1);
        let key = random_key();

        let first = Lock::try_acquire(&mongo, &key, ttl).unwrap();
        assert!(first.is_some());

        let started = std::time::Instant::now();

        // Polling must outlast the first holder's TTL before it succeeds.
        let second = Lock::try_acquire_with_timeout(
            &mongo,
            &key,
            ttl,
            Duration::from_secs(3),
            Duration::from_millis(100),
        )
        .unwrap();

        assert!(second.is_some());
        assert!(started.elapsed() > ttl);
    }

    #[test]
    fn dropped_locks() {
        let mongo = connect();
        let ttl = Duration::from_secs(1);
        let key = random_key();

        // An unbound lock is dropped, and therefore released, at the end of
        // its statement, so the key can be acquired again right away.
        for _ in 0..2 {
            assert!(Lock::try_acquire(&mongo, &key, ttl).unwrap().is_some());
        }

        let holder = Lock::try_acquire(&mongo, &key, ttl).unwrap();
        let contender = Lock::try_acquire(&mongo, &key, ttl).unwrap();

        assert!(holder.is_some());
        assert!(contender.is_none());

        drop(holder);

        let successor = Lock::try_acquire(&mongo, &key, ttl).unwrap();
        assert!(successor.is_some());
    }
}