Skip to main content

codegraph/
mcp.rs

1use crate::types::{Node, NodeEdge, SearchOptions};
2use crate::{find_nearest_codegraph_root, CodeGraph};
3use anyhow::{anyhow, Result};
4use serde_json::{json, Value};
5use std::io::{self, BufRead, Write};
6use std::path::PathBuf;
7
/// MCP protocol revision reported to the client in the `initialize` reply.
const PROTOCOL_VERSION: &str = "2024-11-05";

/// Usage guidance sent to the client as `result.instructions` in the
/// `initialize` reply. Split with `concat!` purely for readability; the
/// resulting string is a single compile-time literal.
const SERVER_INSTRUCTIONS: &str = concat!(
    "# Codegraph — code intelligence over an indexed knowledge graph\n\n",
    "Start with codegraph_status to check index health. ",
    "Use codegraph_files, codegraph_search, codegraph_context, codegraph_callers/codegraph_callees, ",
    "codegraph_impact, codegraph_node, and codegraph_explore for read-only exploration. ",
    "Treat results as navigation context, not correctness proof; ",
    "final validation still comes from the target repo's tests, type checks, linters, or build commands. ",
    "Do not initialize or reindex a project unless the user explicitly asks for that workspace-changing action."
);
10
/// MCP server speaking newline-delimited JSON-RPC 2.0 over stdin/stdout and
/// exposing CodeGraph exploration tools.
///
/// `MCPServer::default()` is equivalent to `MCPServer::new(None)`.
#[derive(Debug, Default)]
pub struct MCPServer {
    /// Starting directory used to locate the CodeGraph root when a tool call
    /// does not pass `projectPath`. `None` means "use the current working
    /// directory". Replaced by the workspace announced in `initialize`, if any.
    project_path: Option<PathBuf>,
}
14
15impl MCPServer {
16    pub fn new(project_path: Option<PathBuf>) -> Self {
17        Self { project_path }
18    }
19
20    pub fn start(&mut self) -> Result<()> {
21        let stdin = io::stdin();
22        for line in stdin.lock().lines() {
23            let line = line?;
24            if line.trim().is_empty() {
25                continue;
26            }
27            let response = match serde_json::from_str::<Value>(&line) {
28                Ok(message) => self.handle_message(message),
29                Err(_) => Some(error_response(
30                    Value::Null,
31                    -32700,
32                    "Parse error: invalid JSON",
33                )),
34            };
35            if let Some(response) = response {
36                println!("{}", serde_json::to_string(&response)?);
37                io::stdout().flush()?;
38            }
39        }
40        Ok(())
41    }
42
43    fn handle_message(&mut self, message: Value) -> Option<Value> {
44        let id = message.get("id").cloned();
45        let method = message
46            .get("method")
47            .and_then(Value::as_str)
48            .unwrap_or_default();
49        match method {
50            "initialize" => {
51                if let Some(path) = project_path_from_initialize(&message) {
52                    self.project_path = Some(path);
53                }
54                id.map(|id| json!({
55                    "jsonrpc": "2.0",
56                    "id": id,
57                    "result": {
58                        "protocolVersion": PROTOCOL_VERSION,
59                        "capabilities": { "tools": {} },
60                        "serverInfo": { "name": "codegraph", "version": env!("CARGO_PKG_VERSION") },
61                        "instructions": SERVER_INSTRUCTIONS,
62                    }
63                }))
64            }
65            "initialized" => None,
66            "tools/list" => id.map(|id| {
67                json!({
68                    "jsonrpc": "2.0",
69                    "id": id,
70                    "result": { "tools": tools() }
71                })
72            }),
73            "tools/call" => {
74                let Some(id) = id else { return None };
75                let params = message.get("params").cloned().unwrap_or_else(|| json!({}));
76                let name = params
77                    .get("name")
78                    .and_then(Value::as_str)
79                    .unwrap_or_default();
80                let args = params
81                    .get("arguments")
82                    .cloned()
83                    .unwrap_or_else(|| json!({}));
84                match self.execute_tool(name, &args) {
85                    Ok(result) => Some(json!({ "jsonrpc": "2.0", "id": id, "result": result })),
86                    Err(err) => Some(error_response(
87                        id,
88                        -32603,
89                        &format!("Tool execution failed: {err}"),
90                    )),
91                }
92            }
93            "ping" => id.map(|id| json!({ "jsonrpc": "2.0", "id": id, "result": {} })),
94            _ => id.map(|id| error_response(id, -32601, &format!("Method not found: {method}"))),
95        }
96    }
97
98    fn execute_tool(&self, name: &str, args: &Value) -> Result<Value> {
99        let cg = self.open_project(args)?;
100        match name {
101            "codegraph_search" => {
102                let query = required_str(args, "query")?;
103                let limit = clamp(
104                    args.get("limit").and_then(Value::as_i64).unwrap_or(10),
105                    1,
106                    100,
107                );
108                let results = cg.search_nodes(
109                    query,
110                    SearchOptions {
111                        limit,
112                        ..Default::default()
113                    },
114                )?;
115                if results.is_empty() {
116                    Ok(text_result(format!("No results found for \"{query}\"")))
117                } else {
118                    let lines = results
119                        .into_iter()
120                        .map(|r| format_node(&r.node))
121                        .collect::<Vec<_>>()
122                        .join("\n");
123                    Ok(text_result(lines))
124                }
125            }
126            "codegraph_context" => {
127                let task = required_str(args, "task")?;
128                let max_nodes = clamp(
129                    args.get("maxNodes").and_then(Value::as_i64).unwrap_or(20),
130                    1,
131                    200,
132                );
133                let include_code = args
134                    .get("includeCode")
135                    .and_then(Value::as_bool)
136                    .unwrap_or(true);
137                Ok(text_result(cg.build_context(
138                    task,
139                    max_nodes,
140                    include_code,
141                )?))
142            }
143            "codegraph_callers" => {
144                let symbol = required_str(args, "symbol")?;
145                let limit = clamp(
146                    args.get("limit").and_then(Value::as_i64).unwrap_or(20),
147                    1,
148                    100,
149                ) as usize;
150                let nodes = find_matching_nodes(&cg, symbol)?;
151                if nodes.is_empty() {
152                    return Ok(text_result(format!(
153                        "Symbol \"{symbol}\" not found in the codebase"
154                    )));
155                }
156                let mut out = Vec::new();
157                for node in nodes {
158                    out.extend(cg.get_callers(&node.id, 1)?);
159                }
160                Ok(text_result(format_node_edges(
161                    &format!("Callers of {symbol}"),
162                    &out,
163                    limit,
164                )))
165            }
166            "codegraph_callees" => {
167                let symbol = required_str(args, "symbol")?;
168                let limit = clamp(
169                    args.get("limit").and_then(Value::as_i64).unwrap_or(20),
170                    1,
171                    100,
172                ) as usize;
173                let nodes = find_matching_nodes(&cg, symbol)?;
174                if nodes.is_empty() {
175                    return Ok(text_result(format!(
176                        "Symbol \"{symbol}\" not found in the codebase"
177                    )));
178                }
179                let mut out = Vec::new();
180                for node in nodes {
181                    out.extend(cg.get_callees(&node.id, 1)?);
182                }
183                Ok(text_result(format_node_edges(
184                    &format!("Callees of {symbol}"),
185                    &out,
186                    limit,
187                )))
188            }
189            "codegraph_impact" => {
190                let symbol = required_str(args, "symbol")?;
191                let depth = clamp(
192                    args.get("depth").and_then(Value::as_i64).unwrap_or(2),
193                    1,
194                    10,
195                ) as usize;
196                let nodes = find_matching_nodes(&cg, symbol)?;
197                if nodes.is_empty() {
198                    return Ok(text_result(format!(
199                        "Symbol \"{symbol}\" not found in the codebase"
200                    )));
201                }
202                let mut lines = vec![format!("## Impact: {symbol}")];
203                for node in nodes {
204                    let impact = cg.get_impact_radius(&node.id, depth)?;
205                    for n in impact.nodes.values() {
206                        lines.push(format!("- {}", format_node(n)));
207                    }
208                }
209                Ok(text_result(lines.join("\n")))
210            }
211            "codegraph_node" => {
212                let symbol = required_str(args, "symbol")?;
213                let include_code = args
214                    .get("includeCode")
215                    .and_then(Value::as_bool)
216                    .unwrap_or(false);
217                let nodes = find_matching_nodes(&cg, symbol)?;
218                let Some(node) = nodes.first() else {
219                    return Ok(text_result(format!(
220                        "Symbol \"{symbol}\" not found in the codebase"
221                    )));
222                };
223                let mut out = format_node(node);
224                if include_code {
225                    if let Ok(code) = cg.read_node_source(node) {
226                        out.push_str("\n\n```");
227                        out.push_str(node.language.as_str());
228                        out.push('\n');
229                        out.push_str(&code);
230                        out.push_str("\n```");
231                    }
232                }
233                Ok(text_result(out))
234            }
235            "codegraph_explore" => {
236                let query = required_str(args, "query")?;
237                let max_files = clamp(
238                    args.get("maxFiles").and_then(Value::as_i64).unwrap_or(12),
239                    1,
240                    20,
241                );
242                let mut text = cg.build_context(query, max_files * 5, true)?;
243                if text.len() > 35_000 {
244                    text.truncate(35_000);
245                    text.push_str("\n\n[truncated]");
246                }
247                Ok(text_result(text))
248            }
249            "codegraph_status" => {
250                let stats = cg.stats()?;
251                Ok(text_result(format!(
252                    "**Files indexed:** {}\n**Nodes:** {}\n**Edges:** {}",
253                    stats.file_count, stats.node_count, stats.edge_count
254                )))
255            }
256            "codegraph_files" => {
257                let path_filter = args.get("path").and_then(Value::as_str).unwrap_or("");
258                let files = cg.get_all_files()?;
259                let lines = files
260                    .into_iter()
261                    .filter(|f| path_filter.is_empty() || f.path.starts_with(path_filter))
262                    .map(|f| format!("{} ({}, {} symbols)", f.path, f.language, f.node_count))
263                    .collect::<Vec<_>>()
264                    .join("\n");
265                Ok(text_result(if lines.is_empty() {
266                    "No files indexed. Run `codegraph index` first.".into()
267                } else {
268                    lines
269                }))
270            }
271            _ => Err(anyhow!("Unknown tool: {name}")),
272        }
273    }
274
275    fn open_project(&self, args: &Value) -> Result<CodeGraph> {
276        if let Some(path) = args.get("projectPath").and_then(Value::as_str) {
277            return CodeGraph::open(path);
278        }
279        let start = self
280            .project_path
281            .clone()
282            .unwrap_or(std::env::current_dir()?);
283        let root = find_nearest_codegraph_root(&start)
284            .ok_or_else(|| anyhow!("CodeGraph not initialized in {}", start.display()))?;
285        CodeGraph::open(root)
286    }
287}
288
289fn tools() -> Value {
290    json!([
291        tool(
292            "codegraph_search",
293            "Quick symbol search by name.",
294            json!({"query": {"type":"string"}, "kind": {"type":"string"}, "limit": {"type":"number"}, "projectPath": {"type":"string"}}),
295            vec!["query"]
296        ),
297        tool(
298            "codegraph_context",
299            "Build comprehensive context for a task.",
300            json!({"task": {"type":"string"}, "maxNodes": {"type":"number"}, "includeCode": {"type":"boolean"}, "projectPath": {"type":"string"}}),
301            vec!["task"]
302        ),
303        tool(
304            "codegraph_callers",
305            "Find all functions/methods that call a specific symbol.",
306            json!({"symbol": {"type":"string"}, "limit": {"type":"number"}, "projectPath": {"type":"string"}}),
307            vec!["symbol"]
308        ),
309        tool(
310            "codegraph_callees",
311            "Find all functions/methods that a specific symbol calls.",
312            json!({"symbol": {"type":"string"}, "limit": {"type":"number"}, "projectPath": {"type":"string"}}),
313            vec!["symbol"]
314        ),
315        tool(
316            "codegraph_impact",
317            "Analyze the impact radius of changing a symbol.",
318            json!({"symbol": {"type":"string"}, "depth": {"type":"number"}, "projectPath": {"type":"string"}}),
319            vec!["symbol"]
320        ),
321        tool(
322            "codegraph_node",
323            "Get detailed information about a specific code symbol.",
324            json!({"symbol": {"type":"string"}, "includeCode": {"type":"boolean"}, "projectPath": {"type":"string"}}),
325            vec!["symbol"]
326        ),
327        tool(
328            "codegraph_explore",
329            "Deep exploration tool for a topic.",
330            json!({"query": {"type":"string"}, "maxFiles": {"type":"number"}, "projectPath": {"type":"string"}}),
331            vec!["query"]
332        ),
333        tool(
334            "codegraph_status",
335            "Get the status of the CodeGraph index.",
336            json!({"projectPath": {"type":"string"}}),
337            vec![]
338        ),
339        tool(
340            "codegraph_files",
341            "Get indexed project files.",
342            json!({"path": {"type":"string"}, "pattern": {"type":"string"}, "format": {"type":"string"}, "includeMetadata": {"type":"boolean"}, "maxDepth": {"type":"number"}, "projectPath": {"type":"string"}}),
343            vec![]
344        ),
345    ])
346}
347
348fn tool(name: &str, description: &str, properties: Value, required: Vec<&str>) -> Value {
349    json!({
350        "name": name,
351        "description": description,
352        "inputSchema": {
353            "type": "object",
354            "properties": properties,
355            "required": required,
356        }
357    })
358}
359
360fn project_path_from_initialize(message: &Value) -> Option<PathBuf> {
361    let params = message.get("params")?;
362    if let Some(uri) = params.get("rootUri").and_then(Value::as_str) {
363        return Some(file_uri_to_path(uri));
364    }
365    params
366        .get("workspaceFolders")
367        .and_then(Value::as_array)
368        .and_then(|folders| folders.first())
369        .and_then(|folder| folder.get("uri"))
370        .and_then(Value::as_str)
371        .map(file_uri_to_path)
372}
373
374fn file_uri_to_path(uri: &str) -> PathBuf {
375    let without_scheme = uri.strip_prefix("file://").unwrap_or(uri);
376    PathBuf::from(percent_decode(without_scheme))
377}
378
/// Decodes `%XX` escapes (hex digits in either case) in `input` as UTF-8.
///
/// Decoding works on raw bytes, so escaped multi-byte sequences such as
/// `%C3%A9` become `é`, and unescaped non-ASCII text passes through intact.
/// A `%` that is not followed by two hex digits is kept literally. Byte
/// sequences that are not valid UTF-8 after decoding are replaced with U+FFFD.
fn percent_decode(input: &str) -> String {
    /// Value of one ASCII hex digit, or `None` for anything else (including `+`).
    fn hex_val(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = input.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // `get` avoids both out-of-range access and slicing a `&str`
            // through the middle of a multi-byte character.
            let hi = bytes.get(i + 1).copied().and_then(hex_val);
            let lo = bytes.get(i + 2).copied().and_then(hex_val);
            if let (Some(hi), Some(lo)) = (hi, lo) {
                out.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
396
397fn required_str<'a>(args: &'a Value, key: &str) -> Result<&'a str> {
398    args.get(key)
399        .and_then(Value::as_str)
400        .filter(|s| !s.is_empty())
401        .ok_or_else(|| anyhow!("{key} must be a non-empty string"))
402}
403
/// Restricts `value` to the range `[min, max]`.
///
/// The lower bound is applied first, so when `min > max` the result is `max`.
/// Unlike `i64::clamp`, this never panics.
fn clamp(value: i64, min: i64, max: i64) -> i64 {
    let floored = if value < min { min } else { value };
    if floored > max {
        max
    } else {
        floored
    }
}
407
408fn find_matching_nodes(cg: &CodeGraph, symbol: &str) -> Result<Vec<Node>> {
409    Ok(cg
410        .search_nodes(
411            symbol,
412            SearchOptions {
413                limit: 50,
414                ..Default::default()
415            },
416        )?
417        .into_iter()
418        .map(|r| r.node)
419        .collect())
420}
421
422fn format_node(node: &Node) -> String {
423    format!(
424        "{} {} {}:{}",
425        node.kind, node.name, node.file_path, node.start_line
426    )
427}
428
429fn format_node_edges(title: &str, edges: &[NodeEdge], limit: usize) -> String {
430    if edges.is_empty() {
431        return format!("No results found for {title}");
432    }
433    let mut lines = vec![format!("## {title}")];
434    for edge in edges.iter().take(limit) {
435        lines.push(format!("- {}", format_node(&edge.node)));
436    }
437    lines.join("\n")
438}
439
440fn text_result(text: String) -> Value {
441    json!({ "content": [{ "type": "text", "text": text }] })
442}
443
444fn error_response(id: Value, code: i64, message: &str) -> Value {
445    json!({ "jsonrpc": "2.0", "id": id, "error": { "code": code, "message": message } })
446}