use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use coralstack_cmd_ipc::{
ChannelError, CommandChannel, CommandDef, ExecuteResult, Message, MessageId,
};
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::channel::oneshot;
use futures::future::BoxFuture;
use futures::lock::Mutex as AsyncMutex;
use futures::StreamExt;
use rmcp::handler::server::ServerHandler;
use rmcp::model::{
CallToolRequestParams, CallToolResult, Implementation, ListToolsResult, PaginatedRequestParams,
ServerCapabilities, ServerInfo,
};
use rmcp::service::RequestContext;
use rmcp::transport::IntoTransport;
use rmcp::{ErrorData as McpError, RoleServer, ServiceExt};
use serde_json::Value;
use crate::translate::{
command_to_tool, execute_error_to_call_result, is_tool_not_found, mcp_error_for_unknown_tool,
success_to_call_result,
};
/// Default time a tool request waits for its cmd-ipc response before failing.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
#[derive(Debug, thiserror::Error)]
pub enum McpServerError {
#[error("MCP transport error: {0}")]
Transport(String),
#[error("MCP protocol error: {0}")]
Protocol(String),
}
/// A [`CommandChannel`] that exposes cmd-ipc commands as MCP tools.
///
/// Requests built by the MCP handler are queued on `tx` and read back out through
/// [`CommandChannel::recv`]. Responses delivered through [`CommandChannel::send`] are routed
/// to the waiting request by message id via the `pending_*` maps.
pub struct McpServerChannel {
    /// Identifier returned by `CommandChannel::id`.
    id: String,
    /// Implementation name reported in the MCP server info.
    impl_name: Mutex<String>,
    /// Implementation version reported in the MCP server info.
    impl_version: Mutex<String>,
    /// Optional instructions included in the MCP server info.
    instructions: Mutex<Option<String>>,
    /// How long a single request waits for its cmd-ipc response.
    timeout: Mutex<Duration>,
    /// Optional allow-list of command ids; `None` disables allow-list filtering.
    include: Mutex<Option<HashSet<String>>>,
    /// Command ids that are never exposed as tools.
    exclude: Mutex<HashSet<String>>,
    /// Sending half of the outbound request queue (closed by `close`).
    tx: UnboundedSender<Message>,
    /// Receiving half of the outbound request queue, drained by `recv`.
    rx: AsyncMutex<Option<UnboundedReceiver<Message>>>,
    /// Waiters for outstanding `ListCommandsRequest`s, keyed by request id.
    pending_lists: Mutex<HashMap<MessageId, oneshot::Sender<Vec<CommandDef>>>>,
    /// Waiters for outstanding `ExecuteCommandRequest`s, keyed by request id.
    pending_calls: Mutex<HashMap<MessageId, oneshot::Sender<ExecuteResult>>>,
    /// Set by `close`; once true, `send` rejects messages with `ChannelError::Closed`.
    closed: AtomicBool,
}
impl McpServerChannel {
pub fn new(id: impl Into<String>) -> Self {
let (tx, rx) = unbounded();
Self {
id: id.into(),
impl_name: Mutex::new("cmd-ipc-mcp".into()),
impl_version: Mutex::new(env!("CARGO_PKG_VERSION").into()),
instructions: Mutex::new(None),
timeout: Mutex::new(DEFAULT_TIMEOUT),
include: Mutex::new(None),
exclude: Mutex::new(HashSet::new()),
tx,
rx: AsyncMutex::new(Some(rx)),
pending_lists: Mutex::new(HashMap::new()),
pending_calls: Mutex::new(HashMap::new()),
closed: AtomicBool::new(false),
}
}
pub fn with_implementation(self, name: impl Into<String>, version: impl Into<String>) -> Self {
*self.impl_name.lock().unwrap() = name.into();
*self.impl_version.lock().unwrap() = version.into();
self
}
pub fn with_instructions(self, instructions: impl Into<String>) -> Self {
*self.instructions.lock().unwrap() = Some(instructions.into());
self
}
pub fn with_timeout(self, timeout: Duration) -> Self {
*self.timeout.lock().unwrap() = timeout;
self
}
pub fn with_include<I, S>(self, ids: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
*self.include.lock().unwrap() = Some(ids.into_iter().map(Into::into).collect());
self
}
pub fn with_exclude<I, S>(self, ids: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
*self.exclude.lock().unwrap() = ids.into_iter().map(Into::into).collect();
self
}
fn is_exposed(&self, command_id: &str) -> bool {
if command_id.starts_with('_') {
return false;
}
if self.exclude.lock().unwrap().contains(command_id) {
return false;
}
if let Some(ref allow) = *self.include.lock().unwrap() {
if !allow.contains(command_id) {
return false;
}
}
true
}
pub async fn serve<T, E, A>(self: Arc<Self>, transport: T) -> Result<(), McpServerError>
where
T: IntoTransport<RoleServer, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
let handler = McpHandler { channel: self };
let service = handler
.serve(transport)
.await
.map_err(|e| McpServerError::Transport(e.to_string()))?;
service
.waiting()
.await
.map_err(|e| McpServerError::Protocol(e.to_string()))?;
Ok(())
}
pub async fn serve_stdio(self: Arc<Self>) -> Result<(), McpServerError> {
self.serve(rmcp::transport::io::stdio()).await
}
pub fn into_handler(self: Arc<Self>) -> impl ServerHandler + Clone {
McpHandler { channel: self }
}
fn server_info(&self) -> ServerInfo {
let capabilities = ServerCapabilities::builder().enable_tools().build();
let implementation = Implementation::new(
self.impl_name.lock().unwrap().clone(),
self.impl_version.lock().unwrap().clone(),
);
let mut info = ServerInfo::new(capabilities).with_server_info(implementation);
if let Some(ref s) = *self.instructions.lock().unwrap() {
info = info.with_instructions(s.clone());
}
info
}
fn timeout_duration(&self) -> Duration {
*self.timeout.lock().unwrap()
}
}
impl CommandChannel for McpServerChannel {
    /// Returns the id given to [`McpServerChannel::new`].
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Nothing to start: the outbound queue exists from construction, so this always
    /// succeeds immediately.
    fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>> {
        Box::pin(async { Ok(()) })
    }

    /// Marks the channel closed, closes the outbound queue, and drops every pending waiter.
    ///
    /// Dropping the waiters' senders makes in-flight requests observe a cancelled receiver
    /// rather than waiting for their timeout.
    fn close(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.closed.store(true, Ordering::SeqCst);
            self.tx.close_channel();
            self.pending_lists.lock().unwrap().clear();
            self.pending_calls.lock().unwrap().clear();
        })
    }

    /// Delivers a response from the cmd-ipc side to the request waiting on it.
    ///
    /// Only `ListCommandsResponse` and `ExecuteCommandResponse` are routed, keyed by `thid`.
    /// Responses without a matching waiter, and every other message kind, are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Closed`] once `close` has run.
    fn send(&self, msg: Message) -> Result<(), ChannelError> {
        if self.closed.load(Ordering::SeqCst) {
            return Err(ChannelError::Closed);
        }
        match msg {
            Message::ListCommandsResponse { thid, commands, .. } => {
                let waiter = self.pending_lists.lock().unwrap().remove(&thid);
                if let Some(waiter) = waiter {
                    // The requester may already have given up; a dropped receiver is fine.
                    let _ = waiter.send(commands);
                }
            }
            Message::ExecuteCommandResponse { thid, response, .. } => {
                let waiter = self.pending_calls.lock().unwrap().remove(&thid);
                if let Some(waiter) = waiter {
                    let _ = waiter.send(response);
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Yields the next queued outbound request.
    ///
    /// Returns `None` once the queue has been closed and drained.
    fn recv(&self) -> BoxFuture<'_, Option<Message>> {
        Box::pin(async move {
            let mut slot = self.rx.lock().await;
            match slot.as_mut() {
                Some(rx) => rx.next().await,
                None => None,
            }
        })
    }
}
/// `rmcp` server handler that forwards MCP tool requests to an [`McpServerChannel`].
///
/// Clones share the same channel through the `Arc`.
#[derive(Clone)]
struct McpHandler {
    /// Channel used to queue cmd-ipc requests and receive their responses.
    channel: Arc<McpServerChannel>,
}
impl McpHandler {
    /// Sends one cmd-ipc request and waits for its matching response.
    ///
    /// The steps are:
    /// 1. Generate a fresh [`MessageId`].
    /// 2. Register a oneshot waiter for it through `register_pending`.
    /// 3. Queue the request produced by `build_request` on the channel's outbound sender.
    ///
    /// `CommandChannel::send` completes the waiter when a response with the same thread id
    /// arrives. However the call ends (success, error, timeout, or the future being dropped),
    /// the waiter is removed from the pending maps, so failed requests do not accumulate.
    ///
    /// # Errors
    ///
    /// Returns an MCP internal error if any of these happen:
    /// - the outbound queue is closed;
    /// - the waiter is dropped before a response arrives (e.g. the channel was closed);
    /// - no response arrives within the channel's configured timeout.
    async fn round_trip<T, F>(
        &self,
        build_request: impl FnOnce(MessageId) -> Message,
        register_pending: F,
    ) -> Result<T, McpError>
    where
        F: FnOnce(MessageId, oneshot::Sender<T>, &McpServerChannel),
    {
        /// Removes a request's waiter from the pending maps when dropped.
        ///
        /// Without this, a request that fails to send, times out, or is cancelled would leave
        /// its `oneshot::Sender` in the map until the channel is closed. On success `send` has
        /// already removed the entry, so the removal here is a no-op.
        struct PendingGuard<'a> {
            channel: &'a McpServerChannel,
            id: MessageId,
        }

        impl Drop for PendingGuard<'_> {
            fn drop(&mut self) {
                // Each request gets its own `MessageId::new_v4()`, so clearing the id from both
                // maps only touches this request's entry. Recover from a poisoned lock instead
                // of panicking inside `drop`.
                self.channel
                    .pending_lists
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .remove(&self.id);
                self.channel
                    .pending_calls
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .remove(&self.id);
            }
        }

        let id = MessageId::new_v4();
        let (sender, receiver) = oneshot::channel();
        register_pending(id, sender, &self.channel);
        // Lives until the end of this function, including across the await below.
        let _pending = PendingGuard {
            channel: &*self.channel,
            id,
        };

        self.channel
            .tx
            .unbounded_send(build_request(id))
            .map_err(|e| {
                McpError::internal_error(format!("cmd-ipc channel closed: {e}"), None)
            })?;

        match tokio::time::timeout(self.channel.timeout_duration(), receiver).await {
            Ok(Ok(value)) => Ok(value),
            // The sender was dropped, e.g. by `close` clearing the pending maps.
            Ok(Err(_)) => Err(McpError::internal_error(
                "cmd-ipc channel closed before response".to_string(),
                None,
            )),
            Err(_) => Err(McpError::internal_error(
                "timed out waiting for cmd-ipc response".to_string(),
                None,
            )),
        }
    }
}
impl ServerHandler for McpHandler {
    /// Reports the server info built from the channel's configuration.
    fn get_info(&self) -> ServerInfo {
        self.channel.server_info()
    }

    /// Lists, as MCP tools, the cmd-ipc commands that pass the channel's exposure filter.
    ///
    /// Pagination is not supported. Every tool is returned in a single page with no cursor.
    ///
    /// # Errors
    ///
    /// Propagates the internal error from the cmd-ipc round trip (closed channel or timeout).
    async fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<ListToolsResult, McpError> {
        let defs = self
            .round_trip(
                |id| Message::ListCommandsRequest { id, meta: None },
                |id, sender, ch| {
                    ch.pending_lists.lock().unwrap().insert(id, sender);
                },
            )
            .await?;
        let tools = defs
            .iter()
            .filter(|d| self.channel.is_exposed(&d.id))
            .map(command_to_tool)
            .collect();
        Ok(ListToolsResult {
            tools,
            next_cursor: None,
            ..Default::default()
        })
    }

    /// Executes the cmd-ipc command named by the tool and converts its result.
    ///
    /// A `null` or absent command result becomes an empty success result. Command errors
    /// other than "tool not found" are returned as a successful `CallToolResult` built by
    /// `execute_error_to_call_result`, so the client sees them as tool-level errors.
    ///
    /// # Errors
    ///
    /// - Returns the unknown-tool MCP error if the tool is not exposed, or if the cmd-ipc side
    ///   reports it as not found.
    /// - Propagates the internal error from the round trip (closed channel or timeout).
    async fn call_tool(
        &self,
        request: CallToolRequestParams,
        _ctx: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, McpError> {
        let name = request.name.to_string();
        if !self.channel.is_exposed(&name) {
            return Err(mcp_error_for_unknown_tool(&name));
        }
        // Absent arguments mean "no request payload". `Value::Object` is never null, so no
        // further null check is needed.
        let request_payload = request.arguments.map(Value::Object);
        let command_id = name.clone();
        let response = self
            .round_trip(
                // `build_request` is `FnOnce`, so the id and the (possibly large) JSON payload
                // are moved into the message instead of being deep-cloned.
                move |id| Message::ExecuteCommandRequest {
                    id,
                    meta: None,
                    command_id,
                    request: request_payload,
                },
                |id, sender, ch| {
                    ch.pending_calls.lock().unwrap().insert(id, sender);
                },
            )
            .await?;
        match response {
            ExecuteResult::Ok { result, .. } => {
                // Treat an explicit JSON `null` the same as no result.
                Ok(success_to_call_result(result.filter(|v| !v.is_null())))
            }
            ExecuteResult::Err { error, .. } => {
                if is_tool_not_found(&error) {
                    Err(mcp_error_for_unknown_tool(&name))
                } else {
                    Ok(execute_error_to_call_result(error))
                }
            }
        }
    }
}