codeprism_mcp/tools/analysis/
flow.rs

1//! Data flow analysis tools
2
3#![allow(clippy::too_many_arguments)]
4
5use crate::tools_legacy::{CallToolParams, CallToolResult, Tool, ToolContent};
6use crate::CodePrismMcpServer;
7use anyhow::Result;
8use serde_json::Value;
9
10/// Trace data flow for a specific symbol
11async fn trace_symbol_data_flow(
12    server: &CodePrismMcpServer,
13    node: &codeprism_core::Node,
14    direction: &str,
15    max_depth: usize,
16    include_transformations: bool,
17) -> Value {
18    let mut flows = Vec::new();
19    let mut transformations = Vec::new();
20
21    // Get dependencies and references for the symbol
22    match direction {
23        "forward" | "both" => {
24            if let Ok(dependencies) = server
25                .graph_query()
26                .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
27            {
28                for dep in dependencies.iter().take(max_depth) {
29                    flows.push(serde_json::json!({
30                        "direction": "forward",
31                        "from": {
32                            "id": node.id.to_hex(),
33                            "name": node.name,
34                            "kind": format!("{:?}", node.kind)
35                        },
36                        "to": {
37                            "id": dep.target_node.id.to_hex(),
38                            "name": dep.target_node.name,
39                            "kind": format!("{:?}", dep.target_node.kind),
40                            "file": dep.target_node.file.display().to_string()
41                        },
42                        "edge_type": format!("{:?}", dep.edge_kind)
43                    }));
44
45                    if include_transformations
46                        && matches!(dep.edge_kind, codeprism_core::EdgeKind::Calls)
47                    {
48                        transformations.push(serde_json::json!({
49                            "type": "function_call",
50                            "source": node.name,
51                            "target": dep.target_node.name,
52                            "transformation": "call"
53                        }));
54                    }
55                }
56            }
57        }
58        _ => {}
59    }
60
61    match direction {
62        "backward" | "both" => {
63            if let Ok(references) = server.graph_query().find_references(&node.id) {
64                for ref_info in references.iter().take(max_depth) {
65                    flows.push(serde_json::json!({
66                        "direction": "backward",
67                        "from": {
68                            "id": ref_info.source_node.id.to_hex(),
69                            "name": ref_info.source_node.name,
70                            "kind": format!("{:?}", ref_info.source_node.kind),
71                            "file": ref_info.source_node.file.display().to_string()
72                        },
73                        "to": {
74                            "id": node.id.to_hex(),
75                            "name": node.name,
76                            "kind": format!("{:?}", node.kind)
77                        },
78                        "edge_type": format!("{:?}", ref_info.edge_kind)
79                    }));
80                }
81            }
82        }
83        _ => {}
84    }
85
86    serde_json::json!({
87        "target": node.name,
88        "analysis": {
89            "direction": direction,
90            "max_depth": max_depth,
91            "include_transformations": include_transformations,
92            "symbol_info": {
93                "id": node.id.to_hex(),
94                "name": node.name,
95                "kind": format!("{:?}", node.kind),
96                "file": node.file.display().to_string()
97            }
98        },
99        "data_flow": {
100            "flows": flows,
101            "transformations": transformations
102        },
103        "summary": {
104            "total_flow_steps": flows.len(),
105            "transformation_count": transformations.len(),
106            "directions_analyzed": match direction {
107                "both" => vec!["forward", "backward"],
108                dir => vec![dir]
109            }
110        }
111    })
112}
113
114/// Analyze transitive dependencies for a specific symbol
115async fn analyze_symbol_transitive_dependencies(
116    server: &CodePrismMcpServer,
117    node: &codeprism_core::Node,
118    max_depth: usize,
119    detect_cycles: bool,
120) -> Value {
121    let mut direct_deps = Vec::new();
122    let mut transitive_deps = Vec::new();
123    let mut visited = std::collections::HashSet::new();
124    let mut cycles = Vec::new();
125
126    // Get direct dependencies
127    if let Ok(dependencies) = server
128        .graph_query()
129        .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
130    {
131        for dep in &dependencies {
132            direct_deps.push(serde_json::json!({
133                "id": dep.target_node.id.to_hex(),
134                "name": dep.target_node.name,
135                "kind": format!("{:?}", dep.target_node.kind),
136                "file": dep.target_node.file.display().to_string(),
137                "edge_type": format!("{:?}", dep.edge_kind)
138            }));
139        }
140
141        // Collect transitive dependencies
142        for dep in &dependencies {
143            collect_transitive_deps(
144                server,
145                &dep.target_node.id,
146                &mut transitive_deps,
147                &mut visited,
148                &mut cycles,
149                max_depth,
150                1,
151                detect_cycles,
152                &node.id,
153            )
154            .await;
155        }
156    }
157
158    serde_json::json!({
159        "target": node.name,
160        "analysis": {
161            "max_depth": max_depth,
162            "detect_cycles": detect_cycles,
163            "symbol_info": {
164                "id": node.id.to_hex(),
165                "name": node.name,
166                "kind": format!("{:?}", node.kind),
167                "file": node.file.display().to_string()
168            }
169        },
170        "dependencies": {
171            "direct": direct_deps,
172            "transitive": transitive_deps,
173            "cycles": cycles
174        },
175        "summary": {
176            "total_direct": direct_deps.len(),
177            "total_transitive": transitive_deps.len(),
178            "max_depth_reached": visited.len(),
179            "cycles_detected": cycles.len()
180        }
181    })
182}
183
184/// Recursively collect transitive dependencies
185async fn collect_transitive_deps(
186    server: &CodePrismMcpServer,
187    current_id: &codeprism_core::NodeId,
188    transitive_deps: &mut Vec<Value>,
189    visited: &mut std::collections::HashSet<codeprism_core::NodeId>,
190    cycles: &mut Vec<Value>,
191    max_depth: usize,
192    current_depth: usize,
193    detect_cycles: bool,
194    root_id: &codeprism_core::NodeId,
195) {
196    if current_depth >= max_depth {
197        return;
198    }
199
200    if visited.contains(current_id) {
201        if detect_cycles && current_id == root_id {
202            cycles.push(serde_json::json!({
203                "cycle_detected": true,
204                "depth": current_depth,
205                "node_id": current_id.to_hex()
206            }));
207        }
208        return;
209    }
210
211    visited.insert(*current_id);
212
213    if let Ok(dependencies) = server
214        .graph_query()
215        .find_dependencies(current_id, codeprism_core::graph::DependencyType::Direct)
216    {
217        for dep in dependencies {
218            transitive_deps.push(serde_json::json!({
219                "id": dep.target_node.id.to_hex(),
220                "name": dep.target_node.name,
221                "kind": format!("{:?}", dep.target_node.kind),
222                "file": dep.target_node.file.display().to_string(),
223                "edge_type": format!("{:?}", dep.edge_kind),
224                "depth": current_depth
225            }));
226
227            // Recurse
228            Box::pin(collect_transitive_deps(
229                server,
230                &dep.target_node.id,
231                transitive_deps,
232                visited,
233                cycles,
234                max_depth,
235                current_depth + 1,
236                detect_cycles,
237                root_id,
238            ))
239            .await;
240        }
241    }
242}
243
244/// List flow analysis tools
245pub fn list_tools() -> Vec<Tool> {
246    vec![
247        Tool {
248            name: "trace_data_flow".to_string(),
249            title: Some("Trace Data Flow".to_string()),
250            description: "Track data flow through the codebase, following variable assignments, function parameters, and transformations".to_string(),
251            input_schema: serde_json::json!({
252                "type": "object",
253                "properties": {
254                    "variable_or_parameter": {
255                        "type": "string",
256                        "description": "Symbol ID of variable or parameter to trace"
257                    },
258                    "direction": {
259                        "type": "string",
260                        "enum": ["forward", "backward", "both"],
261                        "description": "Direction to trace data flow",
262                        "default": "forward"
263                    },
264                    "include_transformations": {
265                        "type": "boolean",
266                        "description": "Include data transformations (method calls, assignments)",
267                        "default": true
268                    },
269                    "max_depth": {
270                        "type": "number",
271                        "description": "Maximum depth for data flow tracing",
272                        "default": 10,
273                        "minimum": 1,
274                        "maximum": 50
275                    },
276                    "follow_function_calls": {
277                        "type": "boolean",
278                        "description": "Follow data flow across function calls",
279                        "default": true
280                    },
281                    "include_field_access": {
282                        "type": "boolean",
283                        "description": "Include field/attribute access patterns",
284                        "default": true
285                    }
286                },
287                "required": ["variable_or_parameter"]
288            }),
289        },
290        Tool {
291            name: "analyze_transitive_dependencies".to_string(),
292            title: Some("Analyze Transitive Dependencies".to_string()),
293            description: "Analyze complete dependency chains, detect cycles, and map transitive relationships".to_string(),
294            input_schema: serde_json::json!({
295                "type": "object",
296                "properties": {
297                    "target": {
298                        "type": "string",
299                        "description": "Symbol ID or file path to analyze"
300                    },
301                    "max_depth": {
302                        "type": "number",
303                        "description": "Maximum depth for transitive analysis",
304                        "default": 5,
305                        "minimum": 1,
306                        "maximum": 20
307                    },
308                    "detect_cycles": {
309                        "type": "boolean",
310                        "description": "Detect circular dependencies",
311                        "default": true
312                    },
313                    "include_external_dependencies": {
314                        "type": "boolean",
315                        "description": "Include external/third-party dependencies",
316                        "default": false
317                    },
318                    "dependency_types": {
319                        "type": "array",
320                        "items": {
321                            "type": "string",
322                            "enum": ["calls", "imports", "reads", "writes", "extends", "implements", "all"]
323                        },
324                        "description": "Types of dependencies to analyze",
325                        "default": ["all"]
326                    }
327                },
328                "required": ["target"]
329            }),
330        }
331    ]
332}
333
334/// Route flow analysis tool calls
335pub async fn call_tool(
336    server: &CodePrismMcpServer,
337    params: &CallToolParams,
338) -> Result<CallToolResult> {
339    match params.name.as_str() {
340        "trace_data_flow" => trace_data_flow(server, params.arguments.as_ref()).await,
341        "analyze_transitive_dependencies" => {
342            analyze_transitive_dependencies(server, params.arguments.as_ref()).await
343        }
344        _ => Err(anyhow::anyhow!(
345            "Unknown flow analysis tool: {}",
346            params.name
347        )),
348    }
349}
350
351/// Trace data flow
352async fn trace_data_flow(
353    server: &CodePrismMcpServer,
354    arguments: Option<&Value>,
355) -> Result<CallToolResult> {
356    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
357
358    // Support multiple parameter names for backward compatibility
359    let variable_or_parameter = args
360        .get("variable_or_parameter")
361        .or_else(|| args.get("start_symbol"))
362        .or_else(|| args.get("symbol"))
363        .or_else(|| args.get("target"))
364        .and_then(|v| v.as_str())
365        .ok_or_else(|| {
366            anyhow::anyhow!(
367                "Missing variable_or_parameter parameter (or start_symbol, symbol, target)"
368            )
369        })?;
370
371    let direction = args
372        .get("direction")
373        .and_then(|v| v.as_str())
374        .unwrap_or("forward");
375
376    let max_depth = args
377        .get("max_depth")
378        .and_then(|v| v.as_u64())
379        .map(|v| v as usize)
380        .unwrap_or(10);
381
382    let include_transformations = args
383        .get("include_transformations")
384        .and_then(|v| v.as_bool())
385        .unwrap_or(true);
386
387    // Try to resolve the symbol
388    let result = if let Ok(symbol_results) =
389        server
390            .graph_query()
391            .search_symbols(variable_or_parameter, None, Some(1))
392    {
393        if let Some(symbol_result) = symbol_results.first() {
394            // Found the symbol, now trace its data flow
395            trace_symbol_data_flow(
396                server,
397                &symbol_result.node,
398                direction,
399                max_depth,
400                include_transformations,
401            )
402            .await
403        } else {
404            serde_json::json!({
405                "target": variable_or_parameter,
406                "error": "Symbol not found",
407                "suggestion": "Check if the symbol name is correct or try using a different identifier"
408            })
409        }
410    } else {
411        serde_json::json!({
412            "target": variable_or_parameter,
413            "error": "Failed to search for symbol",
414            "suggestion": "Ensure the repository is properly indexed"
415        })
416    };
417
418    Ok(CallToolResult {
419        content: vec![ToolContent::Text {
420            text: serde_json::to_string_pretty(&result)?,
421        }],
422        is_error: Some(result.get("error").is_some()),
423    })
424}
425
426/// Analyze transitive dependencies
427async fn analyze_transitive_dependencies(
428    server: &CodePrismMcpServer,
429    arguments: Option<&Value>,
430) -> Result<CallToolResult> {
431    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
432
433    // Support multiple parameter names for backward compatibility
434    let target = args
435        .get("target")
436        .or_else(|| args.get("symbol"))
437        .and_then(|v| v.as_str())
438        .ok_or_else(|| anyhow::anyhow!("Missing target parameter (or symbol)"))?;
439
440    let max_depth = args
441        .get("max_depth")
442        .and_then(|v| v.as_u64())
443        .map(|v| v as usize)
444        .unwrap_or(5);
445
446    let detect_cycles = args
447        .get("detect_cycles")
448        .and_then(|v| v.as_bool())
449        .unwrap_or(true);
450
451    // Try to resolve the symbol
452    let result = if let Ok(symbol_results) =
453        server.graph_query().search_symbols(target, None, Some(1))
454    {
455        if let Some(symbol_result) = symbol_results.first() {
456            // Found the symbol, now analyze its transitive dependencies
457            analyze_symbol_transitive_dependencies(
458                server,
459                &symbol_result.node,
460                max_depth,
461                detect_cycles,
462            )
463            .await
464        } else {
465            serde_json::json!({
466                "target": target,
467                "error": "Symbol not found",
468                "suggestion": "Check if the symbol name is correct or try using a different identifier"
469            })
470        }
471    } else {
472        serde_json::json!({
473            "target": target,
474            "error": "Failed to search for symbol",
475            "suggestion": "Ensure the repository is properly indexed"
476        })
477    };
478
479    Ok(CallToolResult {
480        content: vec![ToolContent::Text {
481            text: serde_json::to_string_pretty(&result)?,
482        }],
483        is_error: Some(result.get("error").is_some()),
484    })
485}