1use crate::tools_legacy::{CallToolParams, CallToolResult, Tool, ToolContent};
4use crate::CodePrismMcpServer;
5use anyhow::Result;
6use serde_json::Value;
7
8async fn trace_symbol_data_flow(
10 server: &CodePrismMcpServer,
11 node: &codeprism_core::Node,
12 direction: &str,
13 max_depth: usize,
14 include_transformations: bool,
15) -> Value {
16 let mut flows = Vec::new();
17 let mut transformations = Vec::new();
18
19 match direction {
21 "forward" | "both" => {
22 if let Ok(dependencies) = server
23 .graph_query()
24 .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
25 {
26 for dep in dependencies.iter().take(max_depth) {
27 flows.push(serde_json::json!({
28 "direction": "forward",
29 "from": {
30 "id": node.id.to_hex(),
31 "name": node.name,
32 "kind": format!("{:?}", node.kind)
33 },
34 "to": {
35 "id": dep.target_node.id.to_hex(),
36 "name": dep.target_node.name,
37 "kind": format!("{:?}", dep.target_node.kind),
38 "file": dep.target_node.file.display().to_string()
39 },
40 "edge_type": format!("{:?}", dep.edge_kind)
41 }));
42
43 if include_transformations
44 && matches!(dep.edge_kind, codeprism_core::EdgeKind::Calls)
45 {
46 transformations.push(serde_json::json!({
47 "type": "function_call",
48 "source": node.name,
49 "target": dep.target_node.name,
50 "transformation": "call"
51 }));
52 }
53 }
54 }
55 }
56 _ => {}
57 }
58
59 match direction {
60 "backward" | "both" => {
61 if let Ok(references) = server.graph_query().find_references(&node.id) {
62 for ref_info in references.iter().take(max_depth) {
63 flows.push(serde_json::json!({
64 "direction": "backward",
65 "from": {
66 "id": ref_info.source_node.id.to_hex(),
67 "name": ref_info.source_node.name,
68 "kind": format!("{:?}", ref_info.source_node.kind),
69 "file": ref_info.source_node.file.display().to_string()
70 },
71 "to": {
72 "id": node.id.to_hex(),
73 "name": node.name,
74 "kind": format!("{:?}", node.kind)
75 },
76 "edge_type": format!("{:?}", ref_info.edge_kind)
77 }));
78 }
79 }
80 }
81 _ => {}
82 }
83
84 serde_json::json!({
85 "target": node.name,
86 "analysis": {
87 "direction": direction,
88 "max_depth": max_depth,
89 "include_transformations": include_transformations,
90 "symbol_info": {
91 "id": node.id.to_hex(),
92 "name": node.name,
93 "kind": format!("{:?}", node.kind),
94 "file": node.file.display().to_string()
95 }
96 },
97 "data_flow": {
98 "flows": flows,
99 "transformations": transformations
100 },
101 "summary": {
102 "total_flow_steps": flows.len(),
103 "transformation_count": transformations.len(),
104 "directions_analyzed": match direction {
105 "both" => vec!["forward", "backward"],
106 dir => vec![dir]
107 }
108 }
109 })
110}
111
112async fn analyze_symbol_transitive_dependencies(
114 server: &CodePrismMcpServer,
115 node: &codeprism_core::Node,
116 max_depth: usize,
117 detect_cycles: bool,
118) -> Value {
119 let mut direct_deps = Vec::new();
120 let mut transitive_deps = Vec::new();
121 let mut visited = std::collections::HashSet::new();
122 let mut cycles = Vec::new();
123
124 if let Ok(dependencies) = server
126 .graph_query()
127 .find_dependencies(&node.id, codeprism_core::graph::DependencyType::Direct)
128 {
129 for dep in &dependencies {
130 direct_deps.push(serde_json::json!({
131 "id": dep.target_node.id.to_hex(),
132 "name": dep.target_node.name,
133 "kind": format!("{:?}", dep.target_node.kind),
134 "file": dep.target_node.file.display().to_string(),
135 "edge_type": format!("{:?}", dep.edge_kind)
136 }));
137 }
138
139 for dep in &dependencies {
141 collect_transitive_deps(
142 server,
143 &dep.target_node.id,
144 &mut transitive_deps,
145 &mut visited,
146 &mut cycles,
147 max_depth,
148 1,
149 detect_cycles,
150 &node.id,
151 )
152 .await;
153 }
154 }
155
156 serde_json::json!({
157 "target": node.name,
158 "analysis": {
159 "max_depth": max_depth,
160 "detect_cycles": detect_cycles,
161 "symbol_info": {
162 "id": node.id.to_hex(),
163 "name": node.name,
164 "kind": format!("{:?}", node.kind),
165 "file": node.file.display().to_string()
166 }
167 },
168 "dependencies": {
169 "direct": direct_deps,
170 "transitive": transitive_deps,
171 "cycles": cycles
172 },
173 "summary": {
174 "total_direct": direct_deps.len(),
175 "total_transitive": transitive_deps.len(),
176 "max_depth_reached": visited.len(),
177 "cycles_detected": cycles.len()
178 }
179 })
180}
181
182async fn collect_transitive_deps(
184 server: &CodePrismMcpServer,
185 current_id: &codeprism_core::NodeId,
186 transitive_deps: &mut Vec<Value>,
187 visited: &mut std::collections::HashSet<codeprism_core::NodeId>,
188 cycles: &mut Vec<Value>,
189 max_depth: usize,
190 current_depth: usize,
191 detect_cycles: bool,
192 root_id: &codeprism_core::NodeId,
193) {
194 if current_depth >= max_depth {
195 return;
196 }
197
198 if visited.contains(current_id) {
199 if detect_cycles && current_id == root_id {
200 cycles.push(serde_json::json!({
201 "cycle_detected": true,
202 "depth": current_depth,
203 "node_id": current_id.to_hex()
204 }));
205 }
206 return;
207 }
208
209 visited.insert(*current_id);
210
211 if let Ok(dependencies) = server
212 .graph_query()
213 .find_dependencies(current_id, codeprism_core::graph::DependencyType::Direct)
214 {
215 for dep in dependencies {
216 transitive_deps.push(serde_json::json!({
217 "id": dep.target_node.id.to_hex(),
218 "name": dep.target_node.name,
219 "kind": format!("{:?}", dep.target_node.kind),
220 "file": dep.target_node.file.display().to_string(),
221 "edge_type": format!("{:?}", dep.edge_kind),
222 "depth": current_depth
223 }));
224
225 Box::pin(collect_transitive_deps(
227 server,
228 &dep.target_node.id,
229 transitive_deps,
230 visited,
231 cycles,
232 max_depth,
233 current_depth + 1,
234 detect_cycles,
235 root_id,
236 ))
237 .await;
238 }
239 }
240}
241
242pub fn list_tools() -> Vec<Tool> {
244 vec![
245 Tool {
246 name: "trace_data_flow".to_string(),
247 title: Some("Trace Data Flow".to_string()),
248 description: "Track data flow through the codebase, following variable assignments, function parameters, and transformations".to_string(),
249 input_schema: serde_json::json!({
250 "type": "object",
251 "properties": {
252 "variable_or_parameter": {
253 "type": "string",
254 "description": "Symbol ID of variable or parameter to trace"
255 },
256 "direction": {
257 "type": "string",
258 "enum": ["forward", "backward", "both"],
259 "description": "Direction to trace data flow",
260 "default": "forward"
261 },
262 "include_transformations": {
263 "type": "boolean",
264 "description": "Include data transformations (method calls, assignments)",
265 "default": true
266 },
267 "max_depth": {
268 "type": "number",
269 "description": "Maximum depth for data flow tracing",
270 "default": 10,
271 "minimum": 1,
272 "maximum": 50
273 },
274 "follow_function_calls": {
275 "type": "boolean",
276 "description": "Follow data flow across function calls",
277 "default": true
278 },
279 "include_field_access": {
280 "type": "boolean",
281 "description": "Include field/attribute access patterns",
282 "default": true
283 }
284 },
285 "required": ["variable_or_parameter"]
286 }),
287 },
288 Tool {
289 name: "analyze_transitive_dependencies".to_string(),
290 title: Some("Analyze Transitive Dependencies".to_string()),
291 description: "Analyze complete dependency chains, detect cycles, and map transitive relationships".to_string(),
292 input_schema: serde_json::json!({
293 "type": "object",
294 "properties": {
295 "target": {
296 "type": "string",
297 "description": "Symbol ID or file path to analyze"
298 },
299 "max_depth": {
300 "type": "number",
301 "description": "Maximum depth for transitive analysis",
302 "default": 5,
303 "minimum": 1,
304 "maximum": 20
305 },
306 "detect_cycles": {
307 "type": "boolean",
308 "description": "Detect circular dependencies",
309 "default": true
310 },
311 "include_external_dependencies": {
312 "type": "boolean",
313 "description": "Include external/third-party dependencies",
314 "default": false
315 },
316 "dependency_types": {
317 "type": "array",
318 "items": {
319 "type": "string",
320 "enum": ["calls", "imports", "reads", "writes", "extends", "implements", "all"]
321 },
322 "description": "Types of dependencies to analyze",
323 "default": ["all"]
324 }
325 },
326 "required": ["target"]
327 }),
328 }
329 ]
330}
331
332pub async fn call_tool(
334 server: &CodePrismMcpServer,
335 params: &CallToolParams,
336) -> Result<CallToolResult> {
337 match params.name.as_str() {
338 "trace_data_flow" => trace_data_flow(server, params.arguments.as_ref()).await,
339 "analyze_transitive_dependencies" => {
340 analyze_transitive_dependencies(server, params.arguments.as_ref()).await
341 }
342 _ => Err(anyhow::anyhow!(
343 "Unknown flow analysis tool: {}",
344 params.name
345 )),
346 }
347}
348
349async fn trace_data_flow(
351 server: &CodePrismMcpServer,
352 arguments: Option<&Value>,
353) -> Result<CallToolResult> {
354 let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
355
356 let variable_or_parameter = args
358 .get("variable_or_parameter")
359 .or_else(|| args.get("start_symbol"))
360 .or_else(|| args.get("symbol"))
361 .or_else(|| args.get("target"))
362 .and_then(|v| v.as_str())
363 .ok_or_else(|| {
364 anyhow::anyhow!(
365 "Missing variable_or_parameter parameter (or start_symbol, symbol, target)"
366 )
367 })?;
368
369 let direction = args
370 .get("direction")
371 .and_then(|v| v.as_str())
372 .unwrap_or("forward");
373
374 let max_depth = args
375 .get("max_depth")
376 .and_then(|v| v.as_u64())
377 .map(|v| v as usize)
378 .unwrap_or(10);
379
380 let include_transformations = args
381 .get("include_transformations")
382 .and_then(|v| v.as_bool())
383 .unwrap_or(true);
384
385 let result = if let Ok(symbol_results) =
387 server
388 .graph_query()
389 .search_symbols(variable_or_parameter, None, Some(1))
390 {
391 if let Some(symbol_result) = symbol_results.first() {
392 trace_symbol_data_flow(
394 server,
395 &symbol_result.node,
396 direction,
397 max_depth,
398 include_transformations,
399 )
400 .await
401 } else {
402 serde_json::json!({
403 "target": variable_or_parameter,
404 "error": "Symbol not found",
405 "suggestion": "Check if the symbol name is correct or try using a different identifier"
406 })
407 }
408 } else {
409 serde_json::json!({
410 "target": variable_or_parameter,
411 "error": "Failed to search for symbol",
412 "suggestion": "Ensure the repository is properly indexed"
413 })
414 };
415
416 Ok(CallToolResult {
417 content: vec![ToolContent::Text {
418 text: serde_json::to_string_pretty(&result)?,
419 }],
420 is_error: Some(result.get("error").is_some()),
421 })
422}
423
424async fn analyze_transitive_dependencies(
426 server: &CodePrismMcpServer,
427 arguments: Option<&Value>,
428) -> Result<CallToolResult> {
429 let args = arguments.ok_or_else(|| anyhow::anyhow!("Missing arguments"))?;
430
431 let target = args
433 .get("target")
434 .or_else(|| args.get("symbol"))
435 .and_then(|v| v.as_str())
436 .ok_or_else(|| anyhow::anyhow!("Missing target parameter (or symbol)"))?;
437
438 let max_depth = args
439 .get("max_depth")
440 .and_then(|v| v.as_u64())
441 .map(|v| v as usize)
442 .unwrap_or(5);
443
444 let detect_cycles = args
445 .get("detect_cycles")
446 .and_then(|v| v.as_bool())
447 .unwrap_or(true);
448
449 let result = if let Ok(symbol_results) =
451 server.graph_query().search_symbols(target, None, Some(1))
452 {
453 if let Some(symbol_result) = symbol_results.first() {
454 analyze_symbol_transitive_dependencies(
456 server,
457 &symbol_result.node,
458 max_depth,
459 detect_cycles,
460 )
461 .await
462 } else {
463 serde_json::json!({
464 "target": target,
465 "error": "Symbol not found",
466 "suggestion": "Check if the symbol name is correct or try using a different identifier"
467 })
468 }
469 } else {
470 serde_json::json!({
471 "target": target,
472 "error": "Failed to search for symbol",
473 "suggestion": "Ensure the repository is properly indexed"
474 })
475 };
476
477 Ok(CallToolResult {
478 content: vec![ToolContent::Text {
479 text: serde_json::to_string_pretty(&result)?,
480 }],
481 is_error: Some(result.get("error").is_some()),
482 })
483}