1use crate::error::{OverrideError, Result};
2use tree_sitter::{Parser, Query, QueryCursor, StreamingIteratorMut};
3
/// A Rust function definition located in a source file.
///
/// Positions are reported by tree-sitter and are zero-based; columns are
/// counted in bytes within their line.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionInfo {
    /// The function's identifier, e.g. `main`.
    pub name: String,
    /// Zero-based row on which the function item starts.
    pub start_line: usize,
    /// Zero-based row on which the function item ends.
    pub end_line: usize,
    /// Zero-based byte column of the item's first character on `start_line`.
    pub start_column: usize,
    /// Zero-based byte column just past the item's last character on `end_line`.
    pub end_column: usize,
    /// First line of the item's source text, trimmed of surrounding whitespace.
    pub signature: String,
    /// Whether the function was classified as a test (via test attributes or a
    /// `test_` naming convention).
    pub is_test: bool,
    /// Whether the function was detected as `async`.
    pub is_async: bool,
}
24
/// Locates Rust function definitions in source text using tree-sitter.
///
/// Holds a parser configured with the Rust grammar and a pre-compiled query,
/// so one instance can be reused for many calls.
pub struct FunctionDetector {
    /// Parser configured with the tree-sitter Rust grammar.
    parser: Parser,
    /// Compiled query capturing each `function_item` node and its name.
    function_query: Query,
}
30
31impl FunctionDetector {
32 pub fn new() -> Result<Self> {
34 let mut parser = Parser::new();
35 let language = tree_sitter_rust::LANGUAGE;
36 parser
37 .set_language(&language.into())
38 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
39
40 let query_source = r#"
43(function_item
44 name: (identifier) @function.name
45) @function.definition
46"#;
47
48 let function_query = Query::new(&language.into(), query_source)
49 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
50
51 Ok(Self {
52 parser,
53 function_query,
54 })
55 }
56
57 pub fn find_functions(&mut self, source: &str) -> Result<Vec<FunctionInfo>> {
59 let tree = self
60 .parser
61 .parse(source, None)
62 .ok_or_else(|| OverrideError::ParseError("Failed to parse source".to_string()))?;
63
64 let root_node = tree.root_node();
65 let mut cursor = QueryCursor::new();
66
67 let mut functions = Vec::new();
68 let mut matches = cursor.matches(&self.function_query, root_node, source.as_bytes());
69
70 while let Some(match_) = matches.next_mut() {
71 let mut name = None;
72 let mut node = None;
73
74 for capture in match_.captures {
75 let capture_name = &self.function_query.capture_names()[capture.index as usize];
76 match capture_name as &str {
77 "function.name" => {
78 name = Some(
79 capture
80 .node
81 .utf8_text(source.as_bytes())
82 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?
83 .to_string(),
84 );
85 }
86 "function.definition" => {
87 node = Some(capture.node);
88 }
89 _ => {}
90 }
91 }
92
93 if let (Some(name), Some(node)) = (name, node) {
94 let start_pos = node.start_position();
95 let end_pos = node.end_position();
96
97 let signature = node
99 .utf8_text(source.as_bytes())
100 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?
101 .lines()
102 .next()
103 .unwrap_or("")
104 .trim()
105 .to_string();
106
107 let is_test = self.has_test_attribute(&node, source)?;
109
110 let is_async = signature.starts_with("async ");
112
113 functions.push(FunctionInfo {
114 name,
115 start_line: start_pos.row,
116 end_line: end_pos.row,
117 start_column: start_pos.column,
118 end_column: end_pos.column,
119 signature,
120 is_test,
121 is_async,
122 });
123 }
124 }
125
126 Ok(functions)
127 }
128
129 pub fn find_function_at_line(
131 &mut self,
132 source: &str,
133 line: usize,
134 ) -> Result<Option<FunctionInfo>> {
135 let functions = self.find_functions(source)?;
136
137 Ok(functions
138 .into_iter()
139 .find(|f| line >= f.start_line && line <= f.end_line))
140 }
141
142 pub fn find_function_at_position(
144 &mut self,
145 source: &str,
146 line: usize,
147 column: usize,
148 ) -> Result<Option<FunctionInfo>> {
149 let functions = self.find_functions(source)?;
150
151 Ok(functions
153 .into_iter()
154 .filter(|f| {
155 line >= f.start_line
156 && line <= f.end_line
157 && (line > f.start_line || column >= f.start_column)
158 && (line < f.end_line || column <= f.end_column)
159 })
160 .min_by_key(|f| (f.end_line - f.start_line, f.end_column - f.start_column)))
161 }
162
163 pub fn find_functions_by_name(
165 &mut self,
166 source: &str,
167 name: &str,
168 ) -> Result<Vec<FunctionInfo>> {
169 let functions = self.find_functions(source)?;
170
171 Ok(functions
172 .into_iter()
173 .filter(|f| f.name.contains(name))
174 .collect())
175 }
176
177 fn has_test_attribute(&self, node: &tree_sitter::Node, source: &str) -> Result<bool> {
179 if let Ok(text) = node.utf8_text(source.as_bytes()) {
181 if text.contains("fn test_") {
182 return Ok(true);
183 }
184 }
185
186 if let Some(prev) = node.prev_sibling() {
188 if prev.kind() == "attribute_item" {
189 let text = prev
190 .utf8_text(source.as_bytes())
191 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
192 return Ok(text.contains("#[test]") || text.contains("#[tokio::test]"));
193 }
194 }
195
196 let mut current = *node;
198 while let Some(parent) = current.parent() {
199 if parent.kind() == "impl_item" {
200 break;
201 }
202 if let Some(prev) = parent.prev_sibling() {
203 if prev.kind() == "attribute_item" {
204 let text = prev
205 .utf8_text(source.as_bytes())
206 .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
207 if text.contains("#[test]") || text.contains("#[tokio::test]") {
208 return Ok(true);
209 }
210 }
211 }
212 current = parent;
213 }
214
215 Ok(false)
216 }
217}
218
219impl Default for FunctionDetector {
220 fn default() -> Self {
221 Self::new().expect("Failed to create FunctionDetector")
222 }
223}
224
225pub fn find_function_at_position(
227 file_path: &std::path::Path,
228 line: usize,
229 column: Option<usize>,
230) -> Result<Option<FunctionInfo>> {
231 let source = std::fs::read_to_string(file_path)?;
232 let mut detector = FunctionDetector::new()?;
233
234 if let Some(col) = column {
235 detector.find_function_at_position(&source, line, col)
236 } else {
237 detector.find_function_at_line(&source, line)
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Top-level functions are reported in source order.
    #[test]
    fn test_find_simple_function() {
        let source = r#"
fn main() {
    println!("Hello, world!");
}

fn helper() -> i32 {
    42
}
"#;

        let found = FunctionDetector::new()
            .unwrap()
            .find_functions(source)
            .unwrap();
        let names: Vec<&str> = found.iter().map(|f| f.name.as_str()).collect();

        assert_eq!(names, ["main", "helper"]);
    }

    /// Methods inside an `impl` block are detected like free functions.
    #[test]
    fn test_find_impl_methods() {
        let source = r#"
struct MyStruct;

impl MyStruct {
    fn new() -> Self {
        Self
    }

    fn method(&self) {
        // method body
    }
}
"#;

        let found = FunctionDetector::new()
            .unwrap()
            .find_functions(source)
            .unwrap();
        let names: Vec<&str> = found.iter().map(|f| f.name.as_str()).collect();

        assert_eq!(names, ["new", "method"]);
    }

    /// Attribute-marked and `test_`-named functions are flagged as tests.
    #[test]
    fn test_find_test_functions() {
        let source = r#"
#[test]
fn test_something() {
    assert_eq!(1 + 1, 2);
}

#[tokio::test]
async fn test_async() {
    // async test
}

fn test_by_name() {
    // This should also be detected as a test
}
"#;

        let found = FunctionDetector::new()
            .unwrap()
            .find_functions(source)
            .unwrap();

        assert_eq!(found.len(), 3);
        assert!(found.iter().all(|f| f.is_test));
        assert!(found[1].is_async);
    }

    /// Line lookups hit the covering function and miss blank lines between them.
    #[test]
    fn test_find_function_at_line() {
        let source = r#"
fn first() {
    // line 2
    // line 3
}

fn second() {
    // line 7
}
"#;

        let mut detector = FunctionDetector::new().unwrap();

        for (line, expected) in [(2, Some("first")), (7, Some("second")), (5, None)] {
            let found = detector.find_function_at_line(source, line).unwrap();
            assert_eq!(found.as_ref().map(|f| f.name.as_str()), expected);
        }
    }

    /// Position lookups prefer the innermost enclosing function.
    #[test]
    fn test_find_function_at_position() {
        let source = r#"
fn outer() {
    fn inner() {
        // line 3, various columns
    }
}
"#;

        let mut detector = FunctionDetector::new().unwrap();

        for (line, column, expected) in [(3, 8, "inner"), (1, 0, "outer")] {
            let found = detector
                .find_function_at_position(source, line, column)
                .unwrap();
            assert_eq!(found.map(|f| f.name).as_deref(), Some(expected));
        }
    }
}
362}