use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use rmcp::service::{Peer, RoleClient, RunningService, serve_client};
use tokio::process::{ChildStderr, Command};
use tokio::sync::{Mutex, RwLock};
use super::config::McpServerConfig;
pub struct SharedConnection {
#[allow(dead_code)]
pub server_name: String,
peer: RwLock<Peer<RoleClient>>,
running_service: Mutex<Option<RunningService<RoleClient, ()>>>,
}
impl SharedConnection {
pub(crate) fn new(
server_name: String,
peer: Peer<RoleClient>,
rs: RunningService<RoleClient, ()>,
) -> Self {
Self {
server_name,
peer: RwLock::new(peer),
running_service: Mutex::new(Some(rs)),
}
}
pub async fn current_peer(&self) -> Peer<RoleClient> {
self.peer.read().await.clone()
}
pub async fn replace(
&self,
new_peer: Peer<RoleClient>,
new_rs: RunningService<RoleClient, ()>,
) {
let _old = {
let mut rs_guard = self.running_service.lock().await;
(*rs_guard).replace(new_rs)
};
*self.peer.write().await = new_peer;
}
pub async fn shutdown(&self) {
let mut rs = self.running_service.lock().await;
rs.take(); }
}
/// Connects to the MCP server described by `config`, giving up if the
/// handshake does not finish within the configured `mcp_init` timeout.
///
/// # Errors
///
/// Returns the underlying connection error, or a timeout error naming
/// `server_name` when the deadline elapses first.
pub async fn connect(
    server_name: String,
    config: &McpServerConfig,
) -> anyhow::Result<Arc<SharedConnection>> {
    let limit = crate::timeout::Timeouts::get().mcp_init;
    let attempt = connect_inner(server_name.clone(), config);
    tokio::time::timeout(limit, attempt)
        .await
        .unwrap_or_else(|_elapsed| {
            Err(anyhow::anyhow!(
                "MCP server {server_name:?} did not initialize within {}s — skipping",
                limit.as_secs(),
            ))
        })
}
/// Connects without any deadline and wraps the resulting peer/service pair in
/// a shared [`SharedConnection`].
async fn connect_inner(
    server_name: String,
    config: &McpServerConfig,
) -> anyhow::Result<Arc<SharedConnection>> {
    raw_connect(&server_name, config)
        .await
        .map(|(peer, running)| Arc::new(SharedConnection::new(server_name, peer, running)))
}
/// Spawns or dials the MCP server described by `config` and completes the
/// MCP client handshake.
///
/// * `McpServerConfig::Command` launches `command` with `args` and the extra
///   `env` variables. It pipes the child's stderr into
///   [`spawn_stderr_forwarder`] and runs the MCP client over the child-process
///   transport.
/// * `McpServerConfig::Url` builds a streamable-HTTP client transport for
///   `url` with the configured `headers` attached as custom headers.
///
/// Returns the client [`Peer`] together with the [`RunningService`] that owns
/// the connection.
///
/// # Errors
///
/// Fails if the child process cannot be spawned, a configured header is not a
/// valid HTTP header, or the MCP initialization handshake fails. Each error
/// message names `server_name`.
pub async fn raw_connect(
    server_name: &str,
    config: &McpServerConfig,
) -> anyhow::Result<(Peer<RoleClient>, RunningService<RoleClient, ()>)> {
    match config {
        McpServerConfig::Command {
            command,
            args,
            env,
            allow_external_paths: _,
        } => {
            let mut cmd = Command::new(command);
            cmd.args(args);
            for (k, v) in env {
                cmd.env(k, v);
            }
            // A bare io::Error here (e.g. "No such file or directory") would not
            // say which server or binary failed, so attach both.
            let (transport, stderr) =
                rmcp::transport::child_process::TokioChildProcess::builder(cmd)
                    .stderr(Stdio::piped())
                    .spawn()
                    .map_err(|e| {
                        anyhow::anyhow!(
                            "Failed to spawn MCP server '{server_name}' (command {command:?}): {e}"
                        )
                    })?;
            if let Some(child_stderr) = stderr {
                spawn_stderr_forwarder(server_name.to_string(), child_stderr);
            }
            let rs = serve_client((), transport)
                .await
                .map_err(|e| anyhow::anyhow!("MCP connection failed for '{server_name}': {e}"))?;
            let peer = rs.peer().clone();
            Ok((peer, rs))
        }
        McpServerConfig::Url {
            url,
            headers,
            allow_external_paths: _,
        } => {
            // parse_headers errors do not know which server they belong to.
            let custom_headers = parse_headers(headers).map_err(|e| {
                anyhow::anyhow!("Invalid headers for MCP server '{server_name}': {e}")
            })?;
            let cfg = rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig::with_uri(url.as_str())
                .custom_headers(custom_headers);
            type HttpClient = rmcp::transport::StreamableHttpClientTransport<reqwest::Client>;
            let transport = HttpClient::from_config(cfg);
            let rs = serve_client((), transport).await.map_err(|e| {
                anyhow::anyhow!("MCP HTTP connection failed for '{server_name}': {e}")
            })?;
            let peer = rs.peer().clone();
            Ok((peer, rs))
        }
    }
}
/// Lists every tool exposed by the connection's current peer, following
/// pagination via rmcp's `list_all_tools`.
pub async fn list_tools(
    conn: &SharedConnection,
) -> Result<Vec<rmcp::model::Tool>, rmcp::ServiceError> {
    conn.current_peer().await.list_all_tools().await
}
/// Maximum number of bytes of one stderr line forwarded before the rest of
/// that line is discarded.
const MAX_STDERR_LINE_BYTES: usize = 16 * 1024;

/// Incremental byte-to-line splitter for a child's stderr stream.
///
/// Bytes are fed one at a time via [`push`](Self::push), and completed lines
/// come back without their terminator. A line longer than `max_line_bytes` is
/// emitted once, cut at the cap and suffixed with `" ...[truncated]"`. The
/// remainder of that line, up to the next `\n`, is dropped.
struct StderrLineSplitter {
    /// Bytes of the line currently being assembled.
    buf: Vec<u8>,
    /// Set after an over-long line was emitted; bytes are ignored until `\n`.
    draining: bool,
    /// Cap on bytes kept for a single line.
    max_line_bytes: usize,
}

impl StderrLineSplitter {
    /// Creates a splitter that truncates lines at `max_line_bytes` bytes.
    fn new(max_line_bytes: usize) -> Self {
        Self {
            buf: Vec::with_capacity(1024),
            draining: false,
            max_line_bytes,
        }
    }

    /// Feeds one byte and returns a completed line when one becomes available.
    ///
    /// * `\n` completes the current line; an empty line yields `Some(vec![])`.
    ///   A `\r` directly before it (a CRLF terminator) is removed.
    /// * A `\r` arriving while the line is still empty is skipped.
    /// * Once the buffered line reaches `max_line_bytes`, it is returned early
    ///   with the truncation marker, and the rest of the line is discarded.
    fn push(&mut self, b: u8) -> Option<Vec<u8>> {
        if b == b'\n' {
            if std::mem::replace(&mut self.draining, false) {
                // The over-long line was already emitted when it hit the cap.
                self.buf.clear();
                return None;
            }
            // Strip the CR of a CRLF terminator so Windows-style output does
            // not leave a stray carriage return at the end of every line.
            if self.buf.last() == Some(&b'\r') {
                self.buf.pop();
            }
            return Some(std::mem::take(&mut self.buf));
        }
        if self.draining {
            return None;
        }
        if self.buf.len() >= self.max_line_bytes {
            self.buf.extend_from_slice(b" ...[truncated]");
            self.draining = true;
            return Some(std::mem::take(&mut self.buf));
        }
        if self.buf.is_empty() && b == b'\r' {
            return None;
        }
        self.buf.push(b);
        None
    }

    /// Returns the unterminated final line left at end of stream, if any.
    fn finish(self) -> Option<Vec<u8>> {
        if self.buf.is_empty() {
            None
        } else {
            Some(self.buf)
        }
    }
}
/// Forwards a child process's stderr to the UI, one sanitized line at a time.
///
/// Spawns a detached tokio task that reads `stderr` until EOF or the first
/// read error. The stream is split into lines with [`StderrLineSplitter`]
/// (capped at [`MAX_STDERR_LINE_BYTES`] per line), and each line is handed to
/// [`emit_mcp_line`] under `server_name`. A final line without `\n` is flushed
/// when the stream ends.
fn spawn_stderr_forwarder(server_name: String, mut stderr: ChildStderr) {
    tokio::spawn(async move {
        use tokio::io::AsyncReadExt;
        // The pipe is read directly in 4 KiB chunks. Wrapping it in a BufReader
        // only added a second copy of every byte without reducing read calls.
        let mut splitter = StderrLineSplitter::new(MAX_STDERR_LINE_BYTES);
        let mut chunk = [0u8; 4096];
        loop {
            let n = match stderr.read(&mut chunk).await {
                // EOF: the child closed its stderr.
                Ok(0) => break,
                Ok(n) => n,
                // Forwarding is best-effort; a read error just ends the task.
                Err(_) => break,
            };
            for &b in &chunk[..n] {
                if let Some(line) = splitter.push(b) {
                    emit_mcp_line(&server_name, &line);
                }
            }
        }
        if let Some(line) = splitter.finish() {
            emit_mcp_line(&server_name, &line);
        }
    });
}
/// Decodes one raw stderr line (invalid UTF-8 is replaced lossily), strips
/// control sequences under the strict policy, and posts it as an MCP log
/// notification. Lines that are empty after sanitizing are skipped.
fn emit_mcp_line(server_name: &str, raw: &[u8]) {
    let text = String::from_utf8_lossy(raw);
    let clean = crate::ui::ansi::strip_controls(&text, crate::ui::ansi::StripPolicy::STRICT);
    if !clean.is_empty() {
        crate::ui::notifications::notify_mcp_log(server_name, &clean);
    }
}
/// Converts user-configured header strings into typed HTTP header pairs.
///
/// # Errors
///
/// Returns an error on the first name or value (in map iteration order) that
/// is not a valid HTTP header. Value errors name the header, not its value.
fn parse_headers(
    headers: &HashMap<String, String>,
) -> anyhow::Result<HashMap<http::HeaderName, http::HeaderValue>> {
    headers
        .iter()
        .map(
            |(name, value)| -> anyhow::Result<(http::HeaderName, http::HeaderValue)> {
                let header_name = name
                    .parse::<http::HeaderName>()
                    .map_err(|e| anyhow::anyhow!("Invalid header name '{name}': {e}"))?;
                let header_value = value
                    .parse::<http::HeaderValue>()
                    .map_err(|e| anyhow::anyhow!("Invalid header value for '{name}': {e}"))?;
                Ok((header_name, header_value))
            },
        )
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `data` through a splitter capped at `cap` bytes per line and
    /// collects every emitted line, including the end-of-stream flush.
    fn split(data: &[u8], cap: usize) -> Vec<String> {
        let mut splitter = StderrLineSplitter::new(cap);
        let mut raw_lines: Vec<Vec<u8>> =
            data.iter().filter_map(|&b| splitter.push(b)).collect();
        raw_lines.extend(splitter.finish());
        raw_lines
            .iter()
            .map(|line| String::from_utf8_lossy(line).into_owned())
            .collect()
    }

    #[test]
    fn normal_lines_split_on_newline() {
        let lines = split(b"alpha\nbeta\n", 16);
        assert_eq!(lines, vec!["alpha", "beta"]);
    }

    #[test]
    fn trailing_partial_line_flushed_on_eof() {
        let lines = split(b"alpha\npartial", 16);
        assert_eq!(lines, vec!["alpha", "partial"]);
    }

    #[test]
    fn leading_cr_stripped() {
        let lines = split(b"\rhi\n", 16);
        assert_eq!(lines, vec!["hi"]);
    }

    #[test]
    fn overlong_line_truncates_once_then_drains() {
        let data = [vec![b'x'; 100], b"\n".to_vec()].concat();
        let lines = split(&data, 8);
        assert_eq!(lines.len(), 1, "exactly one emitted line, got {lines:?}");
        let only = &lines[0];
        assert!(only.starts_with("xxxxxxxx"));
        assert!(only.ends_with("...[truncated]"));
        assert_eq!(only.matches("[truncated]").count(), 1);
    }

    #[test]
    fn line_after_overlong_is_clean() {
        let data = [vec![b'x'; 50], b"\n".to_vec(), b"ok\n".to_vec()].concat();
        let lines = split(&data, 8);
        assert_eq!(lines.len(), 2, "got {lines:?}");
        assert!(lines[0].ends_with("...[truncated]"));
        assert_eq!(lines[1], "ok");
    }

    #[test]
    fn overlong_line_at_eof_emits_once() {
        let data = vec![b'x'; 100];
        let lines = split(&data, 8);
        assert_eq!(lines.len(), 1, "got {lines:?}");
        assert!(lines[0].ends_with("...[truncated]"));
    }
}