use crate::error::{Result, SofosError};
use crate::mcp::client::{
MCP_REQUEST_TIMEOUT, create_call_tool_request, create_init_request, parse_call_tool_response,
parse_list_tools_response,
};
use crate::mcp::config::McpServerConfig;
use crate::mcp::protocol::*;
use serde_json::Value;
use std::io::{BufRead, BufReader, Write};
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// Detaches the stdin, stdout and stderr pipe handles from a freshly spawned child.
///
/// On success the child is handed back together with its three pipes. If any pipe is
/// missing, the child is killed and reaped (errors from that cleanup are ignored) and an
/// `McpError` naming `server_name` is returned.
fn take_child_pipes(
    mut process: Child,
    server_name: &str,
) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr)> {
    // All three `take()` calls always run (zip's argument is eager), so every pipe is moved
    // out of `process`; any partially-present handles are dropped inside `zip` on failure.
    let pipes = process
        .stdin
        .take()
        .zip(process.stdout.take())
        .zip(process.stderr.take());

    match pipes {
        Some(((stdin, stdout), stderr)) => Ok((process, stdin, stdout, stderr)),
        None => {
            let _ = process.kill();
            let _ = process.wait();
            Err(SofosError::McpError(format!(
                "Failed to acquire stdin/stdout/stderr for MCP server '{}'",
                server_name
            )))
        }
    }
}
/// Drains the child's stderr on a dedicated OS thread, logging each line at `debug` level
/// after stripping ANSI escape sequences.
///
/// Keeping the pipe drained matters: if nobody reads it, a server that logs heavily can fill
/// the OS pipe buffer and block on its next stderr write.
///
/// A plain `std::thread` is used instead of `tokio::task::spawn_blocking` because the reader
/// lives for as long as the pipe stays open, and tokio's runtime shutdown waits for
/// outstanding blocking tasks. If the thread cannot be spawned, a warning is logged and the
/// pipe handle is dropped.
fn spawn_stderr_reader(server_name: String, stderr: ChildStderr) {
    let thread_server_name = server_name.clone();
    let spawned = std::thread::Builder::new()
        .name("mcp-stderr".to_string())
        .spawn(move || drain_stderr_lines(&thread_server_name, stderr));
    if let Err(e) = spawned {
        tracing::warn!(
            server = %server_name,
            "failed to spawn mcp stderr reader: {}",
            e
        );
    }
}

/// Reads `stderr` line by line until EOF or an I/O error, logging every line.
///
/// Lines are read as raw bytes and decoded with `String::from_utf8_lossy`, so non-UTF-8
/// output (which makes `BufRead::lines` fail) cannot stop the drain. A trailing `\n` or
/// `\r\n` is removed, matching what `lines()` yields. One buffer is reused across lines.
fn drain_stderr_lines(server_name: &str, stderr: ChildStderr) {
    let mut reader = BufReader::new(stderr);
    let mut buf = Vec::new();
    loop {
        buf.clear();
        match reader.read_until(b'\n', &mut buf) {
            Ok(0) => break,
            Ok(_) => {
                let decoded = String::from_utf8_lossy(&buf);
                let line: &str = match decoded.strip_suffix('\n') {
                    Some(rest) => rest.strip_suffix('\r').unwrap_or(rest),
                    None => &*decoded,
                };
                let clean = strip_ansi_escapes(line);
                tracing::debug!(server = %server_name, "mcp stderr: {}", clean);
            }
            Err(e) => {
                tracing::warn!(
                    server = %server_name,
                    "mcp stderr read failed: {}",
                    e
                );
                break;
            }
        }
    }
}
/// Removes ANSI/VT escape sequences from `s`, returning only the visible text.
///
/// Handled forms:
/// - CSI: `ESC [` followed by parameter bytes up to a final byte in `0x40..=0x7e`
///   (e.g. colour codes such as `\x1b[32m`);
/// - OSC: `ESC ]` followed by a payload terminated by BEL (`\x07`) or ST (`ESC \`)
///   (e.g. window titles and terminal hyperlinks);
/// - any other `ESC x` pair, which is dropped as a two-character sequence.
///
/// An unterminated sequence at the end of the input is dropped entirely.
fn strip_ansi_escapes(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if c != '\x1b' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('[') => {
                for cc in chars.by_ref() {
                    if matches!(cc, '\x40'..='\x7e') {
                        break;
                    }
                }
            }
            Some(']') => {
                // Stop *before* an ESC so the outer loop consumes it: `ESC \` (ST) is then
                // dropped by the two-character case, and any other escape that interrupts the
                // OSC payload is processed normally.
                while let Some(&cc) = chars.peek() {
                    if cc == '\x1b' {
                        break;
                    }
                    chars.next();
                    if cc == '\x07' {
                        break;
                    }
                }
            }
            // Two-character escape (or a lone ESC at end of input): drop it.
            _ => {}
        }
    }
    out
}
/// Writes `payload` plus a trailing newline to the child's stdin, then flushes it.
///
/// The stdin mutex is held for the whole write and flush, so messages from concurrent
/// callers are never interleaved on the pipe.
///
/// # Errors
/// Returns `SofosError::McpError` if the stdin mutex is poisoned or if writing or flushing
/// fails.
fn stdio_write_blocking(
    server_name: &str,
    stdin: &Arc<Mutex<ChildStdin>>,
    payload: &str,
) -> Result<()> {
    let mut pipe = match stdin.lock() {
        Ok(guard) => guard,
        Err(e) => {
            return Err(SofosError::McpError(format!("Failed to lock stdin: {}", e)));
        }
    };

    if let Err(e) = writeln!(pipe, "{}", payload) {
        return Err(SofosError::McpError(format!(
            "Failed to write to MCP server '{}': {}",
            server_name, e
        )));
    }

    match pipe.flush() {
        Ok(()) => Ok(()),
        Err(e) => Err(SofosError::McpError(format!(
            "Failed to flush stdin for MCP server '{}': {}",
            server_name, e
        ))),
    }
}
fn stdio_request_blocking(
server_name: &str,
request_lock: &Arc<Mutex<()>>,
stdin: &Arc<Mutex<ChildStdin>>,
stdout: &Arc<Mutex<BufReader<ChildStdout>>>,
request_json: &str,
) -> Result<JsonRpcResponse> {
let _request_guard = request_lock
.lock()
.map_err(|e| SofosError::McpError(format!("Failed to lock MCP request mutex: {}", e)))?;
stdio_write_blocking(server_name, stdin, request_json)?;
let mut stdout_guard = stdout
.lock()
.map_err(|e| SofosError::McpError(format!("Failed to lock stdout: {}", e)))?;
let mut response_line = String::new();
let bytes_read = stdout_guard.read_line(&mut response_line).map_err(|e| {
SofosError::McpError(format!(
"Failed to read from MCP server '{}': {}",
server_name, e
))
})?;
if bytes_read == 0 {
return Err(SofosError::McpError(format!(
"MCP server '{}' closed stdout unexpectedly (server crashed or exited?)",
server_name
)));
}
serde_json::from_str(&response_line).map_err(|e| {
SofosError::McpError(format!(
"Failed to parse response from MCP server '{}': {}",
server_name, e
))
})
}
/// MCP client that exchanges newline-delimited JSON-RPC messages with a child process over
/// its stdin/stdout.
///
/// Shared handles are wrapped in `Arc` so they can be cloned into the closures that
/// `run_with_timeout` runs on tokio's blocking pool.
pub struct StdioClient {
/// Server name from the configuration, used in error and log messages.
server_name: String,
/// The spawned server process; killed when the client is dropped or an operation times out.
process: Arc<Mutex<Child>>,
/// Write end of the server's stdin; each outgoing message is written as one line.
stdin: Arc<Mutex<ChildStdin>>,
/// Buffered read end of the server's stdout, read one line at a time.
stdout: Arc<Mutex<BufReader<ChildStdout>>>,
/// Held for a request's full write-then-read so at most one request is in flight.
request_lock: Arc<Mutex<()>>,
/// Source of JSON-RPC request ids; the first id handed out is 1.
next_id: Arc<AtomicU64>,
}
impl Drop for StdioClient {
fn drop(&mut self) {
if let Ok(mut child) = self.process.lock() {
let _ = child.kill();
let _ = child.wait();
}
}
}
impl StdioClient {
pub async fn new(server_name: String, config: McpServerConfig) -> Result<Self> {
let command = config
.command
.ok_or_else(|| SofosError::McpError("Missing command for stdio server".to_string()))?;
let args = config.args.unwrap_or_default();
let env_vars = config.env.unwrap_or_default();
let mut cmd = Command::new(&command);
cmd.args(&args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
for (key, value) in env_vars {
cmd.env(key, value);
}
let process = cmd.spawn().map_err(|e| {
SofosError::McpError(format!(
"Failed to start MCP server '{}': {}",
server_name, e
))
})?;
let (process, stdin, stdout, stderr) = take_child_pipes(process, &server_name)?;
spawn_stderr_reader(server_name.clone(), stderr);
let client = Self {
server_name: server_name.clone(),
process: Arc::new(Mutex::new(process)),
stdin: Arc::new(Mutex::new(stdin)),
stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
request_lock: Arc::new(Mutex::new(())),
next_id: Arc::new(AtomicU64::new(1)),
};
client.initialize().await?;
Ok(client)
}
async fn run_with_timeout<T, F>(&self, label: &str, blocking: F) -> Result<T>
where
F: FnOnce() -> Result<T> + Send + 'static,
T: Send + 'static,
{
let task = tokio::task::spawn_blocking(blocking);
match tokio::time::timeout(MCP_REQUEST_TIMEOUT, task).await {
Ok(Ok(Ok(value))) => Ok(value),
Ok(Ok(Err(e))) => Err(e),
Ok(Err(join_err)) => Err(SofosError::McpError(format!(
"MCP worker panicked for server '{}' during {}: {}",
self.server_name, label, join_err
))),
Err(_) => {
self.kill_child_detached();
Err(SofosError::McpError(format!(
"MCP server '{}' {} timed out after {}s",
self.server_name,
label,
MCP_REQUEST_TIMEOUT.as_secs()
)))
}
}
}
fn kill_child_detached(&self) {
let process = Arc::clone(&self.process);
tokio::task::spawn_blocking(move || {
if let Ok(mut child) = process.lock() {
let _ = child.kill();
let _ = child.wait();
}
});
}
async fn initialize(&self) -> Result<()> {
let response = self
.send_request(
"initialize",
Some(serde_json::to_value(create_init_request())?),
)
.await?;
let _init_result: InitializeResult = serde_json::from_value(response)?;
self.send_notification("notifications/initialized", None)
.await?;
Ok(())
}
async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let request = JsonRpcRequest::new(id, method.to_string(), params);
let request_json = serde_json::to_string(&request)?;
let server_name = self.server_name.clone();
let request_lock = Arc::clone(&self.request_lock);
let stdin = Arc::clone(&self.stdin);
let stdout = Arc::clone(&self.stdout);
let response = self
.run_with_timeout("request", move || {
stdio_request_blocking(&server_name, &request_lock, &stdin, &stdout, &request_json)
})
.await?;
if let Some(error) = response.error {
return Err(SofosError::McpError(format!(
"MCP server '{}' returned error: {}",
self.server_name, error.message
)));
}
response.result.ok_or_else(|| {
SofosError::McpError(format!(
"MCP server '{}' returned no result",
self.server_name
))
})
}
async fn send_notification(&self, method: &str, _params: Option<Value>) -> Result<()> {
let notification = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
});
let notification_json = serde_json::to_string(¬ification)?;
let server_name = self.server_name.clone();
let stdin = Arc::clone(&self.stdin);
self.run_with_timeout("notification", move || {
stdio_write_blocking(&server_name, &stdin, ¬ification_json)
})
.await
}
pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
let result = self.send_request("tools/list", None).await?;
parse_list_tools_response(result)
}
pub async fn call_tool(&self, name: &str, arguments: Option<Value>) -> Result<CallToolResult> {
let result = self
.send_request(
"tools/call",
Some(serde_json::to_value(create_call_tool_request(
name, arguments,
))?),
)
.await?;
parse_call_tool_response(result)
}
}
#[cfg(test)]
mod tests {
    use super::strip_ansi_escapes;

    #[test]
    fn strips_csi_color_run() {
        // Both the space before `\x1b[32m` and the leading space of " INFO" survive stripping,
        // so two spaces follow the timestamp.
        let input = "\x1b[2m2026-05-15T22:34:54.965614Z\x1b[0m \x1b[32m INFO\x1b[0m start";
        assert_eq!(
            strip_ansi_escapes(input),
            "2026-05-15T22:34:54.965614Z  INFO start"
        );
    }

    #[test]
    fn strips_multi_parameter_csi() {
        assert_eq!(strip_ansi_escapes("\x1b[1;31mred\x1b[0m"), "red");
    }

    #[test]
    fn passes_plain_text_through() {
        assert_eq!(strip_ansi_escapes("no escapes here"), "no escapes here");
    }

    #[test]
    fn drops_bare_escape() {
        assert_eq!(strip_ansi_escapes("a\x1bXb"), "ab");
    }

    #[test]
    fn drops_trailing_lone_escape() {
        assert_eq!(strip_ansi_escapes("abc\x1b"), "abc");
    }

    #[test]
    fn drops_unterminated_csi() {
        assert_eq!(strip_ansi_escapes("ok\x1b[31"), "ok");
    }
}