atomcode_core/tool/
trace_chain.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
7
/// Tool that finds the shortest call chain between two named symbols in the code graph.
///
/// It is a stateless unit struct: all data comes from the tool arguments and the
/// shared `ToolContext` passed to `execute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TraceChainTool;
9
/// Arguments accepted by `trace_chain`, deserialized from the tool-call JSON.
///
/// Both fields are required (neither has a serde default), and their names are the
/// JSON keys `"from"` and `"to"`.
#[derive(Deserialize)]
struct TraceChainArgs {
    /// Name of the symbol the chain starts at; resolved by name in the code graph.
    from: String,
    /// Name of the symbol the chain should end at; resolved by name in the code graph.
    to: String,
}
15
/// Abbreviates `path` for display by keeping at most its final three components.
///
/// A path with three or fewer components is rendered unchanged via `Path::display`.
/// A longer path becomes `.../x/y/z`, where the kept components are joined with `/`
/// on every platform and any non-UTF-8 bytes are replaced lossily.
fn shorten_path(path: &std::path::Path) -> String {
    let total = path.components().count();
    if total > 3 {
        let tail: Vec<String> = path
            .components()
            .skip(total - 3)
            .map(|part| part.as_os_str().to_string_lossy().into_owned())
            .collect();
        format!(".../{}", tail.join("/"))
    } else {
        path.display().to_string()
    }
}
34
35#[async_trait]
36impl Tool for TraceChainTool {
37 fn definition(&self) -> ToolDef {
38 ToolDef {
39 name: "trace_chain",
40 description: "Find the shortest call chain between two symbols. Uses BFS to discover \
41 the path from `from` to `to` through function calls (max 10 hops).\n\
42 Example: {\"from\": \"handle_request\", \"to\": \"save_to_db\"}"
43 .to_string(),
44 parameters: json!({
45 "type": "object",
46 "properties": {
47 "from": { "type": "string", "description": "Source symbol name" },
48 "to": { "type": "string", "description": "Target symbol name" }
49 },
50 "required": ["from", "to"]
51 }),
52 }
53 }
54
55 fn approval(&self, _args: &str) -> ApprovalRequirement {
56 ApprovalRequirement::AutoApprove
57 }
58
59 async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
60 let parsed: TraceChainArgs = serde_json::from_str(args)?;
61
62 let graph = ctx.graph.read().await;
63
64 if !graph.is_ready() {
65 return Ok(ToolResult {
66 call_id: String::new(),
67 output: "Code graph is not yet indexed. The graph will be available after the \
68 background indexer completes. Try again shortly."
69 .to_string(),
70 success: false,
71 });
72 }
73
74 let from_matches = graph.find_by_name(&parsed.from);
75 let to_matches = graph.find_by_name(&parsed.to);
76
77 if from_matches.is_empty() {
78 return Ok(ToolResult {
79 call_id: String::new(),
80 output: format!("Source symbol '{}' not found in code graph.", parsed.from),
81 success: false,
82 });
83 }
84 if to_matches.is_empty() {
85 return Ok(ToolResult {
86 call_id: String::new(),
87 output: format!("Target symbol '{}' not found in code graph.", parsed.to),
88 success: false,
89 });
90 }
91
92 let mut out = String::new();
94 let mut found_any = false;
95
96 for from_sym in &from_matches {
97 for to_sym in &to_matches {
98 if let Some(path) = graph.shortest_path(from_sym.id, to_sym.id) {
99 found_any = true;
100 out.push_str(&format!(
101 "Call chain from '{}' to '{}' ({} hops):\n",
102 parsed.from,
103 parsed.to,
104 path.len() - 1
105 ));
106 for (i, &sym_id) in path.iter().enumerate() {
107 if let Some(node) = graph.node(sym_id) {
108 let arrow = if i == 0 { ">" } else { "→" };
109 out.push_str(&format!(
110 " {} {} ({:?}) — {}\n",
111 arrow,
112 node.name,
113 node.kind,
114 shorten_path(&node.file)
115 ));
116 }
117 }
118 out.push('\n');
119 }
120 }
121 }
122
123 if !found_any {
124 out.push_str(&format!(
125 "No call chain found from '{}' to '{}' (max 10 hops).\n",
126 parsed.from, parsed.to
127 ));
128 }
129
130 Ok(ToolResult {
131 call_id: String::new(),
132 output: out,
133 success: found_any,
134 })
135 }
136}