// atomcode_core/tool/trace_callees.rs

use std::fmt::Write;

use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::json;

use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};

/// The `trace_callees` tool: lists the callees (forward call graph) of a named symbol.
///
/// A stateless unit struct; everything it needs arrives through the tool arguments and the
/// [`ToolContext`] passed to `execute`.
pub struct TraceCalleesTool;

10#[derive(Deserialize)]
11struct TraceCalleesArgs {
12    symbol: String,
13    depth: Option<usize>,
14}
15
/// Shortens `path` for display in tool output.
///
/// Paths with at most three components are rendered unchanged via
/// [`std::path::Path::display`]. Longer paths keep only their last three components,
/// joined with `/` and prefixed with `.../` (e.g. `/home/u/src/lib.rs` becomes
/// `.../u/src/lib.rs`).
fn shorten_path(path: &std::path::Path) -> String {
    let count = path.components().count();
    if count <= 3 {
        return path.display().to_string();
    }
    // Lossy conversion: non-UTF-8 components are rendered with U+FFFD instead of failing.
    let tail: Vec<_> = path
        .components()
        .skip(count - 3)
        .map(|component| component.as_os_str().to_string_lossy())
        .collect();
    format!(".../{}", tail.join("/"))
}

35#[async_trait]
36impl Tool for TraceCalleesTool {
37    fn definition(&self) -> ToolDef {
38        ToolDef {
39            name: "trace_callees",
40            description: "Trace all callees of a symbol (forward call graph). Uses BFS to find \
41                functions/methods that are directly or transitively called by the given symbol.\n\
42                Returns a tree showing callee chains up to the specified depth.\n\
43                Example: {\"symbol\": \"main\", \"depth\": 2}"
44                .to_string(),
45            parameters: json!({
46                "type": "object",
47                "properties": {
48                    "symbol": { "type": "string", "description": "Symbol name to trace callees for" },
49                    "depth": { "type": "integer", "description": "Max traversal depth (default: 3, max: 5)" }
50                },
51                "required": ["symbol"]
52            }),
53        }
54    }
55
56    fn approval(&self, _args: &str) -> ApprovalRequirement {
57        ApprovalRequirement::AutoApprove
58    }
59
60    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
61        let parsed: TraceCalleesArgs = serde_json::from_str(args)?;
62        let depth = parsed.depth.unwrap_or(3).min(5);
63
64        let graph = ctx.graph.read().await;
65
66        if !graph.is_ready() {
67            return Ok(ToolResult {
68                call_id: String::new(),
69                output: "Code graph is not yet indexed. The graph will be available after the \
70                    background indexer completes. Try again shortly."
71                    .to_string(),
72                success: false,
73            });
74        }
75
76        let matches = graph.find_by_name(&parsed.symbol);
77        if matches.is_empty() {
78            return Ok(ToolResult {
79                call_id: String::new(),
80                output: format!(
81                    "Symbol '{}' not found in code graph ({} symbols indexed).",
82                    parsed.symbol,
83                    graph.node_count()
84                ),
85                success: false,
86            });
87        }
88
89        let mut out = String::new();
90        for sym in &matches {
91            out.push_str(&format!(
92                "Callees of {} ({:?}) in {}:\n",
93                sym.name,
94                sym.kind,
95                shorten_path(&sym.file)
96            ));
97
98            let callees = graph.trace_callees(sym.id, depth);
99            if callees.is_empty() {
100                out.push_str("  (no callees found)\n");
101            } else {
102                for (callee_id, d) in &callees {
103                    if let Some(node) = graph.node(*callee_id) {
104                        let indent = "  ".repeat(*d);
105                        out.push_str(&format!(
106                            "{}[depth {}] {} ({:?}) — {}\n",
107                            indent,
108                            d,
109                            node.name,
110                            node.kind,
111                            shorten_path(&node.file)
112                        ));
113                    }
114                }
115            }
116            out.push('\n');
117        }
118
119        Ok(ToolResult {
120            call_id: String::new(),
121            output: out,
122            success: true,
123        })
124    }
125}