mongodb_lock/
lib.rs

//! Rusty distributed locking backed by MongoDB.
//!
//! All [`Mutex`]es can share the same collection (even with different `Key` types) as long as
//! keys used for unrelated purposes never serialize to the same value; otherwise unrelated locks
//! will contend with each other. Using a separate collection per `Key` type and per kind of
//! operation is recommended.
//!
//! All [`RwLock`]s can share the same collection, since each lock is its own document; using a
//! single collection for them is recommended.
//!
9//! ## Similar works
10//!
11//! - <https://github.com/square/mongo-lock>
12//!
13//! ## Example
14//!
15//! ```ignore
16//! #[derive(Clone, Serialize, Deserialize)]
17//! struct MyDocument {
18//!     _id: ObjectId,
19//!     x: i32,
20//! }
21//! let db = client.database("basic");
22//! let docs = db.collection::<MyDocument>("docs");
23//! let lock = Arc::new(mongodb_lock::Mutex::new(&db, "locks").await.unwrap());
24//! let one = MyDocument { _id: ObjectId::new(), x: 1 };
25//! let two = MyDocument { _id: ObjectId::new(), x: 1 };
26//! let three = MyDocument { _id: ObjectId::new(), x: 1 };
27//! docs.insert_many(vec![one.clone(), two.clone(), three.clone()]).await.unwrap();
28//!
29//! let one_id = one._id;
30//! let two_id = two._id;
31//! let clock = lock.clone();
32//! let cdocs = docs.clone();
33//! let first = task::spawn(async move {
34//!     let _guard = clock.lock_default([one_id, two_id]).await.unwrap();
35//!     let a = cdocs.find_one(doc! { "_id": one_id }).await.unwrap().unwrap();
36//!     let b = cdocs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap();
37//!     cdocs.update_many(
38//!         doc! { "_id": { "$in": [one_id,two_id] }},
39//!         doc! { "$set": { "x": a.x + b.x } }
40//!     ).await.unwrap();
41//! });
42//!
43//! let two_id = two._id;
44//! let three_id = three._id;
45//! let clock = lock.clone();
46//! let cdocs = docs.clone();
47//! let second = task::spawn(async move {
//!     let _guard = clock.lock_default([two_id, three_id]).await.unwrap();
49//!     let a = cdocs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap();
50//!     let b = cdocs.find_one(doc! { "_id": three_id }).await.unwrap().unwrap();
51//!     cdocs.update_many(
52//!         doc! { "_id": { "$in": [two_id,three_id] } },
53//!         doc! { "$set": { "x": a.x + b.x } }
54//!     ).await.unwrap();
55//! });
56//!
57//! first.await.unwrap();
58//! second.await.unwrap();
59//!
60//! let a = docs.find_one(doc! { "_id": one_id }).await.unwrap().unwrap().x;
61//! let b = docs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap().x;
62//! let c = docs.find_one(doc! { "_id": three_id }).await.unwrap().unwrap().x;
63//! assert!((a == 2 && b == 3 && c == 3) || (a == 3 && b == 3 && c == 2));
64//! ```
65
66use bson::doc;
67use bson::oid::ObjectId;
68use bson::{Bson, Document};
69use displaydoc::Display;
70use mongodb::{
71    options::IndexOptions,
72    results::{DeleteResult, InsertOneResult},
73    Collection, IndexModel,
74};
75use serde::{Deserialize, Serialize};
76use std::iter::once;
77use std::time::{Duration, Instant};
78use thiserror::Error;
79use tokio::runtime::Handle;
80use tokio::task;
81use tokio::time::sleep;
82
/// Timeout (one minute) applied by [`Mutex::lock_default`], [`RwLock::read_default`] and
/// [`RwLock::write_default`].
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Pause (half a second) between lock attempts used by [`Mutex::lock_default`],
/// [`RwLock::read_default`] and [`RwLock::write_default`].
pub const DEFAULT_WAIT: Duration = Duration::from_millis(500);
89
90/// Error type for [`Mutex::new`].
91#[derive(Debug, Error, Display)]
92pub enum MutexLockError {
93    /// Failed to acquire lock due to timeout.
94    LockTimeout,
95    /// Failed to get [`ObjectId`] from [`InsertOneResult::inserted_id`].
96    ObjectId,
97    /// Failed attempt to acquire lock: {0}
98    Attempt(mongodb::error::Error),
99    /// Failed to create index: {0}
100    CreateIndex(mongodb::error::Error),
101    /// Failed to serialize to bson: {0}
102    ToBson(bson::ser::Error),
103}
104
105/// Error type for [`Mutex::release`].
106#[derive(Debug, Error, Display)]
107enum ReleaseError {
108    /// Failed to start deleting the lock: {0}
109    PreDelete(mongodb::error::Error),
110    /// Failed to finish deleting the lock.
111    PostDelete,
112}
113
114/// A distributed lock guard that acts like [`std::sync::MutexGuard`].
115#[derive(Debug)]
116pub struct MutexGuard<'a, Key: Clone + Send + Sync + Serialize + 'static> {
117    pub lock: &'a Mutex<Key>,
118    pub id: ObjectId,
119    pub rt: Handle,
120}
121
122/// The document used for backing [`Mutex`].
123#[derive(Debug, Serialize, Deserialize)]
124struct MutexDocument<Key> {
125    /// Lock id
126    pub _id: ObjectId,
127    /// Key used for locking.
128    pub key: Key,
129}
130
131/// A distributed lock that acts like [`std::sync::Mutex`].
132#[derive(Debug)]
133pub struct Mutex<Key: Clone + Send + Sync + Serialize + 'static>(Collection<MutexDocument<Key>>);
134
135impl<Key: Clone + Send + Sync + Serialize + 'static> Mutex<Key> {
136    /// Constructs a new [`Mutex`].
137    ///
138    /// # Errors
139    ///
140    /// When [`mongodb::Collection::create_index`] errors.
141    #[inline]
142    pub async fn new(
143        database: &mongodb::Database,
144        collection: &str,
145    ) -> Result<Self, mongodb::error::Error> {
146        let col = database.collection::<MutexDocument<Key>>(collection);
147        col.create_index(
148            IndexModel::builder()
149                .keys(once((String::from("key"), Bson::Int32(1))).collect::<Document>())
150                .options(IndexOptions::builder().unique(true).build())
151                .build(),
152        )
153        .await?;
154        Ok(Self(col))
155    }
156    /// Create [`Mutex`] without initializing the lock.
157    ///
158    /// This should be used when the lock is already initialized; possibly by another process.
159    #[inline]
160    pub async fn new_uninit(database: &mongodb::Database, collection: &str) -> Self {
161        let col = database.collection::<MutexDocument<Key>>(collection);
162        Self(col)
163    }
164    /// Calls [`Mutex::lock`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
165    /// # Errors
166    ///
167    /// When [`Mutex::lock`] errors.
168    #[inline]
169    pub async fn lock_default(&self, key: Key) -> Result<MutexGuard<'_, Key>, MutexLockError> {
170        self.lock(DEFAULT_TIMEOUT, DEFAULT_WAIT, key).await
171    }
172    /// Attempts to lock the given `key` using the given lock `collection`.
173    ///
174    /// Since the Mongodb Rust driver doesn't fully support change streams see
175    /// <https://github.com/mongodb/mongo-rust-driver/issues/1230> a busy polling approach is used
176    /// where it will attempt to acquire the lock for `timeout` sleeping `wait` in between attempts.
177    ///
178    /// In this sense it is like:
179    /// ```
180    /// # use std::time::Duration;
181    /// # use std::time::Instant;
182    /// # fn main() -> Result<(),()> {
183    /// # let rt = tokio::runtime::Runtime::new().unwrap();
184    /// # rt.block_on(async {
185    /// let lock = tokio::sync::Mutex::new(());
186    /// let timeout = Duration::from_secs(1);
187    /// let sleep = Duration::from_millis(100);
188    /// let start = Instant::now();
189    /// let guard = loop {
190    ///     match lock.try_lock() {
191    ///         Ok(guard) => break guard,
192    ///         Err(err) if start.elapsed() > timeout => return Err(()),
193    ///         Err(_) => tokio::time::sleep(sleep).await,
194    ///     }
195    /// };
196    /// // Do some work.
197    /// # Ok(())
198    /// # })
199    /// # }
200    /// ```
201    ///
202    /// # Errors
203    ///
204    /// When:
205    /// - Timing out.
206    /// - [`mongodb::Collection::insert_one`] errors.
207    #[inline]
208    pub async fn lock(
209        &self,
210        timeout: Duration,
211        wait: Duration,
212        key: Key,
213    ) -> Result<MutexGuard<'_, Key>, MutexLockError> {
214        let lock_id = ObjectId::new();
215        let lock_doc = MutexDocument {
216            _id: lock_id,
217            key: key.clone(),
218        };
219
220        let start = Instant::now();
221        loop {
222            if start.elapsed() > timeout {
223                return Err(MutexLockError::LockTimeout);
224            }
225            let insert = self.0.insert_one(&lock_doc).await;
226            match insert {
227                Ok(InsertOneResult { inserted_id, .. }) => {
228                    let id = inserted_id.as_object_id().ok_or(MutexLockError::ObjectId)?;
229                    debug_assert_eq!(id, lock_id, "Document id mismatch");
230                    break Ok(MutexGuard {
231                        lock: self,
232                        id,
233                        rt: Handle::current(),
234                    });
235                }
236                // Wait to retry acquiring the lock.
237                Err(err) if is_duplicate_key_error(&err) => sleep(wait).await,
238                Err(err) => break Err(MutexLockError::Attempt(err)),
239            }
240        }
241    }
242    /// Release the lock.
243    async fn release(&self, lock: ObjectId) -> Result<(), ReleaseError> {
244        let delete = self
245            .0
246            .delete_one(doc! { "_id": lock })
247            .await
248            .map_err(ReleaseError::PreDelete)?;
249        if !matches!(
250            delete,
251            DeleteResult {
252                deleted_count: 1,
253                ..
254            }
255        ) {
256            return Err(ReleaseError::PostDelete);
257        }
258        Ok(())
259    }
260}
261
262// TODO Remove below `expect`.
263#[expect(
264    clippy::unwrap_used,
265    reason = "I do not know a way to propagate the error."
266)]
267impl<Key: Clone + Send + Sync + Serialize + 'static> Drop for MutexGuard<'_, Key> {
268    #[inline]
269    fn drop(&mut self) {
270        let rt = self.rt.clone();
271        let id = self.id;
272        let lock = Mutex(self.lock.0.clone());
273        task::spawn_blocking(move || {
274            rt.block_on(async { lock.release(id).await }).unwrap();
275        });
276    }
277}
278
279/// Check if the error is a duplicate key error.
280#[must_use]
281#[inline]
282pub fn is_duplicate_key_error(error: &mongodb::error::Error) -> bool {
283    if let mongodb::error::ErrorKind::Write(mongodb::error::WriteFailure::WriteError(write_error)) =
284        &*error.kind
285    {
286        write_error.code == 11000 && write_error.message.contains("duplicate key error")
287    } else {
288        false
289    }
290}
291
292/// Error type for [`RwLock::read`].
293#[derive(Debug, Error, Display)]
294pub enum RwLockReadError {
295    /// Failed to query lock: {0}
296    Query(mongodb::error::Error),
297    /// Failed to acquire lock due to timeout.
298    Timeout,
299}
300
301/// Error type for [`RwLock::release_read`].
302#[derive(Debug, Error, Display)]
303enum RwLockReleaseReadError {
304    /// Failed to query lock: {0}
305    Query(mongodb::error::Error),
306    /// Failed to find lock.
307    Find,
308}
309
310/// Error type for [`RwLock::write`].
311#[derive(Debug, Error, Display)]
312pub enum RwLockWriteError {
313    /// Failed to query lock: {0}
314    Query(mongodb::error::Error),
315    /// Failed to acquire lock due to timeout.
316    Timeout,
317}
318
319/// Error type for [`RwLock::release_write`].
320#[derive(Debug, Error, Display)]
321enum RwLockReleaseWriteError {
322    /// Failed to query lock: {0}
323    Query(mongodb::error::Error),
324    /// Failed to find lock.
325    Find,
326}
327
328/// A distributed lock that acts like [`std::sync::RwLock`].
329pub struct RwLock {
330    /// The id of the lock document within the collection.
331    id: ObjectId,
332    /// The collection within which the lock document is stored.
333    collection: Collection<RwLockDocument>,
334}
335impl RwLock {
336    /// Returns the [`ObjectId`] of the underlying lock document stored in the collection.
337    ///
338    /// Intended for usage with [`RwLock::new_uninit`].
339    pub fn id(&self) -> ObjectId {
340        self.id
341    }
342    /// Constructs a new [`RwLock`].
343    ///
344    /// # Errors
345    ///
346    /// When [`mongodb::Collection::insert_one`] errors.
347    #[inline]
348    pub async fn new(
349        database: &mongodb::Database,
350        collection: &str,
351    ) -> Result<Self, mongodb::error::Error> {
352        let col = database.collection(collection);
353        let id = ObjectId::new();
354        col.insert_one(RwLockDocument {
355            _id: id,
356            reads: 0,
357            write: false,
358        })
359        .await?;
360        Ok(Self {
361            id,
362            collection: col,
363        })
364    }
365    /// Create [`RwLock`] without initializing the lock.
366    ///
367    /// This should be used when the lock is already initialized; possibly by another process.
368    #[inline]
369    pub async fn new_uninit(database: &mongodb::Database, collection: &str, id: ObjectId) -> Self {
370        let col = database.collection(collection);
371        Self {
372            id,
373            collection: col,
374        }
375    }
376    /// Calls [`RwLock::read`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
377    ///
378    /// # Errors
379    ///
380    /// When [`RwLock::read`] errors.
381    #[inline]
382    pub async fn read_default(&self) -> Result<RwLockReadGuard<'_>, RwLockReadError> {
383        self.read(DEFAULT_TIMEOUT, DEFAULT_WAIT).await
384    }
385    /// Locks for reading.
386    ///
387    /// # Errors
388    ///
389    /// When:
390    /// - Timing out.
391    /// - [`mongodb::Collection::find_one_and_update`] errors.
392    #[inline]
393    pub async fn read(
394        &self,
395        timeout: Duration,
396        wait: Duration,
397    ) -> Result<RwLockReadGuard<'_>, RwLockReadError> {
398        let now = Instant::now();
399        loop {
400            if now.elapsed() > timeout {
401                return Err(RwLockReadError::Timeout);
402            }
403            let result = self
404                .collection
405                .find_one_and_update(
406                    doc! { "_id": self.id, "write": false },
407                    doc! { "$inc": { "reads": 1i32 } },
408                )
409                .await
410                .map_err(RwLockReadError::Query)?;
411            if let Some(RwLockDocument { _id, write, .. }) = result {
412                debug_assert_eq!(write, false, "Write should be false.");
413                break Ok(RwLockReadGuard {
414                    lock: self,
415                    rt: Handle::current(),
416                });
417            }
418            sleep(wait).await;
419        }
420    }
421    /// Release a read lock.
422    async fn release_read(&self) -> Result<(), RwLockReleaseReadError> {
423        let delete = self
424            .collection
425            .find_one_and_update(doc! { "_id": self.id }, doc! { "$inc": {"reads": -1i32} })
426            .await
427            .map_err(RwLockReleaseReadError::Query)?
428            .ok_or(RwLockReleaseReadError::Find)?;
429        debug_assert!(delete.reads > 0i32, "Reads should be greater than 0");
430        debug_assert_eq!(delete.write, false, "Write lock should be false");
431        Ok(())
432    }
433    /// Calls [`RwLock::write`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
434    ///
435    /// # Errors
436    ///
437    /// When [`RwLock::write`] errors.
438    #[inline]
439    pub async fn write_default(&self) -> Result<RwLockWriteGuard<'_>, RwLockWriteError> {
440        self.write(DEFAULT_TIMEOUT, DEFAULT_WAIT).await
441    }
442    /// Locks for writing.
443    ///
444    /// # Errors
445    ///
446    /// When:
447    /// - Timing out.
448    /// - [`mongodb::Collection::find_one_and_update`] errors.
449    #[inline]
450    pub async fn write(
451        &self,
452        timeout: Duration,
453        wait: Duration,
454    ) -> Result<RwLockWriteGuard<'_>, RwLockWriteError> {
455        let now = Instant::now();
456        loop {
457            if now.elapsed() > timeout {
458                return Err(RwLockWriteError::Timeout);
459            }
460            let result = self
461                .collection
462                .find_one_and_update(
463                    doc! { "_id": self.id, "reads": 0i32, "write": false },
464                    doc! { "$set": { "write": true } },
465                )
466                .await
467                .map_err(RwLockWriteError::Query)?;
468            if let Some(RwLockDocument { _id, reads, write }) = result {
469                debug_assert_eq!(reads, 0i32, "reads should be >0");
470                debug_assert_eq!(write, false, "write should be false");
471                break Ok(RwLockWriteGuard {
472                    lock: self,
473                    rt: Handle::current(),
474                });
475            }
476            sleep(wait).await;
477        }
478    }
479    /// Releases the write lock.
480    async fn release_write(&self) -> Result<(), RwLockReleaseWriteError> {
481        let delete = self
482            .collection
483            .find_one_and_update(
484                doc! { "_id": self.id, "write": true },
485                doc! { "$set": {"write": false} },
486            )
487            .await
488            .map_err(RwLockReleaseWriteError::Query)?
489            .ok_or(RwLockReleaseWriteError::Find)?;
490        debug_assert_eq!(delete.reads, 0i32, "Reads should be zero");
491        Ok(())
492    }
493}
494
495/// A distributed lock guard that acts like [`std::sync::RwLockReadGuard`].
496pub struct RwLockReadGuard<'a> {
497    /// Lock.
498    lock: &'a RwLock,
499    /// Tokio runtime handle.
500    rt: Handle,
501}
502
503// TODO Remove below `expect`.
504#[expect(
505    clippy::unwrap_used,
506    reason = "I do not know a way to propagate the error."
507)]
508impl Drop for RwLockReadGuard<'_> {
509    #[inline]
510    fn drop(&mut self) {
511        let rt = self.rt.clone();
512        let lock = RwLock {
513            collection: self.lock.collection.clone(),
514            id: self.lock.id,
515        };
516        task::spawn_blocking(move || {
517            rt.block_on(async { lock.release_read().await }).unwrap();
518        });
519    }
520}
521
522/// A distributed lock guard that acts like [`std::sync::RwLockWriteGuard`].
523pub struct RwLockWriteGuard<'a> {
524    /// Lock.
525    lock: &'a RwLock,
526    /// Tokio runtime handle.
527    rt: Handle,
528}
529
530// TODO Remove below `expect`.
531#[expect(
532    clippy::unwrap_used,
533    reason = "I do not know a way to propagate the error."
534)]
535impl Drop for RwLockWriteGuard<'_> {
536    #[inline]
537    fn drop(&mut self) {
538        let rt = self.rt.clone();
539        let lock = RwLock {
540            collection: self.lock.collection.clone(),
541            id: self.lock.id,
542        };
543        task::spawn_blocking(move || {
544            rt.block_on(async { lock.release_write().await }).unwrap();
545        });
546    }
547}
548
549/// The document used for backing [`RwLock`].
550#[derive(Debug, Serialize, Deserialize)]
551struct RwLockDocument {
552    /// Lock id
553    pub _id: ObjectId,
554    /// How many read locks are held.
555    pub reads: i32,
556    /// Is write lock held.
557    pub write: bool,
558}