use std::path::PathBuf;
use grapha_core::graph::Graph;
use serde_json::{Value, json};
use tantivy::Index;
use crate::mcp::types::ToolDefinition;
use crate::query;
use crate::recall::{self, Recall};
use crate::search;
use crate::store::Store;
pub struct McpState {
pub graph: Graph,
pub search_index: Index,
pub store_path: PathBuf,
pub recall: Recall,
}
pub fn tool_definitions() -> Vec<ToolDefinition> {
vec![
ToolDefinition {
name: "search_symbols".to_string(),
description: "Search for symbols by name, kind, module, file, or role. Returns matching symbols with relevance scores.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query (symbol name or keyword)"
},
"limit": {
"type": "integer",
"description": "Maximum number of results (default: 20)",
"default": 20
},
"kind": {
"type": "string",
"description": "Filter by symbol kind (function, struct, enum, trait, etc.)"
},
"module": {
"type": "string",
"description": "Filter by module name"
},
"file": {
"type": "string",
"description": "Filter by file path glob"
},
"role": {
"type": "string",
"description": "Filter by role (entry_point, terminal, internal)"
},
"fuzzy": {
"type": "boolean",
"description": "Enable fuzzy matching (default: false)",
"default": false
}
},
"required": ["query"]
}),
},
ToolDefinition {
name: "get_symbol_context".to_string(),
description: "Get 360-degree context for a symbol: callers, callees, implementors, containment, and type references.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "Symbol name or ID"
}
},
"required": ["symbol"]
}),
},
ToolDefinition {
name: "get_impact".to_string(),
description: "Analyze the blast radius of changing a symbol using BFS traversal.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "Symbol name or ID"
},
"depth": {
"type": "integer",
"description": "Maximum traversal depth (default: 3)",
"default": 3
}
},
"required": ["symbol"]
}),
},
ToolDefinition {
name: "get_file_map".to_string(),
description: "Get a map of files and symbols organized by module and directory.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"module": {
"type": "string",
"description": "Filter by module name (optional)"
}
}
}),
},
ToolDefinition {
name: "trace".to_string(),
description: "Trace dataflow forward from a symbol to terminals, or reverse from a symbol back to entry points.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "Symbol name or ID"
},
"direction": {
"type": "string",
"enum": ["forward", "reverse"],
"description": "Trace direction (default: forward)",
"default": "forward"
},
"depth": {
"type": "integer",
"description": "Maximum traversal depth (default: 10 for forward, unlimited for reverse)"
}
},
"required": ["symbol"]
}),
},
ToolDefinition {
name: "get_file_symbols".to_string(),
description: "List all symbols in a file, ordered by source position. Returns declarations (structs, functions, properties, etc.) excluding synthetic view/branch nodes.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"file": {
"type": "string",
"description": "File name or path suffix (e.g. \"RoomPage.swift\" or \"src/main.rs\")"
}
},
"required": ["file"]
}),
},
ToolDefinition {
name: "batch_context".to_string(),
description: "Get 360-degree context for multiple symbols in a single call. Returns a map of symbol ID to context result. More efficient than multiple get_symbol_context calls.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"symbols": {
"type": "array",
"items": { "type": "string" },
"description": "Array of symbol names or IDs"
}
},
"required": ["symbols"]
}),
},
ToolDefinition {
name: "analyze_complexity".to_string(),
description: "Analyze the structural complexity of a type (struct, class, enum, protocol). Returns property count, method count, dependency count, invalidation sources, init parameter count, extension count, containment depth, blast radius, and an overall severity rating.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "Type name or ID to analyze"
}
},
"required": ["symbol"]
}),
},
ToolDefinition {
name: "detect_smells".to_string(),
description: "Scan the entire graph for code smells: god types (>15 properties), excessive dependencies (>10), wide invalidation surfaces (>5 sources), massive inits (>8 params), deep nesting (>5 levels), high fan-out/fan-in (>15 calls), and many extensions (>5). Returns smells sorted by severity.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"module": {
"type": "string",
"description": "Filter smells to a specific module (optional, scans all if omitted)"
}
}
}),
},
ToolDefinition {
name: "get_module_summary".to_string(),
description: "Get high-level metrics for each module: symbol count, file count, symbols by kind, edge count, cross-module coupling ratio, entry points, and terminals. Sorted by symbol count descending.".to_string(),
input_schema: json!({
"type": "object",
"properties": {}
}),
},
ToolDefinition {
name: "reload".to_string(),
description: "Reload the graph and search index from disk. Use after running `grapha index` from the CLI to pick up changes without restarting the MCP server.".to_string(),
input_schema: json!({
"type": "object",
"properties": {}
}),
},
]
}
fn text_content(text: String) -> Value {
json!({
"content": [{
"type": "text",
"text": text
}]
})
}
fn tool_error(message: String) -> Value {
json!({
"content": [{
"type": "text",
"text": message
}],
"isError": true
})
}
fn format_query_error(err: &query::QueryResolveError) -> String {
match err {
query::QueryResolveError::NotFound { query } => {
format!("symbol not found: {query}")
}
query::QueryResolveError::Ambiguous { query, candidates } => {
let mut msg = format!("ambiguous query: {query}\n");
for c in candidates {
msg.push_str(&format!(" - {} [{:?}] in {}\n", c.name, c.kind, c.file));
}
msg.push_str(&format!("hint: {}", query::ambiguity_hint()));
msg
}
query::QueryResolveError::NotFunction { hint } => hint.clone(),
}
}
fn serialize_result<T: serde::Serialize>(result: &T) -> Value {
match serde_json::to_string_pretty(result) {
Ok(json) => text_content(json),
Err(e) => tool_error(format!("failed to serialize result: {e}")),
}
}
pub fn handle_tool_call(state: &mut McpState, tool_name: &str, arguments: &Value) -> Value {
match tool_name {
"search_symbols" => handle_search_symbols(state, arguments),
"get_symbol_context" => handle_get_symbol_context(state, arguments),
"get_impact" => handle_get_impact(state, arguments),
"get_file_map" => handle_get_file_map(state, arguments),
"trace" => handle_trace(state, arguments),
"get_file_symbols" => handle_get_file_symbols(state, arguments),
"batch_context" => handle_batch_context(state, arguments),
"analyze_complexity" => handle_analyze_complexity(state, arguments),
"detect_smells" => handle_detect_smells(state, arguments),
"get_module_summary" => handle_get_module_summary(state),
"reload" => handle_reload(state),
_ => tool_error(format!("unknown tool: {tool_name}")),
}
}
fn resolve_symbol(state: &mut McpState, query: &str) -> Result<String, Value> {
match recall::resolve_with_recall(&state.graph.nodes, query, &mut state.recall) {
Ok(node) => Ok(node.id.clone()),
Err(e) => Err(tool_error(format_query_error(&e))),
}
}
fn handle_search_symbols(state: &McpState, arguments: &Value) -> Value {
let query_str = match arguments.get("query").and_then(|v| v.as_str()) {
Some(q) => q,
None => return tool_error("missing required parameter: query".to_string()),
};
let limit = arguments
.get("limit")
.and_then(|v| v.as_u64())
.unwrap_or(20) as usize;
let options = search::SearchOptions {
kind: arguments
.get("kind")
.and_then(|v| v.as_str())
.map(String::from),
module: arguments
.get("module")
.and_then(|v| v.as_str())
.map(String::from),
file_glob: arguments
.get("file")
.and_then(|v| v.as_str())
.map(String::from),
role: arguments
.get("role")
.and_then(|v| v.as_str())
.map(String::from),
fuzzy: arguments
.get("fuzzy")
.and_then(|v| v.as_bool())
.unwrap_or(false),
};
match search::search_filtered(&state.search_index, query_str, limit, &options) {
Ok(results) => serialize_result(&results),
Err(e) => tool_error(format!("search failed: {e}")),
}
}
fn handle_get_symbol_context(state: &mut McpState, arguments: &Value) -> Value {
let query = match arguments.get("symbol").and_then(|v| v.as_str()) {
Some(s) => s,
None => return tool_error("missing required parameter: symbol".to_string()),
};
let symbol_id = match resolve_symbol(state, query) {
Ok(id) => id,
Err(e) => return e,
};
match query::context::query_context(&state.graph, &symbol_id) {
Ok(result) => serialize_result(&result),
Err(e) => tool_error(format_query_error(&e)),
}
}
fn handle_get_impact(state: &mut McpState, arguments: &Value) -> Value {
let query = match arguments.get("symbol").and_then(|v| v.as_str()) {
Some(s) => s,
None => return tool_error("missing required parameter: symbol".to_string()),
};
let symbol_id = match resolve_symbol(state, query) {
Ok(id) => id,
Err(e) => return e,
};
let depth = arguments.get("depth").and_then(|v| v.as_u64()).unwrap_or(3) as usize;
match query::impact::query_impact(&state.graph, &symbol_id, depth) {
Ok(result) => serialize_result(&result),
Err(e) => tool_error(format_query_error(&e)),
}
}
fn handle_get_file_map(state: &McpState, arguments: &Value) -> Value {
let module = arguments.get("module").and_then(|v| v.as_str());
let result = query::map::file_map(&state.graph, module);
serialize_result(&result)
}
fn handle_trace(state: &mut McpState, arguments: &Value) -> Value {
let query = match arguments.get("symbol").and_then(|v| v.as_str()) {
Some(s) => s,
None => return tool_error("missing required parameter: symbol".to_string()),
};
let symbol_id = match resolve_symbol(state, query) {
Ok(id) => id,
Err(e) => return e,
};
let direction = arguments
.get("direction")
.and_then(|v| v.as_str())
.unwrap_or("forward");
let depth = arguments.get("depth").and_then(|v| v.as_u64());
match direction {
"forward" => {
let max_depth = depth.unwrap_or(10) as usize;
match query::trace::query_trace(&state.graph, &symbol_id, max_depth) {
Ok(result) => serialize_result(&result),
Err(e) => tool_error(format_query_error(&e)),
}
}
"reverse" => {
let max_depth = depth.map(|d| d as usize);
match query::reverse::query_reverse(&state.graph, &symbol_id, max_depth) {
Ok(result) => serialize_result(&result),
Err(e) => tool_error(format_query_error(&e)),
}
}
other => tool_error(format!(
"invalid direction: {other} (expected \"forward\" or \"reverse\")"
)),
}
}
fn handle_get_file_symbols(state: &McpState, arguments: &Value) -> Value {
let file = match arguments.get("file").and_then(|v| v.as_str()) {
Some(f) => f,
None => return tool_error("missing required parameter: file".to_string()),
};
let result = query::file_symbols::query_file_symbols(&state.graph, file);
if result.total == 0 {
return tool_error(format!("no symbols found in file matching: {file}"));
}
serialize_result(&result)
}
fn handle_batch_context(state: &mut McpState, arguments: &Value) -> Value {
let symbols = match arguments.get("symbols").and_then(|v| v.as_array()) {
Some(arr) => arr,
None => return tool_error("missing required parameter: symbols (array)".to_string()),
};
let symbol_strs: Vec<&str> = symbols.iter().filter_map(|v| v.as_str()).collect();
if symbol_strs.is_empty() {
return tool_error("symbols array is empty".to_string());
}
if symbol_strs.len() > 20 {
return tool_error("batch_context supports at most 20 symbols per call".to_string());
}
let mut results: Vec<Value> = Vec::with_capacity(symbol_strs.len());
for symbol in &symbol_strs {
let resolved = resolve_symbol(state, symbol);
let query_id = match &resolved {
Ok(id) => id.as_str(),
Err(_) => symbol,
};
match query::context::query_context(&state.graph, query_id) {
Ok(ctx) => {
results.push(json!({
"query": symbol,
"result": serde_json::to_value(&ctx).unwrap_or(Value::Null),
}));
}
Err(e) => {
results.push(json!({
"query": symbol,
"error": format_query_error(&e),
}));
}
}
}
serialize_result(&results)
}
fn handle_analyze_complexity(state: &mut McpState, arguments: &Value) -> Value {
let query = match arguments.get("symbol").and_then(|v| v.as_str()) {
Some(s) => s,
None => return tool_error("missing required parameter: symbol".to_string()),
};
let symbol_id = match resolve_symbol(state, query) {
Ok(id) => id,
Err(e) => return e,
};
match query::complexity::query_complexity(&state.graph, &symbol_id) {
Ok(result) => serialize_result(&result),
Err(e) => tool_error(format_query_error(&e)),
}
}
fn handle_detect_smells(state: &McpState, arguments: &Value) -> Value {
let module_filter = arguments.get("module").and_then(|v| v.as_str());
let mut result = query::smells::detect_smells(&state.graph);
if let Some(module) = module_filter {
let module_lower = module.to_lowercase();
result.smells.retain(|smell| {
state.graph.nodes.iter().any(|n| {
n.id == smell.symbol.id
&& n.module
.as_ref()
.is_some_and(|m| m.to_lowercase() == module_lower)
})
});
result.total = result.smells.len();
result.by_severity.clear();
for smell in &result.smells {
*result
.by_severity
.entry(smell.severity.clone())
.or_default() += 1;
}
}
serialize_result(&result)
}
fn handle_get_module_summary(state: &McpState) -> Value {
let result = query::module_summary::query_module_summary(&state.graph);
serialize_result(&result)
}
fn handle_reload(state: &mut McpState) -> Value {
let db_path = state.store_path.join("grapha.db");
let search_index_path = state.store_path.join("search_index");
let store = crate::store::sqlite::SqliteStore::new(db_path);
let graph = match store.load() {
Ok(g) => g,
Err(e) => return tool_error(format!("failed to reload graph: {e}")),
};
let search_index = if search_index_path.exists() {
match tantivy::Index::open_in_dir(&search_index_path) {
Ok(idx) => idx,
Err(e) => return tool_error(format!("failed to reload search index: {e}")),
}
} else {
match search::build_index(&graph, &search_index_path) {
Ok(idx) => idx,
Err(e) => return tool_error(format!("failed to build search index: {e}")),
}
};
let node_count = graph.nodes.len();
let edge_count = graph.edges.len();
state.graph = graph;
state.search_index = search_index;
let valid_ids: std::collections::HashSet<&str> =
state.graph.nodes.iter().map(|n| n.id.as_str()).collect();
state.recall.prune(&valid_ids);
text_content(format!(
"Reloaded successfully: {node_count} nodes, {edge_count} edges"
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tool_definitions_count() {
let tools = tool_definitions();
assert_eq!(tools.len(), 11);
let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
assert!(names.contains(&"search_symbols"));
assert!(names.contains(&"get_symbol_context"));
assert!(names.contains(&"get_impact"));
assert!(names.contains(&"get_file_map"));
assert!(names.contains(&"trace"));
assert!(names.contains(&"get_file_symbols"));
assert!(names.contains(&"batch_context"));
assert!(names.contains(&"analyze_complexity"));
assert!(names.contains(&"detect_smells"));
assert!(names.contains(&"get_module_summary"));
assert!(names.contains(&"reload"));
}
#[test]
fn unknown_tool_returns_error() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "nonexistent", &json!({}));
assert!(
result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
#[test]
fn search_symbols_missing_query_returns_error() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "search_symbols", &json!({}));
assert!(
result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
#[test]
fn get_file_symbols_missing_file_returns_error() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "get_file_symbols", &json!({}));
assert!(
result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
#[test]
fn batch_context_empty_array_returns_error() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "batch_context", &json!({"symbols": []}));
assert!(
result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
#[test]
fn detect_smells_on_empty_graph() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "detect_smells", &json!({}));
assert!(
!result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
#[test]
fn get_module_summary_on_empty_graph() {
let mut state = make_test_state();
let result = handle_tool_call(&mut state, "get_module_summary", &json!({}));
assert!(
!result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
);
}
fn make_test_state() -> McpState {
let graph = Graph {
version: String::new(),
nodes: vec![],
edges: vec![],
};
let schema = tantivy::schema::Schema::builder().build();
let index = Index::create_in_ram(schema);
McpState {
graph,
search_index: index,
store_path: PathBuf::from("/tmp/test"),
recall: Recall::new(),
}
}
}