1use crate::error::ResearchError;
4use crate::models::{Language, ReferenceKind, SymbolReference};
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use tree_sitter::{Language as TSLanguage, Parser};
8
/// Stateless entry point for finding usages of known symbols in source files.
///
/// The type carries no data; all functionality is exposed through associated
/// functions such as `ReferenceTracker::track_references`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReferenceTracker;
11
12#[derive(Debug, Clone)]
14pub struct ReferenceTrackingResult {
15 pub references_by_symbol: HashMap<String, Vec<SymbolReference>>,
17 pub references_by_file: HashMap<PathBuf, Vec<SymbolReference>>,
19}
20
21impl ReferenceTracker {
22 pub fn track_references(
33 path: &Path,
34 language: &Language,
35 content: &str,
36 known_symbols: &HashMap<String, String>,
37 ) -> Result<Vec<SymbolReference>, ResearchError> {
38 let mut parser = Parser::new();
39 let ts_language = Self::get_tree_sitter_language(language)?;
40 parser
41 .set_language(ts_language)
42 .map_err(|_| ResearchError::AnalysisFailed {
43 reason: format!("Failed to set language for {:?}", language),
44 context: "Reference tracking requires a valid tree-sitter language parser"
45 .to_string(),
46 })?;
47
48 let tree = parser.parse(content, None)
49 .ok_or_else(|| ResearchError::AnalysisFailed {
50 reason: "Failed to parse file".to_string(),
51 context: "Tree-sitter parser could not generate an abstract syntax tree for reference tracking".to_string(),
52 })?;
53
54 let mut references = Vec::new();
55 let root = tree.root_node();
56
57 Self::track_references_recursive(
59 &root,
60 content,
61 path,
62 language,
63 known_symbols,
64 &mut references,
65 )?;
66
67 Ok(references)
68 }
69
70 fn track_references_recursive(
72 node: &tree_sitter::Node,
73 content: &str,
74 path: &Path,
75 language: &Language,
76 known_symbols: &HashMap<String, String>,
77 references: &mut Vec<SymbolReference>,
78 ) -> Result<(), ResearchError> {
79 if let Some(reference) =
81 Self::extract_reference_from_node(node, content, path, language, known_symbols)
82 {
83 references.push(reference);
84 }
85
86 let mut cursor = node.walk();
88 for child in node.children(&mut cursor) {
89 Self::track_references_recursive(
90 &child,
91 content,
92 path,
93 language,
94 known_symbols,
95 references,
96 )?;
97 }
98
99 Ok(())
100 }
101
102 fn extract_reference_from_node(
104 node: &tree_sitter::Node,
105 content: &str,
106 path: &Path,
107 language: &Language,
108 known_symbols: &HashMap<String, String>,
109 ) -> Option<SymbolReference> {
110 match language {
111 Language::Rust => Self::extract_rust_reference(node, content, path, known_symbols),
112 Language::TypeScript => {
113 Self::extract_typescript_reference(node, content, path, known_symbols)
114 }
115 Language::Python => Self::extract_python_reference(node, content, path, known_symbols),
116 Language::Go => Self::extract_go_reference(node, content, path, known_symbols),
117 Language::Java => Self::extract_java_reference(node, content, path, known_symbols),
118 _ => None,
119 }
120 }
121
122 fn extract_rust_reference(
124 node: &tree_sitter::Node,
125 content: &str,
126 path: &Path,
127 known_symbols: &HashMap<String, String>,
128 ) -> Option<SymbolReference> {
129 let kind_str = node.kind();
130
131 if kind_str != "identifier" {
133 return None;
134 }
135
136 let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
137
138 let symbol_id = known_symbols.get(&name)?.clone();
140
141 let line = Self::get_line_from_byte_offset(content, node.start_byte());
142
143 Some(SymbolReference {
144 symbol_id,
145 file: path.to_path_buf(),
146 line,
147 kind: ReferenceKind::Usage,
148 })
149 }
150
151 fn extract_typescript_reference(
153 node: &tree_sitter::Node,
154 content: &str,
155 path: &Path,
156 known_symbols: &HashMap<String, String>,
157 ) -> Option<SymbolReference> {
158 let kind_str = node.kind();
159
160 if kind_str != "identifier" && kind_str != "type_identifier" {
162 return None;
163 }
164
165 let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
166
167 let symbol_id = known_symbols.get(&name)?.clone();
169
170 let line = Self::get_line_from_byte_offset(content, node.start_byte());
171
172 Some(SymbolReference {
173 symbol_id,
174 file: path.to_path_buf(),
175 line,
176 kind: ReferenceKind::Usage,
177 })
178 }
179
180 fn extract_python_reference(
182 node: &tree_sitter::Node,
183 content: &str,
184 path: &Path,
185 known_symbols: &HashMap<String, String>,
186 ) -> Option<SymbolReference> {
187 let kind_str = node.kind();
188
189 if kind_str != "identifier" {
191 return None;
192 }
193
194 let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
195
196 let symbol_id = known_symbols.get(&name)?.clone();
198
199 let line = Self::get_line_from_byte_offset(content, node.start_byte());
200
201 Some(SymbolReference {
202 symbol_id,
203 file: path.to_path_buf(),
204 line,
205 kind: ReferenceKind::Usage,
206 })
207 }
208
209 fn extract_go_reference(
211 node: &tree_sitter::Node,
212 content: &str,
213 path: &Path,
214 known_symbols: &HashMap<String, String>,
215 ) -> Option<SymbolReference> {
216 let kind_str = node.kind();
217
218 if kind_str != "identifier" {
220 return None;
221 }
222
223 let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
224
225 let symbol_id = known_symbols.get(&name)?.clone();
227
228 let line = Self::get_line_from_byte_offset(content, node.start_byte());
229
230 Some(SymbolReference {
231 symbol_id,
232 file: path.to_path_buf(),
233 line,
234 kind: ReferenceKind::Usage,
235 })
236 }
237
238 fn extract_java_reference(
240 node: &tree_sitter::Node,
241 content: &str,
242 path: &Path,
243 known_symbols: &HashMap<String, String>,
244 ) -> Option<SymbolReference> {
245 let kind_str = node.kind();
246
247 if kind_str != "identifier" {
249 return None;
250 }
251
252 let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
253
254 let symbol_id = known_symbols.get(&name)?.clone();
256
257 let line = Self::get_line_from_byte_offset(content, node.start_byte());
258
259 Some(SymbolReference {
260 symbol_id,
261 file: path.to_path_buf(),
262 line,
263 kind: ReferenceKind::Usage,
264 })
265 }
266
267 fn get_tree_sitter_language(language: &Language) -> Result<TSLanguage, ResearchError> {
269 match language {
270 Language::Rust => Ok(tree_sitter_rust::language()),
271 Language::TypeScript => Ok(tree_sitter_typescript::language_typescript()),
272 Language::Python => Ok(tree_sitter_python::language()),
273 Language::Go => Ok(tree_sitter_go::language()),
274 Language::Java => Ok(tree_sitter_java::language()),
275 _ => Err(ResearchError::AnalysisFailed {
276 reason: format!("Unsupported language for reference tracking: {:?}", language),
277 context: "Reference tracking is only supported for Rust, TypeScript, Python, Go, and Java".to_string(),
278 }),
279 }
280 }
281
/// Converts a byte offset into `content` to a 1-based line number.
///
/// The line number is one plus the number of `\n` bytes that precede
/// `byte_offset`. Offsets past the end of `content` are clamped to its length.
///
/// The count is taken over raw bytes rather than a `&str` slice, so an offset
/// that falls inside a multi-byte UTF-8 character does not panic. This is
/// safe because the byte `0x0A` never occurs inside a multi-byte sequence.
fn get_line_from_byte_offset(content: &str, byte_offset: usize) -> usize {
    let end = byte_offset.min(content.len());
    let newlines_before = content.as_bytes()[..end]
        .iter()
        .filter(|&&byte| byte == b'\n')
        .count();
    newlines_before + 1
}
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Every reference found for `x` must carry the mapped id, the input path,
    /// the `Usage` kind, and line 1 (the snippet is a single line).
    #[test]
    fn test_track_rust_references() {
        let content = "fn main() { let x = 5; println!(\"{}\", x); }";
        let path = Path::new("test.rs");
        let mut known_symbols = HashMap::new();
        known_symbols.insert("x".to_string(), "test.rs:1:11".to_string());

        let references =
            ReferenceTracker::track_references(path, &Language::Rust, content, &known_symbols)
                .expect("Failed to track references");

        assert!(!references.is_empty());
        for reference in &references {
            assert_eq!(reference.symbol_id, "test.rs:1:11");
            assert_eq!(reference.file.as_path(), path);
            assert_eq!(reference.line, 1);
            assert!(matches!(reference.kind, ReferenceKind::Usage));
        }
    }

    /// `x` occurs once on line 2 (assignment) and once on line 3 (call argument).
    #[test]
    fn test_track_python_references() {
        let content = "def foo():\n x = 5\n print(x)";
        let path = Path::new("test.py");
        let mut known_symbols = HashMap::new();
        known_symbols.insert("x".to_string(), "test.py:2:5".to_string());

        let references =
            ReferenceTracker::track_references(path, &Language::Python, content, &known_symbols)
                .expect("Failed to track references");

        let lines: Vec<usize> = references.iter().map(|reference| reference.line).collect();
        assert_eq!(lines, vec![2, 3]);
        for reference in &references {
            assert_eq!(reference.symbol_id, "test.py:2:5");
            assert_eq!(reference.file.as_path(), path);
            assert!(matches!(reference.kind, ReferenceKind::Usage));
        }
    }

    /// With no known symbols nothing can match, so the result is empty.
    #[test]
    fn test_track_references_empty_symbols() {
        let content = "fn main() { let x = 5; }";
        let path = Path::new("test.rs");
        let known_symbols = HashMap::new();

        let references =
            ReferenceTracker::track_references(path, &Language::Rust, content, &known_symbols)
                .expect("Failed to track references");

        assert!(references.is_empty());
    }

    /// Line numbers are 1-based and offsets past the end are clamped.
    #[test]
    fn test_get_line_from_byte_offset() {
        let content = "line1\nline2\nline3";
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 0), 1);
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 6), 2);
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 12), 3);
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 1000), 3);
        assert_eq!(ReferenceTracker::get_line_from_byte_offset("", 0), 1);
    }

    /// Languages without a grammar are rejected with `AnalysisFailed`.
    #[test]
    fn test_unsupported_language() {
        let content = "some code";
        let path = Path::new("test.unknown");
        let known_symbols = HashMap::new();
        let result = ReferenceTracker::track_references(
            path,
            &Language::Other("unknown".to_string()),
            content,
            &known_symbols,
        );

        assert!(matches!(
            result,
            Err(ResearchError::AnalysisFailed { .. })
        ));
    }
}