1use anyhow::{anyhow, Context, Result};
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use tree_sitter::{Language, Parser, Query, QueryCursor};
10
11use crate::queries::{JAVASCRIPT_QUERY, PYTHON_QUERY, RUST_QUERY};
12
/// A function located in a piece of source code.
///
/// The field descriptions below reflect how `extract_functions` populates the
/// struct; because every field is public, values built elsewhere are not
/// validated against these conventions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FunctionNode {
    /// Byte offset of the start of the function within the source text.
    pub start_byte: usize,
    /// Byte offset of the end of the function within the source text.
    pub end_byte: usize,
    /// 1-based line on which the function starts.
    pub start_line: usize,
    /// 1-based line on which the function ends.
    pub end_line: usize,
    /// Text of the function body capture, or the text of the whole function
    /// when the query produced no separate body capture.
    pub body: String,
    /// Function name, when the query captured one.
    pub name: Option<String>,
}
29
30impl FunctionNode {
31 pub fn new(
33 start_byte: usize,
34 end_byte: usize,
35 start_line: usize,
36 end_line: usize,
37 body: String,
38 ) -> Self {
39 Self {
40 start_byte,
41 end_byte,
42 start_line,
43 end_line,
44 body,
45 name: None,
46 }
47 }
48
49 pub fn with_name(
51 start_byte: usize,
52 end_byte: usize,
53 start_line: usize,
54 end_line: usize,
55 body: String,
56 name: String,
57 ) -> Self {
58 Self {
59 start_byte,
60 end_byte,
61 start_line,
62 end_line,
63 body,
64 name: Some(name),
65 }
66 }
67
68 pub fn len(&self) -> usize {
70 self.end_byte - self.start_byte
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.len() == 0
76 }
77}
78
79pub fn extract_functions(code: &str, lang: Language) -> Result<Vec<FunctionNode>> {
94 extract_functions_with_path(code, lang, None)
95}
96
97fn extract_functions_with_path(
99 code: &str,
100 lang: Language,
101 path: Option<&Path>,
102) -> Result<Vec<FunctionNode>> {
103 let mut parser = Parser::new();
105 parser
106 .set_language(lang)
107 .context("Failed to set language for parser")?;
108
109 let tree = parser
111 .parse(code, None)
112 .ok_or_else(|| anyhow!("Failed to parse source code"))?;
113
114 let query_source = get_query_for_language(lang)?;
116
117 let query = Query::new(lang, query_source).context("Failed to compile Tree-sitter query")?;
119
120 let mut cursor = QueryCursor::new();
122 let matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
123
124 let mut functions = Vec::new();
126
127 for match_ in matches {
128 let mut func_start = None;
129 let mut func_end = None;
130 let mut func_start_line = None;
131 let mut func_end_line = None;
132 let mut func_name = None;
133 let mut func_body = None;
134
135 for capture in match_.captures {
136 let node = capture.node;
137 let capture_name = &query.capture_names()[capture.index as usize];
138
139 match capture_name.as_str() {
140 "func" => {
141 func_start = Some(node.start_byte());
142 func_end = Some(node.end_byte());
143 func_start_line = Some(node.start_position().row + 1);
145 func_end_line = Some(node.end_position().row + 1);
146 }
147 "function.name" => {
148 func_name = Some(
149 node.utf8_text(code.as_bytes())
150 .with_context(|| {
151 if let Some(p) = path {
152 format!(
153 "Invalid UTF-8 in function name at {}:{}",
154 p.display(),
155 node.start_position().row + 1
156 )
157 } else {
158 format!(
159 "Invalid UTF-8 in function name at line {}",
160 node.start_position().row + 1
161 )
162 }
163 })?
164 .to_string(),
165 );
166 }
167 "function.body" => {
168 func_body = Some(
169 node.utf8_text(code.as_bytes())
170 .with_context(|| {
171 if let Some(p) = path {
172 format!(
173 "Invalid UTF-8 in function body at {}:{}",
174 p.display(),
175 node.start_position().row + 1
176 )
177 } else {
178 format!(
179 "Invalid UTF-8 in function body at line {}",
180 node.start_position().row + 1
181 )
182 }
183 })?
184 .to_string(),
185 );
186 }
187 _ => {}
188 }
189 }
190
191 if let (Some(start), Some(end), Some(start_line), Some(end_line)) =
193 (func_start, func_end, func_start_line, func_end_line)
194 {
195 let body = func_body.unwrap_or_else(|| code[start..end].to_string());
196
197 let function = if let Some(name) = func_name {
198 FunctionNode::with_name(start, end, start_line, end_line, body, name)
199 } else {
200 FunctionNode::new(start, end, start_line, end_line, body)
201 };
202
203 functions.push(function);
204 }
205 }
206
207 Ok(functions)
208}
209
210fn get_query_for_language(lang: Language) -> Result<&'static str> {
212 let rust_lang = tree_sitter_rust::language();
216 let python_lang = tree_sitter_python::language();
217 let javascript_lang = tree_sitter_javascript::language();
218
219 if is_same_language(lang, rust_lang) {
220 Ok(&RUST_QUERY)
221 } else if is_same_language(lang, python_lang) {
222 Ok(&PYTHON_QUERY)
223 } else if is_same_language(lang, javascript_lang) {
224 Ok(&JAVASCRIPT_QUERY)
225 } else {
226 Err(anyhow!("Unsupported language"))
227 }
228}
229
230fn is_same_language(lang1: Language, lang2: Language) -> bool {
235 lang1.version() == lang2.version() && lang1.node_kind_count() == lang2.node_kind_count()
238}
239
240pub fn extract_rust_functions(code: &str) -> Result<Vec<FunctionNode>> {
242 extract_functions(code, tree_sitter_rust::language())
243}
244
245pub fn extract_python_functions(code: &str) -> Result<Vec<FunctionNode>> {
247 extract_functions(code, tree_sitter_python::language())
248}
249
250pub fn extract_javascript_functions(code: &str) -> Result<Vec<FunctionNode>> {
252 extract_functions(code, tree_sitter_javascript::language())
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Two plain Rust functions are found, with their names and body text.
    #[test]
    fn test_extract_rust_function() {
        let code = r#"
fn hello_world() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}
"#;

        let functions = extract_rust_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(functions[0].name.as_deref(), Some("hello_world"));
        assert!(functions[0].body.contains("println!"));

        assert_eq!(functions[1].name.as_deref(), Some("add"));
        assert!(functions[1].body.contains("a + b"));
    }

    /// Two Python `def`s are found and named correctly.
    #[test]
    fn test_extract_python_function() {
        let code = r#"
def greet(name):
    return f"Hello, {name}!"

def multiply(x, y):
    return x * y
"#;

        let functions = extract_python_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("greet"));
        assert_eq!(functions[1].name.as_deref(), Some("multiply"));
    }

    /// A function declaration and an arrow function both count as functions.
    #[test]
    fn test_extract_javascript_function() {
        let code = r#"
function sayHello() {
    console.log("Hello!");
}

const add = (a, b) => {
    return a + b;
};
"#;

        let functions = extract_javascript_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("sayHello"));
        assert!(functions[0].body.contains("console.log"));
    }

    /// `len` is the byte-span width and a non-zero span is not empty.
    #[test]
    fn test_function_node_length() {
        let node = FunctionNode::new(10, 50, 1, 5, "test body".to_string());
        assert_eq!(node.len(), 40);
        assert!(!node.is_empty());
    }

    /// Empty input parses successfully and yields no functions.
    #[test]
    fn test_empty_code() {
        let functions = extract_rust_functions("").unwrap();
        assert_eq!(functions.len(), 0);
    }

    /// Malformed input is not an error; it simply yields no functions.
    #[test]
    fn test_invalid_syntax() {
        let code = "fn broken {{{";
        let result = extract_rust_functions(code);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }
}