use crate::error::{OverrideError, Result};
use tree_sitter::{Parser, Query, QueryCursor, StreamingIteratorMut};
/// A Rust function item located by `FunctionDetector`.
///
/// Positions are copied from tree-sitter: lines are zero-based rows, columns
/// are offsets within the row (tree-sitter counts them in bytes), and the end
/// position points just past the function's closing brace.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionInfo {
    /// The function's identifier, e.g. `main`.
    pub name: String,
    /// Zero-based row on which the function item starts.
    pub start_line: usize,
    /// Zero-based row on which the function item ends.
    pub end_line: usize,
    /// Column of the first character of the function item on `start_line`.
    pub start_column: usize,
    /// Column just past the last character of the function item on `end_line`.
    pub end_column: usize,
    /// First line of the function's source text, trimmed of surrounding whitespace.
    pub signature: String,
    /// Heuristic flag: the function is named like a test or carries a test attribute.
    pub is_test: bool,
    /// Whether the function is declared `async`.
    pub is_async: bool,
}
/// Locates function definitions in Rust source text using a tree-sitter
/// parser together with a precompiled query.
pub struct FunctionDetector {
    /// Query that captures each `function_item` node and its name identifier.
    function_query: Query,
    /// Parser configured with the tree-sitter Rust grammar.
    parser: Parser,
}
impl FunctionDetector {
    /// Creates a detector whose parser uses the tree-sitter Rust grammar and
    /// whose query captures every `function_item` together with its name.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideError::TreeSitterError`] if the grammar cannot be set
    /// on the parser or the query fails to compile.
    pub fn new() -> Result<Self> {
        let mut parser = Parser::new();
        let language = tree_sitter_rust::LANGUAGE;
        parser
            .set_language(&language.into())
            .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
        let query_source = r#"
(function_item
name: (identifier) @function.name
) @function.definition
"#;
        let function_query = Query::new(&language.into(), query_source)
            .map_err(|e| OverrideError::TreeSitterError(e.to_string()))?;
        Ok(Self {
            parser,
            function_query,
        })
    }

    /// Parses `source` and returns one [`FunctionInfo`] per `fn` item in the
    /// tree (free functions, `impl` methods and functions nested inside other
    /// functions), in the order the query cursor yields them.
    ///
    /// # Errors
    ///
    /// Returns [`OverrideError::ParseError`] if tree-sitter produces no tree,
    /// and [`OverrideError::TreeSitterError`] if a node's text cannot be read
    /// as UTF-8.
    pub fn find_functions(&mut self, source: &str) -> Result<Vec<FunctionInfo>> {
        let tree = self
            .parser
            .parse(source, None)
            .ok_or_else(|| OverrideError::ParseError("Failed to parse source".to_string()))?;
        let capture_names = self.function_query.capture_names();
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&self.function_query, tree.root_node(), source.as_bytes());
        let mut functions = Vec::new();

        while let Some(query_match) = matches.next_mut() {
            let mut name = None;
            let mut definition = None;
            for capture in query_match.captures {
                let capture_name: &str = &capture_names[capture.index as usize];
                match capture_name {
                    "function.name" => name = Some(Self::node_text(&capture.node, source)?),
                    "function.definition" => definition = Some(capture.node),
                    _ => {}
                }
            }
            let (Some(name), Some(node)) = (name, definition) else {
                continue;
            };

            let start = node.start_position();
            let end = node.end_position();
            let signature = Self::node_text(&node, source)?
                .lines()
                .next()
                .unwrap_or("")
                .trim()
                .to_string();
            // Naming heuristic is applied to the function's own name only; the
            // body text (nested helpers, string literals) must not influence it.
            let is_test = name.starts_with("test_") || Self::has_test_attribute(&node, source)?;
            // Read the `async` keyword from the syntax tree: the first line of
            // text may start with a visibility modifier such as `pub`.
            let is_async = Self::is_async_function(&node);

            functions.push(FunctionInfo {
                name: name.to_string(),
                start_line: start.row,
                end_line: end.row,
                start_column: start.column,
                end_column: end.column,
                signature,
                is_test,
                is_async,
            });
        }
        Ok(functions)
    }

    /// Returns the function whose (inclusive) row range contains `line`.
    ///
    /// When several functions qualify (nested functions, or functions sharing a
    /// row), the one spanning the fewest rows is returned; ties go to the
    /// earliest entry produced by [`Self::find_functions`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::find_functions`].
    pub fn find_function_at_line(
        &mut self,
        source: &str,
        line: usize,
    ) -> Result<Option<FunctionInfo>> {
        let functions = self.find_functions(source)?;
        Ok(functions
            .into_iter()
            .filter(|f| (f.start_line..=f.end_line).contains(&line))
            .min_by_key(|f| f.end_line - f.start_line))
    }

    /// Returns the innermost function whose range contains (`line`, `column`).
    ///
    /// Both endpoints are treated as inclusive, so a position just past the
    /// closing brace still belongs to the function.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::find_functions`].
    pub fn find_function_at_position(
        &mut self,
        source: &str,
        line: usize,
        column: usize,
    ) -> Result<Option<FunctionInfo>> {
        let target = (line, column);
        let functions = self.find_functions(source)?;
        Ok(functions
            .into_iter()
            .filter(|f| {
                (f.start_line, f.start_column) <= target && target <= (f.end_line, f.end_column)
            })
            // Function ranges nest, so the innermost enclosing function is the
            // one that starts last. Comparing start positions also avoids the
            // `end_column - start_column` usize underflow that a width-based
            // key hits when the closing brace sits left of the start column.
            .max_by_key(|f| (f.start_line, f.start_column)))
    }

    /// Returns every function whose name contains `name` as a substring.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::find_functions`].
    pub fn find_functions_by_name(
        &mut self,
        source: &str,
        name: &str,
    ) -> Result<Vec<FunctionInfo>> {
        let functions = self.find_functions(source)?;
        Ok(functions
            .into_iter()
            .filter(|f| f.name.contains(name))
            .collect())
    }

    /// Returns the UTF-8 text of `node` within `source`.
    fn node_text<'s>(node: &tree_sitter::Node<'_>, source: &'s str) -> Result<&'s str> {
        node.utf8_text(source.as_bytes())
            .map_err(|e| OverrideError::TreeSitterError(e.to_string()))
    }

    /// Reports whether `node`, or any ancestor below the nearest `impl_item`,
    /// is preceded by a test attribute (so helpers nested inside a `#[test]`
    /// function are also reported as tests).
    fn has_test_attribute(node: &tree_sitter::Node<'_>, source: &str) -> Result<bool> {
        let mut current = Some(*node);
        while let Some(item) = current {
            if item.kind() == "impl_item" {
                break;
            }
            if Self::preceded_by_test_attribute(&item, source)? {
                return Ok(true);
            }
            current = item.parent();
        }
        Ok(false)
    }

    /// Scans the run of attribute items (and interleaved comments) directly
    /// before `item` and reports whether any of them is a test attribute.
    fn preceded_by_test_attribute(item: &tree_sitter::Node<'_>, source: &str) -> Result<bool> {
        let mut sibling = item.prev_sibling();
        while let Some(prev) = sibling {
            match prev.kind() {
                "attribute_item" => {
                    if Self::is_test_attribute(Self::node_text(&prev, source)?) {
                        return Ok(true);
                    }
                }
                // Comments may sit between stacked attributes and the item.
                "line_comment" | "block_comment" => {}
                _ => break,
            }
            sibling = prev.prev_sibling();
        }
        Ok(false)
    }

    /// Returns `true` for attributes whose path is `test` or ends in `::test`,
    /// e.g. `#[test]`, `#[tokio::test]`, `#[tokio::test(flavor = "multi_thread")]`.
    fn is_test_attribute(attribute: &str) -> bool {
        let Some(inner) = attribute
            .trim()
            .strip_prefix('#')
            .map(str::trim_start)
            .and_then(|rest| rest.strip_prefix('['))
            .and_then(|rest| rest.strip_suffix(']'))
        else {
            return false;
        };
        let path = inner
            .trim_start()
            .split(|c: char| c == '(' || c == '=' || c.is_whitespace())
            .next()
            .unwrap_or_default();
        path == "test" || path.ends_with("::test")
    }

    /// Reports whether the function item's modifiers include `async`.
    fn is_async_function(node: &tree_sitter::Node<'_>) -> bool {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if child.kind() != "function_modifiers" {
                continue;
            }
            let mut modifier_cursor = child.walk();
            if child
                .children(&mut modifier_cursor)
                .any(|modifier| modifier.kind() == "async")
            {
                return true;
            }
        }
        false
    }
}
impl Default for FunctionDetector {
fn default() -> Self {
Self::new().expect("Failed to create FunctionDetector")
}
}
/// Reads the file at `file_path` and looks up the function at the given
/// location with a freshly created `FunctionDetector`.
///
/// With `Some(column)` the lookup is by exact position; with `None` only the
/// line is used.
///
/// # Errors
///
/// Propagates errors from reading the file, from building the detector, and
/// from the lookup itself.
pub fn find_function_at_position(
    file_path: &std::path::Path,
    line: usize,
    column: Option<usize>,
) -> Result<Option<FunctionInfo>> {
    let contents = std::fs::read_to_string(file_path)?;
    let mut finder = FunctionDetector::new()?;
    match column {
        Some(col) => finder.find_function_at_position(&contents, line, col),
        None => finder.find_function_at_line(&contents, line),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_simple_function() {
        let source = r#"
fn main() {
println!("Hello, world!");
}
fn helper() -> i32 {
42
}
"#;
        let mut detector = FunctionDetector::new().unwrap();
        let functions = detector.find_functions(source).unwrap();
        assert_eq!(functions.len(), 2);
        assert_eq!(functions[0].name, "main");
        assert_eq!(functions[1].name, "helper");
    }

    #[test]
    fn test_find_impl_methods() {
        let source = r#"
struct MyStruct;
impl MyStruct {
fn new() -> Self {
Self
}
fn method(&self) {
// method body
}
}
"#;
        let mut detector = FunctionDetector::new().unwrap();
        let functions = detector.find_functions(source).unwrap();
        assert_eq!(functions.len(), 2);
        assert_eq!(functions[0].name, "new");
        assert_eq!(functions[1].name, "method");
    }

    #[test]
    fn test_find_test_functions() {
        let source = r#"
#[test]
fn test_something() {
assert_eq!(1 + 1, 2);
}
#[tokio::test]
async fn test_async() {
// async test
}
fn test_by_name() {
// This should also be detected as a test
}
"#;
        let mut detector = FunctionDetector::new().unwrap();
        let functions = detector.find_functions(source).unwrap();
        assert_eq!(functions.len(), 3);
        assert!(functions[0].is_test);
        assert!(functions[1].is_test);
        assert!(functions[1].is_async);
        assert!(functions[2].is_test);
    }

    #[test]
    fn test_find_function_at_line() {
        // Rows are zero-based and row 0 is the empty line right after `r#"`.
        // The blank row 5 separates the two functions, so it belongs to neither.
        let source = r#"
fn first() {
// line 2
// line 3
}

fn second() {
// line 7
}
"#;
        let mut detector = FunctionDetector::new().unwrap();
        let func = detector.find_function_at_line(source, 2).unwrap();
        assert_eq!(func.unwrap().name, "first");
        let func = detector.find_function_at_line(source, 7).unwrap();
        assert_eq!(func.unwrap().name, "second");
        let func = detector.find_function_at_line(source, 5).unwrap();
        assert!(func.is_none());
    }

    #[test]
    fn test_find_function_at_position() {
        let source = r#"
fn outer() {
fn inner() {
// line 3, various columns
}
}
"#;
        let mut detector = FunctionDetector::new().unwrap();
        let func = detector.find_function_at_position(source, 3, 8).unwrap();
        assert_eq!(func.unwrap().name, "inner");
        let func = detector.find_function_at_position(source, 1, 0).unwrap();
        assert_eq!(func.unwrap().name, "outer");
    }
}