use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
/// Location recorded for a plugin's services in a [`ServiceLocationMap`].
///
/// Derives `Hash` (in addition to `Eq`) so the value can be used as a key in
/// hashed collections or grouped without a manual mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceLocation {
    /// The plugin's services were registered as local.
    Local,
    /// The plugin's services were registered as remote.
    Remote,
}
/// Registry mapping plugin ids to the [`ServiceLocation`] recorded for them.
///
/// All access goes through an internal `RwLock`, so a shared reference can be
/// used from several threads: lookups take the read lock, updates the write lock.
#[derive(Debug, Default)]
pub struct ServiceLocationMap {
    /// Plugin id → most recently recorded location for that plugin.
    entries: RwLock<HashMap<String, ServiceLocation>>,
}
impl ServiceLocationMap {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn mark_local(&self, plugin_id: &str) {
let mut guard = self
.entries
.write()
.expect("service location map lock poisoned");
guard.insert(plugin_id.to_string(), ServiceLocation::Local);
}
pub fn mark_remote(&self, plugin_id: &str) {
let mut guard = self
.entries
.write()
.expect("service location map lock poisoned");
guard
.entry(plugin_id.to_string())
.or_insert(ServiceLocation::Remote);
}
#[must_use]
pub fn get(&self, plugin_id: &str) -> Option<ServiceLocation> {
let guard = self
.entries
.read()
.expect("service location map lock poisoned");
guard.get(plugin_id).copied()
}
pub fn clear(&self) {
let mut guard = self
.entries
.write()
.expect("service location map lock poisoned");
guard.clear();
}
#[must_use]
pub fn len(&self) -> usize {
self.entries
.read()
.expect("service location map lock poisoned")
.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Returns a shared handle to the process-wide [`ServiceLocationMap`].
///
/// The map is created on the first call and stored in a `static`. Every call
/// returns a clone of that same `Arc`, so all callers read and write one set of
/// entries.
#[must_use]
pub fn global_service_locations() -> Arc<ServiceLocationMap> {
    static INSTANCE: OnceLock<Arc<ServiceLocationMap>> = OnceLock::new();
    let shared = INSTANCE.get_or_init(|| Arc::new(ServiceLocationMap::new()));
    Arc::clone(shared)
}
#[cfg(test)]
mod tests {
    use super::{ServiceLocation, ServiceLocationMap, global_service_locations};
    use std::sync::Arc;

    /// Plugin id used by the single-entry tests.
    const CONTEXTS: &str = "bmux.contexts";

    #[test]
    fn mark_local_records_local_location() {
        let locations = ServiceLocationMap::new();
        locations.mark_local(CONTEXTS);
        assert_eq!(locations.get(CONTEXTS), Some(ServiceLocation::Local));
    }

    #[test]
    fn mark_remote_records_remote_location_when_absent() {
        let locations = ServiceLocationMap::new();
        locations.mark_remote(CONTEXTS);
        assert_eq!(locations.get(CONTEXTS), Some(ServiceLocation::Remote));
    }

    #[test]
    fn mark_remote_does_not_override_local() {
        let locations = ServiceLocationMap::new();
        locations.mark_local(CONTEXTS);
        locations.mark_remote(CONTEXTS);
        assert_eq!(locations.get(CONTEXTS), Some(ServiceLocation::Local));
    }

    #[test]
    fn mark_local_overrides_remote() {
        let locations = ServiceLocationMap::new();
        locations.mark_remote(CONTEXTS);
        locations.mark_local(CONTEXTS);
        assert_eq!(locations.get(CONTEXTS), Some(ServiceLocation::Local));
    }

    #[test]
    fn get_missing_plugin_returns_none() {
        assert_eq!(ServiceLocationMap::new().get("bmux.unknown"), None);
    }

    #[test]
    fn concurrent_mark_and_get_is_safe() {
        let locations = ServiceLocationMap::new();
        // `scope` joins every worker before returning and re-raises any worker
        // panic, so a failure inside a thread still fails the test.
        std::thread::scope(|scope| {
            for worker in 0..8 {
                let locations = &locations;
                scope.spawn(move || {
                    for step in 0..500 {
                        let id = format!("plugin-{worker}-{}", step % 16);
                        match step % 2 {
                            0 => locations.mark_local(&id),
                            _ => locations.mark_remote(&id),
                        }
                        let _ = locations.get(&id);
                    }
                });
            }
        });
        assert!(!locations.is_empty());
    }

    #[test]
    fn clear_resets_all_entries() {
        let locations = ServiceLocationMap::new();
        locations.mark_local("a");
        locations.mark_remote("b");
        assert_eq!(locations.len(), 2);
        locations.clear();
        assert!(locations.is_empty());
    }

    #[test]
    fn global_service_locations_returns_same_instance() {
        let (first, second) = (global_service_locations(), global_service_locations());
        assert!(Arc::ptr_eq(&first, &second));
    }
}