#![doc(issue_tracker_base_url = "https://github.com/lazureykis/mongo-lock/issues")]
mod error;
mod util;
pub use error::Error;
/// Database used for lock documents when the client has no default database configured.
const DEFAULT_DB_NAME: &str = "mongo-lock";
/// Name of the collection that holds one document per lock key.
const COLLECTION_NAME: &str = "locks";
use mongodb::bson::{doc, Document};
use mongodb::error::{ErrorKind, WriteError, WriteFailure};
use mongodb::options::{IndexOptions, UpdateOptions};
use mongodb::sync::{Client, Collection};
use mongodb::IndexModel;
use std::time::Duration;
/// Returns the collection that stores lock documents.
///
/// The client's default database is used when `default_database()` reports one
/// (presumably the database named in the connection string — confirm against the
/// driver version in use); otherwise the lock collection lives in `DEFAULT_DB_NAME`.
#[inline]
fn collection(mongo: &Client) -> Collection<Document> {
    let db = match mongo.default_database() {
        Some(configured) => configured,
        None => mongo.database(DEFAULT_DB_NAME),
    };
    db.collection(COLLECTION_NAME)
}
/// Creates a TTL index on `expiresAt` in the lock collection.
///
/// The index is built with `expire_after` set to zero seconds, so MongoDB's TTL
/// monitor may remove a lock document once its `expiresAt` time has passed. This
/// cleans up abandoned locks that were never released.
///
/// # Errors
/// Returns an error if the index cannot be created.
pub fn prepare_database(mongo: &Client) -> Result<(), Error> {
    let expire_immediately = IndexOptions::builder()
        .expire_after(Some(Duration::from_secs(0)))
        .build();
    let ttl_index = IndexModel::builder()
        .keys(doc! {"expiresAt": 1})
        .options(expire_immediately)
        .build();
    let locks = collection(mongo);
    locks.create_index(ttl_index, None)?;
    Ok(())
}
/// A lease-based lock backed by one document in the lock collection.
///
/// The document's `_id` is the lock key and its `expiresAt` field is the end of
/// the lease. While that lease has not yet passed, other `try_acquire` calls for
/// the same key return `None`. Once it has passed, another caller may take the
/// key over, even if this value is still alive. Dropping the value releases the
/// lock on a best-effort basis.
pub struct Lock {
    mongo: Client,
    /// `_id` of the lock document, i.e. the caller-supplied key.
    id: String,
    /// Exact `expiresAt` value this holder wrote. It serves as an ownership token
    /// on release, so a holder whose lease expired (and whose key was then
    /// re-acquired by someone else) does not delete the new holder's lock.
    expires_at: mongodb::bson::Bson,
    /// `true` until the lock document has been successfully deleted by `release`.
    acquired: bool,
}

impl Lock {
    /// Makes a single, non-blocking attempt to acquire the lock named `key`.
    ///
    /// On success the lease lasts for `ttl`. After that, another caller may take
    /// the key over.
    ///
    /// Returns `Ok(Some(lock))` when acquired. Returns `Ok(None)` when a lock
    /// document for `key` exists whose lease has not yet passed.
    ///
    /// # Errors
    /// Returns any MongoDB error other than a duplicate-key write error (code 11000).
    pub fn try_acquire(mongo: &Client, key: &str, ttl: Duration) -> Result<Option<Lock>, Error> {
        let (now, expires_at) = util::now_and_expires_at(ttl);
        // Convert once, so the exact value stored in the database can be kept as
        // this holder's ownership token (checked again in `release`).
        let expires_at = mongodb::bson::Bson::from(expires_at);
        // Matches only an existing lock document for `key` whose lease has passed.
        let query = doc! {
            "_id": key,
            "expiresAt": {"$lte": now},
        };
        let update = doc! {
            "$set": {
                "expiresAt": expires_at.clone(),
            },
            "$setOnInsert": {
                "_id": key,
            },
        };
        // Upsert semantics: with no document for `key`, one is inserted (acquired).
        // With an unexpired document, the filter misses and the insert collides on
        // `_id`, which surfaces as a duplicate-key error below.
        let options = UpdateOptions::builder().upsert(true).build();
        match collection(mongo).update_one(query, update, options) {
            Ok(result) => {
                if result.upserted_id.is_some() || result.modified_count == 1 {
                    Ok(Some(Lock {
                        mongo: mongo.clone(),
                        id: key.to_string(),
                        expires_at,
                        acquired: true,
                    }))
                } else {
                    Ok(None)
                }
            }
            Err(err) => {
                // Duplicate key (11000): the lock is currently held by someone else.
                if let ErrorKind::Write(WriteFailure::WriteError(WriteError {
                    code: 11000, ..
                })) = *err.kind
                {
                    Ok(None)
                } else {
                    Err(err.into())
                }
            }
        }
    }

    /// Repeatedly tries to acquire `key`, blocking the current thread between attempts.
    ///
    /// Each attempt uses `key_ttl` as the lease length. Between failed attempts
    /// the thread sleeps for `lock_poll_interval`, but never longer than the time
    /// left before `lock_wait_timeout` runs out.
    ///
    /// Returns `Ok(None)` if the lock could not be acquired before
    /// `lock_wait_timeout` elapsed.
    ///
    /// # Errors
    /// Returns the first error from `try_acquire`; no further attempts are made.
    pub fn try_acquire_with_timeout(
        mongo: &Client,
        key: &str,
        key_ttl: Duration,
        lock_wait_timeout: Duration,
        lock_poll_interval: Duration,
    ) -> Result<Option<Lock>, Error> {
        let start = std::time::Instant::now();
        loop {
            if let Some(lock) = Lock::try_acquire(mongo, key, key_ttl)? {
                return Ok(Some(lock));
            }
            let elapsed = start.elapsed();
            if elapsed > lock_wait_timeout {
                return Ok(None);
            }
            // `elapsed <= lock_wait_timeout` here, so the subtraction cannot underflow.
            // Clamping prevents oversleeping past the caller's deadline.
            std::thread::sleep(lock_poll_interval.min(lock_wait_timeout - elapsed));
        }
    }

    /// Deletes the lock document, but only if it is still the one this holder wrote.
    ///
    /// Returns `Ok(true)` when the document was deleted. Returns `Ok(false)` when
    /// this value was already released, or when the document is gone or belongs
    /// to a newer holder (its `expiresAt` no longer matches ours).
    ///
    /// # Errors
    /// Returns the driver error if the delete fails. `acquired` stays `true` in
    /// that case.
    fn release(&mut self) -> Result<bool, mongodb::error::Error> {
        if !self.acquired {
            return Ok(false);
        }
        // Filtering on our own `expiresAt` as well as `_id` keeps an expired holder
        // from deleting a lock that someone else has since re-acquired.
        let filter = doc! {
            "_id": &self.id,
            "expiresAt": self.expires_at.clone(),
        };
        let result = collection(&self.mongo).delete_one(filter, None)?;
        self.acquired = false;
        Ok(result.deleted_count == 1)
    }
}

impl Drop for Lock {
    /// Best-effort release. `drop` cannot propagate errors, so a failed delete is
    /// ignored.
    fn drop(&mut self) {
        let _ = self.release();
    }
}
#[cfg(test)]
mod tests {
    //! Integration tests; they need a MongoDB server reachable at `mongodb://localhost`.
    use super::*;

    const SHORT_TTL: Duration = Duration::from_secs(1);
    const LONG_TTL: Duration = Duration::from_secs(5);

    /// Connects to the local test server and makes sure the TTL index exists.
    fn connect() -> Client {
        let client = Client::with_uri_str("mongodb://localhost").unwrap();
        prepare_database(&client).unwrap();
        client
    }

    /// Returns a random 30-character alphanumeric key, so tests never share a lock.
    fn gen_random_key() -> String {
        use rand::{distributions::Alphanumeric, thread_rng, Rng};
        let chars = thread_rng().sample_iter(&Alphanumeric).take(30);
        chars.map(char::from).collect()
    }

    #[test]
    fn simple_locks() {
        let mongo = connect();
        let first_key = gen_random_key();
        let second_key = gen_random_key();

        // A held key cannot be taken a second time.
        let held = Lock::try_acquire(&mongo, &first_key, LONG_TTL).unwrap();
        assert!(held.is_some());
        let duplicate = Lock::try_acquire(&mongo, &first_key, LONG_TTL).unwrap();
        assert!(duplicate.is_none());

        // An explicit release reports success and frees the key.
        let released = held.unwrap().release().unwrap();
        assert!(released);

        // The freed key can be acquired again; other keys are independent.
        let reacquired = Lock::try_acquire(&mongo, &first_key, LONG_TTL).unwrap();
        assert!(reacquired.is_some());
        let other = Lock::try_acquire(&mongo, &second_key, LONG_TTL).unwrap();
        assert!(other.is_some());
        reacquired.unwrap().release().unwrap();
        other.unwrap().release().unwrap();
    }

    #[test]
    fn with_ttl() {
        let mongo = connect();
        let key = gen_random_key();
        let lock = Lock::try_acquire(&mongo, &key, SHORT_TTL).unwrap();
        assert!(lock.is_some());

        // Each probe drops its lock (if any) immediately after the check.
        let can_take = || Lock::try_acquire(&mongo, &key, SHORT_TTL).unwrap().is_some();
        assert!(!can_take());
        std::thread::sleep(SHORT_TTL);
        assert!(can_take());
    }

    #[test]
    fn with_ttl_and_retry() {
        let mongo = connect();
        let key = gen_random_key();
        let lock = Lock::try_acquire(&mongo, &key, SHORT_TTL).unwrap();
        assert!(lock.is_some());

        let started = std::time::Instant::now();
        let lock2 = Lock::try_acquire_with_timeout(
            &mongo,
            &key,
            SHORT_TTL,
            Duration::from_secs(3),
            Duration::from_millis(100),
        )
        .unwrap();
        assert!(lock2.is_some());
        // The first lock was never released, so the retry had to wait for its lease to pass.
        assert!(started.elapsed() > SHORT_TTL);
    }

    #[test]
    fn dropped_locks() {
        let mongo = connect();
        let key = gen_random_key();
        let acquire = || Lock::try_acquire(&mongo, &key, SHORT_TTL).unwrap();

        // A lock dropped right away frees the key for the next caller.
        for _ in 0..2 {
            assert!(acquire().is_some());
        }

        let first = acquire();
        let second = acquire();
        assert!(first.is_some());
        assert!(second.is_none());
        drop(first);
        let third = acquire();
        assert!(third.is_some());
    }
}