use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::sleep;
use synwire_core::agents::error::AgentError;
use synwire_core::mcp::traits::{McpConnectionState, McpServerStatus, McpTransport};
/// Book-keeping for a single registered MCP server.
struct ManagedServer {
    /// Whether the server is enabled. Disabled servers are skipped by `start_all` and by the
    /// health monitor, and `call_tool` rejects them.
    enabled: bool,
    /// How long the health monitor waits before reconnecting this server after it has been
    /// observed as disconnected.
    reconnect_delay: Duration,
    /// The connection used to talk to the server.
    transport: Box<dyn McpTransport>,
}
impl std::fmt::Debug for ManagedServer {
    /// Formats the server's flags. The boxed transport is omitted, presumably because
    /// `dyn McpTransport` is not `Debug`, so the output ends with `..` to show that a field
    /// was left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ManagedServer");
        builder.field("enabled", &self.enabled);
        builder.field("reconnect_delay", &self.reconnect_delay);
        builder.finish_non_exhaustive()
    }
}
/// Owns a set of named MCP server connections and manages their lifecycle: connecting,
/// disconnecting, enabling/disabling and reconnecting them.
///
/// The registry is shared behind an `Arc` and guarded by an async (tokio) `RwLock`.
#[derive(Default)]
#[derive(Debug)]
pub struct McpLifecycleManager {
    /// Registered servers, keyed by the name they were registered under.
    servers: Arc<RwLock<HashMap<String, ManagedServer>>>,
}
impl McpLifecycleManager {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub async fn register(
&self,
name: impl Into<String>,
transport: impl McpTransport + 'static,
reconnect_delay: Duration,
) {
let _ = self.servers.write().await.insert(
name.into(),
ManagedServer {
transport: Box::new(transport),
enabled: true,
reconnect_delay,
},
);
}
pub async fn start_all(&self) -> Result<(), AgentError> {
let names: Vec<String> = self
.servers
.read()
.await
.iter()
.filter(|(_, server)| server.enabled)
.map(|(name, _)| name.clone())
.collect();
for name in names {
let guard = self.servers.read().await;
if let Some(server) = guard.get(&name) {
tracing::info!(%name, "Connecting MCP server");
server.transport.connect().await?;
}
}
Ok(())
}
pub async fn stop_all(&self) -> Result<(), AgentError> {
let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
for name in names {
let guard = self.servers.read().await;
if let Some(server) = guard.get(&name) {
tracing::info!(%name, "Disconnecting MCP server");
let _ = server.transport.disconnect().await;
}
}
Ok(())
}
pub async fn enable(&self, name: &str) -> Result<(), AgentError> {
let guard = self.servers.read().await;
if let Some(server) = guard.get(name)
&& !server.enabled
{
drop(guard);
let _ = self
.servers
.write()
.await
.get_mut(name)
.map(|s| s.enabled = true);
let guard = self.servers.read().await;
if let Some(server) = guard.get(name) {
server.transport.connect().await?;
}
}
Ok(())
}
pub async fn disable(&self, name: &str) -> Result<(), AgentError> {
let found = {
let mut guard = self.servers.write().await;
if let Some(server) = guard.get_mut(name) {
server.enabled = false;
true
} else {
false
}
};
if found {
let guard = self.servers.read().await;
if let Some(server) = guard.get(name) {
server.transport.disconnect().await?;
}
}
Ok(())
}
pub async fn all_status(&self) -> Vec<McpServerStatus> {
let names: Vec<String> = self.servers.read().await.keys().cloned().collect();
let mut statuses = Vec::new();
for name in names {
let guard = self.servers.read().await;
if let Some(server) = guard.get(&name) {
statuses.push(server.transport.status().await);
}
}
statuses
}
#[allow(clippy::significant_drop_tightening)]
pub async fn list_tools(
&self,
server_name: &str,
) -> Result<Vec<synwire_core::mcp::traits::McpToolDescriptor>, AgentError> {
let guard = self.servers.read().await;
let server = guard
.get(server_name)
.ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
server.transport.list_tools().await
}
#[allow(clippy::significant_drop_tightening)]
pub async fn call_tool(
&self,
server_name: &str,
tool_name: &str,
arguments: serde_json::Value,
) -> Result<serde_json::Value, AgentError> {
let (enabled, needs_reconnect) = {
let guard = self.servers.read().await;
let server = guard
.get(server_name)
.ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
let status = server.transport.status().await;
(
server.enabled,
status.state != McpConnectionState::Connected,
)
};
if !enabled {
return Err(AgentError::Vfs(format!(
"MCP server {server_name} is disabled"
)));
}
if needs_reconnect {
tracing::warn!(%server_name, "MCP server not connected — attempting reconnect");
let guard = self.servers.read().await;
if let Some(server) = guard.get(server_name) {
server.transport.reconnect().await?;
}
}
let guard = self.servers.read().await;
let server = guard
.get(server_name)
.ok_or_else(|| AgentError::Vfs(format!("Unknown MCP server: {server_name}")))?;
server.transport.call_tool(tool_name, arguments).await
}
#[allow(clippy::significant_drop_tightening)]
pub fn spawn_health_monitor(self: Arc<Self>, interval: Duration) {
drop(tokio::spawn(async move {
loop {
sleep(interval).await;
let disconnected: Option<(String, Duration)> = {
let guard = self.servers.read().await;
let mut found = None;
for (name, server) in guard.iter() {
if !server.enabled {
continue;
}
let status = server.transport.status().await;
if status.state == McpConnectionState::Disconnected {
tracing::warn!(%name, "MCP server disconnected — scheduling reconnect");
found = Some((name.clone(), server.reconnect_delay));
break;
}
}
found
};
if let Some((name, delay)) = disconnected {
sleep(delay).await;
let guard = self.servers.read().await;
if let Some(server) = guard.get(&name)
&& let Err(e) = server.transport.reconnect().await
{
tracing::error!(%name, %e, "MCP reconnect failed");
}
}
}
}));
}
}