Skip to main content

graphify_serve/mcp/
mod.rs

1//! MCP (Model Context Protocol) server implementation.
2//!
3//! Implements JSON-RPC 2.0 over stdio for AI coding assistant integration.
4//! Protocol spec: <https://modelcontextprotocol.io/>
5
6mod handlers;
7mod tools;
8
9use std::io::{self, BufRead, Write};
10use std::path::Path;
11
12use graphify_core::graph::KnowledgeGraph;
13use serde_json::{Value, json};
14use tracing::{debug, error, info};
15
16use crate::ServeError;
17use crate::search::SearchIndex;
18
19const SERVER_NAME: &str = "graphify-rs";
20const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
21const PROTOCOL_VERSION: &str = "2024-11-05";
22
23fn jsonrpc_response(id: &Value, result: Value) -> Value {
24    json!({
25        "jsonrpc": "2.0",
26        "id": id,
27        "result": result
28    })
29}
30
31fn jsonrpc_error(id: &Value, code: i64, message: &str) -> Value {
32    json!({
33        "jsonrpc": "2.0",
34        "id": id,
35        "error": {
36            "code": code,
37            "message": message
38        }
39    })
40}
41
42fn dispatch_tools_call(graph: &KnowledgeGraph, index: &SearchIndex, request: &Value) -> Value {
43    let id = &request["id"];
44    let tool_name = request["params"]["name"].as_str().unwrap_or("");
45    let args = &request["params"]["arguments"];
46
47    debug!("tools/call: {tool_name}");
48
49    let result = match tool_name {
50        "query_graph" => handlers::handle_query_graph(graph, index, args),
51        "get_node" => handlers::handle_get_node(graph, args),
52        "get_neighbors" => handlers::handle_get_neighbors(graph, args),
53        "get_community" => handlers::handle_get_community(graph, args),
54        "god_nodes" => handlers::handle_god_nodes(graph, args),
55        "graph_stats" => handlers::handle_graph_stats(graph),
56        "shortest_path" => handlers::handle_shortest_path(graph, args),
57        "find_all_paths" => handlers::handle_find_all_paths(graph, args),
58        "weighted_path" => handlers::handle_weighted_path(graph, args),
59        "community_bridges" => handlers::handle_community_bridges(graph, args),
60        "graph_diff" => handlers::handle_graph_diff(graph, args),
61        "pagerank" => handlers::handle_pagerank(graph, args),
62        "detect_cycles" => handlers::handle_detect_cycles(graph, args),
63        "smart_summary" => handlers::handle_smart_summary(graph, args),
64        "find_similar" => handlers::handle_find_similar(graph, args),
65        "explore" => handlers::handle_explore(graph, index, args),
66        _ => handlers::tool_result_error(&format!("Unknown tool: {tool_name}")),
67    };
68
69    jsonrpc_response(id, result)
70}
71
72fn dispatch(graph: &KnowledgeGraph, index: &SearchIndex, request: &Value) -> Option<Value> {
73    let method = request["method"].as_str().unwrap_or("");
74    let id = &request["id"];
75
76    match method {
77        "initialize" => {
78            info!("MCP initialize");
79            Some(jsonrpc_response(
80                id,
81                json!({
82                    "protocolVersion": PROTOCOL_VERSION,
83                    "capabilities": {
84                        "tools": {}
85                    },
86                    "serverInfo": {
87                        "name": SERVER_NAME,
88                        "version": SERVER_VERSION
89                    }
90                }),
91            ))
92        }
93        "notifications/initialized" => {
94            debug!("Client initialized");
95            None
96        }
97        "tools/list" => {
98            debug!("tools/list");
99            Some(jsonrpc_response(
100                id,
101                json!({
102                    "tools": tools::tool_definitions()
103                }),
104            ))
105        }
106        "tools/call" => Some(dispatch_tools_call(graph, index, request)),
107        "ping" => Some(jsonrpc_response(id, json!({}))),
108        _ => {
109            if id.is_null() {
110                None // notification, ignore
111            } else {
112                Some(jsonrpc_error(
113                    id,
114                    -32601,
115                    &format!("Method not found: {method}"),
116                ))
117            }
118        }
119    }
120}
121
122/// Start the MCP server, reading JSON-RPC requests from stdin and writing
123/// responses to stdout. Logs go to stderr so they don't interfere with the
124/// protocol.
125pub fn run_mcp_server(graph_path: &Path) -> Result<(), ServeError> {
126    let graph = crate::load_graph(graph_path)?;
127    let search_index = SearchIndex::build(&graph);
128    let stats = crate::graph_stats(&graph);
129    let null = Value::Null;
130    info!(
131        "MCP server started: {} nodes, {} edges",
132        stats.get("node_count").unwrap_or(&null),
133        stats.get("edge_count").unwrap_or(&null),
134    );
135
136    let stdin = io::stdin();
137    let stdout = io::stdout();
138    let mut stdout_lock = stdout.lock();
139
140    for line in stdin.lock().lines() {
141        let line = match line {
142            Ok(l) => l,
143            Err(e) => {
144                error!("stdin read error: {e}");
145                break;
146            }
147        };
148
149        let trimmed = line.trim();
150        if trimmed.is_empty() {
151            continue;
152        }
153
154        let request: Value = match serde_json::from_str(trimmed) {
155            Ok(v) => v,
156            Err(e) => {
157                error!("JSON parse error: {e}");
158                let err = jsonrpc_error(&Value::Null, -32700, &format!("Parse error: {e}"));
159                if let Ok(json) = serde_json::to_string(&err) {
160                    let _ = writeln!(stdout_lock, "{}", json);
161                }
162                let _ = stdout_lock.flush();
163                continue;
164            }
165        };
166
167        if let Some(response) = dispatch(&graph, &search_index, &request) {
168            let out = match serde_json::to_string(&response) {
169                Ok(s) => s,
170                Err(e) => {
171                    error!("response serialization failed: {e}");
172                    continue;
173                }
174            };
175            if let Err(e) = writeln!(stdout_lock, "{}", out) {
176                error!("stdout write error: {e}");
177                break;
178            }
179            let _ = stdout_lock.flush();
180        }
181    }
182
183    info!("MCP server shutting down");
184    Ok(())
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap;

    /// Build a `Class` node located in `test.rs` with no extra attributes.
    fn make_node(id: &str, label: &str, community: Option<usize>) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community,
            extra: HashMap::new(),
        }
    }

    /// Build an extracted `calls` edge with full confidence and unit weight.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            provenance: None,
            extra: HashMap::new(),
        }
    }

    /// Four nodes in two communities:
    /// auth -> user, auth -> db, user -> db, user -> cache.
    fn test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        let nodes = [
            ("auth", "AuthService", 0),
            ("user", "UserManager", 0),
            ("db", "Database", 1),
            ("cache", "CacheLayer", 1),
        ];
        for (id, label, community) in nodes {
            graph
                .add_node(make_node(id, label, Some(community)))
                .unwrap();
        }

        let edges = [
            ("auth", "user"),
            ("auth", "db"),
            ("user", "db"),
            ("user", "cache"),
        ];
        for (src, tgt) in edges {
            graph.add_edge(make_edge(src, tgt)).unwrap();
        }

        graph
    }

    fn test_index(g: &KnowledgeGraph) -> SearchIndex {
        SearchIndex::build(g)
    }

    /// Two community-less nodes with no edge between them.
    fn isolated_pair(nodes: [(&str, &str); 2]) -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        for (id, label) in nodes {
            graph.add_node(make_node(id, label, None)).unwrap();
        }
        graph
    }

    /// Dispatch `request` against `graph` using a freshly built search index.
    fn send(graph: &KnowledgeGraph, request: &Value) -> Option<Value> {
        let index = test_index(graph);
        dispatch(graph, &index, request)
    }

    /// Issue a `tools/call` request for `tool` and return the response.
    fn call_tool(graph: &KnowledgeGraph, id: u64, tool: &str, arguments: Value) -> Value {
        let request = json!({
            "jsonrpc": "2.0", "method": "tools/call", "id": id,
            "params": {"name": tool, "arguments": arguments}
        });
        send(graph, &request).unwrap()
    }

    /// Text of the first content item of a tool result.
    fn content_text(reply: &Value) -> &str {
        reply["result"]["content"][0]["text"].as_str().unwrap()
    }

    /// First content item of a tool result, parsed as JSON.
    fn content_json(reply: &Value) -> Value {
        serde_json::from_str(content_text(reply)).unwrap()
    }

    #[test]
    fn test_initialize() {
        let request = json!({"jsonrpc": "2.0", "method": "initialize", "id": 1});
        let reply = send(&test_graph(), &request).unwrap();
        assert_eq!(reply["id"], 1);

        let result = &reply["result"];
        assert!(result["protocolVersion"].is_string());
        assert!(result["capabilities"]["tools"].is_object());
        assert_eq!(result["serverInfo"]["name"], SERVER_NAME);
    }

    #[test]
    fn test_tools_list() {
        let request = json!({"jsonrpc": "2.0", "method": "tools/list", "id": 2});
        let reply = send(&test_graph(), &request).unwrap();
        let tools = reply["result"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 16);

        let names: Vec<&str> = tools.iter().map(|t| t["name"].as_str().unwrap()).collect();
        let expected = [
            "query_graph",
            "get_node",
            "get_neighbors",
            "get_community",
            "god_nodes",
            "graph_stats",
            "shortest_path",
        ];
        for name in expected {
            assert!(names.contains(&name));
        }
    }

    #[test]
    fn test_query_graph() {
        let reply = call_tool(
            &test_graph(),
            3,
            "query_graph",
            json!({"question": "auth service"}),
        );
        let text = content_text(&reply);
        assert!(text.contains("Knowledge Graph Context"));
        assert!(text.contains("AuthService"));
    }

    #[test]
    fn test_get_node() {
        let reply = call_tool(&test_graph(), 4, "get_node", json!({"node_id": "auth"}));
        let text = content_text(&reply);
        assert!(text.contains("AuthService"));
        assert!(text.contains("\"degree\""));
    }

    #[test]
    fn test_get_node_not_found() {
        let reply = call_tool(
            &test_graph(),
            5,
            "get_node",
            json!({"node_id": "nonexistent"}),
        );
        assert_eq!(reply["result"]["isError"].as_bool(), Some(true));
    }

    #[test]
    fn test_get_neighbors() {
        let reply = call_tool(
            &test_graph(),
            6,
            "get_neighbors",
            json!({"node_id": "auth", "depth": 1}),
        );
        assert!(content_text(&reply).contains("total_neighbors"));
    }

    #[test]
    fn test_get_community() {
        let reply = call_tool(
            &test_graph(),
            7,
            "get_community",
            json!({"community_id": 0}),
        );
        let text = content_text(&reply);
        assert!(text.contains("AuthService") || text.contains("UserManager"));
    }

    #[test]
    fn test_god_nodes() {
        let reply = call_tool(&test_graph(), 8, "god_nodes", json!({"top_n": 3}));
        assert!(content_text(&reply).contains("god_nodes"));
    }

    #[test]
    fn test_graph_stats() {
        let reply = call_tool(&test_graph(), 9, "graph_stats", json!({}));
        let text = content_text(&reply);
        assert!(text.contains("node_count"));
        assert!(text.contains("edge_count"));
    }

    #[test]
    fn test_shortest_path() {
        let reply = call_tool(
            &test_graph(),
            10,
            "shortest_path",
            json!({"source": "auth", "target": "cache"}),
        );
        assert!(content_text(&reply).contains("path_length"));
        // auth -> user -> cache = length 2
        assert_eq!(content_json(&reply)["path_length"], 2);
    }

    #[test]
    fn test_shortest_path_no_path() {
        let graph = isolated_pair([("a", "A"), ("b", "B")]);
        let reply = call_tool(
            &graph,
            11,
            "shortest_path",
            json!({"source": "a", "target": "b"}),
        );
        assert!(content_text(&reply).contains("No path found"));
    }

    #[test]
    fn test_shortest_path_same_node() {
        let reply = call_tool(
            &test_graph(),
            12,
            "shortest_path",
            json!({"source": "auth", "target": "auth"}),
        );
        assert_eq!(content_json(&reply)["path_length"], 0);
    }

    #[test]
    fn test_unknown_tool() {
        let reply = call_tool(&test_graph(), 13, "nonexistent_tool", json!({}));
        assert_eq!(reply["result"]["isError"].as_bool(), Some(true));
    }

    #[test]
    fn test_unknown_method() {
        let request = json!({"jsonrpc": "2.0", "method": "unknown/method", "id": 14});
        let reply = send(&test_graph(), &request).unwrap();
        let error = &reply["error"];
        assert!(error.is_object());
        assert_eq!(error["code"], -32601);
    }

    #[test]
    fn test_notification_no_response() {
        let request = json!({"jsonrpc": "2.0", "method": "notifications/initialized"});
        assert!(send(&test_graph(), &request).is_none());
    }

    #[test]
    fn test_ping() {
        let request = json!({"jsonrpc": "2.0", "method": "ping", "id": 15});
        let reply = send(&test_graph(), &request).unwrap();
        assert_eq!(reply["id"], 15);
        assert!(reply["result"].is_object());
    }

    #[test]
    fn test_find_all_paths() {
        let reply = call_tool(
            &test_graph(),
            20,
            "find_all_paths",
            json!({"source": "auth", "target": "db", "max_length": 4}),
        );
        let parsed = content_json(&reply);
        assert!(
            parsed["path_count"].as_u64().unwrap() >= 2,
            "should find multiple paths"
        );
    }

    #[test]
    fn test_find_all_paths_no_path() {
        let graph = isolated_pair([("x", "X"), ("y", "Y")]);
        let reply = call_tool(
            &graph,
            21,
            "find_all_paths",
            json!({"source": "x", "target": "y"}),
        );
        let parsed = content_json(&reply);
        assert_eq!(parsed["path_count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_weighted_path() {
        let reply = call_tool(
            &test_graph(),
            22,
            "weighted_path",
            json!({"source": "auth", "target": "cache"}),
        );
        let parsed = content_json(&reply);
        assert!(parsed["path_length"].as_u64().unwrap() >= 1);
        assert!(parsed["total_cost"].as_f64().unwrap() > 0.0);
    }

    #[test]
    fn test_weighted_path_not_found() {
        let graph = isolated_pair([("x", "X"), ("y", "Y")]);
        let reply = call_tool(
            &graph,
            23,
            "weighted_path",
            json!({"source": "x", "target": "y"}),
        );
        assert!(content_text(&reply).contains("No path found"));
    }

    #[test]
    fn test_community_bridges() {
        let reply = call_tool(&test_graph(), 24, "community_bridges", json!({"top_n": 5}));
        assert!(content_json(&reply)["bridges"].is_array());
    }

    #[test]
    fn test_graph_diff_missing_file() {
        let reply = call_tool(
            &test_graph(),
            25,
            "graph_diff",
            json!({"other_graph": "/nonexistent/graph.json"}),
        );
        assert!(content_text(&reply).contains("Failed to load graph"));
    }

    #[test]
    fn test_find_all_paths_missing_source() {
        let reply = call_tool(
            &test_graph(),
            26,
            "find_all_paths",
            json!({"source": "nonexistent", "target": "db"}),
        );
        assert!(content_text(&reply).contains("not found"));
    }

    #[test]
    fn test_weighted_path_with_min_confidence() {
        let reply = call_tool(
            &test_graph(),
            27,
            "weighted_path",
            json!({"source": "auth", "target": "db", "min_confidence": 0.5}),
        );
        let parsed = content_json(&reply);
        assert!(parsed["path_length"].as_u64().unwrap() >= 1);
    }
}