codeprism_mcp/tools/analysis/
flow.rs

1//! Data flow analysis tools
2
3use crate::tools_legacy::{CallToolParams, CallToolResult, Tool, ToolContent};
4use crate::CodePrismMcpServer;
5use anyhow::Result;
6use serde_json::Value;
7
8/// Trace data flow for a specific symbol
9async fn trace_symbol_data_flow(
10    server: &CodePrismMcpServer,
11    node: &codeprism_core::Node,
12    direction: &str,
13    max_depth: usize,
14    include_transformations: bool,
15) -> Value {
16    let mut flows = Vec::new();
17    let mut transformations = Vec::new();
18
19    // Get dependencies and references for the symbol
20    match direction {
21        "forward" | "both" => {
22            if let Ok(dependencies) = server
23                .graph_query()
24                .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
25            {
26                for dep in dependencies.iter().take(max_depth) {
27                    flows.push(serde_json::json!({
28                        "direction": "forward",
29                        "from": {
30                            "id": node.id.to_hex(),
31                            "name": node.name,
32                            "kind": format!("{:?}", node.kind)
33                        },
34                        "to": {
35                            "id": dep.target_node.id.to_hex(),
36                            "name": dep.target_node.name,
37                            "kind": format!("{:?}", dep.target_node.kind),
38                            "file": dep.target_node.file.display().to_string()
39                        },
40                        "edge_type": format!("{:?}", dep.edge_kind)
41                    }));
42
43                    if include_transformations
44                        && matches!(dep.edge_kind, codeprism_core::EdgeKind::Calls)
45                    {
46                        transformations.push(serde_json::json!({
47                            "type": "function_call",
48                            "source": node.name,
49                            "target": dep.target_node.name,
50                            "transformation": "call"
51                        }));
52                    }
53                }
54            }
55        }
56        _ => {}
57    }
58
59    match direction {
60        "backward" | "both" => {
61            if let Ok(references) = server.graph_query().find_references(&node.id) {
62                for ref_info in references.iter().take(max_depth) {
63                    flows.push(serde_json::json!({
64                        "direction": "backward",
65                        "from": {
66                            "id": ref_info.source_node.id.to_hex(),
67                            "name": ref_info.source_node.name,
68                            "kind": format!("{:?}", ref_info.source_node.kind),
69                            "file": ref_info.source_node.file.display().to_string()
70                        },
71                        "to": {
72                            "id": node.id.to_hex(),
73                            "name": node.name,
74                            "kind": format!("{:?}", node.kind)
75                        },
76                        "edge_type": format!("{:?}", ref_info.edge_kind)
77                    }));
78                }
79            }
80        }
81        _ => {}
82    }
83
84    serde_json::json!({
85        "target": node.name,
86        "analysis": {
87            "direction": direction,
88            "max_depth": max_depth,
89            "include_transformations": include_transformations,
90            "symbol_info": {
91                "id": node.id.to_hex(),
92                "name": node.name,
93                "kind": format!("{:?}", node.kind),
94                "file": node.file.display().to_string()
95            }
96        },
97        "data_flow": {
98            "flows": flows,
99            "transformations": transformations
100        },
101        "summary": {
102            "total_flow_steps": flows.len(),
103            "transformation_count": transformations.len(),
104            "directions_analyzed": match direction {
105                "both" => vec!["forward", "backward"],
106                dir => vec![dir]
107            }
108        }
109    })
110}
111
112/// Analyze transitive dependencies for a specific symbol
113async fn analyze_symbol_transitive_dependencies(
114    server: &CodePrismMcpServer,
115    node: &codeprism_core::Node,
116    max_depth: usize,
117    detect_cycles: bool,
118) -> Value {
119    let mut direct_deps = Vec::new();
120    let mut transitive_deps = Vec::new();
121    let mut visited = std::collections::HashSet::new();
122    let mut cycles = Vec::new();
123
124    // Get direct dependencies
125    if let Ok(dependencies) = server
126        .graph_query()
127        .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
128    {
129        for dep in &dependencies {
130            direct_deps.push(serde_json::json!({
131                "id": dep.target_node.id.to_hex(),
132                "name": dep.target_node.name,
133                "kind": format!("{:?}", dep.target_node.kind),
134                "file": dep.target_node.file.display().to_string(),
135                "edge_type": format!("{:?}", dep.edge_kind)
136            }));
137        }
138
139        // Collect transitive dependencies
140        for dep in &dependencies {
141            collect_transitive_deps(
142                server,
143                &dep.target_node.id,
144                &mut transitive_deps,
145                &mut visited,
146                &mut cycles,
147                max_depth,
148                1,
149                detect_cycles,
150                &node.id,
151            )
152            .await;
153        }
154    }
155
156    serde_json::json!({
157        "target": node.name,
158        "analysis": {
159            "max_depth": max_depth,
160            "detect_cycles": detect_cycles,
161            "symbol_info": {
162                "id": node.id.to_hex(),
163                "name": node.name,
164                "kind": format!("{:?}", node.kind),
165                "file": node.file.display().to_string()
166            }
167        },
168        "dependencies": {
169            "direct": direct_deps,
170            "transitive": transitive_deps,
171            "cycles": cycles
172        },
173        "summary": {
174            "total_direct": direct_deps.len(),
175            "total_transitive": transitive_deps.len(),
176            "max_depth_reached": visited.len(),
177            "cycles_detected": cycles.len()
178        }
179    })
180}
181
182/// Recursively collect transitive dependencies
183async fn collect_transitive_deps(
184    server: &CodePrismMcpServer,
185    current_id: &codeprism_core::NodeId,
186    transitive_deps: &mut Vec<Value>,
187    visited: &mut std::collections::HashSet<codeprism_core::NodeId>,
188    cycles: &mut Vec<Value>,
189    max_depth: usize,
190    current_depth: usize,
191    detect_cycles: bool,
192    root_id: &codeprism_core::NodeId,
193) {
194    if current_depth >= max_depth {
195        return;
196    }
197
198    if visited.contains(current_id) {
199        if detect_cycles && current_id == root_id {
200            cycles.push(serde_json::json!({
201                "cycle_detected": true,
202                "depth": current_depth,
203                "node_id": current_id.to_hex()
204            }));
205        }
206        return;
207    }
208
209    visited.insert(*current_id);
210
211    if let Ok(dependencies) = server
212        .graph_query()
213        .find_dependencies(current_id, codeprism_core::graph::DependencyType::Direct)
214    {
215        for dep in dependencies {
216            transitive_deps.push(serde_json::json!({
217                "id": dep.target_node.id.to_hex(),
218                "name": dep.target_node.name,
219                "kind": format!("{:?}", dep.target_node.kind),
220                "file": dep.target_node.file.display().to_string(),
221                "edge_type": format!("{:?}", dep.edge_kind),
222                "depth": current_depth
223            }));
224
225            // Recurse
226            Box::pin(collect_transitive_deps(
227                server,
228                &dep.target_node.id,
229                transitive_deps,
230                visited,
231                cycles,
232                max_depth,
233                current_depth + 1,
234                detect_cycles,
235                root_id,
236            ))
237            .await;
238        }
239    }
240}
241
242/// List flow analysis tools
243pub fn list_tools() -> Vec<Tool> {
244    vec![
245        Tool {
246            name: "trace_data_flow".to_string(),
247            title: Some("Trace Data Flow".to_string()),
248            description: "Track data flow through the codebase, following variable assignments, function parameters, and transformations".to_string(),
249            input_schema: serde_json::json!({
250                "type": "object",
251                "properties": {
252                    "variable_or_parameter": {
253                        "type": "string",
254                        "description": "Symbol ID of variable or parameter to trace"
255                    },
256                    "direction": {
257                        "type": "string",
258                        "enum": ["forward", "backward", "both"],
259                        "description": "Direction to trace data flow",
260                        "default": "forward"
261                    },
262                    "include_transformations": {
263                        "type": "boolean",
264                        "description": "Include data transformations (method calls, assignments)",
265                        "default": true
266                    },
267                    "max_depth": {
268                        "type": "number",
269                        "description": "Maximum depth for data flow tracing",
270                        "default": 10,
271                        "minimum": 1,
272                        "maximum": 50
273                    },
274                    "follow_function_calls": {
275                        "type": "boolean",
276                        "description": "Follow data flow across function calls",
277                        "default": true
278                    },
279                    "include_field_access": {
280                        "type": "boolean",
281                        "description": "Include field/attribute access patterns",
282                        "default": true
283                    }
284                },
285                "required": ["variable_or_parameter"]
286            }),
287        },
288        Tool {
289            name: "analyze_transitive_dependencies".to_string(),
290            title: Some("Analyze Transitive Dependencies".to_string()),
291            description: "Analyze complete dependency chains, detect cycles, and map transitive relationships".to_string(),
292            input_schema: serde_json::json!({
293                "type": "object",
294                "properties": {
295                    "target": {
296                        "type": "string",
297                        "description": "Symbol ID or file path to analyze"
298                    },
299                    "max_depth": {
300                        "type": "number",
301                        "description": "Maximum depth for transitive analysis",
302                        "default": 5,
303                        "minimum": 1,
304                        "maximum": 20
305                    },
306                    "detect_cycles": {
307                        "type": "boolean",
308                        "description": "Detect circular dependencies",
309                        "default": true
310                    },
311                    "include_external_dependencies": {
312                        "type": "boolean",
313                        "description": "Include external/third-party dependencies",
314                        "default": false
315                    },
316                    "dependency_types": {
317                        "type": "array",
318                        "items": {
319                            "type": "string",
320                            "enum": ["calls", "imports", "reads", "writes", "extends", "implements", "all"]
321                        },
322                        "description": "Types of dependencies to analyze",
323                        "default": ["all"]
324                    }
325                },
326                "required": ["target"]
327            }),
328        }
329    ]
330}
331
332/// Route flow analysis tool calls
333pub async fn call_tool(
334    server: &CodePrismMcpServer,
335    params: &CallToolParams,
336) -> Result<CallToolResult> {
337    match params.name.as_str() {
338        "trace_data_flow" => trace_data_flow(server, params.arguments.as_ref()).await,
339        "analyze_transitive_dependencies" => {
340            analyze_transitive_dependencies(server, params.arguments.as_ref()).await
341        }
342        _ => Err(anyhow::anyhow!(
343            "Unknown flow analysis tool: {}",
344            params.name
345        )),
346    }
347}
348
349/// Trace data flow
350async fn trace_data_flow(
351    server: &CodePrismMcpServer,
352    arguments: Option<&Value>,
353) -> Result<CallToolResult> {
354    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
355
356    // Support multiple parameter names for backward compatibility
357    let variable_or_parameter = args
358        .get("variable_or_parameter")
359        .or_else(|| args.get("start_symbol"))
360        .or_else(|| args.get("symbol"))
361        .or_else(|| args.get("target"))
362        .and_then(|v| v.as_str())
363        .ok_or_else(|| {
364            anyhow::anyhow!(
365                "Missing variable_or_parameter parameter (or start_symbol, symbol, target)"
366            )
367        })?;
368
369    let direction = args
370        .get("direction")
371        .and_then(|v| v.as_str())
372        .unwrap_or("forward");
373
374    let max_depth = args
375        .get("max_depth")
376        .and_then(|v| v.as_u64())
377        .map(|v| v as usize)
378        .unwrap_or(10);
379
380    let include_transformations = args
381        .get("include_transformations")
382        .and_then(|v| v.as_bool())
383        .unwrap_or(true);
384
385    // Try to resolve the symbol
386    let result = if let Ok(symbol_results) =
387        server
388            .graph_query()
389            .search_symbols(variable_or_parameter, None, Some(1))
390    {
391        if let Some(symbol_result) = symbol_results.first() {
392            // Found the symbol, now trace its data flow
393            trace_symbol_data_flow(
394                server,
395                &symbol_result.node,
396                direction,
397                max_depth,
398                include_transformations,
399            )
400            .await
401        } else {
402            serde_json::json!({
403                "target": variable_or_parameter,
404                "error": "Symbol not found",
405                "suggestion": "Check if the symbol name is correct or try using a different identifier"
406            })
407        }
408    } else {
409        serde_json::json!({
410            "target": variable_or_parameter,
411            "error": "Failed to search for symbol",
412            "suggestion": "Ensure the repository is properly indexed"
413        })
414    };
415
416    Ok(CallToolResult {
417        content: vec![ToolContent::Text {
418            text: serde_json::to_string_pretty(&result)?,
419        }],
420        is_error: Some(result.get("error").is_some()),
421    })
422}
423
424/// Analyze transitive dependencies
425async fn analyze_transitive_dependencies(
426    server: &CodePrismMcpServer,
427    arguments: Option<&Value>,
428) -> Result<CallToolResult> {
429    let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
430
431    // Support multiple parameter names for backward compatibility
432    let target = args
433        .get("target")
434        .or_else(|| args.get("symbol"))
435        .and_then(|v| v.as_str())
436        .ok_or_else(|| anyhow::anyhow!("Missing target parameter (or symbol)"))?;
437
438    let max_depth = args
439        .get("max_depth")
440        .and_then(|v| v.as_u64())
441        .map(|v| v as usize)
442        .unwrap_or(5);
443
444    let detect_cycles = args
445        .get("detect_cycles")
446        .and_then(|v| v.as_bool())
447        .unwrap_or(true);
448
449    // Try to resolve the symbol
450    let result = if let Ok(symbol_results) =
451        server.graph_query().search_symbols(target, None, Some(1))
452    {
453        if let Some(symbol_result) = symbol_results.first() {
454            // Found the symbol, now analyze its transitive dependencies
455            analyze_symbol_transitive_dependencies(
456                server,
457                &symbol_result.node,
458                max_depth,
459                detect_cycles,
460            )
461            .await
462        } else {
463            serde_json::json!({
464                "target": target,
465                "error": "Symbol not found",
466                "suggestion": "Check if the symbol name is correct or try using a different identifier"
467            })
468        }
469    } else {
470        serde_json::json!({
471            "target": target,
472            "error": "Failed to search for symbol",
473            "suggestion": "Ensure the repository is properly indexed"
474        })
475    };
476
477    Ok(CallToolResult {
478        content: vec![ToolContent::Text {
479            text: serde_json::to_string_pretty(&result)?,
480        }],
481        is_error: Some(result.get("error").is_some()),
482    })
483}