1use anyhow::{anyhow, Context, Result};
7use serde::{Deserialize, Serialize};
8use tree_sitter::{Language, Parser, Query, QueryCursor};
9
10use crate::queries::{JAVASCRIPT_QUERY, PYTHON_QUERY, RUST_QUERY};
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14pub struct FunctionNode {
15 pub start_byte: usize,
17 pub end_byte: usize,
19 pub start_line: usize,
21 pub end_line: usize,
23 pub body: String,
25 pub name: Option<String>,
27}
28
29impl FunctionNode {
30 pub fn new(
32 start_byte: usize,
33 end_byte: usize,
34 start_line: usize,
35 end_line: usize,
36 body: String,
37 ) -> Self {
38 Self {
39 start_byte,
40 end_byte,
41 start_line,
42 end_line,
43 body,
44 name: None,
45 }
46 }
47
48 pub fn with_name(
50 start_byte: usize,
51 end_byte: usize,
52 start_line: usize,
53 end_line: usize,
54 body: String,
55 name: String,
56 ) -> Self {
57 Self {
58 start_byte,
59 end_byte,
60 start_line,
61 end_line,
62 body,
63 name: Some(name),
64 }
65 }
66
67 pub fn len(&self) -> usize {
69 self.end_byte - self.start_byte
70 }
71
72 pub fn is_empty(&self) -> bool {
74 self.len() == 0
75 }
76}
77
78pub fn extract_functions(code: &str, lang: Language) -> Result<Vec<FunctionNode>> {
93 let mut parser = Parser::new();
95 parser
96 .set_language(lang)
97 .context("Failed to set language for parser")?;
98
99 let tree = parser
101 .parse(code, None)
102 .ok_or_else(|| anyhow!("Failed to parse source code"))?;
103
104 let query_source = get_query_for_language(lang)?;
106
107 let query = Query::new(lang, query_source).context("Failed to compile Tree-sitter query")?;
109
110 let mut cursor = QueryCursor::new();
112 let matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
113
114 let mut functions = Vec::new();
116
117 for match_ in matches {
118 let mut func_start = None;
119 let mut func_end = None;
120 let mut func_start_line = None;
121 let mut func_end_line = None;
122 let mut func_name = None;
123 let mut func_body = None;
124
125 for capture in match_.captures {
126 let node = capture.node;
127 let capture_name = &query.capture_names()[capture.index as usize];
128
129 match capture_name.as_str() {
130 "func" => {
131 func_start = Some(node.start_byte());
132 func_end = Some(node.end_byte());
133 func_start_line = Some(node.start_position().row + 1);
135 func_end_line = Some(node.end_position().row + 1);
136 }
137 "function.name" => {
138 func_name = Some(
139 node.utf8_text(code.as_bytes())
140 .context("Invalid UTF-8 in function name")?
141 .to_string(),
142 );
143 }
144 "function.body" => {
145 func_body = Some(
146 node.utf8_text(code.as_bytes())
147 .context("Invalid UTF-8 in function body")?
148 .to_string(),
149 );
150 }
151 _ => {}
152 }
153 }
154
155 if let (Some(start), Some(end), Some(start_line), Some(end_line)) =
157 (func_start, func_end, func_start_line, func_end_line)
158 {
159 let body = func_body.unwrap_or_else(|| code[start..end].to_string());
160
161 let function = if let Some(name) = func_name {
162 FunctionNode::with_name(start, end, start_line, end_line, body, name)
163 } else {
164 FunctionNode::new(start, end, start_line, end_line, body)
165 };
166
167 functions.push(function);
168 }
169 }
170
171 Ok(functions)
172}
173
174fn get_query_for_language(lang: Language) -> Result<&'static str> {
176 let rust_lang = tree_sitter_rust::language();
180 let python_lang = tree_sitter_python::language();
181 let javascript_lang = tree_sitter_javascript::language();
182
183 if is_same_language(lang, rust_lang) {
184 Ok(&RUST_QUERY)
185 } else if is_same_language(lang, python_lang) {
186 Ok(&PYTHON_QUERY)
187 } else if is_same_language(lang, javascript_lang) {
188 Ok(&JAVASCRIPT_QUERY)
189 } else {
190 Err(anyhow!("Unsupported language"))
191 }
192}
193
194fn is_same_language(lang1: Language, lang2: Language) -> bool {
199 lang1.version() == lang2.version() && lang1.node_kind_count() == lang2.node_kind_count()
202}
203
204pub fn extract_rust_functions(code: &str) -> Result<Vec<FunctionNode>> {
206 extract_functions(code, tree_sitter_rust::language())
207}
208
209pub fn extract_python_functions(code: &str) -> Result<Vec<FunctionNode>> {
211 extract_functions(code, tree_sitter_python::language())
212}
213
214pub fn extract_javascript_functions(code: &str) -> Result<Vec<FunctionNode>> {
216 extract_functions(code, tree_sitter_javascript::language())
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_rust_function() {
        let code = r#"
fn hello_world() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}
"#;

        let functions = extract_rust_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("hello_world"));
        assert!(functions[0].body.contains("println!"));

        assert_eq!(functions[1].name.as_deref(), Some("add"));
        assert!(functions[1].body.contains("a + b"));
    }

    #[test]
    fn test_extract_python_function() {
        let code = r#"
def greet(name):
    return f"Hello, {name}!"

def multiply(x, y):
    return x * y
"#;

        let functions = extract_python_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("greet"));
        assert_eq!(functions[1].name.as_deref(), Some("multiply"));
    }

    #[test]
    fn test_extract_javascript_function() {
        let code = r#"
function sayHello() {
    console.log("Hello!");
}

const add = (a, b) => {
    return a + b;
};
"#;

        let functions = extract_javascript_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("sayHello"));
        assert!(functions[0].body.contains("console.log"));
    }

    #[test]
    fn test_function_node_length() {
        let node = FunctionNode::new(10, 50, 1, 5, "test body".to_string());
        assert_eq!(node.len(), 40);
        assert!(!node.is_empty());
    }

    #[test]
    fn test_zero_length_node_is_empty() {
        let node = FunctionNode::new(7, 7, 3, 3, String::new());
        assert_eq!(node.len(), 0);
        assert!(node.is_empty());
    }

    #[test]
    fn test_function_node_constructors() {
        let unnamed = FunctionNode::new(0, 4, 1, 1, "body".to_string());
        assert_eq!(unnamed.name, None);

        let named = FunctionNode::with_name(0, 4, 1, 1, "body".to_string(), "f".to_string());
        assert_eq!(named.name.as_deref(), Some("f"));
        assert_eq!(named.body, unnamed.body);
        assert_eq!(named.len(), unnamed.len());
    }

    #[test]
    fn test_empty_code() {
        let functions = extract_rust_functions("").unwrap();
        assert!(functions.is_empty());
    }

    #[test]
    fn test_invalid_syntax() {
        let code = "fn broken {{{";
        let functions =
            extract_rust_functions(code).expect("malformed input should not make extraction fail");
        assert!(functions.is_empty());
    }
}