1use std::cell::RefCell;
9
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11
12use crate::stdlib::json_to_vm_value;
13use crate::value::{VmClosure, VmError, VmValue};
14use crate::vm::Vm;
15
thread_local! {
    /// Tool registry handed to the most recent `mcp_serve(...)` call on this thread.
    ///
    /// Written by the `mcp_serve` builtin (see `register_mcp_server_builtins`) and
    /// drained by `take_mcp_serve_registry`; `None` until `mcp_serve` has been called,
    /// and again after the value has been taken.
    static MCP_SERVE_REGISTRY: RefCell<Option<VmValue>> = const { RefCell::new(None) };
}
21
22pub fn register_mcp_server_builtins(vm: &mut Vm) {
24 vm.register_builtin("mcp_serve", |args, _out| {
25 let registry = args.first().cloned().ok_or_else(|| {
26 VmError::Runtime("mcp_serve: requires a tool_registry argument".into())
27 })?;
28
29 if let VmValue::Dict(d) = ®istry {
31 match d.get("_type") {
32 Some(VmValue::String(t)) if &**t == "tool_registry" => {}
33 _ => {
34 return Err(VmError::Runtime(
35 "mcp_serve: argument must be a tool registry (created with tool_registry())"
36 .into(),
37 ));
38 }
39 }
40 } else {
41 return Err(VmError::Runtime(
42 "mcp_serve: argument must be a tool registry".into(),
43 ));
44 }
45
46 MCP_SERVE_REGISTRY.with(|cell| {
47 *cell.borrow_mut() = Some(registry);
48 });
49
50 Ok(VmValue::Nil)
51 });
52}
53
54pub fn take_mcp_serve_registry() -> Option<VmValue> {
57 MCP_SERVE_REGISTRY.with(|cell| cell.borrow_mut().take())
58}
59
/// MCP protocol revision reported as `protocolVersion` in the `initialize` response.
const PROTOCOL_VERSION: &str = "2024-11-05";
62
/// A single tool exposed by [`McpServer`] (built, e.g., by `tool_registry_to_mcp_tools`).
pub struct McpToolDef {
    /// Tool name reported by `tools/list` and matched against `params.name` in `tools/call`.
    pub name: String,
    /// Human-readable description reported by `tools/list`.
    pub description: String,
    /// JSON Schema for the tool's arguments, reported as `inputSchema` by `tools/list`.
    pub input_schema: serde_json::Value,
    /// VM closure invoked by `tools/call` with the request's `arguments` converted to a VM value.
    pub handler: VmClosure,
}
70
/// Minimal MCP server that serves a fixed set of VM-backed tools as
/// newline-delimited JSON-RPC over stdin/stdout (see `McpServer::run`).
pub struct McpServer {
    /// Name reported in `serverInfo` of the `initialize` response.
    server_name: String,
    /// Version reported in `serverInfo`; set from this crate's `CARGO_PKG_VERSION`.
    server_version: String,
    /// Tools listed by `tools/list` and looked up by name in `tools/call`.
    tools: Vec<McpToolDef>,
}
77
78impl McpServer {
79 pub fn new(server_name: String, tools: Vec<McpToolDef>) -> Self {
80 Self {
81 server_name,
82 server_version: env!("CARGO_PKG_VERSION").to_string(),
83 tools,
84 }
85 }
86
87 pub async fn run(&self, vm: &mut Vm) -> Result<(), VmError> {
90 let stdin = BufReader::new(tokio::io::stdin());
91 let mut stdout = tokio::io::stdout();
92 let mut lines = stdin.lines();
93
94 while let Ok(Some(line)) = lines.next_line().await {
95 let trimmed = line.trim();
96 if trimmed.is_empty() {
97 continue;
98 }
99
100 let msg: serde_json::Value = match serde_json::from_str(trimmed) {
101 Ok(v) => v,
102 Err(_) => continue,
103 };
104
105 let method = msg.get("method").and_then(|m| m.as_str()).unwrap_or("");
106 let id = msg.get("id").cloned();
107 let params = msg.get("params").cloned().unwrap_or(serde_json::json!({}));
108
109 if id.is_none() {
111 continue;
113 }
114
115 let id = id.unwrap();
116
117 let response = match method {
118 "initialize" => self.handle_initialize(&id),
119 "ping" => serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} }),
120 "tools/list" => self.handle_tools_list(&id),
121 "tools/call" => self.handle_tools_call(&id, ¶ms, vm).await,
122 _ => serde_json::json!({
123 "jsonrpc": "2.0",
124 "id": id,
125 "error": {
126 "code": -32601,
127 "message": format!("Method not found: {method}")
128 }
129 }),
130 };
131
132 let mut response_line = serde_json::to_string(&response)
133 .map_err(|e| VmError::Runtime(format!("MCP server serialization error: {e}")))?;
134 response_line.push('\n');
135 stdout
136 .write_all(response_line.as_bytes())
137 .await
138 .map_err(|e| VmError::Runtime(format!("MCP server write error: {e}")))?;
139 stdout
140 .flush()
141 .await
142 .map_err(|e| VmError::Runtime(format!("MCP server flush error: {e}")))?;
143 }
144
145 Ok(())
146 }
147
148 fn handle_initialize(&self, id: &serde_json::Value) -> serde_json::Value {
149 serde_json::json!({
150 "jsonrpc": "2.0",
151 "id": id,
152 "result": {
153 "protocolVersion": PROTOCOL_VERSION,
154 "capabilities": {
155 "tools": {}
156 },
157 "serverInfo": {
158 "name": self.server_name,
159 "version": self.server_version
160 }
161 }
162 })
163 }
164
165 fn handle_tools_list(&self, id: &serde_json::Value) -> serde_json::Value {
166 let tools: Vec<serde_json::Value> = self
167 .tools
168 .iter()
169 .map(|t| {
170 serde_json::json!({
171 "name": t.name,
172 "description": t.description,
173 "inputSchema": t.input_schema,
174 })
175 })
176 .collect();
177
178 serde_json::json!({
179 "jsonrpc": "2.0",
180 "id": id,
181 "result": { "tools": tools }
182 })
183 }
184
185 async fn handle_tools_call(
186 &self,
187 id: &serde_json::Value,
188 params: &serde_json::Value,
189 vm: &mut Vm,
190 ) -> serde_json::Value {
191 let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
192
193 let tool = match self.tools.iter().find(|t| t.name == tool_name) {
194 Some(t) => t,
195 None => {
196 return serde_json::json!({
197 "jsonrpc": "2.0",
198 "id": id,
199 "error": {
200 "code": -32602,
201 "message": format!("Unknown tool: {tool_name}")
202 }
203 });
204 }
205 };
206
207 let arguments = params
209 .get("arguments")
210 .cloned()
211 .unwrap_or(serde_json::json!({}));
212 let args_vm = json_to_vm_value(&arguments);
213
214 let result = vm.call_closure_pub(&tool.handler, &[args_vm], &[]).await;
216
217 match result {
218 Ok(value) => {
219 let text = value.display();
220 serde_json::json!({
221 "jsonrpc": "2.0",
222 "id": id,
223 "result": {
224 "content": [{ "type": "text", "text": text }],
225 "isError": false
226 }
227 })
228 }
229 Err(e) => {
230 let error_text = format!("{e}");
231 serde_json::json!({
232 "jsonrpc": "2.0",
233 "id": id,
234 "result": {
235 "content": [{ "type": "text", "text": error_text }],
236 "isError": true
237 }
238 })
239 }
240 }
241 }
242}
243
244pub fn tool_registry_to_mcp_tools(registry: &VmValue) -> Result<Vec<McpToolDef>, VmError> {
246 let dict = match registry {
247 VmValue::Dict(d) => d,
248 _ => {
249 return Err(VmError::Runtime(
250 "mcp_serve: argument must be a tool registry".into(),
251 ));
252 }
253 };
254
255 match dict.get("_type") {
257 Some(VmValue::String(t)) if &**t == "tool_registry" => {}
258 _ => {
259 return Err(VmError::Runtime(
260 "mcp_serve: argument must be a tool registry (created with tool_registry())".into(),
261 ));
262 }
263 }
264
265 let tools = match dict.get("tools") {
266 Some(VmValue::List(list)) => list,
267 _ => return Ok(Vec::new()),
268 };
269
270 let mut mcp_tools = Vec::new();
271 for tool in tools.iter() {
272 if let VmValue::Dict(entry) = tool {
273 let name = entry.get("name").map(|v| v.display()).unwrap_or_default();
274 let description = entry
275 .get("description")
276 .map(|v| v.display())
277 .unwrap_or_default();
278
279 let handler = match entry.get("handler") {
280 Some(VmValue::Closure(c)) => (**c).clone(),
281 _ => {
282 return Err(VmError::Runtime(format!(
283 "mcp_serve: tool '{name}' has no handler closure"
284 )));
285 }
286 };
287
288 let input_schema = params_to_json_schema(entry.get("parameters"));
289
290 mcp_tools.push(McpToolDef {
291 name,
292 description,
293 input_schema,
294 handler,
295 });
296 }
297 }
298
299 Ok(mcp_tools)
300}
301
302fn params_to_json_schema(params: Option<&VmValue>) -> serde_json::Value {
314 let params_dict = match params {
315 Some(VmValue::Dict(d)) => d,
316 _ => {
317 return serde_json::json!({ "type": "object", "properties": {} });
318 }
319 };
320
321 let mut properties = serde_json::Map::new();
322 let mut required = Vec::new();
323
324 for (param_name, param_def) in params_dict.iter() {
325 if let VmValue::Dict(def) = param_def {
326 let mut prop = serde_json::Map::new();
327
328 if let Some(VmValue::String(t)) = def.get("type") {
329 prop.insert("type".to_string(), serde_json::Value::String(t.to_string()));
330 }
331 if let Some(VmValue::String(d)) = def.get("description") {
332 prop.insert(
333 "description".to_string(),
334 serde_json::Value::String(d.to_string()),
335 );
336 }
337
338 let is_required = matches!(def.get("required"), Some(VmValue::Bool(true)));
340 if is_required {
341 required.push(serde_json::Value::String(param_name.clone()));
342 }
343
344 properties.insert(param_name.clone(), serde_json::Value::Object(prop));
345 } else if let VmValue::String(type_str) = param_def {
346 let mut prop = serde_json::Map::new();
348 prop.insert(
349 "type".to_string(),
350 serde_json::Value::String(type_str.to_string()),
351 );
352 properties.insert(param_name.clone(), serde_json::Value::Object(prop));
353 }
354 }
355
356 let mut schema = serde_json::Map::new();
357 schema.insert(
358 "type".to_string(),
359 serde_json::Value::String("object".to_string()),
360 );
361 schema.insert(
362 "properties".to_string(),
363 serde_json::Value::Object(properties),
364 );
365 if !required.is_empty() {
366 schema.insert("required".to_string(), serde_json::Value::Array(required));
367 }
368
369 serde_json::Value::Object(schema)
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use std::rc::Rc;

    /// Shorthand for a `VmValue::String`.
    fn vm_str(text: &str) -> VmValue {
        VmValue::String(Rc::from(text))
    }

    /// Wraps `tools` in a dict tagged as a `tool_registry`.
    fn registry_with_tools(tools: Vec<VmValue>) -> VmValue {
        let mut registry = BTreeMap::new();
        registry.insert("_type".to_string(), vm_str("tool_registry"));
        registry.insert("tools".to_string(), VmValue::List(Rc::new(tools)));
        VmValue::Dict(Rc::new(registry))
    }

    #[test]
    fn test_params_to_json_schema_empty() {
        let schema = params_to_json_schema(None);
        assert_eq!(
            schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[test]
    fn test_params_to_json_schema_non_dict_params() {
        let schema = params_to_json_schema(Some(&VmValue::Nil));
        assert_eq!(
            schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[test]
    fn test_params_to_json_schema_with_params() {
        let mut params = BTreeMap::new();
        let mut param_def = BTreeMap::new();
        param_def.insert("type".to_string(), VmValue::String(Rc::from("string")));
        param_def.insert(
            "description".to_string(),
            VmValue::String(Rc::from("A file path")),
        );
        param_def.insert("required".to_string(), VmValue::Bool(true));
        params.insert("path".to_string(), VmValue::Dict(Rc::new(param_def)));

        let schema = params_to_json_schema(Some(&VmValue::Dict(Rc::new(params))));
        let expected = serde_json::json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "A file path"
                }
            },
            "required": ["path"]
        });
        assert_eq!(schema, expected);
    }

    #[test]
    fn test_params_to_json_schema_optional_param_omits_required() {
        let mut param_def = BTreeMap::new();
        param_def.insert("type".to_string(), vm_str("integer"));
        param_def.insert("required".to_string(), VmValue::Bool(false));
        let mut params = BTreeMap::new();
        params.insert("limit".to_string(), VmValue::Dict(Rc::new(param_def)));

        let schema = params_to_json_schema(Some(&VmValue::Dict(Rc::new(params))));
        assert_eq!(
            schema,
            serde_json::json!({
                "type": "object",
                "properties": { "limit": { "type": "integer" } }
            })
        );
    }

    #[test]
    fn test_params_to_json_schema_simple_form() {
        let mut params = BTreeMap::new();
        params.insert("query".to_string(), VmValue::String(Rc::from("string")));

        let schema = params_to_json_schema(Some(&VmValue::Dict(Rc::new(params))));
        assert_eq!(
            schema["properties"]["query"]["type"],
            serde_json::json!("string")
        );
        // The shorthand form never marks a parameter as required.
        assert!(schema.get("required").is_none());
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_invalid() {
        let result = tool_registry_to_mcp_tools(&VmValue::Nil);
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_wrong_type_tag() {
        let mut registry = BTreeMap::new();
        registry.insert("_type".to_string(), vm_str("something_else"));
        let result = tool_registry_to_mcp_tools(&VmValue::Dict(Rc::new(registry)));
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_empty() {
        let mut registry = BTreeMap::new();
        registry.insert(
            "_type".to_string(),
            VmValue::String(Rc::from("tool_registry")),
        );
        registry.insert("tools".to_string(), VmValue::List(Rc::new(Vec::new())));

        let result = tool_registry_to_mcp_tools(&VmValue::Dict(Rc::new(registry)));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_without_tools_key() {
        let mut registry = BTreeMap::new();
        registry.insert("_type".to_string(), vm_str("tool_registry"));
        let result = tool_registry_to_mcp_tools(&VmValue::Dict(Rc::new(registry)));
        assert!(matches!(result, Ok(ref tools) if tools.is_empty()));
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_skips_non_dict_entries() {
        let registry = registry_with_tools(vec![VmValue::Nil, vm_str("not a tool")]);
        let result = tool_registry_to_mcp_tools(&registry);
        assert!(matches!(result, Ok(ref tools) if tools.is_empty()));
    }

    #[test]
    fn test_tool_registry_to_mcp_tools_missing_handler() {
        let mut tool = BTreeMap::new();
        tool.insert("name".to_string(), vm_str("read_file"));
        let registry = registry_with_tools(vec![VmValue::Dict(Rc::new(tool))]);

        let result = tool_registry_to_mcp_tools(&registry);
        assert!(matches!(
            result,
            Err(VmError::Runtime(ref msg)) if msg.contains("has no handler closure")
        ));
    }
}