use super::{McpFunction, McpToolCall, McpToolResult};
use crate::config::{McpConnectionType, McpServerConfig};
use anyhow::Result;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant, SystemTime};
use tokio::time::sleep;
/// Liveness states tracked for a managed MCP server process.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ServerHealth {
    /// Process is alive (liveness probe passed or a request is in flight).
    Running,
    /// Process exited or its pipes broke; eligible for restart on next use.
    Dead,
    // NOTE(review): Restarting/Unreachable are not set in this chunk —
    // presumably driven by restart orchestration elsewhere; confirm.
    Restarting,
    /// The most recent start attempt failed.
    Failed,
    Unreachable,
}

/// Restart and health bookkeeping kept per server (see `SERVER_RESTART_INFO`).
#[derive(Debug, Clone)]
pub struct ServerRestartInfo {
    /// Number of successful (re)starts, including the very first spawn.
    pub restart_count: u32,
    /// Wall-clock time of the most recent successful (re)start, if any.
    pub last_restart_time: Option<SystemTime>,
    /// Last observed health state.
    pub health_status: ServerHealth,
    /// Consecutive failed start attempts since the last success.
    pub consecutive_failures: u32,
    /// Wall-clock time of the most recent liveness check, if any.
    pub last_health_check: Option<SystemTime>,
}

impl Default for ServerRestartInfo {
    /// A freshly tracked server: never restarted, optimistically `Running`.
    fn default() -> Self {
        ServerRestartInfo {
            restart_count: 0,
            last_restart_time: None,
            health_status: ServerHealth::Running,
            consecutive_failures: 0,
            last_health_check: None,
        }
    }
}
lazy_static::lazy_static! {
    /// Per-server restart/health bookkeeping, keyed by server name.
    pub static ref SERVER_RESTART_INFO: Arc<RwLock<HashMap<String, ServerRestartInfo>>> =
        Arc::new(RwLock::new(HashMap::new()));
    /// Per-server async mutexes that serialize start/restart attempts.
    static ref SERVER_RESTART_MUTEXES: Arc<RwLock<HashMap<String, Arc<tokio::sync::Mutex<()>>>>> =
        Arc::new(RwLock::new(HashMap::new()));
}
// Slot holding the join handle of the single in-flight blocking request a
// stdin server may have at a time.
type InFlightHandle =
    Arc<std::sync::Mutex<Option<tokio::task::JoinHandle<anyhow::Result<serde_json::Value>>>>>;
// Rolling buffer of recent stderr lines captured from a server process.
type StderrBuffer = Arc<std::sync::Mutex<Vec<String>>>;
lazy_static::lazy_static! {
    /// Live server process handles, keyed by server name.
    pub static ref SERVER_PROCESSES: Arc<RwLock<HashMap<String, Arc<Mutex<ServerProcess>>>>> =
        Arc::new(RwLock::new(HashMap::new()));
    /// One in-flight request slot per server (stdin transport is not multiplexed).
    static ref SERVER_IN_FLIGHT: Arc<RwLock<HashMap<String, InFlightHandle>>> =
        Arc::new(RwLock::new(HashMap::new()));
    /// Number of successful `ensure_server_running` calls per server.
    static ref SERVER_REF_COUNTS: Arc<RwLock<HashMap<String, usize>>> =
        Arc::new(RwLock::new(HashMap::new()));
    /// Recent stderr lines per server (capped around 50 by the reader thread).
    static ref SERVER_STDERR: Arc<RwLock<HashMap<String, StderrBuffer>>> =
        Arc::new(RwLock::new(HashMap::new()));
    /// Cached `initialize` responses per server.
    static ref SERVER_CAPABILITIES: Arc<RwLock<HashMap<String, rmcp::model::InitializeResult>>> =
        Arc::new(RwLock::new(HashMap::new()));
}
#[cfg(unix)]
lazy_static::lazy_static! {
    /// Process-group ids per server; used to signal the whole group on Unix
    /// (children are spawned with `process_group(0)`, so pid == pgid).
    static ref SERVER_PGIDS: Arc<RwLock<HashMap<String, libc::pid_t>>> =
        Arc::new(RwLock::new(HashMap::new()));
}
lazy_static::lazy_static! {
    /// Fallback (CLI, non-session) notification channel.
    static ref CLI_NOTIFICATION_SENDER: RwLock<Option<tokio::sync::mpsc::UnboundedSender<crate::websocket::ServerMessage>>> =
        RwLock::new(None);
    /// Notifications queued while no CLI sender is registered; drained into
    /// the new sender by `set_notification_sender`.
    static ref CLI_PENDING_NOTIFICATIONS: RwLock<Vec<crate::websocket::ServerMessage>> =
        RwLock::new(Vec::new());
}
lazy_static::lazy_static! {
    /// Fallback CLI session context as `(role, project, workdir)`.
    static ref CLI_SESSION_CONTEXT: RwLock<(String, String, String)> = RwLock::new((String::new(), String::new(), String::new()));
}
pub fn derive_project_id() -> String {
use sha2::{Digest, Sha256};
let source = std::process::Command::new("git")
.args(["remote", "get-url", "origin"])
.output()
.ok()
.filter(|o| o.status.success())
.and_then(|o| String::from_utf8(o.stdout).ok())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_default()
.to_string_lossy()
.into_owned()
});
let hash = Sha256::digest(source.as_bytes());
hex::encode(hash)[..16].to_string()
}
/// Derive a stable 16-hex-char project identifier for `path`.
///
/// Uses the repository's `origin` remote URL when `path` is inside a git
/// checkout with a remote configured; otherwise hashes the path string itself.
pub fn derive_project_id_from_path(path: &std::path::Path) -> String {
    use sha2::{Digest, Sha256};
    // Best-effort git lookup: any failure (no git, no repo, no remote,
    // non-UTF8 output, empty URL) collapses to None.
    let git_remote = std::process::Command::new("git")
        .args(["remote", "get-url", "origin"])
        .current_dir(path)
        .output()
        .ok()
        .filter(|out| out.status.success())
        .and_then(|out| String::from_utf8(out.stdout).ok())
        .map(|url| url.trim().to_string())
        .filter(|url| !url.is_empty());
    let identity = git_remote.unwrap_or_else(|| path.to_string_lossy().into_owned());
    let digest = Sha256::digest(identity.as_bytes());
    // First 16 hex chars (64 bits) are plenty for a project key.
    hex::encode(digest)[..16].to_string()
}
pub fn set_session_context(role: &str, project: &str, workdir: &str) {
if let Some(_session_id) = crate::session::context::current_session_id() {
}
*CLI_SESSION_CONTEXT.write().unwrap() =
(role.to_string(), project.to_string(), workdir.to_string());
}
pub fn get_session_context() -> (String, String, String, String, String) {
let (full_role, project, workdir) = {
if let Some(session_id) = crate::session::context::current_session_id() {
if let Some(role) = crate::session::context::get_session_role(&session_id) {
let project = crate::session::context::get_session_workdir_anchor(&session_id)
.map(|p| crate::mcp::process::derive_project_id_from_path(&p))
.unwrap_or_default();
let workdir = crate::session::context::get_session_workdir_anchor(&session_id)
.map(|p| p.to_string_lossy().into_owned())
.unwrap_or_default();
(role, project, workdir)
} else {
CLI_SESSION_CONTEXT.read().unwrap().clone()
}
} else {
CLI_SESSION_CONTEXT.read().unwrap().clone()
}
};
let session_id = crate::session::context::current_session_id().unwrap_or_default();
let (domain, spec) = match full_role.split_once(':') {
Some((d, s)) => (d.to_string(), s.to_string()),
None => (full_role, String::new()),
};
(domain, spec, project, session_id, workdir)
}
/// Initialize the CLI session context from the current environment.
///
/// Derives the project id (git remote or cwd hash) and records the current
/// working directory alongside the given `role`.
pub fn init_session_context(role: &str) {
    let workdir = std::env::current_dir()
        .map(|dir| dir.to_string_lossy().into_owned())
        .unwrap_or_default();
    let project = derive_project_id();
    set_session_context(role, &project, &workdir);
}
/// Register a notification sender.
///
/// With a `session_id`, the sender is registered in the session registry.
/// Without one (CLI mode), any notifications queued while no sender existed
/// are flushed into the new sender, which then becomes the global CLI sender.
pub fn set_notification_sender(
    session_id: Option<String>,
    tx: tokio::sync::mpsc::UnboundedSender<crate::websocket::ServerMessage>,
) {
    if let Some(sid) = session_id {
        crate::session::context::register_notification_sender(sid, tx);
        return;
    }
    // CLI path: drain the backlog first so queued notifications keep order.
    let backlog = {
        let mut pending = CLI_PENDING_NOTIFICATIONS.write().unwrap();
        std::mem::take(&mut *pending)
    };
    for message in backlog {
        let _ = tx.send(message);
    }
    *CLI_NOTIFICATION_SENDER.write().unwrap() = Some(tx);
}
/// Unregister a notification sender: the session-scoped one when a
/// `session_id` is given, otherwise the global CLI sender.
pub fn clear_notification_sender(session_id: Option<String>) {
    if let Some(sid) = session_id {
        crate::session::context::unregister_notification_sender(sid);
    } else {
        *CLI_NOTIFICATION_SENDER.write().unwrap() = None;
    }
}
/// Deliver a server message, preferring the active session's sender and
/// falling back to the global CLI sender. Dropped silently if neither exists.
pub fn send_notification_message(msg: crate::websocket::ServerMessage) {
    if let Some(sid) = crate::session::context::current_session_id() {
        if let Some(tx) = crate::session::context::get_notification_sender_by_id(&sid) {
            let _ = tx.send(msg);
            return;
        }
    }
    // No session-scoped sender — try the CLI fallback.
    if let Some(tx) = CLI_NOTIFICATION_SENDER.read().unwrap().as_ref() {
        let _ = tx.send(msg);
    }
}
/// Forward an MCP server notification to the UI layer.
///
/// Routing order: explicit `session_id` argument, then the task-local current
/// session, then the global CLI sender. When no sender exists at all the
/// message is queued in `CLI_PENDING_NOTIFICATIONS` for later flushing.
fn emit_notification(
    server_name: &str,
    method: &str,
    params: &serde_json::Value,
    session_id: Option<&str>,
) {
    let msg = crate::websocket::ServerMessage::McpNotification(
        crate::websocket::McpNotificationPayload {
            server: server_name.to_string(),
            method: method.to_string(),
            params: params.clone(),
        },
    );
    let target_session = session_id
        .map(str::to_string)
        .or_else(crate::session::context::current_session_id);
    if let Some(sid) = target_session {
        if let Some(tx) = crate::session::context::get_notification_sender_by_id(&sid) {
            let _ = tx.send(msg);
            return;
        }
    }
    let guard = CLI_NOTIFICATION_SENDER.read().unwrap();
    if let Some(tx) = guard.as_ref() {
        let _ = tx.send(msg);
    } else {
        // Release the read lock before touching the pending queue.
        drop(guard);
        CLI_PENDING_NOTIFICATIONS.write().unwrap().push(msg);
    }
}
/// Handle to a spawned MCP server process.
pub enum ServerProcess {
    /// HTTP-mode server: only the child process is tracked; requests travel
    /// over the network, not through these pipes.
    Http(Child),
    /// Stdin-mode server: newline-delimited JSON-RPC over the child's stdio.
    Stdin {
        child: Child,
        reader: BufReader<std::process::ChildStdout>,
        writer: BufWriter<std::process::ChildStdin>,
        /// Monotonic JSON-RPC id counter, used when the caller does not
        /// supply an override id.
        next_id: Arc<AtomicU64>,
        /// Set when shutdown begins; further requests are rejected.
        is_shutdown: Arc<AtomicBool>,
    },
}
impl ServerProcess {
    /// Terminate the child process, then reap it.
    ///
    /// HTTP servers get a direct `kill()`. Stdin servers are shut down more
    /// gently: mark the shutdown flag, flush stdin, give the child a moment
    /// to exit on its own, then (Unix) SIGTERM the whole process group and
    /// finally SIGKILL it. In all cases we poll `try_wait` for up to 5s so
    /// the child is reaped; a child that outlives the timeout is logged and
    /// tolerated (may linger as a zombie), not treated as an error.
    pub fn kill(&mut self) -> Result<()> {
        match self {
            ServerProcess::Http(child) => {
                child
                    .kill()
                    .map_err(|e| anyhow::anyhow!("Failed to kill HTTP process: {}", e))?;
                // Poll for exit so the child gets reaped rather than zombied.
                let start = std::time::Instant::now();
                let timeout = std::time::Duration::from_secs(5);
                while start.elapsed() < timeout {
                    match child.try_wait() {
                        Ok(Some(_)) => return Ok(()),
                        Ok(None) => std::thread::sleep(std::time::Duration::from_millis(100)),
                        Err(e) => {
                            return Err(anyhow::anyhow!("Error waiting for HTTP process: {}", e))
                        }
                    }
                }
                crate::log_debug!("HTTP process did not terminate within timeout, may be zombie");
                Ok(())
            }
            ServerProcess::Stdin {
                child,
                is_shutdown,
                writer,
                ..
            } => {
                // Flag first so concurrent request paths bail out early.
                is_shutdown.store(true, Ordering::SeqCst);
                if let Err(e) = writer.flush() {
                    crate::log_debug!("Failed to flush stdin before shutdown: {}", e);
                }
                // Brief grace period: a well-behaved server exits when its
                // stdin goes quiet.
                std::thread::sleep(std::time::Duration::from_millis(100));
                match child.try_wait() {
                    Ok(Some(_)) => {
                        crate::log_debug!("Process terminated gracefully after stdin close");
                        return Ok(());
                    }
                    Ok(None) => {
                        crate::log_debug!(
                            "Process didn't terminate after stdin close, sending SIGTERM"
                        );
                    }
                    Err(e) => {
                        crate::log_debug!("Error checking process status: {}", e);
                    }
                }
                #[cfg(unix)]
                {
                    // Child was spawned with process_group(0), so its pid is
                    // also its pgid; kill(-pgid) signals the whole group.
                    let pid = child.id();
                    let pgid = pid as libc::pid_t;
                    unsafe {
                        libc::kill(-pgid, libc::SIGTERM);
                    }
                    crate::log_debug!(
                        "Sent SIGTERM to process group {} for graceful shutdown",
                        pgid
                    );
                    std::thread::sleep(std::time::Duration::from_millis(200));
                    match child.try_wait() {
                        Ok(Some(_)) => {
                            crate::log_debug!("Process terminated after SIGTERM");
                            return Ok(());
                        }
                        _ => {
                            crate::log_debug!("Process still alive after SIGTERM, sending SIGKILL");
                        }
                    }
                    unsafe {
                        libc::kill(-pgid, libc::SIGKILL);
                    }
                }
                #[cfg(not(unix))]
                {
                    // No process groups on this platform: plain kill.
                    let _ = child.kill();
                }
                // Final reap loop, same as the HTTP path.
                let start = std::time::Instant::now();
                let timeout = std::time::Duration::from_secs(5);
                while start.elapsed() < timeout {
                    match child.try_wait() {
                        Ok(Some(_)) => return Ok(()),
                        Ok(None) => std::thread::sleep(std::time::Duration::from_millis(100)),
                        Err(e) => {
                            return Err(anyhow::anyhow!("Error waiting for stdin process: {}", e))
                        }
                    }
                }
                crate::log_debug!("Stdin process did not terminate within timeout, may be zombie");
                Ok(())
            }
        }
    }
    /// Non-blocking liveness probe: `Ok(Some(status))` once the child has
    /// exited, `Ok(None)` while it is still running.
    pub fn try_wait(&mut self) -> Result<Option<std::process::ExitStatus>> {
        match self {
            ServerProcess::Http(child) => child
                .try_wait()
                .map_err(|e| anyhow::anyhow!("Failed to check HTTP process: {}", e)),
            ServerProcess::Stdin { child, .. } => child
                .try_wait()
                .map_err(|e| anyhow::anyhow!("Failed to check stdin process: {}", e)),
        }
    }
}
/// Return the per-server restart mutex, creating it on first use.
///
/// The fast path takes only the read lock; on a miss we upgrade to the write
/// lock and use the entry API, which also resolves the race where another
/// thread inserted between the two lock acquisitions.
fn get_server_restart_mutex(server_id: &str) -> Arc<tokio::sync::Mutex<()>> {
    if let Some(existing) = SERVER_RESTART_MUTEXES.read().unwrap().get(server_id) {
        return existing.clone();
    }
    SERVER_RESTART_MUTEXES
        .write()
        .unwrap()
        .entry(server_id.to_string())
        .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
        .clone()
}
/// Drop the restart mutex entry for a server being torn down.
fn cleanup_server_restart_mutex(server_id: &str) {
    SERVER_RESTART_MUTEXES.write().unwrap().remove(server_id);
}
/// Ensure `server` is running (starting it if needed) and, on success, bump
/// its reference count.
///
/// Serialized per server via the restart mutex so concurrent callers cannot
/// race to spawn duplicates. Returns the server URL (`http://…` or
/// `stdin://<name>`).
pub async fn ensure_server_running(server: &McpServerConfig) -> Result<String> {
    let server_id = server.name();
    let restart_mutex = get_server_restart_mutex(server_id);
    let _guard = restart_mutex.lock().await;
    crate::log_debug!("Checking server '{}' status for potential start", server_id);
    let result = start_server_once_if_needed(server).await;
    if result.is_ok() {
        let mut counts = SERVER_REF_COUNTS.write().unwrap();
        let count = counts.entry(server_id.to_string()).or_insert(0);
        *count += 1;
        crate::log_debug!("Server '{}' ref count: {}", server_id, *count);
    }
    crate::log_debug!("Completed server '{}' check", server_id);
    result
}
/// Start `server` only if it is not already running and healthy.
///
/// Inspects the registered process (if any): a live process short-circuits
/// to the existing URL, a busy process (mutex held by an in-flight request)
/// is treated as healthy, and a dead process is killed and cleaned up before
/// a fresh spawn. `SERVER_RESTART_INFO` is updated on every path. Callers
/// are expected to hold the per-server restart mutex (see
/// `ensure_server_running`).
async fn start_server_once_if_needed(server: &McpServerConfig) -> Result<String> {
    let server_id = server.name();
    {
        let processes = SERVER_PROCESSES.read().unwrap();
        if let Some(process_arc) = processes.get(server_id) {
            // try_lock: contention means a request is mid-flight on this
            // server, so assume it is alive rather than blocking here.
            match process_arc.try_lock() {
                Ok(mut process) => {
                    // try_wait() == Ok(None) means the child is still running.
                    let is_alive = match &mut *process {
                        ServerProcess::Http(child) => child
                            .try_wait()
                            .map(|status| status.is_none())
                            .unwrap_or(false),
                        ServerProcess::Stdin {
                            child, is_shutdown, ..
                        } => {
                            // A stdin server is only healthy if the process is
                            // up AND it has not been flagged for shutdown.
                            let process_alive = child
                                .try_wait()
                                .map(|status| status.is_none())
                                .unwrap_or(false);
                            let not_marked_shutdown = !is_shutdown.load(Ordering::SeqCst);
                            process_alive && not_marked_shutdown
                        }
                    };
                    if is_alive {
                        {
                            let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                            let info = restart_info_guard.entry(server_id.to_string()).or_default();
                            info.health_status = ServerHealth::Running;
                            info.last_health_check = Some(SystemTime::now());
                        }
                        crate::log_debug!("Server '{}' is already running and healthy", server_id);
                        match server.connection_type() {
                            McpConnectionType::Http => return get_server_url(server),
                            McpConnectionType::Stdin => {
                                return Ok("stdin://".to_string() + server_id)
                            }
                            McpConnectionType::Builtin => {
                                unreachable!("Builtin servers should not use this function")
                            }
                        }
                    } else {
                        crate::log_info!(
                            "Server '{}' process is dead - cleaning up before restart",
                            server_id
                        );
                        // Best-effort reap of the dead child before respawning.
                        if let Err(e) = process.kill() {
                            crate::log_debug!(
                                "Failed to kill dead server process '{}': {}",
                                server_id,
                                e
                            );
                        }
                        {
                            let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                            let info = restart_info_guard.entry(server_id.to_string()).or_default();
                            info.health_status = ServerHealth::Dead;
                        }
                    }
                }
                Err(_) => {
                    // Lock held elsewhere: a request is in flight right now.
                    {
                        let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                        let info = restart_info_guard.entry(server_id.to_string()).or_default();
                        info.health_status = ServerHealth::Running;
                        info.last_health_check = Some(SystemTime::now());
                    }
                    crate::log_debug!(
                        "Server '{}' is busy (in-flight request) — treating as healthy",
                        server_id
                    );
                    match server.connection_type() {
                        McpConnectionType::Http => return get_server_url(server),
                        McpConnectionType::Stdin => return Ok("stdin://".to_string() + server_id),
                        McpConnectionType::Builtin => {
                            unreachable!("Builtin servers should not use this function")
                        }
                    }
                }
            }
        } else {
            crate::log_debug!(
                "Server '{}' not found in registry - needs initial start",
                server_id
            );
        }
    }
    // No healthy process: drop any stale registry entries before spawning.
    {
        let mut processes = SERVER_PROCESSES.write().unwrap();
        processes.remove(server_id);
    }
    {
        let mut in_flight_map = SERVER_IN_FLIGHT.write().unwrap();
        in_flight_map.remove(server_id);
    }
    crate::log_info!("Starting MCP server: {}", server_id);
    match start_server_process(server).await {
        Ok(url) => {
            {
                let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                let info = restart_info_guard.entry(server_id.to_string()).or_default();
                info.health_status = ServerHealth::Running;
                // Every successful spawn counts, including the first start.
                info.restart_count += 1;
                info.last_restart_time = Some(SystemTime::now());
                info.last_health_check = Some(SystemTime::now());
                info.consecutive_failures = 0;
            }
            crate::log_info!("Successfully started server '{}'", server_id);
            Ok(url)
        }
        Err(e) => {
            {
                let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                let info = restart_info_guard.entry(server_id.to_string()).or_default();
                info.health_status = ServerHealth::Failed;
                info.consecutive_failures += 1;
            }
            crate::log_error!("Failed to start server '{}': {}", server_id, e);
            Err(anyhow::anyhow!(
                "Failed to start server '{}': {}",
                server_id,
                e
            ))
        }
    }
}
/// Spawn an external MCP server process and register it in the global maps.
///
/// Only `Stdin`-configured servers are valid here; `Http` configs point at an
/// externally managed URL and `Builtin` servers never run as processes. The
/// child is placed in its own process group (Unix) / new process group
/// (Windows) so signals aimed at our own group do not kill it, and so we can
/// later signal the whole group. Returns the reachable URL
/// (`http://…` or `stdin://<name>`).
async fn start_server_process(server: &McpServerConfig) -> Result<String> {
    let (command, args) = match server {
        McpServerConfig::Stdin { command, args, .. } => (command.as_str(), args.as_slice()),
        McpServerConfig::Http { url, .. } => {
            return Err(anyhow::anyhow!(
                "HTTP server '{}' should not be started as a process (URL: {}) - use Stdin type for local processes",
                server.name(),
                url
            ));
        }
        McpServerConfig::Builtin { .. } => {
            return Err(anyhow::anyhow!(
                "Builtin server '{}' should not be started as external process",
                server.name()
            ));
        }
    };
    let mut cmd = Command::new(command);
    if !args.is_empty() {
        cmd.args(args);
    }
    #[cfg(unix)]
    {
        use std::os::unix::process::CommandExt;
        // New process group with pgid == child pid (signal isolation).
        cmd.process_group(0);
    }
    #[cfg(windows)]
    {
        use std::os::windows::process::CommandExt;
        // CREATE_NEW_PROCESS_GROUP — Windows analogue of process_group(0).
        cmd.creation_flags(0x00000200);
    }
    match server.connection_type() {
        McpConnectionType::Http => {
            cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
            crate::log_debug!(
                "🚀 Starting MCP server (HTTP mode, signal-isolated): {}",
                server.name()
            );
            let child = cmd.spawn().map_err(|e| {
                anyhow::anyhow!("Failed to start MCP server '{}': {}", server.name(), e)
            })?;
            #[cfg(unix)]
            {
                let mut pgids = SERVER_PGIDS.write().unwrap();
                pgids.insert(server.name().to_string(), child.id() as libc::pid_t);
            }
            {
                let mut processes = SERVER_PROCESSES.write().unwrap();
                processes.insert(
                    server.name().to_string(),
                    Arc::new(Mutex::new(ServerProcess::Http(child))),
                );
            }
            // Stale tool lists must not survive a (re)start.
            crate::mcp::server::clear_function_cache_for_server(server.name());
            // Poll the HTTP endpoint until it answers or 10s elapse.
            let start_time = Instant::now();
            let max_wait = Duration::from_secs(10);
            let server_url = get_server_url(server)?;
            loop {
                if start_time.elapsed() > max_wait {
                    return Err(anyhow::anyhow!(
                        "Timed out waiting for MCP server to start: {}",
                        server.name()
                    ));
                }
                if can_connect(&server_url).await {
                    crate::log_debug!("✅ MCP server started: {} at {}", server.name(), server_url);
                    return Ok(server_url);
                }
                sleep(Duration::from_millis(500)).await;
            }
        }
        McpConnectionType::Stdin => {
            cmd.stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped());
            crate::log_debug!(
                "🚀 Starting MCP server (stdin mode, signal-isolated): {}",
                server.name()
            );
            let mut child = cmd.spawn().map_err(|e| {
                anyhow::anyhow!("Failed to start MCP server '{}': {}", server.name(), e)
            })?;
            #[cfg(unix)]
            {
                let mut pgids = SERVER_PGIDS.write().unwrap();
                pgids.insert(server.name().to_string(), child.id() as libc::pid_t);
            }
            let child_stdin = child.stdin.take().ok_or_else(|| {
                anyhow::anyhow!("Failed to open stdin for MCP server: {}", server.name())
            })?;
            let child_stdout = child.stdout.take().ok_or_else(|| {
                anyhow::anyhow!("Failed to open stdout for MCP server: {}", server.name())
            })?;
            // Shared rolling buffer of recent stderr for diagnostics.
            let stderr_buf: Arc<std::sync::Mutex<Vec<String>>> =
                Arc::new(std::sync::Mutex::new(Vec::new()));
            {
                let mut map = SERVER_STDERR.write().unwrap();
                map.insert(server.name().to_string(), stderr_buf.clone());
            }
            if let Some(child_stderr) = child.stderr.take() {
                let buf = stderr_buf;
                let sname = server.name().to_string();
                // Dedicated thread drains stderr so the child never blocks on
                // a full pipe; only the last ~50 lines are retained.
                std::thread::spawn(move || {
                    let reader = BufReader::new(child_stderr);
                    for line in reader.lines() {
                        match line {
                            Ok(l) => {
                                let trimmed = l.trim().to_string();
                                if !trimmed.is_empty() {
                                    crate::log_debug!("MCP '{}' stderr: {}", sname, trimmed);
                                    if let Ok(mut b) = buf.lock() {
                                        b.push(trimmed);
                                        if b.len() > 50 {
                                            let drain_count = b.len() - 50;
                                            b.drain(..drain_count);
                                        }
                                    }
                                }
                            }
                            Err(_) => break,
                        }
                    }
                });
            }
            let writer = BufWriter::new(child_stdin);
            let reader = BufReader::new(child_stdout);
            let server_process = ServerProcess::Stdin {
                child,
                reader,
                writer,
                next_id: Arc::new(AtomicU64::new(1)),
                is_shutdown: Arc::new(AtomicBool::new(false)),
            };
            // Register the in-flight slot BEFORE the process so request code
            // never sees a process without its slot.
            {
                let mut in_flight_map = SERVER_IN_FLIGHT.write().unwrap();
                in_flight_map.insert(
                    server.name().to_string(),
                    Arc::new(std::sync::Mutex::new(None)),
                );
            }
            {
                let mut processes = SERVER_PROCESSES.write().unwrap();
                processes.insert(
                    server.name().to_string(),
                    Arc::new(Mutex::new(server_process)),
                );
            }
            crate::mcp::server::clear_function_cache_for_server(server.name());
            // Sanity check that the registration above is visible.
            let _process_arc = {
                let processes = SERVER_PROCESSES.read().unwrap();
                processes.get(server.name()).cloned().ok_or_else(|| {
                    anyhow::anyhow!("Server not found right after creation: {}", server.name())
                })?
            };
            // MCP handshake; on failure include captured stderr in the error
            // and tear the process back down.
            let init_result = initialize_stdin_server(server.name()).await;
            if let Err(e) = &init_result {
                let stderr_lines = {
                    let map = SERVER_STDERR.read().unwrap();
                    map.get(server.name())
                        .and_then(|buf| buf.lock().ok().map(|b| b.clone()))
                        .unwrap_or_default()
                };
                let stderr_detail = if stderr_lines.is_empty() {
                    String::new()
                } else {
                    format!("\nServer stderr:\n {}", stderr_lines.join("\n "))
                };
                crate::log_error!(
                    "Failed to initialize stdin MCP server '{}': {}{}",
                    server.name(),
                    e,
                    stderr_detail
                );
                if let Err(cleanup_err) = cleanup_server_process(server.name()) {
                    crate::log_debug!(
                        "Failed to cleanup server '{}' after init failure: {}",
                        server.name(),
                        cleanup_err
                    );
                }
                return Err(anyhow::anyhow!(
                    "Failed to initialize stdin MCP server '{}': {}{}",
                    server.name(),
                    e,
                    stderr_detail
                ));
            }
            let stdin_url = format!("stdin://{}", server.name());
            Ok(stdin_url)
        }
        McpConnectionType::Builtin => Err(anyhow::anyhow!(
            "Builtin servers should not use process management"
        )),
    }
}
/// Perform the MCP `initialize` handshake with a freshly spawned stdin server.
///
/// Sends `initialize` (embedding the current session context under
/// `capabilities.experimental.session`), caches the parsed
/// `InitializeResult`, then fires the `notifications/initialized`
/// notification. A failed `initialized` send is logged but not fatal;
/// a missing/invalid `result` or an `error` response is.
async fn initialize_stdin_server(server_name: &str) -> Result<()> {
    let (role, spec, project, session_id, workdir) = get_session_context();
    let session_obj = serde_json::json!({
        "role": role,
        "spec": spec,
        "project": project,
        "session_id": session_id,
        "workdir": workdir,
    });
    let init_message = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "clientInfo": {
                "name": "octomind",
                "version": env!("CARGO_PKG_VERSION")
            },
            "protocolVersion": "2025-03-26",
            "capabilities": {
                "experimental": {
                    "session": session_obj
                }
            }
        }
    });
    // override_id = 1 so the response-id check matches the fixed "id": 1.
    let response = communicate_with_stdin_server(server_name, &init_message, 1, None).await?;
    if let Some(error) = response.get("error") {
        return Err(anyhow::anyhow!(
            "Server returned error during initialization: {}",
            error
        ));
    }
    if let Some(result_value) = response.get("result").cloned() {
        match serde_json::from_value::<rmcp::model::InitializeResult>(result_value) {
            Ok(init_info) => {
                crate::log_debug!(
                    "Stdin server '{}': {} v{}, protocol {}",
                    server_name,
                    init_info.server_info.name,
                    init_info.server_info.version,
                    init_info.protocol_version
                );
                if let Some(ref instructions) = init_info.instructions {
                    crate::log_debug!("Server '{}' instructions: {}", server_name, instructions);
                }
                store_server_capabilities(server_name, init_info);
            }
            Err(e) => {
                // Unparseable capabilities are tolerated: the handshake still
                // succeeded, we just cannot cache the details.
                crate::log_debug!(
                    "Failed to parse InitializeResult for '{}': {}",
                    server_name,
                    e
                );
            }
        }
    } else {
        return Err(anyhow::anyhow!(
            "Server did not return a valid result during initialization"
        ));
    }
    // Per the MCP spec, the client must follow up with this notification.
    let initialized_message = json!({
        "jsonrpc": "2.0",
        "method": "notifications/initialized",
        "params": {}
    });
    if let Err(e) = send_stdin_notification(server_name, &initialized_message).await {
        crate::log_error!(
            "Warning: Error sending initialized notification to MCP server: {}",
            e
        );
    }
    Ok(())
}
/// Cache a server's `initialize` result (capabilities, instructions, …).
pub fn store_server_capabilities(server_name: &str, init_result: rmcp::model::InitializeResult) {
    SERVER_CAPABILITIES
        .write()
        .unwrap()
        .insert(server_name.to_string(), init_result);
}
/// Fetch the cached `initialize` result for a server, if one was stored.
pub fn get_server_capabilities(server_name: &str) -> Option<rmcp::model::InitializeResult> {
    SERVER_CAPABILITIES.read().unwrap().get(server_name).cloned()
}
/// Fetch the instructions string a server sent during `initialize`, if any.
pub fn get_server_instructions(server_name: &str) -> Option<String> {
    SERVER_CAPABILITIES
        .read()
        .unwrap()
        .get(server_name)
        .and_then(|caps| caps.instructions.clone())
}
/// Probe whether a server URL is reachable.
///
/// `stdin://` pseudo-URLs are always considered reachable (liveness is
/// tracked via the process handle instead); HTTP URLs must answer a GET
/// with a success status.
async fn can_connect(url: &str) -> bool {
    if url.starts_with("stdin://") {
        return true;
    }
    reqwest::Client::new()
        .get(url)
        .send()
        .await
        .map(|response| response.status().is_success())
        .unwrap_or(false)
}
/// Resolve the URL for a server: an explicitly configured URL wins, stdin
/// servers get a `stdin://<name>` pseudo-URL, everything else falls back to
/// the default local HTTP endpoint.
fn get_server_url(server: &McpServerConfig) -> Result<String> {
    if let Some(url) = server.url() {
        Ok(url.to_string())
    } else if matches!(server.connection_type(), McpConnectionType::Stdin) {
        Ok(format!("stdin://{}", server.name()))
    } else {
        Ok("http://localhost:8008".to_string())
    }
}
/// Send one JSON-RPC request to a stdin server using the default timeout.
///
/// Thin wrapper over `communicate_with_stdin_server_extended_timeout`; see
/// that function for the semantics of `override_id` and cancellation.
pub async fn communicate_with_stdin_server(
    server_name: &str,
    message: &Value,
    override_id: u64,
    cancellation_token: Option<tokio::sync::watch::Receiver<bool>>,
) -> Result<Value> {
    // Default round-trip budget for ordinary protocol requests.
    const DEFAULT_TIMEOUT_SECS: u64 = 15;
    communicate_with_stdin_server_extended_timeout(
        server_name,
        message,
        override_id,
        DEFAULT_TIMEOUT_SECS,
        cancellation_token,
    )
    .await
}
/// Send one JSON-RPC request to a stdin-based MCP server and wait for the
/// matching response.
///
/// * `override_id` — when > 0, forces the JSON-RPC `id` and enables
///   response-id validation; when 0 the server's monotonic counter is used.
/// * `timeout_seconds` — overall deadline for the round trip.
/// * `cancellation_token` — optional watch channel; when it flips to `true`
///   the request is abandoned and (on Unix) the server's process group is
///   terminated.
///
/// Server-initiated notifications (messages carrying a `method` but no `id`)
/// received while waiting are forwarded via `emit_notification` rather than
/// returned. Broken pipes mark the server `Dead` so the next call restarts it.
///
/// Fix in this revision: the `emit_notification` call previously contained a
/// mangled `¶ms` token (HTML-escaped `&params`), which did not compile.
pub async fn communicate_with_stdin_server_extended_timeout(
    server_name: &str,
    message: &Value,
    override_id: u64,
    timeout_seconds: u64,
    cancellation_token: Option<tokio::sync::watch::Receiver<bool>>,
) -> Result<Value> {
    // Fast-path cancellation before doing any work.
    if let Some(ref token) = cancellation_token {
        if *token.borrow() {
            return Err(anyhow::anyhow!("Operation cancelled before communication"));
        }
    }
    let server_process = {
        let processes = SERVER_PROCESSES
            .read()
            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on server processes"))?;
        processes
            .get(server_name)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("Server not found: {}", server_name))?
    };
    let in_flight_arc = {
        let in_flight_map = SERVER_IN_FLIGHT.read().unwrap();
        in_flight_map
            .get(server_name)
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("No in-flight slot for server: {}", server_name))?
    };
    // The stdin transport is not multiplexed: wait (briefly) for any previous
    // request to drain. If it will not finish, mark the server dead so it is
    // restarted on the next call instead of wedging forever.
    let previous_handle = in_flight_arc.lock().unwrap().take();
    if let Some(handle) = previous_handle {
        let wait_secs = std::time::Duration::from_secs(5);
        if tokio::time::timeout(wait_secs, handle).await.is_err() {
            crate::log_debug!(
                "Previous in-flight task for server '{}' did not finish in time — marking dead for restart",
                server_name
            );
            let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
            let info = restart_info_guard
                .entry(server_name.to_string())
                .or_default();
            info.health_status = ServerHealth::Dead;
            return Err(anyhow::anyhow!(
                "Server '{}' previous operation timed out — will restart on next call",
                server_name
            ));
        }
    }
    // Assign the request id and stamp jsonrpc while holding the process lock,
    // then release it so the blocking task can re-acquire it.
    let (final_message, request_id, child_pid) = {
        let mut process_guard = server_process
            .lock()
            .map_err(|_| anyhow::anyhow!("Failed to acquire lock on server process"))?;
        match &mut *process_guard {
            ServerProcess::Stdin {
                next_id,
                is_shutdown,
                child,
                ..
            } => {
                if is_shutdown.load(Ordering::SeqCst) {
                    return Err(anyhow::anyhow!("Server {} is shut down", server_name));
                }
                let actual_id = if override_id > 0 {
                    override_id
                } else {
                    next_id.fetch_add(1, Ordering::SeqCst)
                };
                let mut final_msg = message.clone();
                if let Some(obj) = final_msg.as_object_mut() {
                    obj.insert("id".to_string(), json!(actual_id));
                    if !obj.contains_key("jsonrpc") {
                        obj.insert("jsonrpc".to_string(), json!("2.0"));
                    }
                }
                let pid = child.id();
                (final_msg, actual_id, pid)
            }
            _ => {
                return Err(anyhow::anyhow!(
                    "Server {} is not a stdin-based server",
                    server_name
                ))
            }
        }
    };
    let server_name_for_error = server_name.to_string();
    let server_name_for_closure = server_name.to_string();
    let final_message_clone = final_message.clone();
    let request_id_clone = request_id;
    let session_id_for_closure = crate::session::context::current_session_id();
    // Cooperative cancellation flag polled by the blocking read loop.
    let cancel_flag = Arc::new(AtomicBool::new(false));
    let cancel_flag_for_blocking = cancel_flag.clone();
    // Pipe I/O is blocking; run it on the blocking thread pool.
    let blocking_handle = tokio::task::spawn_blocking(move || {
        let mut process = server_process
            .lock()
            .map_err(|_| anyhow::anyhow!("Failed to acquire lock on server process"))?;
        match &mut *process {
            ServerProcess::Stdin {
                writer,
                reader,
                is_shutdown,
                ..
            } => {
                if is_shutdown.load(Ordering::SeqCst) {
                    return Err(anyhow::anyhow!(
                        "Server {} is shut down",
                        server_name_for_closure
                    ));
                }
                // Newline-delimited JSON framing.
                let mut message_str = serde_json::to_string(&final_message_clone)?
                    .trim_end()
                    .to_string();
                message_str.push('\n');
                match writer.write_all(message_str.as_bytes()) {
                    Ok(_) => {}
                    Err(e) => {
                        // Broken pipe ⇒ child died: flag for restart.
                        if e.kind() == std::io::ErrorKind::BrokenPipe {
                            {
                                let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                                let info = restart_info_guard
                                    .entry(server_name_for_closure.clone())
                                    .or_default();
                                info.health_status = ServerHealth::Dead;
                            }
                            crate::log_debug!("Broken pipe detected on write for server '{}', marking for cleanup", server_name_for_closure);
                            return Err(anyhow::anyhow!(
                                "Server '{}' appears to have died (broken pipe on write). Will attempt restart on next call.",
                                server_name_for_closure
                            ));
                        }
                        return Err(anyhow::anyhow!("Failed to write to stdin: {}", e));
                    }
                }
                match writer.flush() {
                    Ok(_) => {}
                    Err(e) => {
                        if e.kind() == std::io::ErrorKind::BrokenPipe {
                            {
                                let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                                let info = restart_info_guard
                                    .entry(server_name_for_closure.clone())
                                    .or_default();
                                info.health_status = ServerHealth::Dead;
                            }
                            crate::log_debug!("Broken pipe detected on flush for server '{}', marking for cleanup", server_name_for_closure);
                            return Err(anyhow::anyhow!(
                                "Server '{}' appears to have died (broken pipe on flush). Will attempt restart on next call.",
                                server_name_for_closure
                            ));
                        }
                        return Err(anyhow::anyhow!("Failed to flush stdin: {}", e));
                    }
                }
                #[cfg(unix)]
                {
                    // Make stdout non-blocking so the read loop can poll the
                    // cancel flag instead of parking in read_line forever.
                    use std::os::unix::io::AsRawFd;
                    let fd = reader.get_ref().as_raw_fd();
                    unsafe {
                        let flags = libc::fcntl(fd, libc::F_GETFL);
                        if flags != -1 {
                            libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
                        }
                    }
                }
                let mut response_str = String::new();
                // Read lines until we see the response; skip notifications
                // and non-JSON noise.
                let response = loop {
                    if cancel_flag_for_blocking.load(Ordering::Relaxed) {
                        return Err(anyhow::anyhow!(
                            "Operation cancelled while waiting for server response"
                        ));
                    }
                    let len_before = response_str.len();
                    let read_result = match reader.read_line(&mut response_str) {
                        Ok(n) => n,
                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // Nothing available yet — poll again shortly.
                            std::thread::sleep(std::time::Duration::from_millis(50));
                            continue;
                        }
                        Err(e) => {
                            return Err(anyhow::anyhow!("Failed to read from stdout: {}", e));
                        }
                    };
                    let ended_with_newline = response_str.ends_with('\n');
                    if !ended_with_newline {
                        // read 0 bytes with nothing buffered ⇒ EOF: the
                        // server closed stdout. Attach recent stderr as a hint.
                        if read_result == 0 && response_str.len() == len_before {
                            let stderr_hint = {
                                let map = SERVER_STDERR.read().unwrap();
                                map.get(&server_name_for_closure)
                                    .and_then(|buf| {
                                        buf.lock().ok().and_then(|b| {
                                            let last: Vec<_> =
                                                b.iter().rev().take(10).cloned().collect();
                                            if last.is_empty() {
                                                None
                                            } else {
                                                let mut lines = last;
                                                lines.reverse();
                                                Some(format!(
                                                    "\nServer stderr:\n {}",
                                                    lines.join("\n ")
                                                ))
                                            }
                                        })
                                    })
                                    .unwrap_or_default()
                            };
                            return Err(anyhow::anyhow!(
                                "Server closed connection while reading response{}",
                                stderr_hint
                            ));
                        }
                        // Partial line — keep accumulating.
                        continue;
                    }
                    let line = std::mem::take(&mut response_str);
                    let msg: Value = match serde_json::from_str(&line) {
                        Ok(v) => v,
                        Err(_) => {
                            // Non-JSON output on stdout: surface it, skip it.
                            let trimmed = line.trim();
                            if !trimmed.is_empty() {
                                eprintln!(
                                    "⚠️ MCP '{}' prints: {}",
                                    server_name_for_closure, trimmed
                                );
                            }
                            continue;
                        }
                    };
                    // method-without-id ⇒ server notification, not our reply.
                    if msg.get("method").is_some() && msg.get("id").is_none() {
                        let method = msg
                            .get("method")
                            .and_then(|m| m.as_str())
                            .unwrap_or("unknown");
                        let params = msg
                            .get("params")
                            .cloned()
                            .unwrap_or(serde_json::Value::Null);
                        emit_notification(
                            &server_name_for_closure,
                            method,
                            &params,
                            session_id_for_closure.as_deref(),
                        );
                        continue;
                    }
                    break msg;
                };
                // Validate the correlation id (only enforced for forced ids).
                let response_id = response.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
                if response_id != request_id_clone && override_id > 0 {
                    return Err(anyhow::anyhow!(
                        "Response ID {} does not match request ID {}",
                        response_id,
                        request_id_clone
                    ));
                }
                Ok(response)
            }
            ServerProcess::Http(_) => Err(anyhow::anyhow!(
                "Server {} is not a stdin-based server",
                server_name_for_closure
            )),
        }
    });
    // Future that resolves when the watch channel flips to true (or pends
    // forever when no token was supplied).
    let cancellation_token_clone = cancellation_token.clone();
    let cancellation_future = async move {
        if let Some(mut token) = cancellation_token_clone {
            while !*token.borrow() {
                if token.changed().await.is_err() {
                    break;
                }
            }
        } else {
            std::future::pending::<()>().await;
        }
    };
    let mut handle_opt = Some(blocking_handle);
    tokio::select! {
        result = tokio::time::timeout(
            std::time::Duration::from_secs(timeout_seconds),
            handle_opt.take().unwrap(),
        ) => {
            // Completed or timed out: either way the slot is free again.
            *in_flight_arc.lock().unwrap() = None;
            match result {
                Ok(task_result) => task_result?,
                Err(_) => Err(anyhow::anyhow!("Timeout ({} seconds) communicating with stdin server: {}", timeout_seconds, server_name_for_error))
            }
        },
        _ = cancellation_future => {
            // Cancelled: tell the blocking task to stop, park its handle in
            // the in-flight slot so the next caller can drain it, and kill
            // the server's process group (SIGTERM, then SIGKILL shortly after).
            cancel_flag.store(true, Ordering::Relaxed);
            *in_flight_arc.lock().unwrap() = handle_opt.take();
            #[cfg(unix)]
            {
                let pgid = child_pid as libc::pid_t;
                unsafe {
                    libc::kill(-pgid, libc::SIGTERM);
                }
                crate::log_debug!(
                    "Sent SIGTERM to process group {} (server '{}') on cancellation",
                    pgid,
                    server_name_for_error
                );
                tokio::spawn(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                    unsafe {
                        libc::kill(-pgid, libc::SIGKILL);
                    }
                });
            }
            {
                let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
                let info = restart_info_guard
                    .entry(server_name_for_error.clone())
                    .or_default();
                info.health_status = ServerHealth::Dead;
            }
            Err(anyhow::anyhow!("Operation cancelled while communicating with server: {}", server_name_for_error))
        }
    }
}
/// Enumerate the tools a stdin MCP server exposes via paginated `tools/list`.
///
/// Follows `next_cursor` for up to `MAX_PAGES` pages. A server-side error is
/// logged and whatever was collected so far is returned (best effort); a
/// transport error propagates.
pub async fn get_stdin_server_functions(server: &McpServerConfig) -> Result<Vec<McpFunction>> {
    let mut all_functions = Vec::new();
    let mut cursor: Option<String> = None;
    // Safety cap so a server returning a cycling cursor cannot loop forever.
    const MAX_PAGES: usize = 20;
    for page in 0..MAX_PAGES {
        let mut params = json!({});
        if let Some(ref c) = cursor {
            params["cursor"] = json!(c);
        }
        let message = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": params
        });
        crate::log_debug!(
            "tools/list request to '{}' (page {}, cursor: {:?})",
            server.name(),
            page + 1,
            cursor
        );
        let response = communicate_with_stdin_server(server.name(), &message, 1, None).await?;
        if let Some(error) = response.get("error") {
            // Best effort: return what we have instead of failing the caller.
            crate::log_error!(
                "Warning: Server returned error during tools/list: {}",
                error
            );
            return Ok(all_functions);
        }
        if let Some(result_value) = response.get("result").cloned() {
            match serde_json::from_value::<rmcp::model::ListToolsResult>(result_value) {
                Ok(list_result) => {
                    let next = list_result.next_cursor.clone();
                    let functions =
                        crate::mcp::server::parse_tools_from_list_result(&list_result, server);
                    all_functions.extend(functions);
                    // Empty or absent cursor means this was the last page.
                    match next {
                        Some(c) if !c.is_empty() => cursor = Some(c),
                        _ => break,
                    }
                }
                Err(e) => {
                    crate::log_debug!("Failed to deserialize ListToolsResult: {}", e);
                    break;
                }
            }
        } else {
            crate::log_debug!("Invalid response format from tools/list: {}", response);
            break;
        }
    }
    Ok(all_functions)
}
/// Executes a tool call against a stdin-based MCP server via the
/// JSON-RPC `tools/call` method.
///
/// Transport failures and server-reported JSON-RPC errors are converted
/// into an error `McpToolResult` rather than propagated, so callers
/// always receive a result object they can surface.
///
/// Fix: removed the dead `_output` JSON value that was built from the
/// error fields and never used.
pub async fn execute_stdin_tool_call(
    call: &McpToolCall,
    server: &McpServerConfig,
    cancellation_token: Option<tokio::sync::watch::Receiver<bool>>,
) -> Result<McpToolResult> {
    let message = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/call", "params": {
            "name": call.tool_name,
            "arguments": call.parameters
        }
    });
    let response = match communicate_with_stdin_server_extended_timeout(
        server.name(),
        &message,
        1,
        server.timeout_seconds(),
        cancellation_token,
    )
    .await
    {
        Ok(resp) => resp,
        Err(e) => {
            crate::log_error!("Error executing tool call '{}': {}", call.tool_name, e);
            return Ok(McpToolResult::error(
                call.tool_name.clone(),
                call.tool_id.clone(),
                format!("Error executing tool: {}", e),
            ));
        }
    };
    // A JSON-RPC error object from the server becomes an error result.
    if let Some(error) = response.get("error") {
        let error_message = error
            .get("message")
            .and_then(|m| m.as_str())
            .unwrap_or("Unknown error");
        let error_code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
        return Ok(McpToolResult::error(
            call.tool_name.clone(),
            call.tool_id.clone(),
            format!("{} (code: {})", error_message, error_code),
        ));
    }
    // A missing or malformed `result` degrades to a placeholder payload
    // instead of failing the whole call.
    let call_tool_result = response
        .get("result")
        .cloned()
        .and_then(|v| serde_json::from_value::<rmcp::model::CallToolResult>(v).ok())
        .unwrap_or_else(|| {
            rmcp::model::CallToolResult::success(vec![rmcp::model::Content::text("No result")])
        });
    Ok(McpToolResult {
        tool_name: call.tool_name.clone(),
        tool_id: call.tool_id.clone(),
        result: call_tool_result,
    })
}
/// Kills every registered MCP server process and clears all per-server
/// bookkeeping (PGIDs, in-flight handles, ref counts, stderr buffers,
/// capabilities, restart mutexes, and the function cache).
///
/// Errors from individual kills are logged rather than propagated, so
/// one stuck server cannot block teardown of the rest; the function
/// itself always returns `Ok(())`.
pub fn stop_all_servers() -> Result<()> {
    let mut processes = SERVER_PROCESSES.write().unwrap();
    for (name, process_arc) in processes.iter() {
        crate::log_debug!("Stopping MCP server: {}", name);
        // try_lock: a server whose mutex is held (request in flight) must
        // not block shutdown; fall back to signalling its process group.
        match process_arc.try_lock() {
            Ok(mut process) => {
                if let Err(e) = process.kill() {
                    crate::log_error!("Failed to kill MCP server '{}': {}", name, e);
                }
            }
            Err(_) => {
                crate::log_debug!(
                    "Could not acquire lock for server '{}', using PGID for SIGKILL",
                    name
                );
                #[cfg(unix)]
                {
                    let pgids = SERVER_PGIDS.read().unwrap();
                    if let Some(pgid) = pgids.get(name) {
                        crate::log_debug!(
                            "Sending SIGKILL to process group {} for server '{}'",
                            pgid,
                            name
                        );
                        // SAFETY: plain libc syscall; a negative PID targets
                        // the entire process group so helper children die too.
                        unsafe {
                            libc::kill(-*pgid, libc::SIGKILL);
                        }
                    } else {
                        crate::log_debug!("No PGID found for server '{}', process may leak", name);
                    }
                }
                #[cfg(not(unix))]
                {
                    crate::log_debug!("No PGID fallback on non-Unix for server '{}'", name);
                }
            }
        }
    }
    processes.clear();
    // Purge all ancillary per-server state maps, each behind its own
    // short-lived write lock.
    #[cfg(unix)]
    {
        let mut pgids = SERVER_PGIDS.write().unwrap();
        pgids.clear();
    }
    {
        let mut in_flight_map = SERVER_IN_FLIGHT.write().unwrap();
        in_flight_map.clear();
    }
    crate::mcp::server::clear_all_function_cache();
    {
        let mut mutexes = SERVER_RESTART_MUTEXES.write().unwrap();
        mutexes.clear();
        crate::log_debug!("Cleared all server restart mutexes");
    }
    {
        let mut counts = SERVER_REF_COUNTS.write().unwrap();
        counts.clear();
    }
    {
        let mut stderr_map = SERVER_STDERR.write().unwrap();
        stderr_map.clear();
    }
    {
        let mut caps = SERVER_CAPABILITIES.write().unwrap();
        caps.clear();
    }
    Ok(())
}
/// Tears down the process and bookkeeping for one server, unless any
/// session still holds a reference to it (see `SERVER_REF_COUNTS`).
///
/// Returns an error only when the server is not present in the process
/// registry; a kill failure is logged and still counted as success.
///
/// NOTE(review): the ref-count check and the registry removal below are
/// two separate lock acquisitions, so a racing acquire could in theory
/// re-reference the server between them — confirm callers serialize this.
pub fn cleanup_server_process(server_name: &str) -> Result<()> {
    {
        // Refuse to clean up while other sessions still use the server.
        let counts = SERVER_REF_COUNTS.read().unwrap();
        let refs = counts.get(server_name).copied().unwrap_or(0);
        if refs > 0 {
            crate::log_debug!(
                "Skipping cleanup of server '{}': {} session(s) still using it",
                server_name,
                refs
            );
            return Ok(());
        }
    }
    let mut processes = SERVER_PROCESSES.write().unwrap();
    if let Some(process_arc) = processes.remove(server_name) {
        match process_arc.try_lock() {
            Ok(mut process) => {
                crate::log_debug!("Cleaning up server process '{}'", server_name);
                if let Err(e) = process.kill() {
                    crate::log_debug!("Failed to kill server process '{}': {}", server_name, e);
                }
            }
            Err(_) => {
                // Mutex is held (a request is in flight): signal the
                // process group instead of blocking on the lock.
                crate::log_debug!(
                    "Could not acquire lock for server '{}' during cleanup, using PGID for SIGKILL",
                    server_name
                );
                #[cfg(unix)]
                {
                    let pgids = SERVER_PGIDS.read().unwrap();
                    if let Some(pgid) = pgids.get(server_name) {
                        crate::log_debug!(
                            "Sending SIGKILL to process group {} for server '{}' during cleanup",
                            pgid,
                            server_name
                        );
                        // SAFETY: raw libc call; -pgid addresses the whole
                        // process group of the server.
                        unsafe {
                            libc::kill(-*pgid, libc::SIGKILL);
                        }
                    } else {
                        crate::log_debug!(
                            "No PGID found for server '{}' during cleanup, process may leak",
                            server_name
                        );
                    }
                }
                #[cfg(not(unix))]
                {
                    crate::log_debug!(
                        "No PGID fallback on non-Unix for server '{}' during cleanup",
                        server_name
                    );
                }
            }
        }
        // Purge every per-server side table so a later start begins fresh.
        crate::mcp::server::clear_function_cache_for_server(server_name);
        #[cfg(unix)]
        {
            let mut pgids = SERVER_PGIDS.write().unwrap();
            pgids.remove(server_name);
        }
        {
            let mut in_flight_map = SERVER_IN_FLIGHT.write().unwrap();
            in_flight_map.remove(server_name);
        }
        {
            let mut stderr_map = SERVER_STDERR.write().unwrap();
            stderr_map.remove(server_name);
        }
        {
            let mut caps = SERVER_CAPABILITIES.write().unwrap();
            caps.remove(server_name);
        }
        cleanup_server_restart_mutex(server_name);
        crate::log_debug!("Server '{}' removed from registry", server_name);
        Ok(())
    } else {
        Err(anyhow::anyhow!(
            "Server '{}' not found in registry",
            server_name
        ))
    }
}
/// Drops one session reference to `server_name`; once the last reference
/// is released, the underlying process is torn down via
/// `cleanup_server_process` (failures there are only logged).
pub fn release_server(server_name: &str) {
    let last_reference_gone = {
        let mut counts = SERVER_REF_COUNTS.write().unwrap();
        match counts.get_mut(server_name) {
            Some(count) => {
                // Saturating decrement: never underflow below zero.
                if *count > 0 {
                    *count -= 1;
                }
                let remaining = *count;
                crate::log_debug!(
                    "Server '{}' ref count after release: {}",
                    server_name,
                    remaining
                );
                if remaining == 0 {
                    counts.remove(server_name);
                }
                remaining == 0
            }
            // Unknown server: nothing to release, nothing to clean up.
            None => false,
        }
    };
    if last_reference_gone {
        if let Err(e) = cleanup_server_process(server_name) {
            crate::log_debug!(
                "Failed to cleanup server '{}' after release: {}",
                server_name,
                e
            );
        }
    }
}
/// Checks liveness of a registered server process and records the
/// outcome (health status + timestamp) in `SERVER_RESTART_INFO`.
///
/// Returns `false` for servers not present in the process registry. If
/// the process mutex is currently held (a request is in flight), the
/// server is assumed alive rather than blocking on the lock.
///
/// Fix: the unregistered branch previously created a tracking entry via
/// `or_default()` (whose default status is `Running`) and only stamped
/// `last_health_check`, so `get_server_health` could report a
/// nonexistent server as Running. Both branches now record the actual
/// status.
pub fn is_server_running(server_name: &str) -> bool {
    let processes = SERVER_PROCESSES.read().unwrap();
    let is_alive = match processes.get(server_name) {
        Some(process_arc) => match process_arc.try_lock() {
            // try_wait() yields Ok(None) while the child is still running.
            Ok(mut process) => process
                .try_wait()
                .map(|status| status.is_none())
                .unwrap_or(false),
            // Lock contention means something is actively talking to the
            // process; treat it as alive.
            Err(_) => true,
        },
        None => false,
    };
    {
        let mut restart_info_guard = SERVER_RESTART_INFO.write().unwrap();
        let info = restart_info_guard
            .entry(server_name.to_string())
            .or_default();
        info.health_status = if is_alive {
            ServerHealth::Running
        } else {
            ServerHealth::Dead
        };
        info.last_health_check = Some(SystemTime::now());
    }
    is_alive
}
/// Returns the last recorded health status for `server_name`, defaulting
/// to `Dead` for servers that have never been tracked.
pub fn get_server_health(server_name: &str) -> ServerHealth {
    SERVER_RESTART_INFO
        .read()
        .unwrap()
        .get(server_name)
        .map_or(ServerHealth::Dead, |info| info.health_status)
}
pub fn get_server_restart_info(server_name: &str) -> ServerRestartInfo {
let restart_info_guard = SERVER_RESTART_INFO.read().unwrap();
restart_info_guard
.get(server_name)
.cloned()
.unwrap_or_default()
}
/// Clears the restart/failure counters for `server_name` so a previously
/// failed server becomes eligible for restart again.
///
/// Errors if the server was never tracked in `SERVER_RESTART_INFO`.
pub fn reset_server_failure_state(server_name: &str) -> Result<()> {
    let mut guard = SERVER_RESTART_INFO.write().unwrap();
    match guard.get_mut(server_name) {
        Some(info) => {
            info.restart_count = 0;
            info.consecutive_failures = 0;
            // NOTE(review): status is set to Dead here rather than Running —
            // presumably so the next health-check/restart path picks the
            // server up fresh; confirm against the restart logic.
            info.health_status = ServerHealth::Dead;
            crate::log_debug!("Reset failure state for server '{}'", server_name);
            Ok(())
        }
        None => Err(anyhow::anyhow!(
            "Server '{}' not found in restart tracking",
            server_name
        )),
    }
}
/// Polls every registered server and returns a name → health map.
///
/// Server names are snapshotted first so the registry lock is not held
/// while each server is probed via `is_server_running` (which also
/// updates the shared restart-info bookkeeping as a side effect).
pub async fn perform_health_check_all_servers() -> HashMap<String, ServerHealth> {
    let names: Vec<String> = SERVER_PROCESSES
        .read()
        .unwrap()
        .keys()
        .cloned()
        .collect();
    let mut statuses = HashMap::new();
    for name in names {
        let health = if is_server_running(&name) {
            ServerHealth::Running
        } else {
            ServerHealth::Dead
        };
        crate::log_debug!("Health check: Server '{}' is {:?}", name, health);
        statuses.insert(name, health);
    }
    statuses
}
/// Builds a per-server report of (current health, restart info).
///
/// Fix: the previous version called `get_server_health` — which acquires
/// a read lock on `SERVER_RESTART_INFO` — while already holding a read
/// guard on that same `RwLock`. `std::sync::RwLock` documents that a
/// reentrant acquisition may deadlock (e.g. when a writer is queued).
/// The health status is now read directly from the held guard's entry,
/// which is the exact value `get_server_health` would have returned.
pub fn get_server_status_report() -> HashMap<String, (ServerHealth, ServerRestartInfo)> {
    let restart_info_guard = SERVER_RESTART_INFO.read().unwrap();
    restart_info_guard
        .iter()
        .map(|(server_name, info)| (server_name.clone(), (info.health_status, info.clone())))
        .collect()
}
/// Sends a fire-and-forget JSON-RPC notification to a stdin-based
/// server: the message is written and flushed on a blocking thread, with
/// no response expected or awaited.
///
/// Errors if the server is unknown, shut down, not stdin-based, or if
/// the write/flush fails.
async fn send_stdin_notification(server_name: &str, message: &Value) -> Result<()> {
    // Clone the Arc handle out so the registry lock is released before
    // any blocking write happens.
    let process_handle = {
        let registry = SERVER_PROCESSES
            .read()
            .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on server processes"))?;
        match registry.get(server_name) {
            Some(arc) => arc.clone(),
            None => return Err(anyhow::anyhow!("Server not found: {}", server_name)),
        }
    };
    let name = server_name.to_string();
    let payload = message.clone();
    // Stdin writes are synchronous; run them off the async runtime.
    tokio::task::spawn_blocking(move || {
        let mut guard = process_handle
            .lock()
            .map_err(|_| anyhow::anyhow!("Failed to acquire lock on server process"))?;
        if let ServerProcess::Stdin {
            writer,
            is_shutdown,
            ..
        } = &mut *guard
        {
            if is_shutdown.load(Ordering::SeqCst) {
                return Err(anyhow::anyhow!("Server {} is shut down", name));
            }
            // Normalize to exactly one trailing newline — the framing the
            // stdio transport expects.
            let mut line = serde_json::to_string(&payload)?.trim_end().to_string();
            line.push('\n');
            writer
                .write_all(line.as_bytes())
                .map_err(|e| anyhow::anyhow!("Failed to write notification: {}", e))?;
            writer
                .flush()
                .map_err(|e| anyhow::anyhow!("Failed to flush notification: {}", e))
        } else {
            Err(anyhow::anyhow!(
                "Server {} is not a stdin-based server",
                name
            ))
        }
    })
    .await
    .map_err(|e| anyhow::anyhow!("Blocking task failed: {}", e))?
}