// atomcode_core/tool/trace_chain.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
7
/// Tool that reports call chains between two named symbols in the code graph.
///
/// The type is stateless; all behavior lives in its `Tool` implementation.
#[derive(Debug)]
pub struct TraceChainTool;
9
10#[derive(Deserialize)]
11struct TraceChainArgs {
12    from: String,
13    to: String,
14}
15
/// Renders `path` for display, keeping at most its last three components.
///
/// Paths with three or fewer components are returned via `Path::display`
/// unchanged. Longer paths become `.../<a>/<b>/<c>`, where the kept components
/// are joined with `/` regardless of platform and converted lossily to UTF-8.
///
/// If the only components that would be dropped are the root directory or a
/// Windows prefix, the path is returned unchanged: replacing `/` with `.../`
/// would make the result longer than the input and hide nothing useful.
fn shorten_path(path: &std::path::Path) -> String {
    const KEPT_COMPONENTS: usize = 3;

    let components: Vec<_> = path.components().collect();
    if components.len() <= KEPT_COMPONENTS {
        return path.display().to_string();
    }

    let (elided, tail) = components.split_at(components.len() - KEPT_COMPONENTS);

    // Nothing but the anchor (root / drive prefix) would be elided, so the
    // "shortened" form would not actually be shorter.
    let only_anchor_elided = elided.iter().all(|component| {
        matches!(
            component,
            std::path::Component::RootDir | std::path::Component::Prefix(_)
        )
    });
    if only_anchor_elided {
        return path.display().to_string();
    }

    let mut shortened = String::from("...");
    for component in tail {
        shortened.push('/');
        shortened.push_str(&component.as_os_str().to_string_lossy());
    }
    shortened
}
34
35#[async_trait]
36impl Tool for TraceChainTool {
37    fn definition(&self) -> ToolDef {
38        ToolDef {
39            name: "trace_chain",
40            description: "Find the shortest call chain between two symbols. Uses BFS to discover \
41                the path from `from` to `to` through function calls (max 10 hops).\n\
42                Example: {\"from\": \"handle_request\", \"to\": \"save_to_db\"}"
43                .to_string(),
44            parameters: json!({
45                "type": "object",
46                "properties": {
47                    "from": { "type": "string", "description": "Source symbol name" },
48                    "to": { "type": "string", "description": "Target symbol name" }
49                },
50                "required": ["from", "to"]
51            }),
52        }
53    }
54
55    fn approval(&self, _args: &str) -> ApprovalRequirement {
56        ApprovalRequirement::AutoApprove
57    }
58
59    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
60        let parsed: TraceChainArgs = serde_json::from_str(args)?;
61
62        let graph = ctx.graph.read().await;
63
64        if !graph.is_ready() {
65            return Ok(ToolResult {
66                call_id: String::new(),
67                output: "Code graph is not yet indexed. The graph will be available after the \
68                    background indexer completes. Try again shortly."
69                    .to_string(),
70                success: false,
71            });
72        }
73
74        let from_matches = graph.find_by_name(&parsed.from);
75        let to_matches = graph.find_by_name(&parsed.to);
76
77        if from_matches.is_empty() {
78            return Ok(ToolResult {
79                call_id: String::new(),
80                output: format!("Source symbol '{}' not found in code graph.", parsed.from),
81                success: false,
82            });
83        }
84        if to_matches.is_empty() {
85            return Ok(ToolResult {
86                call_id: String::new(),
87                output: format!("Target symbol '{}' not found in code graph.", parsed.to),
88                success: false,
89            });
90        }
91
92        // Try all combinations of from/to matches to find any path
93        let mut out = String::new();
94        let mut found_any = false;
95
96        for from_sym in &from_matches {
97            for to_sym in &to_matches {
98                if let Some(path) = graph.shortest_path(from_sym.id, to_sym.id) {
99                    found_any = true;
100                    out.push_str(&format!(
101                        "Call chain from '{}' to '{}' ({} hops):\n",
102                        parsed.from,
103                        parsed.to,
104                        path.len() - 1
105                    ));
106                    for (i, &sym_id) in path.iter().enumerate() {
107                        if let Some(node) = graph.node(sym_id) {
108                            let arrow = if i == 0 { ">" } else { "→" };
109                            out.push_str(&format!(
110                                "  {} {} ({:?}) — {}\n",
111                                arrow,
112                                node.name,
113                                node.kind,
114                                shorten_path(&node.file)
115                            ));
116                        }
117                    }
118                    out.push('\n');
119                }
120            }
121        }
122
123        if !found_any {
124            out.push_str(&format!(
125                "No call chain found from '{}' to '{}' (max 10 hops).\n",
126                parsed.from, parsed.to
127            ));
128        }
129
130        Ok(ToolResult {
131            call_id: String::new(),
132            output: out,
133            success: found_any,
134        })
135    }
136}