mongodb_lock/lib.rs
1//! Rusty distributed locking backed by Mongodb.
2//!
//! All [`Mutex`]s can share the same collection (even with different `Key` types) as long as
//! their keys never collide, because the `key` field is unique across the whole collection. I
//! would recommend using different collections for different `Key` types and different
//! collections for each type of operation.
6//!
7//! All [`RwLock`]s can share the same collection. I would recommend using the same collection.
8//!
9//! ## Similar works
10//!
11//! - <https://github.com/square/mongo-lock>
12//!
13//! ## Example
14//!
15//! ```ignore
16//! #[derive(Clone, Serialize, Deserialize)]
17//! struct MyDocument {
18//! _id: ObjectId,
19//! x: i32,
20//! }
21//! let db = client.database("basic");
22//! let docs = db.collection::<MyDocument>("docs");
23//! let lock = Arc::new(mongodb_lock::Mutex::new(&db, "locks").await.unwrap());
24//! let one = MyDocument { _id: ObjectId::new(), x: 1 };
25//! let two = MyDocument { _id: ObjectId::new(), x: 1 };
26//! let three = MyDocument { _id: ObjectId::new(), x: 1 };
27//! docs.insert_many(vec![one.clone(), two.clone(), three.clone()]).await.unwrap();
28//!
29//! let one_id = one._id;
30//! let two_id = two._id;
31//! let clock = lock.clone();
32//! let cdocs = docs.clone();
33//! let first = task::spawn(async move {
34//! let _guard = clock.lock_default([one_id, two_id]).await.unwrap();
35//! let a = cdocs.find_one(doc! { "_id": one_id }).await.unwrap().unwrap();
36//! let b = cdocs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap();
37//! cdocs.update_many(
38//! doc! { "_id": { "$in": [one_id,two_id] }},
39//! doc! { "$set": { "x": a.x + b.x } }
40//! ).await.unwrap();
41//! });
42//!
43//! let two_id = two._id;
44//! let three_id = three._id;
45//! let clock = lock.clone();
46//! let cdocs = docs.clone();
47//! let second = task::spawn(async move {
//!     let _guard = clock.lock_default([two_id, three_id]).await.unwrap();
49//! let a = cdocs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap();
50//! let b = cdocs.find_one(doc! { "_id": three_id }).await.unwrap().unwrap();
51//! cdocs.update_many(
52//! doc! { "_id": { "$in": [two_id,three_id] } },
53//! doc! { "$set": { "x": a.x + b.x } }
54//! ).await.unwrap();
55//! });
56//!
57//! first.await.unwrap();
58//! second.await.unwrap();
59//!
60//! let a = docs.find_one(doc! { "_id": one_id }).await.unwrap().unwrap().x;
61//! let b = docs.find_one(doc! { "_id": two_id }).await.unwrap().unwrap().x;
62//! let c = docs.find_one(doc! { "_id": three_id }).await.unwrap().unwrap().x;
63//! assert!((a == 2 && b == 3 && c == 3) || (a == 3 && b == 3 && c == 2));
64//! ```
65
66use bson::doc;
67use bson::oid::ObjectId;
68use bson::{Bson, Document};
69use displaydoc::Display;
70use mongodb::{
71 options::IndexOptions,
72 results::{DeleteResult, InsertOneResult},
73 Collection, IndexModel,
74};
75use serde::{Deserialize, Serialize};
76use std::iter::once;
77use std::time::{Duration, Instant};
78use thiserror::Error;
79use tokio::runtime::Handle;
80use tokio::task;
81use tokio::time::sleep;
82
83/// The default timeout used by [`Mutex::lock_default`], [`RwLock::read_default`] and
84/// [`RwLock::write_default`].
85pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);
86/// The default wait used by [`Mutex::lock_default`], [`RwLock::read_default`] and
87/// [`RwLock::write_default`].
88pub const DEFAULT_WAIT: Duration = Duration::from_millis(500);
89
90/// Error type for [`Mutex::new`].
91#[derive(Debug, Error, Display)]
92pub enum MutexLockError {
93 /// Failed to acquire lock due to timeout.
94 LockTimeout,
95 /// Failed to get [`ObjectId`] from [`InsertOneResult::inserted_id`].
96 ObjectId,
97 /// Failed attempt to acquire lock: {0}
98 Attempt(mongodb::error::Error),
99 /// Failed to create index: {0}
100 CreateIndex(mongodb::error::Error),
101 /// Failed to serialize to bson: {0}
102 ToBson(bson::ser::Error),
103}
104
105/// Error type for [`Mutex::release`].
106#[derive(Debug, Error, Display)]
107enum ReleaseError {
108 /// Failed to start deleting the lock: {0}
109 PreDelete(mongodb::error::Error),
110 /// Failed to finish deleting the lock.
111 PostDelete,
112}
113
114/// A distributed lock guard that acts like [`std::sync::MutexGuard`].
115#[derive(Debug)]
116pub struct MutexGuard<'a, Key: Clone + Send + Sync + Serialize + 'static> {
117 pub lock: &'a Mutex<Key>,
118 pub id: ObjectId,
119 pub rt: Handle,
120}
121
122/// The document used for backing [`Mutex`].
123#[derive(Debug, Serialize, Deserialize)]
124struct MutexDocument<Key> {
125 /// Lock id
126 pub _id: ObjectId,
127 /// Key used for locking.
128 pub key: Key,
129}
130
131/// A distributed lock that acts like [`std::sync::Mutex`].
132#[derive(Debug)]
133pub struct Mutex<Key: Clone + Send + Sync + Serialize + 'static>(Collection<MutexDocument<Key>>);
134
135impl<Key: Clone + Send + Sync + Serialize + 'static> Mutex<Key> {
136 /// Constructs a new [`Mutex`].
137 ///
138 /// # Errors
139 ///
140 /// When [`mongodb::Collection::create_index`] errors.
141 #[inline]
142 pub async fn new(
143 database: &mongodb::Database,
144 collection: &str,
145 ) -> Result<Self, mongodb::error::Error> {
146 let col = database.collection::<MutexDocument<Key>>(collection);
147 col.create_index(
148 IndexModel::builder()
149 .keys(once((String::from("key"), Bson::Int32(1))).collect::<Document>())
150 .options(IndexOptions::builder().unique(true).build())
151 .build(),
152 )
153 .await?;
154 Ok(Self(col))
155 }
156 /// Create [`Mutex`] without initializing the lock.
157 ///
158 /// This should be used when the lock is already initialized; possibly by another process.
159 #[inline]
160 pub async fn new_uninit(database: &mongodb::Database, collection: &str) -> Self {
161 let col = database.collection::<MutexDocument<Key>>(collection);
162 Self(col)
163 }
164 /// Calls [`Mutex::lock`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
165 /// # Errors
166 ///
167 /// When [`Mutex::lock`] errors.
168 #[inline]
169 pub async fn lock_default(&self, key: Key) -> Result<MutexGuard<'_, Key>, MutexLockError> {
170 self.lock(DEFAULT_TIMEOUT, DEFAULT_WAIT, key).await
171 }
172 /// Attempts to lock the given `key` using the given lock `collection`.
173 ///
174 /// Since the Mongodb Rust driver doesn't fully support change streams see
175 /// <https://github.com/mongodb/mongo-rust-driver/issues/1230> a busy polling approach is used
176 /// where it will attempt to acquire the lock for `timeout` sleeping `wait` in between attempts.
177 ///
178 /// In this sense it is like:
179 /// ```
180 /// # use std::time::Duration;
181 /// # use std::time::Instant;
182 /// # fn main() -> Result<(),()> {
183 /// # let rt = tokio::runtime::Runtime::new().unwrap();
184 /// # rt.block_on(async {
185 /// let lock = tokio::sync::Mutex::new(());
186 /// let timeout = Duration::from_secs(1);
187 /// let sleep = Duration::from_millis(100);
188 /// let start = Instant::now();
189 /// let guard = loop {
190 /// match lock.try_lock() {
191 /// Ok(guard) => break guard,
192 /// Err(err) if start.elapsed() > timeout => return Err(()),
193 /// Err(_) => tokio::time::sleep(sleep).await,
194 /// }
195 /// };
196 /// // Do some work.
197 /// # Ok(())
198 /// # })
199 /// # }
200 /// ```
201 ///
202 /// # Errors
203 ///
204 /// When:
205 /// - Timing out.
206 /// - [`mongodb::Collection::insert_one`] errors.
207 #[inline]
208 pub async fn lock(
209 &self,
210 timeout: Duration,
211 wait: Duration,
212 key: Key,
213 ) -> Result<MutexGuard<'_, Key>, MutexLockError> {
214 let lock_id = ObjectId::new();
215 let lock_doc = MutexDocument {
216 _id: lock_id,
217 key: key.clone(),
218 };
219
220 let start = Instant::now();
221 loop {
222 if start.elapsed() > timeout {
223 return Err(MutexLockError::LockTimeout);
224 }
225 let insert = self.0.insert_one(&lock_doc).await;
226 match insert {
227 Ok(InsertOneResult { inserted_id, .. }) => {
228 let id = inserted_id.as_object_id().ok_or(MutexLockError::ObjectId)?;
229 debug_assert_eq!(id, lock_id, "Document id mismatch");
230 break Ok(MutexGuard {
231 lock: self,
232 id,
233 rt: Handle::current(),
234 });
235 }
236 // Wait to retry acquiring the lock.
237 Err(err) if is_duplicate_key_error(&err) => sleep(wait).await,
238 Err(err) => break Err(MutexLockError::Attempt(err)),
239 }
240 }
241 }
242 /// Release the lock.
243 async fn release(&self, lock: ObjectId) -> Result<(), ReleaseError> {
244 let delete = self
245 .0
246 .delete_one(doc! { "_id": lock })
247 .await
248 .map_err(ReleaseError::PreDelete)?;
249 if !matches!(
250 delete,
251 DeleteResult {
252 deleted_count: 1,
253 ..
254 }
255 ) {
256 return Err(ReleaseError::PostDelete);
257 }
258 Ok(())
259 }
260}
261
262// TODO Remove below `expect`.
263#[expect(
264 clippy::unwrap_used,
265 reason = "I do not know a way to propagate the error."
266)]
267impl<Key: Clone + Send + Sync + Serialize + 'static> Drop for MutexGuard<'_, Key> {
268 #[inline]
269 fn drop(&mut self) {
270 let rt = self.rt.clone();
271 let id = self.id;
272 let lock = Mutex(self.lock.0.clone());
273 task::spawn_blocking(move || {
274 rt.block_on(async { lock.release(id).await }).unwrap();
275 });
276 }
277}
278
279/// Check if the error is a duplicate key error.
280#[must_use]
281#[inline]
282pub fn is_duplicate_key_error(error: &mongodb::error::Error) -> bool {
283 if let mongodb::error::ErrorKind::Write(mongodb::error::WriteFailure::WriteError(write_error)) =
284 &*error.kind
285 {
286 write_error.code == 11000 && write_error.message.contains("duplicate key error")
287 } else {
288 false
289 }
290}
291
292/// Error type for [`RwLock::read`].
293#[derive(Debug, Error, Display)]
294pub enum RwLockReadError {
295 /// Failed to query lock: {0}
296 Query(mongodb::error::Error),
297 /// Failed to acquire lock due to timeout.
298 Timeout,
299}
300
301/// Error type for [`RwLock::release_read`].
302#[derive(Debug, Error, Display)]
303enum RwLockReleaseReadError {
304 /// Failed to query lock: {0}
305 Query(mongodb::error::Error),
306 /// Failed to find lock.
307 Find,
308}
309
310/// Error type for [`RwLock::write`].
311#[derive(Debug, Error, Display)]
312pub enum RwLockWriteError {
313 /// Failed to query lock: {0}
314 Query(mongodb::error::Error),
315 /// Failed to acquire lock due to timeout.
316 Timeout,
317}
318
319/// Error type for [`RwLock::release_write`].
320#[derive(Debug, Error, Display)]
321enum RwLockReleaseWriteError {
322 /// Failed to query lock: {0}
323 Query(mongodb::error::Error),
324 /// Failed to find lock.
325 Find,
326}
327
328/// A distributed lock that acts like [`std::sync::RwLock`].
329pub struct RwLock {
330 /// The id of the lock document within the collection.
331 id: ObjectId,
332 /// The collection within which the lock document is stored.
333 collection: Collection<RwLockDocument>,
334}
335impl RwLock {
336 /// Returns the [`ObjectId`] of the underlying lock document stored in the collection.
337 ///
338 /// Intended for usage with [`RwLock::new_uninit`].
339 pub fn id(&self) -> ObjectId {
340 self.id
341 }
342 /// Constructs a new [`RwLock`].
343 ///
344 /// # Errors
345 ///
346 /// When [`mongodb::Collection::insert_one`] errors.
347 #[inline]
348 pub async fn new(
349 database: &mongodb::Database,
350 collection: &str,
351 ) -> Result<Self, mongodb::error::Error> {
352 let col = database.collection(collection);
353 let id = ObjectId::new();
354 col.insert_one(RwLockDocument {
355 _id: id,
356 reads: 0,
357 write: false,
358 })
359 .await?;
360 Ok(Self {
361 id,
362 collection: col,
363 })
364 }
365 /// Create [`RwLock`] without initializing the lock.
366 ///
367 /// This should be used when the lock is already initialized; possibly by another process.
368 #[inline]
369 pub async fn new_uninit(database: &mongodb::Database, collection: &str, id: ObjectId) -> Self {
370 let col = database.collection(collection);
371 Self {
372 id,
373 collection: col,
374 }
375 }
376 /// Calls [`RwLock::read`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
377 ///
378 /// # Errors
379 ///
380 /// When [`RwLock::read`] errors.
381 #[inline]
382 pub async fn read_default(&self) -> Result<RwLockReadGuard<'_>, RwLockReadError> {
383 self.read(DEFAULT_TIMEOUT, DEFAULT_WAIT).await
384 }
385 /// Locks for reading.
386 ///
387 /// # Errors
388 ///
389 /// When:
390 /// - Timing out.
391 /// - [`mongodb::Collection::find_one_and_update`] errors.
392 #[inline]
393 pub async fn read(
394 &self,
395 timeout: Duration,
396 wait: Duration,
397 ) -> Result<RwLockReadGuard<'_>, RwLockReadError> {
398 let now = Instant::now();
399 loop {
400 if now.elapsed() > timeout {
401 return Err(RwLockReadError::Timeout);
402 }
403 let result = self
404 .collection
405 .find_one_and_update(
406 doc! { "_id": self.id, "write": false },
407 doc! { "$inc": { "reads": 1i32 } },
408 )
409 .await
410 .map_err(RwLockReadError::Query)?;
411 if let Some(RwLockDocument { _id, write, .. }) = result {
412 debug_assert_eq!(write, false, "Write should be false.");
413 break Ok(RwLockReadGuard {
414 lock: self,
415 rt: Handle::current(),
416 });
417 }
418 sleep(wait).await;
419 }
420 }
421 /// Release a read lock.
422 async fn release_read(&self) -> Result<(), RwLockReleaseReadError> {
423 let delete = self
424 .collection
425 .find_one_and_update(doc! { "_id": self.id }, doc! { "$inc": {"reads": -1i32} })
426 .await
427 .map_err(RwLockReleaseReadError::Query)?
428 .ok_or(RwLockReleaseReadError::Find)?;
429 debug_assert!(delete.reads > 0i32, "Reads should be greater than 0");
430 debug_assert_eq!(delete.write, false, "Write lock should be false");
431 Ok(())
432 }
433 /// Calls [`RwLock::write`] with [`DEFAULT_TIMEOUT`] and [`DEFAULT_WAIT`].
434 ///
435 /// # Errors
436 ///
437 /// When [`RwLock::write`] errors.
438 #[inline]
439 pub async fn write_default(&self) -> Result<RwLockWriteGuard<'_>, RwLockWriteError> {
440 self.write(DEFAULT_TIMEOUT, DEFAULT_WAIT).await
441 }
442 /// Locks for writing.
443 ///
444 /// # Errors
445 ///
446 /// When:
447 /// - Timing out.
448 /// - [`mongodb::Collection::find_one_and_update`] errors.
449 #[inline]
450 pub async fn write(
451 &self,
452 timeout: Duration,
453 wait: Duration,
454 ) -> Result<RwLockWriteGuard<'_>, RwLockWriteError> {
455 let now = Instant::now();
456 loop {
457 if now.elapsed() > timeout {
458 return Err(RwLockWriteError::Timeout);
459 }
460 let result = self
461 .collection
462 .find_one_and_update(
463 doc! { "_id": self.id, "reads": 0i32, "write": false },
464 doc! { "$set": { "write": true } },
465 )
466 .await
467 .map_err(RwLockWriteError::Query)?;
468 if let Some(RwLockDocument { _id, reads, write }) = result {
469 debug_assert_eq!(reads, 0i32, "reads should be >0");
470 debug_assert_eq!(write, false, "write should be false");
471 break Ok(RwLockWriteGuard {
472 lock: self,
473 rt: Handle::current(),
474 });
475 }
476 sleep(wait).await;
477 }
478 }
479 /// Releases the write lock.
480 async fn release_write(&self) -> Result<(), RwLockReleaseWriteError> {
481 let delete = self
482 .collection
483 .find_one_and_update(
484 doc! { "_id": self.id, "write": true },
485 doc! { "$set": {"write": false} },
486 )
487 .await
488 .map_err(RwLockReleaseWriteError::Query)?
489 .ok_or(RwLockReleaseWriteError::Find)?;
490 debug_assert_eq!(delete.reads, 0i32, "Reads should be zero");
491 Ok(())
492 }
493}
494
495/// A distributed lock guard that acts like [`std::sync::RwLockReadGuard`].
496pub struct RwLockReadGuard<'a> {
497 /// Lock.
498 lock: &'a RwLock,
499 /// Tokio runtime handle.
500 rt: Handle,
501}
502
503// TODO Remove below `expect`.
504#[expect(
505 clippy::unwrap_used,
506 reason = "I do not know a way to propagate the error."
507)]
508impl Drop for RwLockReadGuard<'_> {
509 #[inline]
510 fn drop(&mut self) {
511 let rt = self.rt.clone();
512 let lock = RwLock {
513 collection: self.lock.collection.clone(),
514 id: self.lock.id,
515 };
516 task::spawn_blocking(move || {
517 rt.block_on(async { lock.release_read().await }).unwrap();
518 });
519 }
520}
521
522/// A distributed lock guard that acts like [`std::sync::RwLockWriteGuard`].
523pub struct RwLockWriteGuard<'a> {
524 /// Lock.
525 lock: &'a RwLock,
526 /// Tokio runtime handle.
527 rt: Handle,
528}
529
530// TODO Remove below `expect`.
531#[expect(
532 clippy::unwrap_used,
533 reason = "I do not know a way to propagate the error."
534)]
535impl Drop for RwLockWriteGuard<'_> {
536 #[inline]
537 fn drop(&mut self) {
538 let rt = self.rt.clone();
539 let lock = RwLock {
540 collection: self.lock.collection.clone(),
541 id: self.lock.id,
542 };
543 task::spawn_blocking(move || {
544 rt.block_on(async { lock.release_write().await }).unwrap();
545 });
546 }
547}
548
549/// The document used for backing [`RwLock`].
550#[derive(Debug, Serialize, Deserialize)]
551struct RwLockDocument {
552 /// Lock id
553 pub _id: ObjectId,
554 /// How many read locks are held.
555 pub reads: i32,
556 /// Is write lock held.
557 pub write: bool,
558}