use bson::doc;
use bson::oid::ObjectId;
use mongodb::Client;
use mongodb_lock::Mutex;
use mongodb_lock::RwLock;
use serde::{Deserialize, Serialize};
use std::net::Ipv4Addr;
use std::process::{Command, Stdio};
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::sync::Arc;
use std::time::Duration;
use tokio::task;
use tokio::time::sleep;
use tracing::info;
struct MongodbClient {
client: mongodb::Client,
_server: Arc<Mongodb>,
}
impl AsRef<Client> for MongodbClient {
fn as_ref(&self) -> &Client {
&self.client
}
}
pub struct Mongodb {
name: String,
port: u16,
}
impl Mongodb {
async fn new() -> Self {
static PORT: AtomicU16 = AtomicU16::new(27021);
let port = PORT.fetch_add(1, Ordering::SeqCst);
let name = uuid::Uuid::new_v4().to_string();
info!("port: {port}");
info!("name: {name}");
let _mongodb = Command::new("atlas")
.args([
"deployments",
"setup",
&name,
"--type",
"local",
"--port",
&port.to_string(),
"--force",
])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.unwrap();
info!("setup deployment");
sleep(Duration::from_secs(3)).await;
Mongodb { name, port }
}
fn client(this: Arc<Self>) -> MongodbClient {
let host = mongodb::options::ServerAddress::Tcp {
host: Ipv4Addr::LOCALHOST.to_string(),
port: Some(this.port),
};
let opts = mongodb::options::ClientOptions::builder()
.hosts(vec![host])
.direct_connection(true)
.build();
let client = mongodb::Client::with_options(opts).unwrap();
MongodbClient {
client,
_server: this,
}
}
}
impl Drop for Mongodb {
fn drop(&mut self) {
info!("deleting deployment");
let _mongodb = Command::new("atlas")
.args(["deployments", "delete", &self.name, "--force"])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.unwrap();
info!("deleted");
}
}
#[tokio::test]
async fn single() {
use mongodb_lock::*;
#[derive(Clone, Serialize, Deserialize)]
struct MyDocument {
_id: ObjectId,
x: i32,
}
let _guard = tracing::subscriber::set_default(
tracing_subscriber::fmt::Subscriber::builder()
.with_test_writer()
.finish(),
);
let mongodb = Arc::new(Mongodb::new().await);
let client = Mongodb::client(mongodb);
let db = client.as_ref().database("basic");
let docs = db.collection::<MyDocument>("docs");
let lock = Arc::new(Mutex::new(&db, "locks").await.unwrap());
let one = MyDocument {
_id: ObjectId::new(),
x: 1,
};
let two = MyDocument {
_id: ObjectId::new(),
x: 1,
};
let three = MyDocument {
_id: ObjectId::new(),
x: 1,
};
docs.insert_many(vec![one.clone(), two.clone(), three.clone()])
.await
.unwrap();
let one_id = one._id;
let two_id = two._id;
let clock = lock.clone();
let cdocs = docs.clone();
let first = task::spawn(async move {
let _guard = clock.lock_default([one_id, two_id]).await.unwrap();
let a = cdocs
.find_one(doc! { "_id": one_id })
.await
.unwrap()
.unwrap();
let b = cdocs
.find_one(doc! { "_id": two_id })
.await
.unwrap()
.unwrap();
cdocs
.update_many(
doc! { "_id": { "$in": [one_id,two_id] }},
doc! { "$set": { "x": a.x + b.x } },
)
.await
.unwrap();
});
let two_id = two._id;
let three_id = three._id;
let clock = lock.clone();
let cdocs = docs.clone();
let second = task::spawn(async move {
let _guard = clock.lock_default([two_id, three_id]).await.unwrap();
let a = cdocs
.find_one(doc! { "_id": two_id })
.await
.unwrap()
.unwrap();
let b = cdocs
.find_one(doc! { "_id": three_id })
.await
.unwrap()
.unwrap();
cdocs
.update_many(
doc! { "_id": { "$in": [two_id,three_id] } },
doc! { "$set": { "x": a.x + b.x } },
)
.await
.unwrap();
});
first.await.unwrap();
second.await.unwrap();
let a = docs
.find_one(doc! { "_id": one_id })
.await
.unwrap()
.unwrap()
.x;
let b = docs
.find_one(doc! { "_id": two_id })
.await
.unwrap()
.unwrap()
.x;
let c = docs
.find_one(doc! { "_id": three_id })
.await
.unwrap()
.unwrap()
.x;
assert!((a == 2 && b == 3 && c == 3) || (a == 3 && b == 3 && c == 2));
}
#[tokio::test]
async fn adder() {
#[derive(Debug, Serialize, Deserialize)]
struct Number {
_id: ObjectId,
x: i32,
}
const N: i32 = 10;
const A: i32 = 0;
const B: i32 = 10;
static CHECK_ONE: AtomicBool = AtomicBool::new(false);
static CHECK_TWO: AtomicBool = AtomicBool::new(false);
let _guard = tracing::subscriber::set_default(
tracing_subscriber::fmt::Subscriber::builder()
.with_test_writer()
.finish(),
);
let mongodb = Arc::new(Mongodb::new().await);
let client = Mongodb::client(mongodb);
let db = client.as_ref().database("adder");
let lock = Arc::new(Mutex::new(&db, "locks").await.unwrap());
let cola = db.collection::<Number>("first");
let colb = db.collection::<Number>("second");
let ida1 = ObjectId::new();
let ida2 = ObjectId::new();
cola.insert_one(Number { _id: ida1, x: A }).await.unwrap();
cola.insert_one(Number { _id: ida2, x: B }).await.unwrap();
let idb1 = ObjectId::new();
let idb2 = ObjectId::new();
colb.insert_one(Number { _id: idb1, x: A }).await.unwrap();
colb.insert_one(Number { _id: idb2, x: B }).await.unwrap();
let tasks = (0..N)
.flat_map(|_| {
let (clock, ccola, ccolb, cida1, cidb1) =
(lock.clone(), cola.clone(), colb.clone(), ida1, idb1);
let one = task::spawn(async move {
let guard = clock.lock_default([cida1, cidb1]).await.unwrap();
assert!(!CHECK_ONE.swap(true, Ordering::SeqCst));
let num = ccola
.find_one(doc! { "_id": cida1 })
.await
.unwrap()
.unwrap();
ccolb
.update_one(doc! { "_id": cidb1 }, doc! { "$set": { "x": num.x + 1 } })
.await
.unwrap();
let num = ccolb
.find_one(doc! { "_id": cidb1 })
.await
.unwrap()
.unwrap();
ccola
.update_one(doc! { "_id": cida1 }, doc! { "$set": { "x": num.x + 1 } })
.await
.unwrap();
assert!(CHECK_ONE.swap(false, Ordering::SeqCst));
drop(guard);
});
let (clock, ccola, ccolb, cida2, cidb2) =
(lock.clone(), cola.clone(), colb.clone(), ida2, idb2);
let two = task::spawn(async move {
let guard = clock.lock_default([cida2, cidb2]).await.unwrap();
assert!(!CHECK_TWO.swap(true, Ordering::SeqCst));
let num = ccola
.find_one(doc! { "_id": cida2 })
.await
.unwrap()
.unwrap();
ccolb
.update_one(doc! { "_id": cidb2 }, doc! { "$set": { "x": num.x + 1 } })
.await
.unwrap();
let num = ccolb
.find_one(doc! { "_id": cidb2 })
.await
.unwrap()
.unwrap();
ccola
.update_one(doc! { "_id": cida2 }, doc! { "$set": { "x": num.x + 1 } })
.await
.unwrap();
assert!(CHECK_TWO.swap(false, Ordering::SeqCst));
drop(guard);
});
[one, two]
})
.collect::<Vec<_>>();
for task in tasks {
task.await.unwrap();
}
const T: i32 = N * 2;
let numa1 = cola.find_one(doc! { "_id": ida1 }).await.unwrap().unwrap();
let numb1 = colb.find_one(doc! { "_id": idb1 }).await.unwrap().unwrap();
info!("numa1: {numa1:?}");
info!("numb1: {numb1:?}");
assert!(
(numa1.x == T - A + 1 && numb1.x == T + A) || (numb1.x == T + A - 1 && numa1.x == T + A)
);
let numa2 = cola.find_one(doc! { "_id": ida2 }).await.unwrap().unwrap();
let numb2 = colb.find_one(doc! { "_id": idb2 }).await.unwrap().unwrap();
info!("numa2: {numa2:?}");
info!("numb2: {numb2:?}");
assert!(
(numa2.x == T + B - 1 && numb2.x == T + B) || (numb2.x == T + B - 1 && numa2.x == T + B)
);
}
#[tokio::test]
async fn reader() {
#[derive(Debug, Serialize, Deserialize)]
struct Number {
_id: ObjectId,
x: i32,
}
const READS: usize = 10;
const WRITE: usize = 10;
let _guard = tracing::subscriber::set_default(
tracing_subscriber::fmt::Subscriber::builder()
.with_test_writer()
.finish(),
);
let mongodb = Arc::new(Mongodb::new().await);
let client = Mongodb::client(mongodb);
let db = client.as_ref().database("adder");
let lock = Arc::new(RwLock::new(&db, "locks").await.unwrap());
let col = db.collection::<Number>("first");
let id = ObjectId::new();
col.insert_one(Number { _id: id, x: 0 }).await.unwrap();
let reads = (0..READS)
.map(|_| {
let clock = lock.clone();
let ccol = col.clone();
let cid = id.clone();
task::spawn(async move {
let _guard = clock.read_default().await.unwrap();
let a = ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x;
assert_eq!(
ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x,
a
);
assert_eq!(
ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x,
a
);
assert_eq!(
ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x,
a
);
})
})
.collect::<Vec<_>>();
let writes = (0..WRITE)
.map(|_| {
let clock = lock.clone();
let ccol = col.clone();
let cid = id.clone();
task::spawn(async move {
let _guard = clock.write_default().await.unwrap();
let a = ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x;
ccol.update_one(doc! {"_id": cid}, doc! { "$inc": { "x": 1i32 } })
.await
.unwrap();
assert_eq!(
ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x,
a + 1
);
})
})
.collect::<Vec<_>>();
for read in reads {
read.await.unwrap();
}
for write in writes {
write.await.unwrap();
}
}