// atomcode_core/tool/find_references.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Deserialize;
4use serde_json::json;
5use serde_json::Value;
6use tokio::process::Command;
7
8use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
9
/// Tool that locates the definition and usage sites of a symbol by running ripgrep over a
/// directory and classifying each hit as a definition or a usage.
///
/// Stateless: all per-call state comes from the JSON arguments and the `ToolContext`.
#[derive(Debug, Default, Clone, Copy)]
pub struct FindReferencesTool;
11
12#[derive(Deserialize)]
13struct FindReferencesArgs {
14    symbol: String,
15    path: Option<String>,
16}
17
18#[async_trait]
19impl Tool for FindReferencesTool {
20    fn definition(&self) -> ToolDef {
21        ToolDef {
22            name: "find_references",
23            description: "Find all references to a symbol (function, class, variable) across the project.\n\
24                Uses ripgrep for speed, then tree-sitter to classify each match as definition, call, or import.\n\
25                Returns the definition location + all call/usage sites with file:line context.\n\
26                Examples:\n\
27                - {\"symbol\": \"process_data\"} → finds definition + all calls across the project\n\
28                - {\"symbol\": \"UserService\", \"path\": \"src/\"} → search only in src/".to_string(),
29            parameters: json!({
30                "type": "object",
31                "properties": {
32                    "symbol": { "type": "string", "description": "Symbol name to find references for" },
33                    "path": { "type": "string", "description": "Directory to search in (default: working directory)" }
34                },
35                "required": ["symbol"]
36            }),
37        }
38    }
39
40    fn approval(&self, _args: &str) -> ApprovalRequirement {
41        ApprovalRequirement::AutoApprove
42    }
43
44    fn approval_with_context(&self, args: &str, ctx: &ToolContext) -> ApprovalRequirement {
45        let parsed = match serde_json::from_str::<FindReferencesArgs>(args) {
46            Ok(parsed) => parsed,
47            Err(_) => return self.approval(args),
48        };
49        let working_dir = match ctx.working_dir.try_read() {
50            Ok(wd) => wd.clone(),
51            Err(_) => return self.approval(args),
52        };
53        let raw_path = parsed.path.as_deref().unwrap_or(".");
54        match super::approval_for_path(raw_path, &working_dir, super::ExternalPathAction::Read) {
55            Ok(approval) => approval,
56            Err(_) => self.approval(args),
57        }
58    }
59
60    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
61        let parsed: FindReferencesArgs = serde_json::from_str(args)?;
62        let wd = ctx.working_dir.read().await.clone();
63        let search_dir =
64            match super::inspect_path_access(parsed.path.as_deref().unwrap_or("."), &wd) {
65                Ok(access) => access.path.to_string_lossy().to_string(),
66                Err(err) => {
67                    return Ok(ToolResult {
68                        call_id: String::new(),
69                        output: err.to_string(),
70                        success: false,
71                    });
72                }
73            };
74
75        // Use ripgrep to find all occurrences (word boundary match)
76        let pattern = format!(r"\b{}\b", regex::escape(&parsed.symbol));
77        let mut cmd = Command::new("rg");
78        cmd.args(&[
79            "--json",
80            "--line-number",
81            "--color=never",
82            "--max-count=30",
83            "-w", // word boundary
84            &pattern,
85            &search_dir,
86        ]);
87        crate::process_utils::suppress_console_window(&mut cmd);
88        let output = cmd.output().await;
89
90        let rg_output = match output {
91            Ok(o) => String::from_utf8_lossy(&o.stdout).to_string(),
92            Err(_) => {
93                return Ok(ToolResult {
94                    call_id: String::new(),
95                    output: "ripgrep not found. Install it: cargo install ripgrep".to_string(),
96                    success: false,
97                });
98            }
99        };
100
101        if rg_output.trim().is_empty() {
102            return Ok(ToolResult {
103                call_id: String::new(),
104                output: format!(
105                    "No references found for '{}' in {}",
106                    parsed.symbol, search_dir
107                ),
108                success: false,
109            });
110        }
111
112        // Classify each match using tree-sitter
113        let mut searcher = ctx.semantic.lock().await;
114        let mut definitions = Vec::new();
115        let mut references = Vec::new();
116
117        for matched in parse_rg_json_matches(&rg_output).into_iter().take(30) {
118            let file_path = std::path::Path::new(&matched.file);
119
120            // Try to determine if this is a definition or usage
121            let is_def = if let Some(symbols) = searcher.list_symbols(file_path) {
122                symbols
123                    .iter()
124                    .any(|s| s.name == parsed.symbol && s.start_line == matched.line_no)
125            } else {
126                // Heuristic: check if line contains definition keywords
127                let trimmed = matched.content.trim();
128                trimmed.starts_with("fn ")
129                    || trimmed.starts_with("pub fn ")
130                    || trimmed.starts_with("def ")
131                    || trimmed.starts_with("class ")
132                    || trimmed.starts_with("function ")
133                    || trimmed.starts_with("func ")
134                    || trimmed.starts_with("struct ")
135                    || trimmed.starts_with("pub struct ")
136                    || trimmed.starts_with("type ")
137                    || trimmed.starts_with("interface ")
138                    || trimmed.contains("= function")
139                    || trimmed.contains("=> {")
140            };
141
142            let short_file = matched
143                .file
144                .strip_prefix(&search_dir)
145                .unwrap_or(&matched.file)
146                .trim_start_matches('/');
147
148            let entry = format!("  {}:{}: {}", short_file, matched.line_no, matched.content.trim());
149            if is_def {
150                definitions.push(entry);
151            } else {
152                references.push(entry);
153            }
154        }
155
156        let mut out = format!("References for '{}' in {}:\n\n", parsed.symbol, search_dir);
157
158        if !definitions.is_empty() {
159            out.push_str("DEFINITIONS:\n");
160            for d in &definitions {
161                out.push_str(d);
162                out.push('\n');
163            }
164            out.push('\n');
165        }
166
167        if !references.is_empty() {
168            out.push_str(&format!("USAGES ({}):\n", references.len()));
169            for r in &references {
170                out.push_str(r);
171                out.push('\n');
172            }
173        }
174
175        Ok(ToolResult {
176            call_id: String::new(),
177            output: out,
178            success: true,
179        })
180    }
181}
182
/// One `match` message from `rg --json` output.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RgMatch {
    /// Path of the file containing the match, as reported by ripgrep.
    file: String,
    /// Line number of the match, as reported by ripgrep's `line_number` field.
    line_no: usize,
    /// Text of the matched line(s) as reported by ripgrep, including any trailing newline.
    content: String,
}
188
189fn parse_rg_json_matches(output: &str) -> Vec<RgMatch> {
190    output
191        .lines()
192        .filter_map(|line| {
193            let value: Value = serde_json::from_str(line).ok()?;
194            if value.get("type")?.as_str()? != "match" {
195                return None;
196            }
197            let data = value.get("data")?;
198            let file = data.get("path")?.get("text")?.as_str()?.to_string();
199            let line_no = data.get("line_number")?.as_u64()? as usize;
200            let content = data.get("lines")?.get("text")?.as_str()?.to_string();
201            Some(RgMatch {
202                file,
203                line_no,
204                content,
205            })
206        })
207        .collect()
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_rg_json_with_colons_in_matched_file_paths() {
        let output = r#"{"type":"begin","data":{"path":{"text":"src:main.rs"}}}
{"type":"match","data":{"path":{"text":"src:main.rs"},"lines":{"text":"fn target_symbol() {}\n"},"line_number":1,"absolute_offset":0,"submatches":[]}}
{"type":"match","data":{"path":{"text":"src:main.rs"},"lines":{"text":"fn caller() { target_symbol(); }\n"},"line_number":2,"absolute_offset":22,"submatches":[]}}
{"type":"end","data":{"path":{"text":"src:main.rs"},"binary_offset":null,"stats":{}}}
"#;

        let matches = parse_rg_json_matches(output);
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].file, "src:main.rs");
        assert_eq!(matches[0].line_no, 1);
        assert_eq!(matches[0].content, "fn target_symbol() {}\n");
        assert_eq!(matches[1].file, "src:main.rs");
        assert_eq!(matches[1].line_no, 2);
        assert_eq!(matches[1].content, "fn caller() { target_symbol(); }\n");
    }

    /// `rg --json` ends every run with a `summary` message, even when nothing matched, so
    /// the parser must yield no matches for such output rather than relying on emptiness.
    #[test]
    fn ignores_summary_and_malformed_lines() {
        let output = r#"{"type":"begin","data":{"path":{"text":"src/lib.rs"}}}
not json at all
{"type":"end","data":{"path":{"text":"src/lib.rs"},"binary_offset":null,"stats":{}}}
{"data":{"elapsed_total":{"human":"0.001s","nanos":1000000,"secs":0},"stats":{"matched_lines":0,"matches":0}},"type":"summary"}
"#;

        assert!(parse_rg_json_matches(output).is_empty());
        assert!(parse_rg_json_matches("").is_empty());
    }
}