raz_override/
detector.rs

1use crate::error::{OverrideError, Result};
2use tree_sitter::{Parser, Query, QueryCursor, StreamingIteratorMut};
3
/// Information about a function item detected in Rust source.
///
/// Positions come straight from tree-sitter: rows and columns are 0-indexed,
/// columns are byte offsets within the line, and the end position is
/// exclusive (it points just past the closing `}`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionInfo {
    /// The function name (identifier only, no path or generics).
    pub name: String,
    /// Row of the first byte of the function item (0-indexed).
    pub start_line: usize,
    /// Row of the end of the function item (0-indexed).
    pub end_line: usize,
    /// Byte column of the first byte of the function item on `start_line`.
    pub start_column: usize,
    /// Byte column just past the last byte of the function item on `end_line`.
    pub end_column: usize,
    /// First line of the function item's source text, trimmed.
    ///
    /// Outer attributes such as `#[test]` are separate syntax nodes and are
    /// not included; a signature spread over several lines is truncated to
    /// its first line (e.g. `pub fn long_name(`).
    pub signature: String,
    /// Whether the detector classified this as a test function.
    pub is_test: bool,
    /// Whether the detector classified this as an `async` function.
    pub is_async: bool,
}
24
25/// Detects functions in Rust source code using tree-sitter
26pub struct FunctionDetector {
27    parser: Parser,
28    function_query: Query,
29}
30
31impl FunctionDetector {
32    /// Create a new function detector
33    pub fn new() -> Result<Self> {
34        let mut parser = Parser::new();
35        let language = tree_sitter_rust::LANGUAGE;
36        parser
37            .set_language(&language.into())
38            .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
39
40        // Query to find function items at any level
41        // We'll filter out duplicates by checking parent nodes
42        let query_source = r#"
43(function_item
44    name: (identifier) @function.name
45) @function.definition
46"#;
47
48        let function_query = Query::new(&language.into(), query_source)
49            .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
50
51        Ok(Self {
52            parser,
53            function_query,
54        })
55    }
56
57    /// Find all functions in a source file
58    pub fn find_functions(&mut self, source: &str) -> Result<Vec<FunctionInfo>> {
59        let tree = self
60            .parser
61            .parse(source, None)
62            .ok_or_else(|| OverrideError::ParseError("Failed to parse source".to_string()))?;
63
64        let root_node = tree.root_node();
65        let mut cursor = QueryCursor::new();
66
67        let mut functions = Vec::new();
68        let mut matches = cursor.matches(&self.function_query, root_node, source.as_bytes());
69
70        while let Some(match_) = matches.next_mut() {
71            let mut name = None;
72            let mut node = None;
73
74            for capture in match_.captures {
75                let capture_name = &self.function_query.capture_names()[capture.index as usize];
76                match capture_name as &str {
77                    "function.name" => {
78                        name = Some(
79                            capture
80                                .node
81                                .utf8_text(source.as_bytes())
82                                .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?
83                                .to_string(),
84                        );
85                    }
86                    "function.definition" => {
87                        node = Some(capture.node);
88                    }
89                    _ => {}
90                }
91            }
92
93            if let (Some(name), Some(node)) = (name, node) {
94                let start_pos = node.start_position();
95                let end_pos = node.end_position();
96
97                // Extract signature
98                let signature = node
99                    .utf8_text(source.as_bytes())
100                    .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?
101                    .lines()
102                    .next()
103                    .unwrap_or("")
104                    .trim()
105                    .to_string();
106
107                // Check for test attribute
108                let is_test = self.has_test_attribute(&node, source)?;
109
110                // Check if async
111                let is_async = signature.starts_with("async ");
112
113                functions.push(FunctionInfo {
114                    name,
115                    start_line: start_pos.row,
116                    end_line: end_pos.row,
117                    start_column: start_pos.column,
118                    end_column: end_pos.column,
119                    signature,
120                    is_test,
121                    is_async,
122                });
123            }
124        }
125
126        Ok(functions)
127    }
128
129    /// Find the function at a specific line
130    pub fn find_function_at_line(
131        &mut self,
132        source: &str,
133        line: usize,
134    ) -> Result<Option<FunctionInfo>> {
135        let functions = self.find_functions(source)?;
136
137        Ok(functions
138            .into_iter()
139            .find(|f| line >= f.start_line && line <= f.end_line))
140    }
141
142    /// Find the function at a specific position (line and column)
143    pub fn find_function_at_position(
144        &mut self,
145        source: &str,
146        line: usize,
147        column: usize,
148    ) -> Result<Option<FunctionInfo>> {
149        let functions = self.find_functions(source)?;
150
151        // Find the most specific function that contains the position
152        Ok(functions
153            .into_iter()
154            .filter(|f| {
155                line >= f.start_line
156                    && line <= f.end_line
157                    && (line > f.start_line || column >= f.start_column)
158                    && (line < f.end_line || column <= f.end_column)
159            })
160            .min_by_key(|f| (f.end_line - f.start_line, f.end_column - f.start_column)))
161    }
162
163    /// Find functions by name (supports partial matching)
164    pub fn find_functions_by_name(
165        &mut self,
166        source: &str,
167        name: &str,
168    ) -> Result<Vec<FunctionInfo>> {
169        let functions = self.find_functions(source)?;
170
171        Ok(functions
172            .into_iter()
173            .filter(|f| f.name.contains(name))
174            .collect())
175    }
176
177    /// Check if a node has a test attribute
178    fn has_test_attribute(&self, node: &tree_sitter::Node, source: &str) -> Result<bool> {
179        // Check if the function name starts with test_
180        if let Ok(text) = node.utf8_text(source.as_bytes()) {
181            if text.contains("fn test_") {
182                return Ok(true);
183            }
184        }
185
186        // Check for test attribute
187        if let Some(prev) = node.prev_sibling() {
188            if prev.kind() == "attribute_item" {
189                let text = prev
190                    .utf8_text(source.as_bytes())
191                    .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
192                return Ok(text.contains("#[test]") || text.contains("#[tokio::test]"));
193            }
194        }
195
196        // Check parent for attributes (in case of impl blocks)
197        let mut current = *node;
198        while let Some(parent) = current.parent() {
199            if parent.kind() == "impl_item" {
200                break;
201            }
202            if let Some(prev) = parent.prev_sibling() {
203                if prev.kind() == "attribute_item" {
204                    let text = prev
205                        .utf8_text(source.as_bytes())
206                        .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
207                    if text.contains("#[test]") || text.contains("#[tokio::test]") {
208                        return Ok(true);
209                    }
210                }
211            }
212            current = parent;
213        }
214
215        Ok(false)
216    }
217}
218
219impl Default for FunctionDetector {
220    fn default() -> Self {
221        Self::new().expect("Failed to create FunctionDetector")
222    }
223}
224
225/// Find function in a file at a specific position
226pub fn find_function_at_position(
227    file_path: &std::path::Path,
228    line: usize,
229    column: Option<usize>,
230) -> Result<Option<FunctionInfo>> {
231    let source = std::fs::read_to_string(file_path)?;
232    let mut detector = FunctionDetector::new()?;
233
234    if let Some(col) = column {
235        detector.find_function_at_position(&source, line, col)
236    } else {
237        detector.find_function_at_line(&source, line)
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a freshly built detector over `source`, panicking on any failure.
    fn detect(source: &str) -> Vec<FunctionInfo> {
        FunctionDetector::new()
            .expect("detector should build")
            .find_functions(source)
            .expect("source should parse")
    }

    /// Collects the detected function names in order.
    fn names(functions: &[FunctionInfo]) -> Vec<&str> {
        functions.iter().map(|f| f.name.as_str()).collect()
    }

    #[test]
    fn test_find_simple_function() {
        let source = r#"
fn main() {
    println!("Hello, world!");
}

fn helper() -> i32 {
    42
}
"#;

        let found = detect(source);
        assert_eq!(names(&found), ["main", "helper"]);
    }

    #[test]
    fn test_find_impl_methods() {
        let source = r#"
struct MyStruct;

impl MyStruct {
    fn new() -> Self {
        Self
    }
    
    fn method(&self) {
        // method body
    }
}
"#;

        let found = detect(source);
        assert_eq!(names(&found), ["new", "method"]);
    }

    #[test]
    fn test_find_test_functions() {
        let source = r#"
#[test]
fn test_something() {
    assert_eq!(1 + 1, 2);
}

#[tokio::test]
async fn test_async() {
    // async test
}

fn test_by_name() {
    // This should also be detected as a test
}
"#;

        let found = detect(source);
        assert_eq!(found.len(), 3);
        // Attribute-marked and name-marked functions all count as tests.
        assert!(found.iter().all(|f| f.is_test));
        assert!(found[1].is_async);
    }

    #[test]
    fn test_find_function_at_line() {
        let source = r#"
fn first() {
    // line 2
    // line 3
}

fn second() {
    // line 7
}
"#;

        let mut detector = FunctionDetector::new().unwrap();
        let cases = [(2, Some("first")), (7, Some("second")), (5, None)];
        for (line, expected) in cases {
            let hit = detector.find_function_at_line(source, line).unwrap();
            assert_eq!(hit.map(|f| f.name).as_deref(), expected, "line {line}");
        }
    }

    #[test]
    fn test_find_function_at_position() {
        let source = r#"
fn outer() {
    fn inner() {
        // line 3, various columns
    }
}
"#;

        let mut detector = FunctionDetector::new().unwrap();
        // Inside the nested function, then on the first byte of the outer one.
        let cases = [(3, 8, "inner"), (1, 0, "outer")];
        for (line, column, expected) in cases {
            let hit = detector
                .find_function_at_position(source, line, column)
                .unwrap()
                .expect("a function should enclose the position");
            assert_eq!(hit.name, expected);
        }
    }
}
362}