ricecoder_research/
reference_tracker.rs

1//! Reference tracking across source files
2
3use crate::error::ResearchError;
4use crate::models::{Language, ReferenceKind, SymbolReference};
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use tree_sitter::{Language as TSLanguage, Parser};
8
/// Stateless entry point for tracking symbol references across files.
///
/// The type holds no data; its functionality is exposed through associated
/// functions such as `track_references`. The common traits are derived so the
/// tracker can be logged, copied, and default-constructed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReferenceTracker;
11
12/// Result of reference tracking
13#[derive(Debug, Clone)]
14pub struct ReferenceTrackingResult {
15    /// Map from symbol ID to all references to that symbol
16    pub references_by_symbol: HashMap<String, Vec<SymbolReference>>,
17    /// Map from file path to all references in that file
18    pub references_by_file: HashMap<PathBuf, Vec<SymbolReference>>,
19}
20
21impl ReferenceTracker {
22    /// Track symbol references in a source file
23    ///
24    /// # Arguments
25    /// * `path` - Path to the source file
26    /// * `language` - Programming language of the file
27    /// * `content` - File content as string
28    /// * `known_symbols` - Map of symbol names to their IDs
29    ///
30    /// # Returns
31    /// A vector of symbol references found in the file
32    pub fn track_references(
33        path: &Path,
34        language: &Language,
35        content: &str,
36        known_symbols: &HashMap<String, String>,
37    ) -> Result<Vec<SymbolReference>, ResearchError> {
38        let mut parser = Parser::new();
39        let ts_language = Self::get_tree_sitter_language(language)?;
40        parser
41            .set_language(ts_language)
42            .map_err(|_| ResearchError::AnalysisFailed {
43                reason: format!("Failed to set language for {:?}", language),
44                context: "Reference tracking requires a valid tree-sitter language parser"
45                    .to_string(),
46            })?;
47
48        let tree = parser.parse(content, None)
49            .ok_or_else(|| ResearchError::AnalysisFailed {
50                reason: "Failed to parse file".to_string(),
51                context: "Tree-sitter parser could not generate an abstract syntax tree for reference tracking".to_string(),
52            })?;
53
54        let mut references = Vec::new();
55        let root = tree.root_node();
56
57        // Extract references based on language
58        Self::track_references_recursive(
59            &root,
60            content,
61            path,
62            language,
63            known_symbols,
64            &mut references,
65        )?;
66
67        Ok(references)
68    }
69
70    /// Recursively track references from AST nodes
71    fn track_references_recursive(
72        node: &tree_sitter::Node,
73        content: &str,
74        path: &Path,
75        language: &Language,
76        known_symbols: &HashMap<String, String>,
77        references: &mut Vec<SymbolReference>,
78    ) -> Result<(), ResearchError> {
79        // Extract references from current node if applicable
80        if let Some(reference) =
81            Self::extract_reference_from_node(node, content, path, language, known_symbols)
82        {
83            references.push(reference);
84        }
85
86        // Recursively process children
87        let mut cursor = node.walk();
88        for child in node.children(&mut cursor) {
89            Self::track_references_recursive(
90                &child,
91                content,
92                path,
93                language,
94                known_symbols,
95                references,
96            )?;
97        }
98
99        Ok(())
100    }
101
102    /// Extract a single reference from a node if it represents a symbol reference
103    fn extract_reference_from_node(
104        node: &tree_sitter::Node,
105        content: &str,
106        path: &Path,
107        language: &Language,
108        known_symbols: &HashMap<String, String>,
109    ) -> Option<SymbolReference> {
110        match language {
111            Language::Rust => Self::extract_rust_reference(node, content, path, known_symbols),
112            Language::TypeScript => {
113                Self::extract_typescript_reference(node, content, path, known_symbols)
114            }
115            Language::Python => Self::extract_python_reference(node, content, path, known_symbols),
116            Language::Go => Self::extract_go_reference(node, content, path, known_symbols),
117            Language::Java => Self::extract_java_reference(node, content, path, known_symbols),
118            _ => None,
119        }
120    }
121
122    /// Extract references from Rust code
123    fn extract_rust_reference(
124        node: &tree_sitter::Node,
125        content: &str,
126        path: &Path,
127        known_symbols: &HashMap<String, String>,
128    ) -> Option<SymbolReference> {
129        let kind_str = node.kind();
130
131        // Check for identifier usage (not definition)
132        if kind_str != "identifier" {
133            return None;
134        }
135
136        let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
137
138        // Check if this identifier refers to a known symbol
139        let symbol_id = known_symbols.get(&name)?.clone();
140
141        let line = Self::get_line_from_byte_offset(content, node.start_byte());
142
143        Some(SymbolReference {
144            symbol_id,
145            file: path.to_path_buf(),
146            line,
147            kind: ReferenceKind::Usage,
148        })
149    }
150
151    /// Extract references from TypeScript/JavaScript code
152    fn extract_typescript_reference(
153        node: &tree_sitter::Node,
154        content: &str,
155        path: &Path,
156        known_symbols: &HashMap<String, String>,
157    ) -> Option<SymbolReference> {
158        let kind_str = node.kind();
159
160        // Check for identifier usage
161        if kind_str != "identifier" && kind_str != "type_identifier" {
162            return None;
163        }
164
165        let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
166
167        // Check if this identifier refers to a known symbol
168        let symbol_id = known_symbols.get(&name)?.clone();
169
170        let line = Self::get_line_from_byte_offset(content, node.start_byte());
171
172        Some(SymbolReference {
173            symbol_id,
174            file: path.to_path_buf(),
175            line,
176            kind: ReferenceKind::Usage,
177        })
178    }
179
180    /// Extract references from Python code
181    fn extract_python_reference(
182        node: &tree_sitter::Node,
183        content: &str,
184        path: &Path,
185        known_symbols: &HashMap<String, String>,
186    ) -> Option<SymbolReference> {
187        let kind_str = node.kind();
188
189        // Check for identifier usage
190        if kind_str != "identifier" {
191            return None;
192        }
193
194        let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
195
196        // Check if this identifier refers to a known symbol
197        let symbol_id = known_symbols.get(&name)?.clone();
198
199        let line = Self::get_line_from_byte_offset(content, node.start_byte());
200
201        Some(SymbolReference {
202            symbol_id,
203            file: path.to_path_buf(),
204            line,
205            kind: ReferenceKind::Usage,
206        })
207    }
208
209    /// Extract references from Go code
210    fn extract_go_reference(
211        node: &tree_sitter::Node,
212        content: &str,
213        path: &Path,
214        known_symbols: &HashMap<String, String>,
215    ) -> Option<SymbolReference> {
216        let kind_str = node.kind();
217
218        // Check for identifier usage
219        if kind_str != "identifier" {
220            return None;
221        }
222
223        let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
224
225        // Check if this identifier refers to a known symbol
226        let symbol_id = known_symbols.get(&name)?.clone();
227
228        let line = Self::get_line_from_byte_offset(content, node.start_byte());
229
230        Some(SymbolReference {
231            symbol_id,
232            file: path.to_path_buf(),
233            line,
234            kind: ReferenceKind::Usage,
235        })
236    }
237
238    /// Extract references from Java code
239    fn extract_java_reference(
240        node: &tree_sitter::Node,
241        content: &str,
242        path: &Path,
243        known_symbols: &HashMap<String, String>,
244    ) -> Option<SymbolReference> {
245        let kind_str = node.kind();
246
247        // Check for identifier usage
248        if kind_str != "identifier" {
249            return None;
250        }
251
252        let name = node.utf8_text(content.as_bytes()).ok()?.to_string();
253
254        // Check if this identifier refers to a known symbol
255        let symbol_id = known_symbols.get(&name)?.clone();
256
257        let line = Self::get_line_from_byte_offset(content, node.start_byte());
258
259        Some(SymbolReference {
260            symbol_id,
261            file: path.to_path_buf(),
262            line,
263            kind: ReferenceKind::Usage,
264        })
265    }
266
267    /// Get tree-sitter language for a programming language
268    fn get_tree_sitter_language(language: &Language) -> Result<TSLanguage, ResearchError> {
269        match language {
270            Language::Rust => Ok(tree_sitter_rust::language()),
271            Language::TypeScript => Ok(tree_sitter_typescript::language_typescript()),
272            Language::Python => Ok(tree_sitter_python::language()),
273            Language::Go => Ok(tree_sitter_go::language()),
274            Language::Java => Ok(tree_sitter_java::language()),
275            _ => Err(ResearchError::AnalysisFailed {
276                reason: format!("Unsupported language for reference tracking: {:?}", language),
277                context: "Reference tracking is only supported for Rust, TypeScript, Python, Go, and Java".to_string(),
278            }),
279        }
280    }
281
282    /// Calculate line number from byte offset
283    fn get_line_from_byte_offset(content: &str, byte_offset: usize) -> usize {
284        let prefix = &content[..byte_offset.min(content.len())];
285        // Count newlines to get the line number (1-indexed)
286        prefix.matches('\n').count() + 1
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// A known Rust local is reported at least once.
    #[test]
    fn test_track_rust_references() {
        let content = "fn main() { let x = 5; println!(\"{}\", x); }";
        let path = Path::new("test.rs");
        let mut known_symbols = HashMap::new();
        known_symbols.insert("x".to_string(), "test.rs:1:11".to_string());

        let references =
            ReferenceTracker::track_references(path, &Language::Rust, content, &known_symbols)
                .expect("Failed to track references");

        // Should find at least one reference to 'x'
        assert!(!references.is_empty());
    }

    /// A known Python name is reported on each line where it appears.
    #[test]
    fn test_track_python_references() {
        let content = "def foo():\n    x = 5\n    print(x)";
        let path = Path::new("test.py");
        let mut known_symbols = HashMap::new();
        known_symbols.insert("x".to_string(), "test.py:2:5".to_string());

        let references =
            ReferenceTracker::track_references(path, &Language::Python, content, &known_symbols)
                .expect("Failed to track references");

        // 'x' is assigned on line 2 and read on line 3; both identifiers match.
        let lines: Vec<_> = references.iter().map(|reference| reference.line).collect();
        assert_eq!(lines, vec![2, 3]);
        assert!(references
            .iter()
            .all(|reference| reference.symbol_id == "test.py:2:5"));
    }

    /// With no known symbols nothing can match.
    #[test]
    fn test_track_references_empty_symbols() {
        let content = "fn main() { let x = 5; }";
        let path = Path::new("test.rs");
        let known_symbols = HashMap::new();

        let references =
            ReferenceTracker::track_references(path, &Language::Rust, content, &known_symbols)
                .expect("Failed to track references");

        // Should find no references since no symbols are known
        assert!(references.is_empty());
    }

    /// Line numbers are 1-indexed and advance after each newline.
    #[test]
    fn test_get_line_from_byte_offset() {
        let content = "line1\nline2\nline3";
        // Byte offset 0 is at the start of line 1
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 0), 1);
        // Byte offset 6 is after the newline, at the start of line 2
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 6), 2);
        // Byte offset 12 is after the second newline, at the start of line 3
        assert_eq!(ReferenceTracker::get_line_from_byte_offset(content, 12), 3);
    }

    /// Languages without a grammar are rejected with an error.
    #[test]
    fn test_unsupported_language() {
        let content = "some code";
        let path = Path::new("test.unknown");
        let known_symbols = HashMap::new();
        let result = ReferenceTracker::track_references(
            path,
            &Language::Other("unknown".to_string()),
            content,
            &known_symbols,
        );

        assert!(result.is_err());
    }
}