codeprism_mcp/tools/core/navigation.rs

1//! Navigation tools for tracing paths and dependencies
2
3use crate::tools_legacy::{CallToolParams, CallToolResult, Tool, ToolContent};
4use crate::CodePrismMcpServer;
5use anyhow::Result;
6use serde_json::Value;
7
8/// List navigation tools
9pub fn list_tools() -> Vec<Tool> {
10    vec![
11        Tool {
12            name: "trace_path".to_string(),
13            title: Some("Trace Execution Path".to_string()),
14            description: "Find the shortest path between two code symbols".to_string(),
15            input_schema: serde_json::json!({
16                "type": "object",
17                "properties": {
18                    "source": {
19                        "type": "string",
20                        "description": "Source symbol identifier (node ID)"
21                    },
22                    "target": {
23                        "type": "string",
24                        "description": "Target symbol identifier (node ID)"
25                    },
26                    "max_depth": {
27                        "type": "number",
28                        "description": "Maximum search depth",
29                        "default": 10
30                    }
31                },
32                "required": ["source", "target"]
33            }),
34        },
35        Tool {
36            name: "find_dependencies".to_string(),
37            title: Some("Find Dependencies".to_string()),
38            description: "Analyze dependencies for a code symbol or file".to_string(),
39            input_schema: serde_json::json!({
40                "type": "object",
41                "properties": {
42                    "target": {
43                        "type": "string",
44                        "description": "Symbol ID or file path to analyze"
45                    },
46                    "dependency_type": {
47                        "type": "string",
48                        "enum": ["direct", "calls", "imports", "reads", "writes"],
49                        "description": "Type of dependencies to find",
50                        "default": "direct"
51                    }
52                },
53                "required": ["target"]
54            }),
55        },
56        Tool {
57            name: "find_references".to_string(),
58            title: Some("Find References".to_string()),
59            description: "Find all references to a symbol across the codebase".to_string(),
60            input_schema: serde_json::json!({
61                "type": "object",
62                "properties": {
63                    "symbol_id": {
64                        "type": "string",
65                        "description": "Symbol identifier to find references for"
66                    },
67                    "include_definitions": {
68                        "type": "boolean",
69                        "description": "Include symbol definitions",
70                        "default": true
71                    },
72                    "context_lines": {
73                        "type": "number",
74                        "description": "Number of lines before and after the symbol to include as context",
75                        "default": 4
76                    }
77                },
78                "required": ["symbol_id"]
79            }),
80        },
81    ]
82}
83
84/// Route navigation tool calls
85pub async fn call_tool(
86    server: &CodePrismMcpServer,
87    params: &CallToolParams,
88) -> Result<CallToolResult> {
89    match params.name.as_str() {
90        "trace_path" => trace_path(server, params.arguments.as_ref()).await,
91        "find_dependencies" => find_dependencies(server, params.arguments.as_ref()).await,
92        "find_references" => find_references(server, params.arguments.as_ref()).await,
93        _ => Err(anyhow::anyhow!("Unknown navigation tool: {}", params.name)),
94    }
95}
96
97/// Parse node ID from hex string
98fn parse_node_id(hex_str: &str) -> Result<codeprism_core::NodeId> {
99    codeprism_core::NodeId::from_hex(hex_str)
100        .map_err(|e| anyhow::anyhow!("Invalid node ID '{}': {}", hex_str, e))
101}
102
103/// Resolve symbol name to node ID using search
104async fn resolve_symbol_name(
105    server: &CodePrismMcpServer,
106    symbol_name: &str,
107) -> Result<Option<codeprism_core::NodeId>> {
108    // Try to search for the symbol by name
109    let results = server
110        .graph_query()
111        .search_symbols(symbol_name, None, Some(10))?;
112
113    // Look for exact match first
114    for result in &results {
115        if result.node.name == symbol_name {
116            return Ok(Some(result.node.id));
117        }
118    }
119
120    // If no exact match, return the first result if any
121    if let Some(first) = results.first() {
122        Ok(Some(first.node.id))
123    } else {
124        Ok(None)
125    }
126}
127
128/// Resolve symbol identifier - try as node ID first, then as symbol name
129async fn resolve_symbol_identifier(
130    server: &CodePrismMcpServer,
131    identifier: &str,
132) -> Result<codeprism_core::NodeId> {
133    // First try to parse as node ID
134    if let Ok(node_id) = parse_node_id(identifier) {
135        return Ok(node_id);
136    }
137
138    // Then try to resolve as symbol name
139    if let Some(node_id) = resolve_symbol_name(server, identifier).await? {
140        return Ok(node_id);
141    }
142
143    Err(anyhow::anyhow!("Could not resolve symbol identifier '{}'. Please provide either a valid node ID (hex string) or symbol name that exists in the codebase.", identifier))
144}
145
146/// Extract source context around a specific line
147fn extract_source_context(
148    file_path: &std::path::Path,
149    line_number: usize,
150    context_lines: usize,
151) -> Option<serde_json::Value> {
152    if let Ok(content) = std::fs::read_to_string(file_path) {
153        let lines: Vec<&str> = content.lines().collect();
154        let total_lines = lines.len();
155
156        if line_number == 0 || line_number > total_lines {
157            return None;
158        }
159
160        // Convert to 0-based indexing
161        let target_line_idx = line_number - 1;
162
163        // Calculate context range
164        let start_idx = target_line_idx.saturating_sub(context_lines);
165        let end_idx = std::cmp::min(target_line_idx + context_lines + 1, total_lines);
166
167        let context_lines_data: Vec<serde_json::Value> = (start_idx..end_idx)
168            .map(|idx| {
169                serde_json::json!({
170                    "line_number": idx + 1,
171                    "content": lines[idx],
172                    "is_target": idx == target_line_idx
173                })
174            })
175            .collect();
176
177        Some(serde_json::json!({
178            "file": file_path.display().to_string(),
179            "target_line": line_number,
180            "context_start": start_idx + 1,
181            "context_end": end_idx,
182            "lines": context_lines_data
183        }))
184    } else {
185        None
186    }
187}
188
189/// Create node info with source context
190fn create_node_info_with_context(
191    node: &codeprism_core::Node,
192    context_lines: usize,
193) -> serde_json::Value {
194    let mut info = serde_json::json!({
195        "id": node.id.to_hex(),
196        "name": node.name,
197        "kind": format!("{:?}", node.kind),
198        "file": node.file.display().to_string(),
199        "span": {
200            "start_line": node.span.start_line,
201            "end_line": node.span.end_line,
202            "start_column": node.span.start_column,
203            "end_column": node.span.end_column
204        }
205    });
206
207    if let Some(context) = extract_source_context(&node.file, node.span.start_line, context_lines) {
208        info["source_context"] = context;
209    }
210
211    info
212}
213
214/// Validate that a dependency node has a valid name
215fn is_valid_dependency_node(node: &codeprism_core::Node) -> bool {
216    // Filter out Call nodes with invalid names
217    if matches!(node.kind, codeprism_core::NodeKind::Call) {
218        // Check for common invalid patterns
219        if node.name.is_empty()
220            || node.name == ")"
221            || node.name == "("
222            || node.name.trim().is_empty()
223            || node.name.chars().all(|c| !c.is_alphanumeric() && c != '_')
224        {
225            return false;
226        }
227    }
228
229    // All other nodes are considered valid
230    true
231}
232
233/// Trace path between two symbols
234async fn trace_path(
235    server: &CodePrismMcpServer,
236    arguments: Option<&Value>,
237) -> Result<CallToolResult> {
238    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
239
240    let source_str = args
241        .get("source")
242        .and_then(|v| v.as_str())
243        .ok_or_else(|| anyhow::anyhow!("Missing source parameter"))?;
244
245    let target_str = args
246        .get("target")
247        .and_then(|v| v.as_str())
248        .ok_or_else(|| anyhow::anyhow!("Missing target parameter"))?;
249
250    let max_depth = args
251        .get("max_depth")
252        .and_then(|v| v.as_u64())
253        .map(|v| v as usize);
254
255    // Resolve source and target identifiers (node IDs or symbol names)
256    let source_id = resolve_symbol_identifier(server, source_str).await?;
257    let target_id = resolve_symbol_identifier(server, target_str).await?;
258
259    match server
260        .graph_query()
261        .find_path(&source_id, &target_id, max_depth)?
262    {
263        Some(path_result) => {
264            let result = serde_json::json!({
265                "found": true,
266                "source": source_str,
267                "target": target_str,
268                "distance": path_result.distance,
269                "path": path_result.path.iter().map(|id| id.to_hex()).collect::<Vec<_>>(),
270                "edges": path_result.edges.iter().map(|edge| {
271                    serde_json::json!({
272                        "source": edge.source.to_hex(),
273                        "target": edge.target.to_hex(),
274                        "kind": format!("{:?}", edge.kind)
275                    })
276                }).collect::<Vec<_>>()
277            });
278
279            Ok(CallToolResult {
280                content: vec![ToolContent::Text {
281                    text: serde_json::to_string_pretty(&result)?,
282                }],
283                is_error: Some(false),
284            })
285        }
286        None => {
287            let result = serde_json::json!({
288                "found": false,
289                "source": source_str,
290                "target": target_str,
291                "message": "No path found between the specified symbols"
292            });
293
294            Ok(CallToolResult {
295                content: vec![ToolContent::Text {
296                    text: serde_json::to_string_pretty(&result)?,
297                }],
298                is_error: Some(false),
299            })
300        }
301    }
302}
303
304/// Find dependencies of a symbol
305async fn find_dependencies(
306    server: &CodePrismMcpServer,
307    arguments: Option<&Value>,
308) -> Result<CallToolResult> {
309    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
310
311    // Support both "target" and "symbol" parameter names for backward compatibility
312    let target = args
313        .get("target")
314        .or_else(|| args.get("symbol"))
315        .and_then(|v| v.as_str())
316        .ok_or_else(|| anyhow::anyhow!("Missing target parameter (or symbol)"))?;
317
318    let dependency_type_str = args
319        .get("dependency_type")
320        .and_then(|v| v.as_str())
321        .unwrap_or("direct");
322
323    let dependency_type = match dependency_type_str {
324        "direct" => codeprism_core::graph::DependencyType::Direct,
325        "calls" => codeprism_core::graph::DependencyType::Calls,
326        "imports" => codeprism_core::graph::DependencyType::Imports,
327        "reads" => codeprism_core::graph::DependencyType::Reads,
328        "writes" => codeprism_core::graph::DependencyType::Writes,
329        _ => {
330            return Ok(CallToolResult {
331                content: vec![ToolContent::Text {
332                    text: format!("Invalid dependency type: {}", dependency_type_str),
333                }],
334                is_error: Some(true),
335            })
336        }
337    };
338
339    // Try to resolve as symbol identifier first, then as file path
340    let dependencies = if let Ok(node_id) = resolve_symbol_identifier(server, target).await {
341        server
342            .graph_query()
343            .find_dependencies(&node_id, dependency_type)?
344    } else {
345        // Handle file path - find all nodes in the file and get their dependencies
346        let file_path = std::path::PathBuf::from(target);
347        let nodes = server.graph_store().get_nodes_in_file(&file_path);
348        let mut all_deps = Vec::new();
349        for node in nodes {
350            let deps = server
351                .graph_query()
352                .find_dependencies(&node.id, dependency_type.clone())?;
353            all_deps.extend(deps);
354        }
355        all_deps
356    };
357
358    // Filter out invalid Call nodes with malformed names
359    let valid_dependencies: Vec<_> = dependencies
360        .iter()
361        .filter(|dep| is_valid_dependency_node(&dep.target_node))
362        .collect();
363
364    let result = serde_json::json!({
365        "target": target,
366        "dependency_type": dependency_type_str,
367        "dependencies": valid_dependencies.iter().map(|dep| {
368            serde_json::json!({
369                "id": dep.target_node.id.to_hex(),
370                "name": dep.target_node.name,
371                "kind": format!("{:?}", dep.target_node.kind),
372                "file": dep.target_node.file.display().to_string(),
373                "edge_kind": format!("{:?}", dep.edge_kind)
374            })
375        }).collect::<Vec<_>>()
376    });
377
378    Ok(CallToolResult {
379        content: vec![ToolContent::Text {
380            text: serde_json::to_string_pretty(&result)?,
381        }],
382        is_error: Some(false),
383    })
384}
385
386/// Find references to a symbol
387async fn find_references(
388    server: &CodePrismMcpServer,
389    arguments: Option<&Value>,
390) -> Result<CallToolResult> {
391    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
392
393    // Support both "symbol_id" and "symbol" parameter names for backward compatibility
394    let symbol_id_str = args
395        .get("symbol_id")
396        .or_else(|| args.get("symbol"))
397        .and_then(|v| v.as_str())
398        .ok_or_else(|| anyhow::anyhow!("Missing symbol_id parameter (or symbol)"))?;
399
400    let _include_definitions = args
401        .get("include_definitions")
402        .and_then(|v| v.as_bool())
403        .unwrap_or(true);
404
405    let context_lines = args
406        .get("context_lines")
407        .and_then(|v| v.as_u64())
408        .map(|v| v as usize)
409        .unwrap_or(4);
410
411    let symbol_id = resolve_symbol_identifier(server, symbol_id_str).await?;
412    let references = server.graph_query().find_references(&symbol_id)?;
413
414    let result = serde_json::json!({
415        "symbol_id": symbol_id_str,
416        "references": references.iter().map(|ref_| {
417            let mut ref_info = create_node_info_with_context(&ref_.source_node, context_lines);
418            ref_info["edge_kind"] = serde_json::json!(format!("{:?}", ref_.edge_kind));
419            ref_info["reference_location"] = serde_json::json!({
420                "file": ref_.location.file.display().to_string(),
421                "span": {
422                    "start_line": ref_.location.span.start_line,
423                    "end_line": ref_.location.span.end_line,
424                    "start_column": ref_.location.span.start_column,
425                    "end_column": ref_.location.span.end_column
426                }
427            });
428            ref_info
429        }).collect::<Vec<_>>()
430    });
431
432    Ok(CallToolResult {
433        content: vec![ToolContent::Text {
434            text: serde_json::to_string_pretty(&result)?,
435        }],
436        is_error: Some(false),
437    })
438}