use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use parking_lot::RwLock;
use uuid::Uuid;
use crate::database::Database;
use crate::document::value::Value;
use crate::error::{GrumpyError, Result};
use crate::server::GrumpyServer;
/// Thread-safe, cheaply clonable handle to a single [`Database`].
///
/// Every clone points at the same `Database` stored behind an
/// `Arc<RwLock<_>>`; each method holds the lock only for the duration of the
/// one underlying `Database` call it forwards to.
#[derive(Clone)]
pub struct SharedDatabase {
    inner: Arc<RwLock<Database>>,
}

impl SharedDatabase {
    /// Wraps an already-open `db` so it can be shared between threads.
    pub fn new(db: Database) -> Self {
        let inner = Arc::new(RwLock::new(db));
        Self { inner }
    }

    /// Opens the database stored at `path` and wraps it.
    ///
    /// # Errors
    /// Propagates any error returned by `Database::open`.
    pub fn open(path: &Path) -> Result<Self> {
        let opened = Database::open(path)?;
        Ok(Self::new(opened))
    }

    /// Runs `op` against the database while holding the shared (read) lock.
    fn with_read<R>(&self, op: impl FnOnce(&Database) -> R) -> R {
        let guard = self.inner.read();
        op(&*guard)
    }

    /// Runs `op` against the database while holding the exclusive (write) lock.
    ///
    /// NOTE(review): read-only operations such as `get`, `scan` and `query`
    /// also go through this exclusive lock, so they serialize with each other.
    /// Presumably the underlying `Database` methods take `&mut self` (e.g. for
    /// internal caching) — confirm before moving any of them to `with_read`.
    fn with_write<R>(&self, op: impl FnOnce(&mut Database) -> R) -> R {
        let mut guard = self.inner.write();
        op(&mut *guard)
    }

    /// Creates the collection `name`.
    pub fn create_collection(&self, name: &str) -> Result<()> {
        self.with_write(|db| db.create_collection(name))
    }

    /// Drops the collection `name`.
    pub fn drop_collection(&self, name: &str) -> Result<()> {
        self.with_write(|db| db.drop_collection(name))
    }

    /// Returns the names of all collections as owned strings.
    ///
    /// This is the only accessor that takes the shared (read) lock.
    pub fn list_collections(&self) -> Vec<String> {
        self.with_read(|db| {
            db.list_collections()
                .into_iter()
                .map(|s| s.to_string())
                .collect::<Vec<_>>()
        })
    }

    /// Inserts `value` under `key` in `collection`.
    pub fn insert(&self, collection: &str, key: Uuid, value: Value) -> Result<()> {
        self.with_write(|db| db.insert(collection, key, value))
    }

    /// Looks up `key` in `collection`; `Ok(None)` when no document is stored there.
    pub fn get(&self, collection: &str, key: &Uuid) -> Result<Option<Value>> {
        self.with_write(|db| db.get(collection, key))
    }

    /// Replaces the document stored under `key` in `collection` with `value`.
    pub fn update(&self, collection: &str, key: &Uuid, value: Value) -> Result<()> {
        self.with_write(|db| db.update(collection, key, value))
    }

    /// Removes the document stored under `key` in `collection`.
    pub fn delete(&self, collection: &str, key: &Uuid) -> Result<()> {
        self.with_write(|db| db.delete(collection, key))
    }

    /// Returns the `(key, document)` pairs of `collection` whose keys fall in `range`.
    pub fn scan(
        &self,
        collection: &str,
        range: impl std::ops::RangeBounds<Uuid>,
    ) -> Result<Vec<(Uuid, Value)>> {
        self.with_write(|db| db.scan(collection, range))
    }

    /// Creates the index `index_name` on `collection` over `field_path`.
    pub fn create_index(
        &self,
        collection: &str,
        index_name: &str,
        field_path: &str,
    ) -> Result<()> {
        self.with_write(|db| db.create_index(collection, index_name, field_path))
    }

    /// Drops the index `index_name` from `collection`.
    pub fn drop_index(&self, collection: &str, index_name: &str) -> Result<()> {
        self.with_write(|db| db.drop_index(collection, index_name))
    }

    /// Queries `collection` through `index_name` for documents matching `value`.
    pub fn query(
        &self,
        collection: &str,
        index_name: &str,
        value: &Value,
    ) -> Result<Vec<(Uuid, Value)>> {
        self.with_write(|db| db.query(collection, index_name, value))
    }

    /// Queries `collection` through `index_name` for indexed values between
    /// `start` and `end` (bound inclusivity is defined by `Database::query_range`).
    pub fn query_range(
        &self,
        collection: &str,
        index_name: &str,
        start: &Value,
        end: &Value,
    ) -> Result<Vec<(Uuid, Value)>> {
        self.with_write(|db| db.query_range(collection, index_name, start, end))
    }

    /// Resolves a reference to document `id` in `collection`.
    pub fn resolve_ref(&self, collection: &str, id: &Uuid) -> Result<Option<Value>> {
        self.with_write(|db| db.resolve_ref(collection, id))
    }

    /// Recursively resolves references inside `value`, at most `max_depth` levels deep.
    pub fn resolve_deep(&self, value: &Value, max_depth: usize) -> Result<Value> {
        self.with_write(|db| db.resolve_deep(value, max_depth))
    }

    /// Returns the number of documents in `collection`.
    pub fn document_count(&self, collection: &str) -> Result<u64> {
        self.with_write(|db| db.document_count(collection))
    }

    /// Flushes the underlying database.
    pub fn flush(&self) -> Result<()> {
        self.with_write(|db| db.flush())
    }

    /// Compacts `collection`, returning the count reported by `Database::compact`.
    pub fn compact(&self, collection: &str) -> Result<u64> {
        self.with_write(|db| db.compact(collection))
    }

    /// Closes the database if this is the last handle to it.
    ///
    /// When other clones still exist, nothing is closed and `Ok(())` is
    /// returned; the database stays usable through those clones.
    pub fn close(self) -> Result<()> {
        if let Ok(lock) = Arc::try_unwrap(self.inner) {
            lock.into_inner().close()
        } else {
            Ok(())
        }
    }
}
/// Thread-safe, cheaply clonable handle to a [`GrumpyServer`] together with a
/// cache of the [`SharedDatabase`] handles opened through it.
///
/// Both fields are `Arc`s, so every clone shares the same server and the same
/// cache. `Clone` is derived, which only bumps the two reference counts — the
/// same thing the previous hand-written impl did.
#[derive(Clone)]
pub struct SharedServer {
    server: Arc<RwLock<GrumpyServer>>,
    /// Open database handles keyed by `"{client}/{db_name}"`.
    databases: Arc<RwLock<HashMap<String, SharedDatabase>>>,
}
impl SharedServer {
    /// Opens the server rooted at `path`.
    ///
    /// The database-handle cache starts empty; databases are opened lazily by
    /// [`SharedServer::database`].
    ///
    /// # Errors
    /// Propagates any error returned by `GrumpyServer::open`.
    pub fn open(path: &Path) -> Result<Self> {
        let server = GrumpyServer::open(path)?;
        Ok(Self {
            server: Arc::new(RwLock::new(server)),
            databases: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Creates a new client named `name`.
    pub fn create_client(&self, name: &str) -> Result<()> {
        self.server.write().create_client(name)
    }

    /// Drops client `name` after evicting all of its cached database handles.
    ///
    /// Cache keys have the form `"{client}/{db_name}"`, so every key starting
    /// with `"{name}/"` is removed. Handles already cloned out to callers are
    /// not affected by the eviction.
    pub fn drop_client(&self, name: &str) -> Result<()> {
        {
            let mut dbs = self.databases.write();
            let prefix = format!("{name}/");
            dbs.retain(|k, _| !k.starts_with(&prefix));
        }
        self.server.write().drop_client(name)
    }

    /// Returns the names of all clients as owned strings.
    pub fn list_clients(&self) -> Vec<String> {
        let server = self.server.read();
        server.list_clients().into_iter().map(|s| s.to_string()).collect()
    }

    /// Creates database `db_name` for `client`.
    ///
    /// # Errors
    /// Fails if `client` cannot be looked up, or if the client's
    /// `create_database` fails.
    pub fn create_database(&self, client: &str, db_name: &str) -> Result<()> {
        self.server.write().client(client)?.create_database(db_name)
    }

    /// Drops database `db_name` of `client`, evicting its cached handle first.
    pub fn drop_database(&self, client: &str, db_name: &str) -> Result<()> {
        let key = format!("{client}/{db_name}");
        self.databases.write().remove(&key);
        self.server.write().client(client)?.drop_database(db_name)
    }

    /// Lists the databases belonging to `client`.
    pub fn list_databases(&self, client: &str) -> Result<Vec<String>> {
        // Exclusive lock: `client()` is only ever called through a write guard here.
        let mut server = self.server.write();
        let c = server.client(client)?;
        Ok(c.list_databases().into_iter().map(|s| s.to_string()).collect())
    }

    /// Returns the shared handle for `client`'s database `db_name`, opening and
    /// caching it on first use.
    ///
    /// Every caller asking for the same `(client, db_name)` gets a clone of the
    /// same [`SharedDatabase`], even when several threads ask concurrently for a
    /// database that is not cached yet.
    ///
    /// NOTE(review): `client` and `db_name` are joined onto the server root
    /// without validation; confirm callers reject `..` or path separators.
    ///
    /// # Errors
    /// `GrumpyError::DatabaseNotFound` if the database directory does not
    /// exist; otherwise any error from `SharedDatabase::open`.
    pub fn database(&self, client: &str, db_name: &str) -> Result<SharedDatabase> {
        let key = format!("{client}/{db_name}");
        // Fast path: a cache hit only needs the shared lock.
        {
            let dbs = self.databases.read();
            if let Some(db) = dbs.get(&key) {
                return Ok(db.clone());
            }
        }
        // Resolve the path before taking the cache write lock, so the server
        // lock and the cache lock are never held at the same time.
        let db_path = {
            let server = self.server.read();
            server.path().join(client).join(db_name)
        };
        // Slow path: re-check and open while holding the write lock. Otherwise
        // two threads that both missed the cache would each open their own
        // `Database` on the same directory, and the thread whose insert lost
        // would keep using a handle the cache no longer tracks.
        let mut dbs = self.databases.write();
        if let Some(db) = dbs.get(&key) {
            return Ok(db.clone());
        }
        if !db_path.exists() {
            return Err(GrumpyError::DatabaseNotFound(db_name.into()));
        }
        let shared_db = SharedDatabase::open(&db_path)?;
        dbs.insert(key, shared_db.clone());
        Ok(shared_db)
    }

    /// Closes the cached databases and the server if this is the last handle.
    ///
    /// If other `SharedServer` clones exist, nothing is closed and `Ok(())` is
    /// returned. Otherwise every cached [`SharedDatabase`] is closed (a no-op
    /// for any database a caller still holds a handle to), then the server.
    ///
    /// # Errors
    /// Returns the first error encountered. All closes are still attempted.
    pub fn close(self) -> Result<()> {
        let Self { server, databases } = self;
        let mut outcome: Result<()> = Ok(());
        // Close explicitly rather than just dropping, so each cached database
        // goes through `Database::close`.
        if let Ok(lock) = Arc::try_unwrap(databases) {
            for db in lock.into_inner().into_values() {
                // `and` keeps the first error while still running every close.
                outcome = outcome.and(db.close());
            }
        }
        let server_outcome = match Arc::try_unwrap(server) {
            Ok(lock) => lock.into_inner().close(),
            Err(_) => Ok(()),
        };
        outcome.and(server_outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use tempfile::TempDir;

    /// Opens a fresh `SharedDatabase` inside a new temp dir.
    ///
    /// The returned `TempDir` must stay alive while the database is in use.
    fn setup_shared_db() -> (TempDir, SharedDatabase) {
        let dir = TempDir::new().unwrap();
        let db = SharedDatabase::open(dir.path().join("testdb").as_path()).unwrap();
        (dir, db)
    }

    /// Insert / get / update / delete round-trip through the shared handle.
    #[test]
    fn test_shared_database_crud() {
        let (_dir, db) = setup_shared_db();
        db.create_collection("items").unwrap();
        let key = Uuid::from_u128(1);
        db.insert("items", key, Value::Integer(42)).unwrap();
        assert_eq!(db.get("items", &key).unwrap(), Some(Value::Integer(42)));
        db.update("items", &key, Value::Integer(99)).unwrap();
        assert_eq!(db.get("items", &key).unwrap(), Some(Value::Integer(99)));
        db.delete("items", &key).unwrap();
        assert_eq!(db.get("items", &key).unwrap(), None);
    }

    /// Eight threads released together all read back the same 50 documents.
    #[test]
    fn test_shared_database_concurrent_reads() {
        let (_dir, db) = setup_shared_db();
        db.create_collection("nums").unwrap();
        for i in 0u128..50 {
            db.insert("nums", Uuid::from_u128(i), Value::Integer(i as i64))
                .unwrap();
        }
        let barrier = Arc::new(Barrier::new(8));
        let mut handles = Vec::new();
        for _ in 0..8 {
            let db = db.clone();
            let barrier = barrier.clone();
            handles.push(std::thread::spawn(move || {
                barrier.wait();
                for i in 0u128..50 {
                    let val = db.get("nums", &Uuid::from_u128(i)).unwrap();
                    assert_eq!(val, Some(Value::Integer(i as i64)));
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
    }

    /// One writer adds keys 100..200 while four readers re-read keys 0..100.
    #[test]
    fn test_shared_database_writer_and_readers() {
        let (_dir, db) = setup_shared_db();
        db.create_collection("data").unwrap();
        for i in 0u128..100 {
            db.insert("data", Uuid::from_u128(i), Value::Integer(i as i64))
                .unwrap();
        }
        let barrier = Arc::new(Barrier::new(5));
        let db_w = db.clone();
        let b_w = barrier.clone();
        let writer = std::thread::spawn(move || {
            b_w.wait();
            for i in 100u128..200 {
                db_w.insert("data", Uuid::from_u128(i), Value::Integer(i as i64))
                    .unwrap();
            }
        });
        let mut readers = Vec::new();
        for _ in 0..4 {
            let db = db.clone();
            let b = barrier.clone();
            readers.push(std::thread::spawn(move || {
                b.wait();
                for i in 0u128..100 {
                    let val = db.get("data", &Uuid::from_u128(i)).unwrap();
                    assert_eq!(val, Some(Value::Integer(i as i64)));
                }
            }));
        }
        writer.join().unwrap();
        for r in readers {
            r.join().unwrap();
        }
        assert_eq!(db.document_count("data").unwrap(), 200);
    }

    /// An index created through the shared handle is usable by `query`.
    #[test]
    fn test_shared_database_collections_and_indexes() {
        let (_dir, db) = setup_shared_db();
        db.create_collection("users").unwrap();
        db.create_index("users", "by_age", "age").unwrap();
        let key = Uuid::from_u128(1);
        let val = Value::Object(std::collections::BTreeMap::from([
            ("name".into(), Value::String("Alice".into())),
            ("age".into(), Value::Integer(30)),
        ]));
        db.insert("users", key, val).unwrap();
        let results = db.query("users", "by_age", &Value::Integer(30)).unwrap();
        assert_eq!(results.len(), 1);
    }

    /// Opens a fresh `SharedServer` rooted inside a new temp dir.
    ///
    /// The returned `TempDir` must stay alive while the server is in use.
    fn setup_shared_server() -> (TempDir, SharedServer) {
        let dir = TempDir::new().unwrap();
        let server = SharedServer::open(dir.path().join("root").as_path()).unwrap();
        (dir, server)
    }

    /// Clients can be created, listed and dropped.
    #[test]
    fn test_shared_server_client_management() {
        let (_dir, server) = setup_shared_server();
        server.create_client("alice").unwrap();
        server.create_client("bob").unwrap();
        let clients = server.list_clients();
        // Assumes `list_clients` returns names in sorted order — TODO confirm.
        assert_eq!(clients, vec!["alice", "bob"]);
        server.drop_client("bob").unwrap();
        assert_eq!(server.list_clients(), vec!["alice"]);
    }

    /// A database created through the server is reachable via `database()`.
    #[test]
    fn test_shared_server_database_access() {
        let (_dir, server) = setup_shared_server();
        server.create_client("alice").unwrap();
        server.create_database("alice", "mydb").unwrap();
        let db = server.database("alice", "mydb").unwrap();
        db.create_collection("items").unwrap();
        db.insert("items", Uuid::from_u128(1), Value::Integer(42))
            .unwrap();
        assert_eq!(
            db.get("items", &Uuid::from_u128(1)).unwrap(),
            Some(Value::Integer(42))
        );
    }

    /// Four threads, each writing to its own database, do not interfere.
    #[test]
    fn test_shared_server_concurrent_different_databases() {
        let (_dir, server) = setup_shared_server();
        server.create_client("alice").unwrap();
        for i in 0..4 {
            server
                .create_database("alice", &format!("db{i}"))
                .unwrap();
            let db = server.database("alice", &format!("db{i}")).unwrap();
            db.create_collection("items").unwrap();
        }
        let barrier = Arc::new(Barrier::new(4));
        let mut handles = Vec::new();
        for t in 0..4u128 {
            let server = server.clone();
            let barrier = barrier.clone();
            handles.push(std::thread::spawn(move || {
                let db = server.database("alice", &format!("db{t}")).unwrap();
                barrier.wait();
                for i in 0..50 {
                    db.insert(
                        "items",
                        Uuid::from_u128(t * 1000 + i),
                        Value::Integer(i as i64),
                    )
                    .unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        for i in 0..4 {
            let db = server.database("alice", &format!("db{i}")).unwrap();
            assert_eq!(db.document_count("items").unwrap(), 50);
        }
    }

    /// Eight threads share four databases (two writers per database).
    #[test]
    fn test_shared_server_8_threads_4_databases() {
        let (_dir, server) = setup_shared_server();
        server.create_client("test").unwrap();
        for i in 0..4 {
            server
                .create_database("test", &format!("db{i}"))
                .unwrap();
            let db = server.database("test", &format!("db{i}")).unwrap();
            db.create_collection("data").unwrap();
        }
        let barrier = Arc::new(Barrier::new(8));
        let mut handles = Vec::new();
        for t in 0..8u128 {
            let server = server.clone();
            let barrier = barrier.clone();
            handles.push(std::thread::spawn(move || {
                let db_idx = t % 4;
                let db = server.database("test", &format!("db{db_idx}")).unwrap();
                barrier.wait();
                for i in 0..25 {
                    let key = Uuid::from_u128(t * 1000 + i);
                    db.insert("data", key, Value::Integer((t * 1000 + i) as i64))
                        .unwrap();
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        for i in 0..4 {
            let db = server.database("test", &format!("db{i}")).unwrap();
            assert_eq!(db.document_count("data").unwrap(), 50);
        }
    }

    /// One writer and four readers on a server-managed database.
    #[test]
    fn test_shared_server_writer_and_readers_per_db() {
        let (_dir, server) = setup_shared_server();
        server.create_client("c").unwrap();
        server.create_database("c", "mydb").unwrap();
        let db = server.database("c", "mydb").unwrap();
        db.create_collection("nums").unwrap();
        for i in 0u128..100 {
            db.insert("nums", Uuid::from_u128(i), Value::Integer(i as i64))
                .unwrap();
        }
        let barrier = Arc::new(Barrier::new(5));
        let db_w = db.clone();
        let b_w = barrier.clone();
        let writer = std::thread::spawn(move || {
            b_w.wait();
            for i in 100u128..200 {
                db_w.insert("nums", Uuid::from_u128(i), Value::Integer(i as i64))
                    .unwrap();
            }
        });
        let mut readers = Vec::new();
        for _ in 0..4 {
            let db = db.clone();
            let b = barrier.clone();
            readers.push(std::thread::spawn(move || {
                b.wait();
                for i in 0u128..100 {
                    let val = db.get("nums", &Uuid::from_u128(i)).unwrap();
                    assert_eq!(val, Some(Value::Integer(i as i64)));
                }
            }));
        }
        writer.join().unwrap();
        for r in readers {
            r.join().unwrap();
        }
        assert_eq!(db.document_count("nums").unwrap(), 200);
    }

    /// Concurrent writes to two databases of one client stay isolated.
    #[test]
    fn test_shared_server_cross_database_independence() {
        let (_dir, server) = setup_shared_server();
        server.create_client("c").unwrap();
        server.create_database("c", "fast").unwrap();
        server.create_database("c", "slow").unwrap();
        let db_fast = server.database("c", "fast").unwrap();
        let db_slow = server.database("c", "slow").unwrap();
        db_fast.create_collection("items").unwrap();
        db_slow.create_collection("items").unwrap();
        let barrier = Arc::new(Barrier::new(2));
        let b1 = barrier.clone();
        let b2 = barrier.clone();
        let h1 = std::thread::spawn(move || {
            b1.wait();
            for i in 0u128..100 {
                db_fast
                    .insert("items", Uuid::from_u128(i), Value::Integer(i as i64))
                    .unwrap();
            }
        });
        let h2 = std::thread::spawn(move || {
            b2.wait();
            for i in 0u128..100 {
                db_slow
                    .insert(
                        "items",
                        Uuid::from_u128(1000 + i),
                        Value::Integer(i as i64),
                    )
                    .unwrap();
            }
        });
        h1.join().unwrap();
        h2.join().unwrap();

        // The original handles were moved into the threads; fetch them again
        // from the server's cache. Each database must hold exactly its own
        // writer's documents and none of the other writer's keys.
        let db_fast = server.database("c", "fast").unwrap();
        let db_slow = server.database("c", "slow").unwrap();
        assert_eq!(db_fast.document_count("items").unwrap(), 100);
        assert_eq!(db_slow.document_count("items").unwrap(), 100);
        assert_eq!(
            db_fast.get("items", &Uuid::from_u128(0)).unwrap(),
            Some(Value::Integer(0))
        );
        assert_eq!(
            db_slow.get("items", &Uuid::from_u128(1000)).unwrap(),
            Some(Value::Integer(0))
        );
        assert_eq!(db_fast.get("items", &Uuid::from_u128(1000)).unwrap(), None);
        assert_eq!(db_slow.get("items", &Uuid::from_u128(0)).unwrap(), None);
    }
}