Skip to main content

gitcortex_mcp/mcp/
tools.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use gitcortex_core::{schema::NodeKind, store::GraphStore};
5use gitcortex_store::kuzu::KuzuGraphStore;
6use rmcp::{
7    handler::server::router::tool::ToolRouter,
8    handler::server::wrapper::Parameters,
9    model::{
10        CallToolResult, Content, GetPromptRequestParams, GetPromptResult, ListPromptsResult,
11        PaginatedRequestParams, PromptMessage, PromptMessageRole,
12    },
13    prompt, prompt_handler, prompt_router,
14    service::RequestContext,
15    tool, tool_handler, tool_router, RoleServer,
16};
17use schemars::JsonSchema;
18use serde::Deserialize;
19use serde_json::json;
20
21// ── Parameter types ───────────────────────────────────────────────────────────
22
// Arguments for a dispatch-style tool: `action` names the graph operation and
// `params` carries that operation's arguments as a raw JSON object.
// NOTE(review): the `///` comments below are exported verbatim as JSON-schema
// descriptions, so they are what MCP clients read. The `params` description
// lists `src`/`dst`, but the standalone `trace_path` tool (`TracePathParams`)
// names those fields `from`/`to` — confirm which names the dispatcher accepts.
// The action list also omits find_implementors, detect_changes and
// branch_diff_graph, which `active_tool_router` hides in compact mode —
// confirm whether the dispatcher handles them.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct GcxDispatchParams {
    /// Which graph operation to run. One of: lookup_symbol, find_callers, find_callees,
    /// find_unused_symbols, get_subgraph, search_code, start_tour, wiki_symbol,
    /// trace_path, list_definitions, symbol_context, list_symbols_in_range.
    pub action: String,
    /// Parameters for the chosen action as a JSON object (same fields as the
    /// individual tool: name, function_name, seed_name, query, file, branch,
    /// depth, limit, direction, src, dst, start_line, end_line).
    pub params: serde_json::Value,
}
34
35#[derive(Debug, Deserialize, JsonSchema)]
36pub struct LookupSymbolParams {
37    /// Symbol name to search for (unqualified).
38    pub name: String,
39    /// When true, matches any symbol whose name *contains* `name` (substring).
40    /// When false (default), exact match only.
41    pub fuzzy: Option<bool>,
42    /// Branch name (defaults to "main" if omitted).
43    pub branch: Option<String>,
44}
45
46#[derive(Debug, Deserialize, JsonSchema)]
47pub struct FindCallersParams {
48    /// Name of the function/method to find callers of.
49    pub function_name: String,
50    /// How many hops to walk up the call graph (1–5, default 1).
51    /// depth=1 returns direct callers only. depth=3 walks three levels.
52    pub depth: Option<u8>,
53    pub branch: Option<String>,
54}
55
56#[derive(Debug, Deserialize, JsonSchema)]
57pub struct SymbolContextParams {
58    /// Symbol name to look up (unqualified).
59    pub name: String,
60    /// Branch name (defaults to current branch if omitted).
61    pub branch: Option<String>,
62}
63
64#[derive(Debug, Deserialize, JsonSchema)]
65pub struct ListDefinitionsParams {
66    /// Repo-relative path to a source file.
67    pub file: String,
68    pub branch: Option<String>,
69}
70
71#[derive(Debug, Deserialize, JsonSchema)]
72pub struct BranchDiffParams {
73    pub from_branch: String,
74    pub to_branch: String,
75}
76
77#[derive(Debug, Deserialize, JsonSchema)]
78pub struct DetectChangesParams {
79    /// Branch to query (defaults to "main" if omitted).
80    pub branch: Option<String>,
81}
82
83#[derive(Debug, Deserialize, JsonSchema)]
84pub struct FindCalleesParams {
85    /// Name of the function/method to trace callees of.
86    pub function_name: String,
87    /// How many hops to walk forward in the call graph (1–5, default 1).
88    pub depth: Option<u8>,
89    pub branch: Option<String>,
90}
91
92#[derive(Debug, Deserialize, JsonSchema)]
93pub struct FindImplementorsParams {
94    /// Trait, interface, or abstract class name to find implementors of.
95    pub trait_name: String,
96    pub branch: Option<String>,
97}
98
99#[derive(Debug, Deserialize, JsonSchema)]
100pub struct TracePathParams {
101    /// Starting function/method name.
102    pub from: String,
103    /// Target function/method name.
104    pub to: String,
105    pub branch: Option<String>,
106}
107
108#[derive(Debug, Deserialize, JsonSchema)]
109pub struct ListSymbolsInRangeParams {
110    /// Repo-relative path to a source file.
111    pub file: String,
112    /// Start line of the range (1-indexed, inclusive).
113    pub start_line: u32,
114    /// End line of the range (1-indexed, inclusive).
115    pub end_line: u32,
116    pub branch: Option<String>,
117}
118
// Arguments for the `find_unused_symbols` tool.
// NOTE(review): `branch` has no `///` doc, so it gets no JSON-schema
// description. Every visible tool defaults an omitted branch to the server's
// `default_branch`; presumably this one does too — confirm, then document it.
// `kind` is free-form text; confirm how unknown kind names are treated.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct FindUnusedSymbolsParams {
    /// Optional NodeKind filter: "function", "method", "struct", etc.
    pub kind: Option<String>,
    /// Max symbols returned (default 30, capped at 200). `count` always reports
    /// the true total; `truncated` flags when the list was longer.
    pub limit: Option<usize>,
    pub branch: Option<String>,
}
128
// Arguments for the subgraph tool (presumably `get_subgraph`; its handler is
// outside this view).
// NOTE(review): `branch` has no `///` doc, so it gets no JSON-schema
// description. Every visible tool defaults an omitted branch to the server's
// `default_branch`; presumably this one does too — confirm, then document it.
// `direction` is free-form text; confirm how values other than "in", "out"
// and "both" are handled.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct GetSubgraphParams {
    /// Seed symbol name (unqualified).
    pub seed_name: String,
    /// How many hops to expand from the seed (1–5, default 1). Depth 2+ on a
    /// high-degree hub returns a large subgraph — raise deliberately.
    pub depth: Option<u8>,
    /// Direction: "in" (callers/ancestors), "out" (callees/descendants), "both" (default).
    pub direction: Option<String>,
    /// Max nodes returned (default 30, capped at 200). Edges are filtered to the
    /// kept node set; `truncated` flags when the neighbourhood was larger.
    pub limit: Option<usize>,
    pub branch: Option<String>,
}
143
// Arguments for the symbol-summary tool (presumably `wiki_symbol`; its handler
// is outside this view).
// NOTE(review): `branch` has no `///` doc, so it gets no JSON-schema
// description. Every visible tool defaults an omitted branch to the server's
// `default_branch`; presumably this one does too — confirm, then document it.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct WikiSymbolParams {
    /// Symbol to summarise (unqualified name).
    pub name: String,
    pub branch: Option<String>,
}
150
// Arguments for the code-search tool (presumably `search_code`; its handler is
// outside this view).
// NOTE(review): `branch` has no `///` doc, so it gets no JSON-schema
// description. Every visible tool defaults an omitted branch to the server's
// `default_branch`; presumably this one does too — confirm, then document it.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct SearchCodeParams {
    /// Free-text query — substring matched against `name` and `qualified_name`.
    pub query: String,
    /// Max results (default 10, capped at 200).
    pub limit: Option<usize>,
    pub branch: Option<String>,
}
159
// Arguments for the guided-tour tool (presumably `start_tour`; its handler is
// outside this view).
// NOTE(review): `branch` has no `///` doc, so it gets no JSON-schema
// description. Every visible tool defaults an omitted branch to the server's
// `default_branch`; presumably this one does too — confirm, then document it.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct StartTourParams {
    /// Optional seed symbol — when given, the tour walks outward from it
    /// along the call graph. When omitted, picks the highest-centrality
    /// entry points across the repo.
    pub seed: Option<String>,
    /// How many steps in the tour (default 12, capped at 50).
    pub limit: Option<usize>,
    pub branch: Option<String>,
}
170
171// ── Server ────────────────────────────────────────────────────────────────────
172
/// The MCP server handler. One shared `KuzuGraphStore` wrapped in `Arc<Mutex>`
/// so all handler calls can share state safely.
#[derive(Clone)]
pub struct GitCortexServer {
    /// Graph store shared by every clone of the server. Each tool call locks it
    /// and holds the guard for the rest of that call.
    store: Arc<std::sync::Mutex<KuzuGraphStore>>,
    /// Repository root; `detect_changes` passes it to `run_git_diff`.
    repo_root: PathBuf,
    /// Branch used when a request omits `branch`: the branch checked out at
    /// construction time, or "main" if it could not be detected.
    default_branch: String,
    /// When true, `active_tool_router` disables the individual tool routes.
    compact: bool,
}
182
183impl GitCortexServer {
184    pub fn new(repo_root: &Path) -> anyhow::Result<Self> {
185        Self::new_with_mode(repo_root, false)
186    }
187
188    pub fn new_with_mode(repo_root: &Path, compact: bool) -> anyhow::Result<Self> {
189        let store = KuzuGraphStore::open(repo_root)?;
190        let default_branch = detect_current_branch(repo_root).unwrap_or_else(|| "main".into());
191        Ok(Self {
192            store: Arc::new(std::sync::Mutex::new(store)),
193            repo_root: repo_root.to_owned(),
194            default_branch,
195            compact,
196        })
197    }
198
199    fn active_tool_router(&self) -> ToolRouter<Self> {
200        let mut router = Self::tool_router();
201        if self.compact {
202            for name in [
203                "lookup_symbol",
204                "find_callers",
205                "symbol_context",
206                "list_definitions",
207                "branch_diff_graph",
208                "detect_changes",
209                "find_callees",
210                "find_implementors",
211                "trace_path",
212                "list_symbols_in_range",
213                "find_unused_symbols",
214                "get_subgraph",
215                "wiki_symbol",
216                "search_code",
217                "start_tour",
218            ] {
219                router.disable_route(name);
220            }
221        }
222        router
223    }
224}
225
/// Return the short name of the branch `HEAD` points to in `repo_root`.
///
/// Runs `git symbolic-ref --short HEAD`. It yields `None` in any of these cases:
/// - git cannot be spawned (for example, the directory does not exist);
/// - the command exits unsuccessfully (for example, detached `HEAD` or not a repository);
/// - the output is not valid UTF-8;
/// - the output is blank once trimmed.
fn detect_current_branch(repo_root: &Path) -> Option<String> {
    let output = std::process::Command::new("git")
        .args(["symbolic-ref", "--short", "HEAD"])
        .current_dir(repo_root)
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let text = String::from_utf8(output.stdout).ok()?;
    let name = text.trim();
    (!name.is_empty()).then(|| name.to_owned())
}
244
245// ── Tool implementations ──────────────────────────────────────────────────────
246
247#[tool_router]
248impl GitCortexServer {
249    /// Look up all nodes (functions, structs, traits, etc.) by name.
250    #[tool(
251        description = "Look up nodes in the code knowledge graph by name. Set fuzzy=true for substring matching (e.g. 'auth' finds 'validate_auth', 'auth_middleware'). Default is exact match."
252    )]
253    fn lookup_symbol(&self, Parameters(p): Parameters<LookupSymbolParams>) -> CallToolResult {
254        let branch = p
255            .branch
256            .as_deref()
257            .unwrap_or(&self.default_branch)
258            .to_owned();
259        let fuzzy = p.fuzzy.unwrap_or(false);
260        let store = match self.store.lock() {
261            Ok(g) => g,
262            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
263        };
264        match store.lookup_symbol(&branch, &p.name, fuzzy) {
265            Ok(nodes) => {
266                let items: Vec<_> = nodes
267                    .iter()
268                    .map(|n| {
269                        json!({
270                            "id": n.id.as_str(),
271                            "kind": n.kind.to_string(),
272                            "name": n.name,
273                            "qualified_name": n.qualified_name,
274                            "file": n.file.display().to_string(),
275                            "start_line": n.span.start_line,
276                            "end_line": n.span.end_line,
277                            "visibility": format!("{:?}", n.metadata.visibility),
278                            "is_async": n.metadata.is_async,
279                            "is_unsafe": n.metadata.is_unsafe,
280                        })
281                    })
282                    .collect();
283                CallToolResult::structured(json!(items))
284            }
285            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
286        }
287    }
288
289    /// Find all callers of a function or method, with optional multi-hop depth.
290    #[tool(
291        description = "Find callers of a function. depth=1 (default) = direct callers; \
292        depth=2..5 = multi-hop. Results capped per hop; total count always returned."
293    )]
294    fn find_callers(&self, Parameters(p): Parameters<FindCallersParams>) -> CallToolResult {
295        let branch = p
296            .branch
297            .as_deref()
298            .unwrap_or(&self.default_branch)
299            .to_owned();
300        let depth = p.depth.unwrap_or(1).max(1);
301        let store = match self.store.lock() {
302            Ok(g) => g,
303            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
304        };
305
306        // Cap the caller list. The risk level is computed from the true total,
307        // so a hub symbol still reports CRITICAL even though we return a head.
308        const MAX_CALLERS: usize = 25;
309        const MAX_PER_HOP: usize = 15;
310        if depth == 1 {
311            match store.find_callers(&branch, &p.function_name) {
312                Ok(nodes) => {
313                    let total = nodes.len();
314                    let items: Vec<_> = nodes
315                        .iter()
316                        .take(MAX_CALLERS)
317                        .map(|n| {
318                            json!({
319                                "hop": 1,
320                                "kind": n.kind.to_string(),
321                                "name": n.name,
322                                "qualified_name": n.qualified_name,
323                                "file": n.file.display().to_string(),
324                                "start_line": n.span.start_line,
325                            })
326                        })
327                        .collect();
328                    let risk = match total {
329                        0..=2 => "LOW",
330                        3..=10 => "MEDIUM",
331                        11..=30 => "HIGH",
332                        _ => "CRITICAL",
333                    };
334                    CallToolResult::structured(json!({
335                        "summary": format!("{total} caller(s) — risk {risk}{}",
336                            if total > items.len() {
337                                format!(", showing top {}", items.len())
338                            } else { String::new() }),
339                        "function": p.function_name,
340                        "depth": 1,
341                        "risk_level": risk,
342                        "total_callers": total,
343                        "returned": items.len(),
344                        "truncated": total > items.len(),
345                        "callers": items,
346                    }))
347                }
348                Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
349            }
350        } else {
351            match store.find_callers_deep(&branch, &p.function_name, depth) {
352                Ok(result) => {
353                    let hops: Vec<_> = result
354                        .hops
355                        .iter()
356                        .enumerate()
357                        .map(|(i, nodes)| {
358                            let total = nodes.len();
359                            let callers: Vec<_> = nodes
360                                .iter()
361                                .take(MAX_PER_HOP)
362                                .map(|n| {
363                                    json!({
364                                        "kind": n.kind.to_string(),
365                                        "name": n.name,
366                                        "qualified_name": n.qualified_name,
367                                        "file": n.file.display().to_string(),
368                                        "start_line": n.span.start_line,
369                                    })
370                                })
371                                .collect();
372                            json!({
373                                "hop": i + 1,
374                                "total": total,
375                                "truncated": total > MAX_PER_HOP,
376                                "callers": callers,
377                            })
378                        })
379                        .collect();
380                    CallToolResult::structured(json!({
381                        "function": p.function_name,
382                        "depth": depth,
383                        "risk_level": result.risk_level,
384                        "hops": hops,
385                    }))
386                }
387                Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
388            }
389        }
390    }
391
392    /// Get a 360° view of a symbol: definition, callers, callees, and type usages.
393    #[tool(
394        description = "Get a complete picture of a symbol in one call: where it's defined, \
395        what calls it (callers), what it calls (callees), and which code references it as a type. \
396        Use this instead of chaining lookup_symbol + find_callers separately."
397    )]
398    fn symbol_context(&self, Parameters(p): Parameters<SymbolContextParams>) -> CallToolResult {
399        let branch = p
400            .branch
401            .as_deref()
402            .unwrap_or(&self.default_branch)
403            .to_owned();
404        let store = match self.store.lock() {
405            Ok(g) => g,
406            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
407        };
408        match store.symbol_context(&branch, &p.name) {
409            Ok(ctx) => {
410                let node_json = |n: &gitcortex_core::graph::Node| {
411                    json!({
412                        "kind": n.kind.to_string(),
413                        "name": n.name,
414                        "qualified_name": n.qualified_name,
415                        "file": n.file.display().to_string(),
416                        "start_line": n.span.start_line,
417                    })
418                };
419                CallToolResult::structured(json!({
420                    "definition": {
421                        "kind": ctx.definition.kind.to_string(),
422                        "name": ctx.definition.name,
423                        "qualified_name": ctx.definition.qualified_name,
424                        "file": ctx.definition.file.display().to_string(),
425                        "start_line": ctx.definition.span.start_line,
426                        "end_line": ctx.definition.span.end_line,
427                        "visibility": format!("{:?}", ctx.definition.metadata.visibility),
428                        "is_async": ctx.definition.metadata.is_async,
429                    },
430                    "callers": ctx.callers.iter().map(node_json).collect::<Vec<_>>(),
431                    "callees": ctx.callees.iter().map(node_json).collect::<Vec<_>>(),
432                    "used_by": ctx.used_by.iter().map(node_json).collect::<Vec<_>>(),
433                }))
434            }
435            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
436        }
437    }
438
439    /// List all symbols defined in a source file, ordered by line number.
440    #[tool(
441        description = "List all functions, structs, traits, and other definitions in a source file, ordered by line number."
442    )]
443    fn list_definitions(&self, Parameters(p): Parameters<ListDefinitionsParams>) -> CallToolResult {
444        let branch = p
445            .branch
446            .as_deref()
447            .unwrap_or(&self.default_branch)
448            .to_owned();
449        let store = match self.store.lock() {
450            Ok(g) => g,
451            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
452        };
453        match store.list_definitions(&branch, Path::new(&p.file)) {
454            Ok(nodes) => {
455                let items: Vec<_> = nodes
456                    .iter()
457                    .map(|n| {
458                        json!({
459                            "kind": n.kind.to_string(),
460                            "name": n.name,
461                            "qualified_name": n.qualified_name,
462                            "start_line": n.span.start_line,
463                            "end_line": n.span.end_line,
464                            "loc": n.metadata.loc,
465                            "visibility": format!("{:?}", n.metadata.visibility),
466                            "is_async": n.metadata.is_async,
467                        })
468                    })
469                    .collect();
470                CallToolResult::structured(json!(items))
471            }
472            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
473        }
474    }
475
476    /// Compute the graph diff between two branches.
477    #[tool(
478        description = "Show what nodes were added or removed between two branches. Useful for understanding what changed in a feature branch vs main."
479    )]
480    fn branch_diff_graph(&self, Parameters(p): Parameters<BranchDiffParams>) -> CallToolResult {
481        let store = match self.store.lock() {
482            Ok(g) => g,
483            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
484        };
485        match store.branch_diff(&p.from_branch, &p.to_branch) {
486            Ok(diff) => {
487                let added: Vec<_> = diff
488                    .added_nodes
489                    .iter()
490                    .map(|n| {
491                        json!({
492                            "kind": n.kind.to_string(),
493                            "name": n.name,
494                            "file": n.file.display().to_string(),
495                            "start_line": n.span.start_line,
496                        })
497                    })
498                    .collect();
499
500                // Resolve removed node IDs to full node objects from the from_branch.
501                let from_nodes = store.list_all_nodes(&p.from_branch).unwrap_or_default();
502                let from_map: std::collections::HashMap<_, _> =
503                    from_nodes.iter().map(|n| (n.id.clone(), n)).collect();
504                let removed: Vec<_> = diff
505                    .removed_node_ids
506                    .iter()
507                    .filter_map(|id| from_map.get(id))
508                    .map(|n| {
509                        json!({
510                            "kind": n.kind.to_string(),
511                            "name": n.name,
512                            "file": n.file.display().to_string(),
513                            "start_line": n.span.start_line,
514                        })
515                    })
516                    .collect();
517
518                CallToolResult::structured(json!({
519                    "from": p.from_branch,
520                    "to": p.to_branch,
521                    "added_nodes": added,
522                    "removed_nodes": removed,
523                }))
524            }
525            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
526        }
527    }
528
529    /// Detect which indexed symbols are affected by current staged (or HEAD) changes.
530    #[tool(
531        description = "Map the current git diff (staged changes, or HEAD diff if nothing is staged) \
532        to the indexed symbol graph. Returns which functions/structs were changed, their direct callers, \
533        and a risk level. Use this before committing to understand blast radius automatically."
534    )]
535    fn detect_changes(&self, Parameters(p): Parameters<DetectChangesParams>) -> CallToolResult {
536        let branch = p
537            .branch
538            .as_deref()
539            .unwrap_or(&self.default_branch)
540            .to_owned();
541
542        let diff_text = run_git_diff(&self.repo_root, &["diff", "--staged"])
543            .filter(|s| !s.trim().is_empty())
544            .or_else(|| run_git_diff(&self.repo_root, &["diff", "HEAD"]))
545            .unwrap_or_default();
546
547        if diff_text.trim().is_empty() {
548            return CallToolResult::success(vec![Content::text(
549                "No staged or unstaged changes detected.",
550            )]);
551        }
552
553        let hunks = parse_diff_hunks(&diff_text);
554        let store = match self.store.lock() {
555            Ok(g) => g,
556            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
557        };
558
559        let mut changed_symbols: Vec<serde_json::Value> = Vec::new();
560        let mut total_affected: usize = 0;
561
562        for (file_path, ranges) in &hunks {
563            let path = PathBuf::from(file_path);
564            let definitions = match store.list_definitions(&branch, &path) {
565                Ok(d) => d,
566                Err(_) => continue,
567            };
568            for node in &definitions {
569                let overlaps = ranges
570                    .iter()
571                    .any(|(s, e)| node.span.start_line <= *e && node.span.end_line >= *s);
572                if !overlaps {
573                    continue;
574                }
575                let callers = store.find_callers(&branch, &node.name).unwrap_or_default();
576                let caller_names: Vec<&str> = callers.iter().map(|c| c.name.as_str()).collect();
577                total_affected += 1 + caller_names.len();
578                changed_symbols.push(json!({
579                    "kind": node.kind.to_string(),
580                    "name": node.name,
581                    "file": file_path,
582                    "start_line": node.span.start_line,
583                    "end_line": node.span.end_line,
584                    "callers": caller_names,
585                }));
586            }
587        }
588
589        if changed_symbols.is_empty() {
590            return CallToolResult::success(vec![Content::text(
591                "Changed lines do not overlap with any indexed symbols.",
592            )]);
593        }
594
595        let risk_level = match total_affected {
596            0..=5 => "LOW",
597            6..=20 => "MEDIUM",
598            21..=50 => "HIGH",
599            _ => "CRITICAL",
600        };
601
602        CallToolResult::structured(json!({
603            "risk_level": risk_level,
604            "total_affected": total_affected,
605            "changed_symbols": changed_symbols,
606        }))
607    }
608
609    /// Find all callees of a function/method, tracing forward through the call graph.
610    #[tool(
611        description = "Find all functions/methods that the named function calls. \
612        Inverse of find_callers — traces forward (downstream). Use depth=1..5 to walk multiple hops. \
613        Returns callees grouped by hop distance."
614    )]
615    fn find_callees(&self, Parameters(p): Parameters<FindCalleesParams>) -> CallToolResult {
616        let branch = p
617            .branch
618            .as_deref()
619            .unwrap_or(&self.default_branch)
620            .to_owned();
621        let depth = p.depth.unwrap_or(1).max(1);
622        let store = match self.store.lock() {
623            Ok(g) => g,
624            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
625        };
626        match store.find_callees(&branch, &p.function_name, depth) {
627            Ok(result) => {
628                let hops: Vec<_> = result
629                    .hops
630                    .iter()
631                    .enumerate()
632                    .map(|(i, nodes)| {
633                        let callees: Vec<_> = nodes
634                            .iter()
635                            .map(|n| {
636                                json!({
637                                    "kind": n.kind.to_string(),
638                                    "name": n.name,
639                                    "qualified_name": n.qualified_name,
640                                    "file": n.file.display().to_string(),
641                                    "start_line": n.span.start_line,
642                                })
643                            })
644                            .collect();
645                        json!({ "hop": i + 1, "callees": callees })
646                    })
647                    .collect();
648                CallToolResult::structured(json!({
649                    "function": p.function_name,
650                    "depth": depth,
651                    "hops": hops,
652                }))
653            }
654            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
655        }
656    }
657
658    /// Find all structs/classes that implement a trait or interface.
659    #[tool(
660        description = "Find all concrete types (structs, classes) that implement or inherit the named \
661        trait or interface. Works for Rust traits, Java/TypeScript interfaces, and Go structural types."
662    )]
663    fn find_implementors(
664        &self,
665        Parameters(p): Parameters<FindImplementorsParams>,
666    ) -> CallToolResult {
667        let branch = p
668            .branch
669            .as_deref()
670            .unwrap_or(&self.default_branch)
671            .to_owned();
672        let store = match self.store.lock() {
673            Ok(g) => g,
674            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
675        };
676        match store.find_implementors(&branch, &p.trait_name) {
677            Ok(nodes) => {
678                let items: Vec<_> = nodes
679                    .iter()
680                    .map(|n| {
681                        json!({
682                            "kind": n.kind.to_string(),
683                            "name": n.name,
684                            "qualified_name": n.qualified_name,
685                            "file": n.file.display().to_string(),
686                            "start_line": n.span.start_line,
687                        })
688                    })
689                    .collect();
690                CallToolResult::structured(json!({
691                    "trait": p.trait_name,
692                    "implementors": items,
693                }))
694            }
695            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
696        }
697    }
698
699    /// Find a call path between two symbols in the codebase.
700    #[tool(
701        description = "Find a call path from one function to another. Returns the shortest chain of \
702        calls connecting `from` to `to`. Returns an empty array if no path exists within 6 hops. \
703        Most useful for debugging 'how can A reach B?' questions."
704    )]
705    fn trace_path(&self, Parameters(p): Parameters<TracePathParams>) -> CallToolResult {
706        let branch = p
707            .branch
708            .as_deref()
709            .unwrap_or(&self.default_branch)
710            .to_owned();
711        let store = match self.store.lock() {
712            Ok(g) => g,
713            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
714        };
715        match store.trace_path(&branch, &p.from, &p.to) {
716            Ok(path) => {
717                let nodes: Vec<_> = path
718                    .iter()
719                    .map(|n| {
720                        json!({
721                            "kind": n.kind.to_string(),
722                            "name": n.name,
723                            "file": n.file.display().to_string(),
724                            "start_line": n.span.start_line,
725                        })
726                    })
727                    .collect();
728                CallToolResult::structured(json!({
729                    "from": p.from,
730                    "to": p.to,
731                    "found": !path.is_empty(),
732                    "path": nodes,
733                }))
734            }
735            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
736        }
737    }
738
739    /// Find all indexed symbols that overlap a line range in a file.
740    #[tool(
741        description = "List all symbols (functions, structs, etc.) in a source file whose span \
742        overlaps the given line range. Use this to map a stack trace, diff hunk, or grep result \
743        to the symbols responsible."
744    )]
745    fn list_symbols_in_range(
746        &self,
747        Parameters(p): Parameters<ListSymbolsInRangeParams>,
748    ) -> CallToolResult {
749        let branch = p
750            .branch
751            .as_deref()
752            .unwrap_or(&self.default_branch)
753            .to_owned();
754        let store = match self.store.lock() {
755            Ok(g) => g,
756            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
757        };
758        let path = Path::new(&p.file);
759        match store.list_symbols_in_range(&branch, path, p.start_line, p.end_line) {
760            Ok(nodes) => {
761                let items: Vec<_> = nodes
762                    .iter()
763                    .map(|n| {
764                        json!({
765                            "kind": n.kind.to_string(),
766                            "name": n.name,
767                            "qualified_name": n.qualified_name,
768                            "start_line": n.span.start_line,
769                            "end_line": n.span.end_line,
770                            "loc": n.metadata.loc,
771                        })
772                    })
773                    .collect();
774                CallToolResult::structured(json!({
775                    "file": p.file,
776                    "range": { "start": p.start_line, "end": p.end_line },
777                    "symbols": items,
778                }))
779            }
780            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
781        }
782    }
783
784    /// Find symbols with no callers or type references — potential dead code.
785    #[tool(
786        description = "Find symbols that are never called or used as a type anywhere in the indexed \
787        codebase. Useful for identifying dead code, safe-to-rename candidates, or refactoring targets. \
788        Pass kind='function' to restrict to functions only."
789    )]
790    fn find_unused_symbols(
791        &self,
792        Parameters(p): Parameters<FindUnusedSymbolsParams>,
793    ) -> CallToolResult {
794        let branch = p
795            .branch
796            .as_deref()
797            .unwrap_or(&self.default_branch)
798            .to_owned();
799        let kind = p.kind.as_deref().and_then(|k| match k {
800            "function" => Some(NodeKind::Function),
801            "method" => Some(NodeKind::Method),
802            "struct" => Some(NodeKind::Struct),
803            "trait" => Some(NodeKind::Trait),
804            "interface" => Some(NodeKind::Interface),
805            "enum" => Some(NodeKind::Enum),
806            "constant" => Some(NodeKind::Constant),
807            _ => None,
808        });
809        let store = match self.store.lock() {
810            Ok(g) => g,
811            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
812        };
813        let limit = p.limit.unwrap_or(30).min(200);
814        match store.find_unused_symbols(&branch, kind) {
815            Ok(nodes) => {
816                // Return a ranked head, not the whole list. An agent acts on the
817                // first handful; dumping every unused symbol costs more tokens
818                // than a grep the model would have run instead.
819                let items: Vec<_> = nodes
820                    .iter()
821                    .take(limit)
822                    .map(|n| {
823                        json!({
824                            "kind": n.kind.to_string(),
825                            "name": n.name,
826                            "qualified_name": n.qualified_name,
827                            "file": n.file.display().to_string(),
828                            "start_line": n.span.start_line,
829                            "visibility": format!("{:?}", n.metadata.visibility),
830                        })
831                    })
832                    .collect();
833                CallToolResult::structured(json!({
834                    "branch": branch,
835                    "unused_symbols": items,
836                    "count": nodes.len(),
837                    "returned": items.len(),
838                    "truncated": nodes.len() > items.len(),
839                }))
840            }
841            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
842        }
843    }
844
845    /// Return a neighbourhood subgraph around a seed symbol.
846    #[tool(
847        description = "Return the subgraph centred on a seed symbol — nodes and edges reachable \
848        within `depth` hops (default 1; raise for wider context). Direction='out' downstream, \
849        'in' upstream, 'both' (default). Capped at `limit` nodes (default 30) with a `truncated` \
850        flag — prefer find_callers/find_callees for a targeted answer over a wide neighbourhood dump."
851    )]
852    fn get_subgraph(&self, Parameters(p): Parameters<GetSubgraphParams>) -> CallToolResult {
853        let branch = p
854            .branch
855            .as_deref()
856            .unwrap_or(&self.default_branch)
857            .to_owned();
858        let depth = p.depth.unwrap_or(1).clamp(1, 5);
859        let max_nodes = p.limit.unwrap_or(30).min(200);
860        let direction = p.direction.as_deref().unwrap_or("both").to_owned();
861        let store = match self.store.lock() {
862            Ok(g) => g,
863            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
864        };
865        match store.get_subgraph(&branch, &p.seed_name, depth, &direction) {
866            Ok(sg) => {
867                // Cap the node set, then keep only edges whose endpoints both
868                // survive — a full neighbourhood dump on a hub symbol otherwise
869                // costs more tokens than reading the file it describes.
870                let kept: Vec<_> = sg.nodes.iter().take(max_nodes).collect();
871                let kept_ids: std::collections::HashSet<String> =
872                    kept.iter().map(|n| n.id.as_str()).collect();
873                let nodes: Vec<_> = kept
874                    .iter()
875                    .map(|n| {
876                        json!({
877                            "id": n.id.as_str(),
878                            "kind": n.kind.to_string(),
879                            "name": n.name,
880                            "file": n.file.display().to_string(),
881                            "start_line": n.span.start_line,
882                        })
883                    })
884                    .collect();
885                let edges: Vec<_> = sg
886                    .edges
887                    .iter()
888                    .filter(|e| {
889                        kept_ids.contains(&e.src.as_str()) && kept_ids.contains(&e.dst.as_str())
890                    })
891                    .map(|e| {
892                        json!({
893                            "src": e.src.as_str(),
894                            "dst": e.dst.as_str(),
895                            "kind": e.kind.to_string(),
896                        })
897                    })
898                    .collect();
899                CallToolResult::structured(json!({
900                    "seed": p.seed_name,
901                    "depth": depth,
902                    "direction": direction,
903                    "node_count": sg.nodes.len(),
904                    "edge_count": sg.edges.len(),
905                    "returned_nodes": nodes.len(),
906                    "returned_edges": edges.len(),
907                    "truncated": sg.nodes.len() > nodes.len(),
908                    "nodes": nodes,
909                    "edges": edges,
910                }))
911            }
912            Err(e) => CallToolResult::error(vec![Content::text(format!("query failed: {e}"))]),
913        }
914    }
915
916    /// Render a wiki-style markdown summary for a symbol.
917    #[tool(
918        description = "Markdown wiki for a symbol: signature, doc-comment, top callers/callees. \
919        Use for deep explanation; use lookup_symbol for a quick definition."
920    )]
921    fn wiki_symbol(&self, Parameters(p): Parameters<WikiSymbolParams>) -> CallToolResult {
922        let branch = p
923            .branch
924            .as_deref()
925            .unwrap_or(&self.default_branch)
926            .to_owned();
927        let store = match self.store.lock() {
928            Ok(g) => g,
929            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
930        };
931        match super::wiki::render_symbol(&*store, &branch, &p.name) {
932            Ok(markdown) => CallToolResult::structured(json!({
933                "symbol": p.name,
934                "branch": branch,
935                "markdown": markdown,
936            })),
937            Err(e) => CallToolResult::error(vec![Content::text(format!("wiki failed: {e}"))]),
938        }
939    }
940
941    /// Search the graph by name + qualified-name with deterministic ranking.
942    #[tool(
943        description = "Search the code graph by name. Ranks exact > prefix > substring; \
944        functions/structs boosted. Use before grep for symbol discovery. Default limit=10."
945    )]
946    fn search_code(&self, Parameters(p): Parameters<SearchCodeParams>) -> CallToolResult {
947        let branch = p
948            .branch
949            .as_deref()
950            .unwrap_or(&self.default_branch)
951            .to_owned();
952        let store = match self.store.lock() {
953            Ok(g) => g,
954            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
955        };
956        match super::search::search(&*store, &branch, &p.query, p.limit) {
957            Ok(hits) => CallToolResult::structured(json!({
958                "query": p.query,
959                "branch": branch,
960                "count": hits.len(),
961                "hits": hits,
962            })),
963            Err(e) => CallToolResult::error(vec![Content::text(format!("search failed: {e}"))]),
964        }
965    }
966
967    /// Generate a guided tour through the repo's important symbols.
968    #[tool(
969        description = "Generate a guided tour through the codebase. Without a seed, picks the \
970        highest-centrality public functions/structs to give a new contributor an entry path. \
971        With a seed, BFS-walks outward from it along call edges. Returns ordered tour steps \
972        with rationale per step and a rendered markdown plan."
973    )]
974    fn start_tour(&self, Parameters(p): Parameters<StartTourParams>) -> CallToolResult {
975        let branch = p
976            .branch
977            .as_deref()
978            .unwrap_or(&self.default_branch)
979            .to_owned();
980        let store = match self.store.lock() {
981            Ok(g) => g,
982            Err(_) => return CallToolResult::error(vec![Content::text("store mutex poisoned")]),
983        };
984        match super::tour::generate(&*store, &branch, p.seed.as_deref(), p.limit) {
985            Ok(tour) => {
986                let markdown = super::tour::render_markdown(&tour);
987                CallToolResult::structured(json!({
988                    "branch": tour.branch,
989                    "seed": tour.seed,
990                    "steps": tour.steps,
991                    "markdown": markdown,
992                }))
993            }
994            Err(e) => CallToolResult::error(vec![Content::text(format!("tour failed: {e}"))]),
995        }
996    }
997
998    /// Single-entry dispatch — one schema instead of fifteen.
999    ///
1000    /// Prefer this tool to keep per-turn schema overhead low. All individual
1001    /// tools remain available for direct use; this is an additive alias.
1002    #[tool(description = "Query the GitCortex code knowledge graph. \
1003        action: lookup_symbol | find_callers | find_callees | find_unused_symbols | \
1004        get_subgraph | search_code | start_tour | wiki_symbol | trace_path | \
1005        list_definitions | symbol_context | list_symbols_in_range | branch_diff_graph. \
1006        params: JSON object with the same fields as the individual tool (name/function_name/\
1007        seed_name/query/file/branch/depth/limit/direction as applicable). \
1008        Returns identical output to the individual tool.")]
1009    fn gcx(&self, Parameters(p): Parameters<GcxDispatchParams>) -> CallToolResult {
1010        let branch_val = p
1011            .params
1012            .get("branch")
1013            .and_then(|v| v.as_str())
1014            .map(|s| s.to_owned());
1015
1016        // Helper: extract a string field from params.
1017        macro_rules! str_field {
1018            ($key:expr) => {
1019                match p.params.get($key).and_then(|v| v.as_str()) {
1020                    Some(s) => s.to_owned(),
1021                    None => {
1022                        return CallToolResult::error(vec![Content::text(format!(
1023                            "gcx dispatch: params.{} is required for action={}",
1024                            $key, p.action
1025                        ))])
1026                    }
1027                }
1028            };
1029        }
1030
1031        match p.action.as_str() {
1032            "lookup_symbol" => self.lookup_symbol(Parameters(LookupSymbolParams {
1033                name: str_field!("name"),
1034                fuzzy: p.params.get("fuzzy").and_then(|v| v.as_bool()),
1035                branch: branch_val,
1036            })),
1037            "find_callers" => self.find_callers(Parameters(FindCallersParams {
1038                function_name: str_field!("function_name"),
1039                depth: p
1040                    .params
1041                    .get("depth")
1042                    .and_then(|v| v.as_u64())
1043                    .map(|n| n as u8),
1044                branch: branch_val,
1045            })),
1046            "find_callees" => self.find_callees(Parameters(FindCalleesParams {
1047                function_name: str_field!("function_name"),
1048                depth: p
1049                    .params
1050                    .get("depth")
1051                    .and_then(|v| v.as_u64())
1052                    .map(|n| n as u8),
1053                branch: branch_val,
1054            })),
1055            "find_unused_symbols" => {
1056                self.find_unused_symbols(Parameters(FindUnusedSymbolsParams {
1057                    kind: p
1058                        .params
1059                        .get("kind")
1060                        .and_then(|v| v.as_str())
1061                        .map(|s| s.to_owned()),
1062                    limit: p
1063                        .params
1064                        .get("limit")
1065                        .and_then(|v| v.as_u64())
1066                        .map(|n| n as usize),
1067                    branch: branch_val,
1068                }))
1069            }
1070            "get_subgraph" => self.get_subgraph(Parameters(GetSubgraphParams {
1071                seed_name: str_field!("seed_name"),
1072                depth: p
1073                    .params
1074                    .get("depth")
1075                    .and_then(|v| v.as_u64())
1076                    .map(|n| n as u8),
1077                direction: p
1078                    .params
1079                    .get("direction")
1080                    .and_then(|v| v.as_str())
1081                    .map(|s| s.to_owned()),
1082                limit: p
1083                    .params
1084                    .get("limit")
1085                    .and_then(|v| v.as_u64())
1086                    .map(|n| n as usize),
1087                branch: branch_val,
1088            })),
1089            "search_code" => self.search_code(Parameters(SearchCodeParams {
1090                query: str_field!("query"),
1091                limit: p
1092                    .params
1093                    .get("limit")
1094                    .and_then(|v| v.as_u64())
1095                    .map(|n| n as usize),
1096                branch: branch_val,
1097            })),
1098            "start_tour" => self.start_tour(Parameters(StartTourParams {
1099                seed: p
1100                    .params
1101                    .get("seed")
1102                    .and_then(|v| v.as_str())
1103                    .map(|s| s.to_owned()),
1104                limit: p
1105                    .params
1106                    .get("limit")
1107                    .and_then(|v| v.as_u64())
1108                    .map(|n| n as usize),
1109                branch: branch_val,
1110            })),
1111            "wiki_symbol" => self.wiki_symbol(Parameters(WikiSymbolParams {
1112                name: str_field!("name"),
1113                branch: branch_val,
1114            })),
1115            "trace_path" => self.trace_path(Parameters(TracePathParams {
1116                from: p
1117                    .params
1118                    .get("from")
1119                    .or_else(|| p.params.get("src"))
1120                    .and_then(|v| v.as_str())
1121                    .map(|s| s.to_owned())
1122                    .unwrap_or_default(),
1123                to: p
1124                    .params
1125                    .get("to")
1126                    .or_else(|| p.params.get("dst"))
1127                    .and_then(|v| v.as_str())
1128                    .map(|s| s.to_owned())
1129                    .unwrap_or_default(),
1130                branch: branch_val,
1131            })),
1132            "list_definitions" => self.list_definitions(Parameters(ListDefinitionsParams {
1133                file: str_field!("file"),
1134                branch: branch_val,
1135            })),
1136            "symbol_context" => self.symbol_context(Parameters(SymbolContextParams {
1137                name: str_field!("name"),
1138                branch: branch_val,
1139            })),
1140            "list_symbols_in_range" => {
1141                self.list_symbols_in_range(Parameters(ListSymbolsInRangeParams {
1142                    file: str_field!("file"),
1143                    start_line: p
1144                        .params
1145                        .get("start_line")
1146                        .and_then(|v| v.as_u64())
1147                        .unwrap_or(1) as u32,
1148                    end_line: p
1149                        .params
1150                        .get("end_line")
1151                        .and_then(|v| v.as_u64())
1152                        .unwrap_or(u32::MAX as u64) as u32,
1153                    branch: branch_val,
1154                }))
1155            }
1156            other => CallToolResult::error(vec![Content::text(format!(
1157                "gcx dispatch: unknown action '{other}'. Valid: lookup_symbol, find_callers, \
1158                find_callees, find_unused_symbols, get_subgraph, search_code, start_tour, \
1159                wiki_symbol, trace_path, list_definitions, symbol_context, list_symbols_in_range"
1160            ))]),
1161        }
1162    }
1163}
1164
1165// ── Prompt parameter types ────────────────────────────────────────────────────
1166
1167#[derive(Debug, Deserialize, JsonSchema)]
1168pub struct DetectImpactParams {
1169    /// Comma-separated list of changed file paths (repo-relative).
1170    pub changed_files: String,
1171    /// Branch to query (defaults to "main").
1172    pub branch: Option<String>,
1173}
1174
1175#[derive(Debug, Deserialize, JsonSchema)]
1176pub struct GenerateMapParams {
1177    /// Branch to document (defaults to "main").
1178    pub branch: Option<String>,
1179}
1180
1181// ── Prompt implementations ────────────────────────────────────────────────────
1182
1183#[prompt_router]
1184impl GitCortexServer {
1185    /// Analyse the blast radius of changed files before committing.
1186    /// Walks the call graph from changed symbols to find all downstream callers
1187    /// and produces a risk assessment (LOW / MEDIUM / HIGH / CRITICAL).
1188    #[prompt(
1189        name = "detect_impact",
1190        description = "Pre-commit impact analysis — maps changed files to affected callers and scores risk"
1191    )]
1192    fn detect_impact(&self, Parameters(p): Parameters<DetectImpactParams>) -> GetPromptResult {
1193        let branch = p.branch.as_deref().unwrap_or("main");
1194        let files = p.changed_files.trim().to_owned();
1195
1196        let user_msg = format!(
1197            r#"I am about to commit changes to these files on branch `{branch}`:
1198
1199{files}
1200
1201Please analyse the blast radius of these changes using the GitCortex knowledge graph:
1202
12031. For each changed file call `list_definitions` to identify which symbols were likely touched.
12042. For each key function or struct, call `find_callers` to find direct callers.
12053. Repeat `find_callers` one level deeper for any HIGH-traffic callers.
12064. Summarise your findings as:
1207   - **Changed symbols**: list each modified function/struct with its file and line.
1208   - **Direct callers**: who calls the changed code.
1209   - **Transitive callers**: notable callers two hops away.
1210   - **Risk level**: LOW / MEDIUM / HIGH / CRITICAL with a one-line justification.
1211   - **Recommended actions**: tests to run, reviewers to notify, docs to update.
1212"#
1213        );
1214
1215        GetPromptResult::new(vec![PromptMessage::new_text(
1216            PromptMessageRole::User,
1217            user_msg,
1218        )])
1219        .with_description("Impact analysis of staged changes using the call graph")
1220    }
1221
1222    /// Generate a Mermaid architecture diagram from the knowledge graph.
1223    /// Summarises modules, key structs/traits, and their relationships.
1224    #[prompt(
1225        name = "generate_map",
1226        description = "Architecture documentation — produces a Mermaid diagram of modules, types, and key relationships"
1227    )]
1228    fn generate_map(&self, Parameters(p): Parameters<GenerateMapParams>) -> GetPromptResult {
1229        let branch = p.branch.as_deref().unwrap_or("main");
1230
1231        let user_msg = format!(
1232            r#"Generate an architecture map of this codebase on branch `{branch}` using GitCortex.
1233
1234Steps:
12351. Call `list_definitions` on each major source file to collect modules, structs, traits, and functions.
12362. Call `find_callers` on the top-level entry points to understand key execution flows.
12373. Call `lookup_symbol` on core traits to find all their implementors.
1238
1239Then produce:
1240
1241## Architecture Overview
1242A prose summary (3–5 sentences) of what this codebase does and how it is structured.
1243
1244## Module Map
1245```mermaid
1246graph TD
1247  %% Add nodes for each module/crate and edges for depends-on relationships
1248```
1249
1250## Key Types
1251A table: | Type | Kind | Responsibility | Implemented by |
1252
1253## Core Flows
1254Numbered list of the 2–4 most important execution paths (entry point → key functions → output).
1255
1256## Dependency Notes
1257Any circular dependencies, large fan-outs, or architectural concerns visible in the graph.
1258"#
1259        );
1260
1261        GetPromptResult::new(vec![PromptMessage::new_text(
1262            PromptMessageRole::User,
1263            user_msg,
1264        )])
1265        .with_description(
1266            "Architecture documentation with Mermaid diagram from the knowledge graph",
1267        )
1268    }
1269}
1270
1271// ── Combined ServerHandler (tools + prompts) ──────────────────────────────────
1272
1273#[tool_handler(router = self.active_tool_router())]
1274#[prompt_handler(router = Self::prompt_router())]
1275impl rmcp::ServerHandler for GitCortexServer {
1276    fn get_tool(&self, name: &str) -> Option<rmcp::model::Tool> {
1277        self.active_tool_router().get(name).cloned()
1278    }
1279}
1280
1281// ── Git diff helpers ──────────────────────────────────────────────────────────
1282
/// Run `git` with `args` inside `repo_root` and return its stdout.
///
/// Returns `None` when the process cannot be started (e.g. `git` is not on
/// `PATH`, or `repo_root` does not exist) or when it exits with a non-zero
/// status. Stderr is captured and ignored.
///
/// Stdout is decoded lossily. A diff that touches a file containing non-UTF-8
/// bytes (e.g. a Latin-1 source file) keeps every valid part, with only the
/// invalid sequences replaced by U+FFFD. A strict decode would discard the
/// whole diff.
fn run_git_diff(repo_root: &Path, args: &[&str]) -> Option<String> {
    let out = std::process::Command::new("git")
        .args(args)
        .current_dir(repo_root)
        .output()
        .ok()?;
    out.status
        .success()
        .then(|| String::from_utf8_lossy(&out.stdout).into_owned())
}
1295
1296/// Parse unified diff text into `(repo_relative_file_path, [(start_line, end_line)])`.
1297fn parse_diff_hunks(diff: &str) -> Vec<(String, Vec<(u32, u32)>)> {
1298    let mut result: Vec<(String, Vec<(u32, u32)>)> = Vec::new();
1299    let mut cur_file: Option<String> = None;
1300    let mut cur_hunks: Vec<(u32, u32)> = Vec::new();
1301
1302    for line in diff.lines() {
1303        if let Some(path) = line.strip_prefix("+++ b/") {
1304            if let Some(f) = cur_file.take() {
1305                if !cur_hunks.is_empty() {
1306                    result.push((f, std::mem::take(&mut cur_hunks)));
1307                }
1308            }
1309            cur_file = Some(path.to_owned());
1310        } else if line.starts_with("@@ ") {
1311            if let Some(hunk) = parse_hunk_header(line) {
1312                cur_hunks.push(hunk);
1313            }
1314        }
1315    }
1316    if let Some(f) = cur_file {
1317        if !cur_hunks.is_empty() {
1318            result.push((f, cur_hunks));
1319        }
1320    }
1321    result
1322}
1323
/// Extract the new-file line range from a unified diff hunk header.
///
/// Accepts `@@ -old_start[,old_count] +new_start[,new_count] @@ [context]`.
/// It returns the inclusive `(new_start, new_start + new_count - 1)` range, with
/// `new_count` defaulting to 1 when omitted. A zero-length hunk (`+N,0`, a pure
/// deletion) yields `(N, N)`. The end is computed with saturating arithmetic, so
/// a pathological header cannot overflow. Returns `None` if the header is
/// malformed.
fn parse_hunk_header(line: &str) -> Option<(u32, u32)> {
    let rest = line.strip_prefix("@@ ")?;
    let (_, after_plus) = rest.split_once(" +")?;
    // The new-file range runs up to the next space (or end of line).
    let range = after_plus.split(' ').next()?;
    let (start, count) = match range.split_once(',') {
        Some((s, c)) => (s.parse::<u32>().ok()?, c.parse::<u32>().ok()?),
        None => (range.parse::<u32>().ok()?, 1),
    };
    Some((start, start.saturating_add(count.saturating_sub(1))))
}