use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use tree_sitter::{Language, Parser, Query, QueryCursor};
use crate::queries::{JAVASCRIPT_QUERY, PYTHON_QUERY, RUST_QUERY};
/// A function-like region of source code together with its location.
///
/// The field descriptions below reflect how `extract_functions` in this file
/// populates the struct; all fields are public, so values built elsewhere may
/// not follow the same conventions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FunctionNode {
/// Byte offset at which the node captured as `func` starts.
pub start_byte: usize,
/// Byte offset at which the node captured as `func` ends.
pub end_byte: usize,
/// Start row reported by tree-sitter for the `func` node, plus one.
pub start_line: usize,
/// End row reported by tree-sitter for the `func` node, plus one.
pub end_line: usize,
/// Text of the `function.body` capture, or the text of the whole `func`
/// node when the query match has no body capture.
pub body: String,
/// Text of the `function.name` capture; `None` when the match has none.
pub name: Option<String>,
}
impl FunctionNode {
    /// Creates a node without a name.
    ///
    /// `start_byte`/`end_byte` and `start_line`/`end_line` are stored as given;
    /// no validation is performed on them.
    pub fn new(start_byte: usize, end_byte: usize, start_line: usize, end_line: usize, body: String) -> Self {
        Self {
            start_byte,
            end_byte,
            start_line,
            end_line,
            body,
            name: None,
        }
    }

    /// Creates a node that carries `name`; otherwise identical to [`FunctionNode::new`].
    pub fn with_name(start_byte: usize, end_byte: usize, start_line: usize, end_line: usize, body: String, name: String) -> Self {
        Self {
            name: Some(name),
            ..Self::new(start_byte, end_byte, start_line, end_line, body)
        }
    }

    /// Returns the size of the node's byte span (`end_byte - start_byte`).
    ///
    /// The fields are public, so a span with `end_byte < start_byte` can be
    /// constructed. Such a span is reported as length 0 instead of overflowing
    /// (a panic in debug builds, a wrapped huge value in release builds).
    pub fn len(&self) -> usize {
        self.end_byte.saturating_sub(self.start_byte)
    }

    /// Returns `true` when the byte span has zero length (including inverted spans).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
pub fn extract_functions(code: &str, lang: Language) -> Result<Vec<FunctionNode>> {
let mut parser = Parser::new();
parser
.set_language(lang)
.context("Failed to set language for parser")?;
let tree = parser
.parse(code, None)
.ok_or_else(|| anyhow!("Failed to parse source code"))?;
let query_source = get_query_for_language(lang)?;
let query = Query::new(lang, query_source)
.context("Failed to compile Tree-sitter query")?;
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
let mut functions = Vec::new();
for match_ in matches {
let mut func_start = None;
let mut func_end = None;
let mut func_start_line = None;
let mut func_end_line = None;
let mut func_name = None;
let mut func_body = None;
for capture in match_.captures {
let node = capture.node;
let capture_name = &query.capture_names()[capture.index as usize];
match capture_name.as_str() {
"func" => {
func_start = Some(node.start_byte());
func_end = Some(node.end_byte());
func_start_line = Some(node.start_position().row + 1);
func_end_line = Some(node.end_position().row + 1);
}
"function.name" => {
func_name = Some(
node.utf8_text(code.as_bytes())
.context("Invalid UTF-8 in function name")?
.to_string(),
);
}
"function.body" => {
func_body = Some(
node.utf8_text(code.as_bytes())
.context("Invalid UTF-8 in function body")?
.to_string(),
);
}
_ => {}
}
}
if let (Some(start), Some(end), Some(start_line), Some(end_line)) = (func_start, func_end, func_start_line, func_end_line) {
let body = func_body.unwrap_or_else(|| {
code[start..end].to_string()
});
let function = if let Some(name) = func_name {
FunctionNode::with_name(start, end, start_line, end_line, body, name)
} else {
FunctionNode::new(start, end, start_line, end_line, body)
};
functions.push(function);
}
}
Ok(functions)
}
/// Looks up the query source registered for `lang`.
///
/// `lang` is compared against the Rust, Python and JavaScript grammars, in
/// that order, using `is_same_language`; the first grammar that compares equal
/// decides which query constant is returned.
///
/// # Errors
///
/// Returns an `"Unsupported language"` error when `lang` matches none of the
/// three grammars.
fn get_query_for_language(lang: Language) -> Result<&'static str> {
    if is_same_language(lang, tree_sitter_rust::language()) {
        return Ok(&RUST_QUERY);
    }
    if is_same_language(lang, tree_sitter_python::language()) {
        return Ok(&PYTHON_QUERY);
    }
    if is_same_language(lang, tree_sitter_javascript::language()) {
        return Ok(&JAVASCRIPT_QUERY);
    }
    Err(anyhow!("Unsupported language"))
}
/// Reports whether two grammars have the same version and the same number of
/// node kinds.
///
/// NOTE(review): this is a fingerprint comparison, not an identity check. Two
/// distinct grammars that happen to share both values would compare equal.
fn is_same_language(lang1: Language, lang2: Language) -> bool {
    let fingerprint_a = (lang1.version(), lang1.node_kind_count());
    let fingerprint_b = (lang2.version(), lang2.node_kind_count());
    fingerprint_a == fingerprint_b
}
/// Extracts functions from Rust source text.
///
/// Convenience wrapper that calls [`extract_functions`] with the
/// `tree_sitter_rust` grammar; errors are passed through unchanged.
pub fn extract_rust_functions(code: &str) -> Result<Vec<FunctionNode>> {
    let grammar = tree_sitter_rust::language();
    extract_functions(code, grammar)
}
/// Extracts functions from Python source text.
///
/// Convenience wrapper that calls [`extract_functions`] with the
/// `tree_sitter_python` grammar; errors are passed through unchanged.
pub fn extract_python_functions(code: &str) -> Result<Vec<FunctionNode>> {
    let grammar = tree_sitter_python::language();
    extract_functions(code, grammar)
}
/// Extracts functions from JavaScript source text.
///
/// Convenience wrapper that calls [`extract_functions`] with the
/// `tree_sitter_javascript` grammar; errors are passed through unchanged.
pub fn extract_javascript_functions(code: &str) -> Result<Vec<FunctionNode>> {
    let grammar = tree_sitter_javascript::language();
    extract_functions(code, grammar)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two Rust functions: both names and both bodies must be captured.
    #[test]
    fn test_extract_rust_function() {
        let source = r#"
fn hello_world() {
println!("Hello, world!");
}
fn add(a: i32, b: i32) -> i32 {
a + b
}
"#;
        let found = extract_rust_functions(source).unwrap();
        assert_eq!(found.len(), 2);

        let (first, second) = (&found[0], &found[1]);
        assert_eq!(first.name.as_deref(), Some("hello_world"));
        assert!(first.body.contains("println!"));
        assert_eq!(second.name.as_deref(), Some("add"));
        assert!(second.body.contains("a + b"));
    }

    // Two Python definitions: only the count and the names are checked.
    #[test]
    fn test_extract_python_function() {
        let source = r#"
def greet(name):
return f"Hello, {name}!"
def multiply(x, y):
return x * y
"#;
        let found = extract_python_functions(source).unwrap();
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].name.as_deref(), Some("greet"));
        assert_eq!(found[1].name.as_deref(), Some("multiply"));
    }

    // A declaration plus an arrow function: both count, only the first is inspected.
    #[test]
    fn test_extract_javascript_function() {
        let source = r#"
function sayHello() {
console.log("Hello!");
}
const add = (a, b) => {
return a + b;
};
"#;
        let found = extract_javascript_functions(source).unwrap();
        assert_eq!(found.len(), 2);

        let first = &found[0];
        assert_eq!(first.name.as_deref(), Some("sayHello"));
        assert!(first.body.contains("console.log"));
    }

    // `len` is the distance between the two byte offsets.
    #[test]
    fn test_function_node_length() {
        let node = FunctionNode::new(10, 50, 1, 5, "test body".to_string());
        assert_eq!(node.len(), 40);
        assert!(!node.is_empty());
    }

    // Empty input parses successfully and yields nothing.
    #[test]
    fn test_empty_code() {
        let found = extract_rust_functions("").unwrap();
        assert!(found.is_empty());
    }

    // Malformed input is not an error; it simply yields no functions.
    #[test]
    fn test_invalid_syntax() {
        let source = "fn broken {{{";
        let outcome = extract_rust_functions(source);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }
}