1use std::collections::{HashMap, HashSet};
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14
15use coralstack_cmd_ipc::{
16 ChannelError, CommandChannel, CommandDef, ExecuteResult, Message, MessageId,
17};
18use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
19use futures::channel::oneshot;
20use futures::future::BoxFuture;
21use futures::lock::Mutex as AsyncMutex;
22use futures::StreamExt;
23use rmcp::handler::server::ServerHandler;
24use rmcp::model::{
25 CallToolRequestParams, CallToolResult, Implementation, ListToolsResult, PaginatedRequestParams,
26 ServerCapabilities, ServerInfo,
27};
28use rmcp::service::RequestContext;
29use rmcp::transport::IntoTransport;
30use rmcp::{ErrorData as McpError, RoleServer, ServiceExt};
31use serde_json::Value;
32
33use crate::translate::{
34 command_to_tool, execute_error_to_call_result, is_tool_not_found, mcp_error_for_unknown_tool,
35 success_to_call_result,
36};
37
/// How long a forwarded MCP request waits for its cmd-ipc response unless
/// overridden with `McpServerChannel::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
39
40#[derive(Debug, thiserror::Error)]
42pub enum McpServerError {
43 #[error("MCP transport error: {0}")]
44 Transport(String),
45 #[error("MCP protocol error: {0}")]
46 Protocol(String),
47}
48
49pub struct McpServerChannel {
74 id: String,
75 impl_name: Mutex<String>,
76 impl_version: Mutex<String>,
77 instructions: Mutex<Option<String>>,
78 timeout: Mutex<Duration>,
79
80 include: Mutex<Option<HashSet<String>>>,
84 exclude: Mutex<HashSet<String>>,
87
88 tx: UnboundedSender<Message>,
90 rx: AsyncMutex<Option<UnboundedReceiver<Message>>>,
91
92 pending_lists: Mutex<HashMap<MessageId, oneshot::Sender<Vec<CommandDef>>>>,
94 pending_calls: Mutex<HashMap<MessageId, oneshot::Sender<ExecuteResult>>>,
95
96 closed: AtomicBool,
97}
98
99impl McpServerChannel {
100 pub fn new(id: impl Into<String>) -> Self {
104 let (tx, rx) = unbounded();
105 Self {
106 id: id.into(),
107 impl_name: Mutex::new("cmd-ipc-mcp".into()),
108 impl_version: Mutex::new(env!("CARGO_PKG_VERSION").into()),
109 instructions: Mutex::new(None),
110 timeout: Mutex::new(DEFAULT_TIMEOUT),
111 include: Mutex::new(None),
112 exclude: Mutex::new(HashSet::new()),
113 tx,
114 rx: AsyncMutex::new(Some(rx)),
115 pending_lists: Mutex::new(HashMap::new()),
116 pending_calls: Mutex::new(HashMap::new()),
117 closed: AtomicBool::new(false),
118 }
119 }
120
121 pub fn with_implementation(self, name: impl Into<String>, version: impl Into<String>) -> Self {
124 *self.impl_name.lock().unwrap() = name.into();
125 *self.impl_version.lock().unwrap() = version.into();
126 self
127 }
128
129 pub fn with_instructions(self, instructions: impl Into<String>) -> Self {
133 *self.instructions.lock().unwrap() = Some(instructions.into());
134 self
135 }
136
137 pub fn with_timeout(self, timeout: Duration) -> Self {
140 *self.timeout.lock().unwrap() = timeout;
141 self
142 }
143
144 pub fn with_include<I, S>(self, ids: I) -> Self
153 where
154 I: IntoIterator<Item = S>,
155 S: Into<String>,
156 {
157 *self.include.lock().unwrap() = Some(ids.into_iter().map(Into::into).collect());
158 self
159 }
160
161 pub fn with_exclude<I, S>(self, ids: I) -> Self
168 where
169 I: IntoIterator<Item = S>,
170 S: Into<String>,
171 {
172 *self.exclude.lock().unwrap() = ids.into_iter().map(Into::into).collect();
173 self
174 }
175
176 fn is_exposed(&self, command_id: &str) -> bool {
180 if command_id.starts_with('_') {
181 return false;
182 }
183 if self.exclude.lock().unwrap().contains(command_id) {
184 return false;
185 }
186 if let Some(ref allow) = *self.include.lock().unwrap() {
187 if !allow.contains(command_id) {
188 return false;
189 }
190 }
191 true
192 }
193
194 pub async fn serve<T, E, A>(self: Arc<Self>, transport: T) -> Result<(), McpServerError>
219 where
220 T: IntoTransport<RoleServer, E, A>,
221 E: std::error::Error + Send + Sync + 'static,
222 {
223 let handler = McpHandler { channel: self };
224 let service = handler
225 .serve(transport)
226 .await
227 .map_err(|e| McpServerError::Transport(e.to_string()))?;
228 service
229 .waiting()
230 .await
231 .map_err(|e| McpServerError::Protocol(e.to_string()))?;
232 Ok(())
233 }
234
235 pub async fn serve_stdio(self: Arc<Self>) -> Result<(), McpServerError> {
237 self.serve(rmcp::transport::io::stdio()).await
238 }
239
240 pub fn into_handler(self: Arc<Self>) -> impl ServerHandler + Clone {
249 McpHandler { channel: self }
250 }
251
252 fn server_info(&self) -> ServerInfo {
253 let capabilities = ServerCapabilities::builder().enable_tools().build();
254 let implementation = Implementation::new(
255 self.impl_name.lock().unwrap().clone(),
256 self.impl_version.lock().unwrap().clone(),
257 );
258 let mut info = ServerInfo::new(capabilities).with_server_info(implementation);
259 if let Some(ref s) = *self.instructions.lock().unwrap() {
260 info = info.with_instructions(s.clone());
261 }
262 info
263 }
264
265 fn timeout_duration(&self) -> Duration {
266 *self.timeout.lock().unwrap()
267 }
268}
269
270impl CommandChannel for McpServerChannel {
271 fn id(&self) -> &str {
272 &self.id
273 }
274
275 fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>> {
276 Box::pin(async { Ok(()) })
277 }
278
279 fn close(&self) -> BoxFuture<'_, ()> {
280 Box::pin(async move {
281 self.closed.store(true, Ordering::SeqCst);
282 self.tx.close_channel();
284 self.pending_lists.lock().unwrap().clear();
287 self.pending_calls.lock().unwrap().clear();
288 })
289 }
290
291 fn send(&self, msg: Message) -> Result<(), ChannelError> {
298 if self.closed.load(Ordering::SeqCst) {
299 return Err(ChannelError::Closed);
300 }
301 match msg {
302 Message::ListCommandsResponse { thid, commands, .. } => {
303 if let Some(tx) = self.pending_lists.lock().unwrap().remove(&thid) {
304 let _ = tx.send(commands);
305 }
306 }
307 Message::ExecuteCommandResponse { thid, response, .. } => {
308 if let Some(tx) = self.pending_calls.lock().unwrap().remove(&thid) {
309 let _ = tx.send(response);
310 }
311 }
312 _ => {}
313 }
314 Ok(())
315 }
316
317 fn recv(&self) -> BoxFuture<'_, Option<Message>> {
320 Box::pin(async move {
321 let mut guard = self.rx.lock().await;
322 let rx = guard.as_mut()?;
323 rx.next().await
324 })
325 }
326}
327
328#[derive(Clone)]
331struct McpHandler {
332 channel: Arc<McpServerChannel>,
333}
334
335impl McpHandler {
336 async fn round_trip<T, F>(
340 &self,
341 build_request: impl FnOnce(MessageId) -> Message,
342 register_pending: F,
343 ) -> Result<T, McpError>
344 where
345 F: FnOnce(MessageId, oneshot::Sender<T>, &McpServerChannel),
346 {
347 let id = MessageId::new_v4();
348 let (sender, receiver) = oneshot::channel();
349 register_pending(id, sender, &self.channel);
350
351 if let Err(e) = self.channel.tx.unbounded_send(build_request(id)) {
352 return Err(McpError::internal_error(
353 format!("cmd-ipc channel closed: {e}"),
354 None,
355 ));
356 }
357
358 match tokio::time::timeout(self.channel.timeout_duration(), receiver).await {
359 Ok(Ok(value)) => Ok(value),
360 Ok(Err(_)) => Err(McpError::internal_error(
361 "cmd-ipc channel closed before response".to_string(),
362 None,
363 )),
364 Err(_) => Err(McpError::internal_error(
365 "timed out waiting for cmd-ipc response".to_string(),
366 None,
367 )),
368 }
369 }
370}
371
372impl ServerHandler for McpHandler {
373 fn get_info(&self) -> ServerInfo {
374 self.channel.server_info()
375 }
376
377 async fn list_tools(
378 &self,
379 _request: Option<PaginatedRequestParams>,
380 _ctx: RequestContext<RoleServer>,
381 ) -> Result<ListToolsResult, McpError> {
382 let defs = self
383 .round_trip(
384 |id| Message::ListCommandsRequest { id, meta: None },
385 |id, sender, ch| {
386 ch.pending_lists.lock().unwrap().insert(id, sender);
387 },
388 )
389 .await?;
390 let tools = defs
395 .iter()
396 .filter(|d| self.channel.is_exposed(&d.id))
397 .map(command_to_tool)
398 .collect();
399 Ok(ListToolsResult {
400 tools,
401 next_cursor: None,
402 ..Default::default()
403 })
404 }
405
406 async fn call_tool(
407 &self,
408 request: CallToolRequestParams,
409 _ctx: RequestContext<RoleServer>,
410 ) -> Result<CallToolResult, McpError> {
411 let name = request.name.to_string();
412 if !self.channel.is_exposed(&name) {
415 return Err(mcp_error_for_unknown_tool(&name));
416 }
417 let payload = request.arguments.map(Value::Object).unwrap_or(Value::Null);
418 let request_payload = if payload.is_null() {
419 None
420 } else {
421 Some(payload)
422 };
423 let command_id = name.clone();
424
425 let response = self
426 .round_trip(
427 |id| Message::ExecuteCommandRequest {
428 id,
429 meta: None,
430 command_id: command_id.clone(),
431 request: request_payload.clone(),
432 },
433 |id, sender, ch| {
434 ch.pending_calls.lock().unwrap().insert(id, sender);
435 },
436 )
437 .await?;
438
439 match response {
440 ExecuteResult::Ok {
441 result: Some(Value::Null),
442 ..
443 }
444 | ExecuteResult::Ok { result: None, .. } => Ok(success_to_call_result(None)),
445 ExecuteResult::Ok {
446 result: Some(value),
447 ..
448 } => Ok(success_to_call_result(Some(value))),
449 ExecuteResult::Err { error, .. } => {
450 if is_tool_not_found(&error) {
451 Err(mcp_error_for_unknown_tool(&name))
452 } else {
453 Ok(execute_error_to_call_result(error))
454 }
455 }
456 }
457 }
458}