use std::any::Any;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
/// Type-erased connection objects, keyed by the name they were registered under.
type ConnectionMap = HashMap<String, Box<dyn Any + Send + Sync>>;

/// Process-wide registry storage; the lock and map are created lazily on first access.
static CONNECTIONS: OnceLock<RwLock<ConnectionMap>> = OnceLock::new();

/// Returns the global registry lock, initializing it with an empty map the first time.
fn connections() -> &'static RwLock<ConnectionMap> {
    CONNECTIONS.get_or_init(Default::default)
}
/// Registers `client` under `name`, replacing any connection already stored there.
///
/// The client is stored type-erased; borrow it later with `with_connection` using the same
/// concrete type.
///
/// # Panics
///
/// Panics if the registry lock has been poisoned.
pub fn register_connection(name: impl Into<String>, client: impl Any + Send + Sync + 'static) {
    let key: String = name.into();
    // The write guard is a temporary of this statement and is released at its end, so the
    // replaced client (if any) outlives the lock instead of being dropped under it.
    let replaced = connections()
        .write()
        .expect("connections lock poisoned")
        .insert(key, Box::new(client));
    // Drop the old client outside the lock: a `Drop` that touches the registry would
    // otherwise deadlock, and a panicking `Drop` would poison the lock for every caller.
    drop(replaced);
}
/// Attempts to fetch an owned handle to the connection registered under `name`.
///
/// Registered connections are stored as non-cloneable type-erased boxes, so this function
/// can never hand one out and always returns `Err`. Use `with_connection` to borrow a
/// registered connection as its concrete type instead.
///
/// # Errors
///
/// Always. If `name` is registered, the message points at `with_connection()`; otherwise it
/// says the connection is not registered.
///
/// # Panics
///
/// Panics if the registry lock has been poisoned.
pub fn get_connection(name: &str) -> Result<Box<dyn Any + Send + Sync>, String> {
    // The read guard is a temporary released at the end of this statement.
    let registered = connections()
        .read()
        .expect("connections lock poisoned")
        .contains_key(name);
    if registered {
        Err(format!(
            "Use with_connection() for zero-copy access to connection \"{name}\""
        ))
    } else {
        Err(format!(
            "Connection \"{name}\" is not registered. Use register_connection() first."
        ))
    }
}
/// Reports whether a connection is currently registered under `name`.
///
/// # Panics
///
/// Panics if the registry lock has been poisoned.
pub fn has_connection(name: &str) -> bool {
    let registry = connections().read().expect("connections lock poisoned");
    registry.get(name).is_some()
}
/// Borrows the connection registered under `name` as `&T` and passes it to `f`.
///
/// `f` runs while the registry's read lock is held. Calling `register_connection` or
/// `clear_connections` from inside `f` therefore requests a write lock on a thread that
/// already holds a read lock, which `std::sync::RwLock` documents may deadlock or panic.
///
/// # Errors
///
/// Returns `Err` if nothing is registered under `name`, or if the stored value is not a `T`.
///
/// # Panics
///
/// Panics if the registry lock has been poisoned.
pub fn with_connection<T: Any + Send + Sync, R>(
    name: &str,
    f: impl FnOnce(&T) -> R,
) -> Result<R, String> {
    let registry = connections().read().expect("connections lock poisoned");
    let Some(entry) = registry.get(name) else {
        return Err(format!(
            "Connection \"{name}\" is not registered. Use register_connection() first."
        ));
    };
    match entry.downcast_ref::<T>() {
        Some(client) => Ok(f(client)),
        None => Err(format!("Connection \"{name}\" exists but is not the expected type")),
    }
}
/// Removes every registered connection.
///
/// Does nothing if the registry has never been initialized. The removed clients are dropped
/// only after the write lock is released. As a result, a client whose `Drop` touches the
/// registry cannot deadlock, and a panicking `Drop` cannot poison the lock.
///
/// # Panics
///
/// Panics if the registry lock has been poisoned.
pub fn clear_connections() {
    let Some(lock) = CONNECTIONS.get() else {
        return;
    };
    // Swap an empty map in under the lock; the guard is a temporary released at the end
    // of this statement, so the old contents are owned by `removed` outside the lock.
    let removed = std::mem::take(&mut *lock.write().expect("connections lock poisoned"));
    drop(removed);
}
#[cfg(test)]
mod tests {
    //! Every test touches the process-wide registry, so each one is marked `#[serial]`
    //! and (except where noted) starts by clearing it.
    use super::*;
    use serial_test::serial;

    /// Stand-in client type used to exercise typed lookups.
    #[derive(Debug)]
    struct FakeClient {
        endpoint: String,
    }

    #[test]
    #[serial]
    fn test_register_and_check() {
        clear_connections();
        assert!(!has_connection("test"));
        register_connection(
            "test",
            FakeClient {
                endpoint: "https://example.com".into(),
            },
        );
        assert!(has_connection("test"));
    }

    #[test]
    #[serial]
    fn test_with_connection_success() {
        clear_connections();
        register_connection(
            "my-client",
            FakeClient {
                endpoint: "https://api.example.com".into(),
            },
        );
        let endpoint = with_connection::<FakeClient, _>("my-client", |c| c.endpoint.clone());
        assert_eq!(endpoint.unwrap(), "https://api.example.com");
    }

    #[test]
    #[serial]
    fn test_with_connection_missing() {
        clear_connections();
        let result = with_connection::<FakeClient, _>("missing", |_| ());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not registered"));
    }

    #[test]
    #[serial]
    fn test_with_connection_wrong_type() {
        clear_connections();
        register_connection("typed", 42_u32);
        let result = with_connection::<FakeClient, _>("typed", |_| ());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("not the expected type"));
    }

    #[test]
    #[serial]
    fn test_get_connection_always_errs() {
        clear_connections();
        let missing = get_connection("absent");
        assert!(missing.is_err());
        assert!(missing.err().unwrap().contains("not registered"));
        // Registered connections are only reachable by reference, never as an owned box.
        register_connection("present", 7_u32);
        assert!(get_connection("present").is_err());
    }

    #[test]
    #[serial]
    fn test_clear_connections() {
        register_connection("temp", 100_u32);
        assert!(has_connection("temp"));
        clear_connections();
        assert!(!has_connection("temp"));
    }

    #[test]
    #[serial]
    fn test_overwrite_connection() {
        clear_connections();
        register_connection(
            "overwrite",
            FakeClient {
                endpoint: "old".into(),
            },
        );
        register_connection(
            "overwrite",
            FakeClient {
                endpoint: "new".into(),
            },
        );
        let endpoint = with_connection::<FakeClient, _>("overwrite", |c| c.endpoint.clone());
        assert_eq!(endpoint.unwrap(), "new");
    }
}