use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use vibesql_storage::Database;
/// A [`Database`] wrapped in `Arc<RwLock<..>>` so that one instance can be
/// held by several owners at once.
///
/// The lock is `tokio::sync::RwLock`, so acquiring it is an `.await` point.
pub type SharedDatabase = Arc<RwLock<Database>>;
/// Name-keyed collection of [`SharedDatabase`] handles.
///
/// The map itself lives behind an `Arc<RwLock<..>>`, so cloning a
/// `DatabaseRegistry` yields another handle to the same underlying map
/// rather than an independent copy.
#[derive(Clone)]
pub struct DatabaseRegistry {
    /// Database name -> shared handle, guarded by an async read/write lock.
    databases: Arc<RwLock<HashMap<String, SharedDatabase>>>,
}
impl DatabaseRegistry {
    /// Builds a registry whose name -> database map starts out empty.
    pub fn new() -> Self {
        let databases = Arc::new(RwLock::new(HashMap::new()));
        Self { databases }
    }

    /// Returns the handle registered under `name`, inserting a fresh
    /// `Database::new()` under that name first if there is none.
    ///
    /// The lookup is attempted under the read lock first. Only on a miss is
    /// the write lock taken, and the lookup is then repeated because another
    /// task may have inserted the entry between the two lock acquisitions.
    pub async fn get_or_create(&self, name: &str) -> SharedDatabase {
        // Fast path: the read guard is a temporary that is released at the
        // end of this statement, before the write lock is requested below.
        let existing = self.databases.read().await.get(name).cloned();
        if let Some(handle) = existing {
            return handle;
        }

        let mut registry = self.databases.write().await;
        match registry.get(name) {
            // Someone else created it while we were waiting for the write lock.
            Some(handle) => Arc::clone(handle),
            None => {
                let created: SharedDatabase = Arc::new(RwLock::new(Database::new()));
                registry.insert(name.to_owned(), Arc::clone(&created));
                created
            }
        }
    }

    /// Looks up `name` and returns its handle, or `None` if nothing is
    /// registered under that name. Never creates an entry.
    #[allow(dead_code)]
    pub async fn get(&self, name: &str) -> Option<SharedDatabase> {
        let registry = self.databases.read().await;
        registry.get(name).map(Arc::clone)
    }

    /// Returns the names of all registered databases.
    ///
    /// The names come straight from the `HashMap` keys, so their order is
    /// unspecified.
    #[allow(dead_code)]
    pub async fn list_databases(&self) -> Vec<String> {
        let registry = self.databases.read().await;
        let mut names = Vec::with_capacity(registry.len());
        names.extend(registry.keys().cloned());
        names
    }

    /// Returns how many databases are currently registered.
    #[allow(dead_code)]
    pub async fn database_count(&self) -> usize {
        let registry = self.databases.read().await;
        registry.len()
    }

    /// Stores `db` under `name`, wrapped in a new [`SharedDatabase`] handle.
    ///
    /// If an entry with the same name already exists it is replaced in the
    /// map; handles that were previously returned for it keep pointing at
    /// the old instance.
    pub async fn register_database(&self, name: &str, db: Database) {
        let mut registry = self.databases.write().await;
        let handle: SharedDatabase = Arc::new(RwLock::new(db));
        registry.insert(name.to_owned(), handle);
    }
}
impl Default for DatabaseRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asking for the same name twice creates one entry and returns the same handle.
    #[tokio::test]
    async fn test_get_or_create_new_database() {
        let registry = DatabaseRegistry::new();

        let first = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        let second = registry.get_or_create("testdb").await;
        assert_eq!(registry.database_count().await, 1);

        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Distinct names produce distinct entries and distinct handles.
    #[tokio::test]
    async fn test_different_databases() {
        let registry = DatabaseRegistry::new();

        let left = registry.get_or_create("db1").await;
        let right = registry.get_or_create("db2").await;

        assert_eq!(registry.database_count().await, 2);
        assert!(!Arc::ptr_eq(&left, &right));
    }

    /// A table created through one handle is visible through another handle
    /// obtained for the same name.
    #[tokio::test]
    async fn test_shared_data_across_connections() {
        let registry = DatabaseRegistry::new();
        let writer_handle = registry.get_or_create("shared").await;
        let reader_handle = registry.get_or_create("shared").await;

        {
            // Scoped so the write guard is released before the read below.
            let mut writable = writer_handle.write().await;
            let id_column = vibesql_catalog::ColumnSchema::new(
                "id".to_string(),
                vibesql_types::DataType::Integer,
                true,
            );
            let users_schema =
                vibesql_catalog::TableSchema::new("users".to_string(), vec![id_column]);
            writable.create_table(users_schema).unwrap();
        }

        {
            let readable = reader_handle.read().await;
            assert!(readable.get_table("users").is_some());
        }
    }

    /// Every created database name is reported by `list_databases`.
    #[tokio::test]
    async fn test_list_databases() {
        let registry = DatabaseRegistry::new();
        for name in ["alpha", "beta", "gamma"] {
            registry.get_or_create(name).await;
        }

        let listed = registry.list_databases().await;
        assert_eq!(listed.len(), 3);
        for expected in ["alpha", "beta", "gamma"] {
            assert!(listed.iter().any(|name| name.as_str() == expected));
        }
    }
}