use std::collections::BTreeMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::mcp::{connect_mcp_server_from_json, VmMcpClientHandle};
use crate::value::VmError;
/// Configuration for one MCP server as held by the process-wide registry
/// (populated through [`register_servers`]).
#[derive(Clone, Debug)]
pub struct RegisteredMcpServer {
    /// Key under which the server is registered and looked up.
    pub name: String,
    /// Connection spec passed to `connect_mcp_server_from_json` when the
    /// server is first activated by [`ensure_active`].
    pub spec: serde_json::Value,
    /// Lazy flag, reported through [`snapshot_status`].
    /// NOTE(review): nothing else in this module acts on it — confirm where it is consumed.
    pub lazy: bool,
    /// Optional card reference, reported through [`snapshot_status`].
    /// NOTE(review): presumably a path to a server card file — confirm.
    pub card: Option<String>,
    /// How long a connection with no outstanding references stays in the
    /// active set before [`sweep_expired`] removes it. `None` is treated as
    /// zero: [`release`] removes the connection as soon as its last reference
    /// is released.
    pub keep_alive: Option<Duration>,
}
/// Bookkeeping for one connection held in the registry's active set.
struct ActiveConnection {
    /// Handle given to callers; a clone is returned on every acquisition.
    handle: VmMcpClientHandle,
    /// Number of `ensure_active` acquisitions not yet matched by `release`.
    /// The sentinel value `usize::MAX / 2` marks a connection pinned by
    /// `install_active`. Both `ensure_active` and `release` leave that count
    /// unchanged, and `sweep_expired` only considers entries whose count is 0,
    /// so a pinned connection is never removed by either of them.
    ref_count: usize,
    /// When the reference count last dropped to zero. `None` while the
    /// connection is in use or pinned. `sweep_expired` measures the
    /// keep-alive window from this instant.
    last_released_at: Option<Instant>,
}
/// State guarded by the global `REGISTRY` mutex.
struct RegistryInner {
    /// Registered server configurations, keyed by server name.
    servers: BTreeMap<String, RegisteredMcpServer>,
    /// Connections currently held open (in use, idle awaiting keep-alive
    /// expiry, or pinned), keyed by server name.
    active: BTreeMap<String, ActiveConnection>,
}
impl RegistryInner {
    /// Creates an empty registry. It is `const` so that it can initialize the
    /// `REGISTRY` static directly.
    const fn new() -> Self {
        Self {
            servers: BTreeMap::new(),
            active: BTreeMap::new(),
        }
    }
}
/// Process-wide MCP server registry. Every accessor in this module locks it
/// for a short critical section and panics with "mcp registry poisoned" if a
/// previous holder panicked while holding the lock.
static REGISTRY: Mutex<RegistryInner> = Mutex::new(RegistryInner::new());
/// Adds `servers` to the registry, keyed by each server's `name`.
///
/// A name that is already registered has its entry replaced. If `servers`
/// itself contains the same name more than once, the last occurrence wins.
/// Active connections are not touched.
pub fn register_servers(servers: Vec<RegisteredMcpServer>) {
    let keyed = servers
        .into_iter()
        .map(|server| (server.name.clone(), server));
    REGISTRY
        .lock()
        .expect("mcp registry poisoned")
        .servers
        .extend(keyed);
}
/// Reports whether a server named `name` is currently registered.
pub fn is_registered(name: &str) -> bool {
    let registry = REGISTRY.lock().expect("mcp registry poisoned");
    registry.servers.contains_key(name)
}
pub fn get_registration(name: &str) -> Option<RegisteredMcpServer> {
REGISTRY
.lock()
.expect("mcp registry poisoned")
.servers
.get(name)
.cloned()
}
pub fn reset() {
let mut guard = REGISTRY.lock().expect("mcp registry poisoned");
guard.servers.clear();
guard.active.clear();
}
/// Installs `handle` as a pinned active connection for `name`.
///
/// Any existing active entry for `name` is replaced, and the registry's handle
/// for it is dropped. The pinned entry carries the sentinel reference count
/// `usize::MAX / 2`, which `ensure_active` and `release` leave unchanged, so
/// neither `release` nor `sweep_expired` will remove it.
pub fn install_active(name: &str, handle: VmMcpClientHandle) {
    let pinned = ActiveConnection {
        handle,
        ref_count: usize::MAX / 2,
        last_released_at: None,
    };
    REGISTRY
        .lock()
        .expect("mcp registry poisoned")
        .active
        .insert(name.to_owned(), pinned);
}
/// Returns a clone of the active handle for `name`, if one is open.
///
/// Unlike [`ensure_active`], this neither connects nor changes the reference
/// count.
pub fn active_handle(name: &str) -> Option<VmMcpClientHandle> {
    let registry = REGISTRY.lock().expect("mcp registry poisoned");
    let connection = registry.active.get(name)?;
    Some(connection.handle.clone())
}
pub async fn ensure_active(name: &str) -> Result<VmMcpClientHandle, VmError> {
{
let mut guard = REGISTRY.lock().expect("mcp registry poisoned");
if let Some(active) = guard.active.get_mut(name) {
if active.ref_count != usize::MAX / 2 {
active.ref_count = active.ref_count.saturating_add(1);
}
active.last_released_at = None;
return Ok(active.handle.clone());
}
}
let spec = {
let guard = REGISTRY.lock().expect("mcp registry poisoned");
guard.servers.get(name).cloned()
};
let registration = spec.ok_or_else(|| {
VmError::Runtime(format!(
"mcp: no server named '{name}' is registered (check harn.toml)"
))
})?;
let handle = connect_mcp_server_from_json(®istration.spec).await?;
let mut guard = REGISTRY.lock().expect("mcp registry poisoned");
match guard.active.get_mut(name) {
Some(existing) => {
if existing.ref_count != usize::MAX / 2 {
existing.ref_count = existing.ref_count.saturating_add(1);
}
existing.last_released_at = None;
Ok(existing.handle.clone())
}
None => {
guard.active.insert(
name.to_string(),
ActiveConnection {
handle: handle.clone(),
ref_count: 1,
last_released_at: None,
},
);
Ok(handle)
}
}
}
/// Drops one reference to the active connection for `name`.
///
/// The call does nothing in three cases:
/// - no connection for `name` is active;
/// - the connection is pinned (installed via [`install_active`]);
/// - the connection is already idle (reference count 0). This guards against
///   an unbalanced extra `release` refreshing the idle timestamp and so
///   extending the connection's keep-alive window.
///
/// When the last reference is released, what happens depends on the
/// server's `keep_alive` (unset or unregistered counts as zero):
/// - zero: the entry is removed from the active set immediately;
/// - non-zero: the release time is recorded so that [`sweep_expired`] can
///   remove the entry once the window has elapsed.
pub fn release(name: &str) {
    let mut guard = REGISTRY.lock().expect("mcp registry poisoned");
    let keep_alive = guard
        .servers
        .get(name)
        .and_then(|s| s.keep_alive)
        .unwrap_or(Duration::ZERO);
    let Some(active) = guard.active.get_mut(name) else {
        return;
    };
    // Pinned connections are never reference counted, and an idle connection
    // has no reference left to drop.
    if active.ref_count == usize::MAX / 2 || active.ref_count == 0 {
        return;
    }
    if active.ref_count > 1 {
        active.ref_count -= 1;
        return;
    }
    active.ref_count = 0;
    if keep_alive.is_zero() {
        // Take the entry out under the lock, but drop it (and the registry's
        // handle) only after the lock is released, so the handle's destructor
        // never runs while every other registry caller is blocked.
        let removed = guard.active.remove(name);
        drop(guard);
        drop(removed);
    } else {
        active.last_released_at = Some(Instant::now());
    }
}
pub fn sweep_expired() {
let mut guard = REGISTRY.lock().expect("mcp registry poisoned");
let now = Instant::now();
let mut expired: Vec<String> = Vec::new();
for (name, active) in guard.active.iter() {
if active.ref_count != 0 {
continue;
}
let Some(last) = active.last_released_at else {
continue;
};
let ka = guard
.servers
.get(name)
.and_then(|s| s.keep_alive)
.unwrap_or(Duration::ZERO);
if now.duration_since(last) >= ka {
expired.push(name.clone());
}
}
for name in expired {
guard.active.remove(&name);
}
}
/// Point-in-time view of one registered server, as produced by
/// [`snapshot_status`].
#[derive(Clone, Debug)]
pub struct RegistryStatus {
    /// Registered server name.
    pub name: String,
    /// The registration's `lazy` flag.
    pub lazy: bool,
    /// Whether the server currently has an entry in the active set (in use,
    /// idle awaiting keep-alive expiry, or pinned).
    pub active: bool,
    /// Current reference count of the active entry, or 0 when not active.
    /// Pinned connections report the sentinel `usize::MAX / 2`.
    pub ref_count: usize,
    /// The registration's `card` value.
    pub card: Option<String>,
}
/// Returns one [`RegistryStatus`] per registered server, in name order.
///
/// Active connections whose name is not registered (for example ones pinned
/// via [`install_active`] without a registration) do not appear.
pub fn snapshot_status() -> Vec<RegistryStatus> {
    let registry = REGISTRY.lock().expect("mcp registry poisoned");
    registry
        .servers
        .iter()
        .map(|(name, server)| {
            let connection = registry.active.get(name);
            RegistryStatus {
                name: name.clone(),
                lazy: server.lazy,
                active: connection.is_some(),
                ref_count: connection.map_or(0, |c| c.ref_count),
                card: server.card.clone(),
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that read or mutate the process-wide `REGISTRY`.
    static TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Takes `TEST_LOCK`, recovering the guard if an earlier test panicked
    /// while holding it. A plain `.unwrap()` would turn one failing test into
    /// a cascade of `PoisonError` failures in every test that runs after it.
    fn serial() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    fn make_spec(name: &str) -> serde_json::Value {
        serde_json::json!({
            "name": name,
            "transport": "stdio",
            "command": "/bin/true",
            "args": [],
        })
    }

    /// Builds a registration with a stdio spec and no keep-alive.
    fn server(name: &str, lazy: bool, card: Option<&str>) -> RegisteredMcpServer {
        RegisteredMcpServer {
            name: name.into(),
            spec: make_spec(name),
            lazy,
            card: card.map(str::to_owned),
            keep_alive: None,
        }
    }

    #[test]
    fn register_and_snapshot() {
        let _g = serial();
        reset();
        register_servers(vec![server("x", true, Some("card.json"))]);
        let snap = snapshot_status();
        assert_eq!(snap.len(), 1);
        assert_eq!(snap[0].name, "x");
        assert!(snap[0].lazy);
        assert!(!snap[0].active);
        assert_eq!(snap[0].ref_count, 0);
        assert_eq!(snap[0].card.as_deref(), Some("card.json"));
    }

    #[test]
    fn release_on_unknown_is_noop() {
        let _g = serial();
        reset();
        release("doesnt-exist");
        assert!(active_handle("doesnt-exist").is_none());
        assert!(snapshot_status().is_empty());
    }

    #[test]
    fn is_registered_reflects_state() {
        let _g = serial();
        reset();
        assert!(!is_registered("a"));
        register_servers(vec![server("a", false, None)]);
        assert!(is_registered("a"));
    }

    #[test]
    fn re_registering_replaces_by_name() {
        let _g = serial();
        reset();
        register_servers(vec![server("x", false, None)]);
        register_servers(vec![server("x", true, Some("c.json"))]);
        let reg = get_registration("x").expect("x is registered");
        assert!(reg.lazy);
        assert_eq!(reg.card.as_deref(), Some("c.json"));
        assert_eq!(snapshot_status().len(), 1);
    }

    #[test]
    fn reset_forgets_registrations() {
        let _g = serial();
        reset();
        register_servers(vec![server("a", false, None), server("b", true, None)]);
        assert_eq!(snapshot_status().len(), 2);
        reset();
        assert!(!is_registered("a"));
        assert!(get_registration("b").is_none());
        assert!(snapshot_status().is_empty());
    }
}