use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
use tokio::sync::RwLock;
use crate::connection::client::DatabaseClient;
use crate::connection::config::ConnectionConfig;
use crate::error::{Result, SurqlError};
/// Cheaply cloneable handle to a shared set of named database connections.
///
/// All clones share the same underlying state through an [`Arc`], so a connection
/// registered through one handle is visible through every other handle.
#[derive(Clone, Debug, Default)]
pub struct ConnectionRegistry {
    inner: Arc<RegistryInner>,
}
/// Shared state behind a [`ConnectionRegistry`]; each piece has its own lock.
#[derive(Default, Debug)]
struct RegistryInner {
    /// Registered clients keyed by registration name.
    connections: RwLock<HashMap<String, Arc<DatabaseClient>>>,
    /// Configuration each registered client was built from, keyed by the same name.
    configs: RwLock<HashMap<String, ConnectionConfig>>,
    /// Name of the connection used when callers pass no explicit name, if any.
    default_name: RwLock<Option<String>>,
}
impl ConnectionRegistry {
pub fn new() -> Self {
Self::default()
}
pub async fn register(
&self,
name: impl Into<String>,
config: ConnectionConfig,
connect: bool,
set_default: bool,
) -> Result<Arc<DatabaseClient>> {
let name = name.into();
{
let conns = self.inner.connections.read().await;
if conns.contains_key(&name) {
return Err(SurqlError::Registry {
reason: format!("connection {name:?} already registered"),
});
}
}
let client = Arc::new(DatabaseClient::new(config.clone())?);
if connect {
client.connect().await?;
}
let mut conns = self.inner.connections.write().await;
if conns.contains_key(&name) {
return Err(SurqlError::Registry {
reason: format!("connection {name:?} already registered"),
});
}
conns.insert(name.clone(), client.clone());
let mut configs = self.inner.configs.write().await;
configs.insert(name.clone(), config);
let mut default = self.inner.default_name.write().await;
if set_default || default.is_none() {
*default = Some(name);
}
Ok(client)
}
pub async fn unregister(&self, name: &str, disconnect: bool) -> Result<()> {
let mut conns = self.inner.connections.write().await;
let Some(client) = conns.remove(name) else {
return Err(SurqlError::Registry {
reason: format!("connection {name:?} not found"),
});
};
if disconnect && client.is_connected() {
drop(conns);
let _ = client.disconnect().await;
conns = self.inner.connections.write().await;
}
let mut configs = self.inner.configs.write().await;
configs.remove(name);
let mut default = self.inner.default_name.write().await;
if default.as_deref() == Some(name) {
*default = conns.keys().next().cloned();
}
Ok(())
}
pub async fn get(&self, name: Option<&str>) -> Result<Arc<DatabaseClient>> {
let conns = self.inner.connections.read().await;
let lookup = match name {
Some(n) => n.to_owned(),
None => self
.inner
.default_name
.read()
.await
.clone()
.ok_or_else(|| SurqlError::Registry {
reason: "no default connection set".into(),
})?,
};
conns.get(&lookup).cloned().ok_or(SurqlError::Registry {
reason: format!("connection {lookup:?} not found"),
})
}
pub async fn get_config(&self, name: Option<&str>) -> Result<ConnectionConfig> {
let configs = self.inner.configs.read().await;
let lookup = match name {
Some(n) => n.to_owned(),
None => self
.inner
.default_name
.read()
.await
.clone()
.ok_or_else(|| SurqlError::Registry {
reason: "no default connection set".into(),
})?,
};
configs.get(&lookup).cloned().ok_or(SurqlError::Registry {
reason: format!("connection {lookup:?} not found"),
})
}
pub async fn set_default(&self, name: &str) -> Result<()> {
let conns = self.inner.connections.read().await;
if !conns.contains_key(name) {
return Err(SurqlError::Registry {
reason: format!("connection {name:?} not found"),
});
}
*self.inner.default_name.write().await = Some(name.to_owned());
Ok(())
}
pub async fn list(&self) -> Vec<String> {
self.inner
.connections
.read()
.await
.keys()
.cloned()
.collect()
}
pub async fn default_name(&self) -> Option<String> {
self.inner.default_name.read().await.clone()
}
pub async fn disconnect_all(&self) {
let snapshot: Vec<Arc<DatabaseClient>> = self
.inner
.connections
.read()
.await
.values()
.cloned()
.collect();
for client in snapshot {
if client.is_connected() {
let _ = client.disconnect().await;
}
}
}
pub async fn clear(&self) {
self.disconnect_all().await;
self.inner.connections.write().await.clear();
self.inner.configs.write().await.clear();
*self.inner.default_name.write().await = None;
}
}
pub fn get_registry() -> ConnectionRegistry {
global().clone()
}
/// Installs `registry` as the process-wide registry.
///
/// # Errors
///
/// Returns [`SurqlError::Registry`] if the global registry has already been
/// initialised. That happens after an earlier `set_registry` call, and also implicitly
/// after any call to [`get_registry`], which creates an empty registry on first use.
pub fn set_registry(registry: ConnectionRegistry) -> Result<()> {
    match GLOBAL.set(registry) {
        Ok(()) => Ok(()),
        Err(_rejected) => Err(SurqlError::Registry {
            reason: "global registry is already initialised".into(),
        }),
    }
}
/// Lazily-initialised process-wide registry backing `get_registry` and `set_registry`.
static GLOBAL: OnceLock<ConnectionRegistry> = OnceLock::new();

/// Returns the global registry, initialising it with an empty registry on first access.
fn global() -> &'static ConnectionRegistry {
    GLOBAL.get_or_init(ConnectionRegistry::default)
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Builds a config pointing at a local server; only the database name varies.
    fn make_config(db: &str) -> ConnectionConfig {
        ConnectionConfig::builder()
            .url("ws://localhost:8000")
            .namespace("test")
            .database(db)
            .build()
            .expect("valid config")
    }
    #[tokio::test]
    async fn register_without_connect_stores_client() {
        let r = ConnectionRegistry::new();
        let client = r
            .register("primary", make_config("a"), false, false)
            .await
            .expect("register");
        assert!(!client.is_connected());
        let fetched = r.get(Some("primary")).await.expect("fetch");
        assert!(Arc::ptr_eq(&client, &fetched));
        let default_fetched = r.get(None).await.expect("default fetch");
        assert!(Arc::ptr_eq(&client, &default_fetched));
        assert_eq!(r.default_name().await.as_deref(), Some("primary"));
    }
    #[tokio::test]
    async fn duplicate_register_rejects() {
        let r = ConnectionRegistry::new();
        r.register("primary", make_config("a"), false, false)
            .await
            .expect("first");
        let err = r
            .register("primary", make_config("a"), false, false)
            .await
            .unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn unregister_rotates_default() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, false)
            .await
            .unwrap();
        assert_eq!(r.default_name().await.as_deref(), Some("a"));
        r.unregister("a", false).await.unwrap();
        assert_eq!(r.default_name().await.as_deref(), Some("b"));
        r.unregister("b", false).await.unwrap();
        assert!(r.default_name().await.is_none());
        let err = r.get(None).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn unregister_missing_errors() {
        let r = ConnectionRegistry::new();
        let err = r.unregister("ghost", false).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn set_default_requires_known_name() {
        let r = ConnectionRegistry::new();
        let err = r.set_default("ghost").await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn clear_empties_state() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, false)
            .await
            .unwrap();
        r.clear().await;
        assert!(r.list().await.is_empty());
        assert!(r.default_name().await.is_none());
    }
    #[tokio::test]
    async fn list_returns_every_registered_name() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, false)
            .await
            .unwrap();
        let mut names = r.list().await;
        names.sort();
        assert_eq!(names, vec!["a".to_owned(), "b".to_owned()]);
    }
    #[tokio::test]
    async fn set_default_promotes_named_connection() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, false)
            .await
            .unwrap();
        r.set_default("b").await.unwrap();
        assert_eq!(r.default_name().await.as_deref(), Some("b"));
    }
    #[tokio::test]
    async fn get_config_resolves_named_and_default() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.get_config(Some("a")).await.expect("named config");
        r.get_config(None).await.expect("default config");
        let err = r.get_config(Some("ghost")).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn register_with_set_default_takes_over() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, true)
            .await
            .unwrap();
        assert_eq!(r.default_name().await.as_deref(), Some("b"));
    }
    #[tokio::test]
    async fn unregister_non_default_keeps_default_and_drops_config() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.register("b", make_config("b"), false, false)
            .await
            .unwrap();
        // `disconnect = true` on a never-connected client must be a harmless no-op.
        r.unregister("b", true).await.unwrap();
        assert_eq!(r.default_name().await.as_deref(), Some("a"));
        let err = r.get_config(Some("b")).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
        let err = r.get(Some("b")).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
    #[tokio::test]
    async fn clear_removes_configs() {
        let r = ConnectionRegistry::new();
        r.register("a", make_config("a"), false, false)
            .await
            .unwrap();
        r.clear().await;
        let err = r.get_config(Some("a")).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
        let err = r.get_config(None).await.unwrap_err();
        assert!(matches!(err, SurqlError::Registry { .. }));
    }
}