1use anyhow::Context;
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use tree_sitter::{Language, Parser, Query, QueryCursor};
10
11use crate::error::{PolyDupError, Result};
12use crate::queries::{JAVASCRIPT_QUERY, PYTHON_QUERY, RUST_QUERY};
13
/// A single function found in a source file, with its location and text.
///
/// Instances are produced by the extraction functions in this module and can be
/// serialized/deserialized through `serde`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct FunctionNode {
    /// Byte offset in the source at which the function starts.
    pub start_byte: usize,
    /// Byte offset in the source at which the function ends (exclusive).
    pub end_byte: usize,
    /// Line on which the function starts, 1-based.
    pub start_line: usize,
    /// Line on which the function ends, 1-based.
    pub end_line: usize,
    /// Text of the function body; when the query captured no dedicated body,
    /// this holds the text of the whole function instead.
    pub body: String,
    /// Function name, or `None` when the query captured no name.
    pub name: Option<String>,
}
30
31impl FunctionNode {
32 pub fn new(
34 start_byte: usize,
35 end_byte: usize,
36 start_line: usize,
37 end_line: usize,
38 body: String,
39 ) -> Self {
40 Self {
41 start_byte,
42 end_byte,
43 start_line,
44 end_line,
45 body,
46 name: None,
47 }
48 }
49
50 pub fn with_name(
52 start_byte: usize,
53 end_byte: usize,
54 start_line: usize,
55 end_line: usize,
56 body: String,
57 name: String,
58 ) -> Self {
59 Self {
60 start_byte,
61 end_byte,
62 start_line,
63 end_line,
64 body,
65 name: Some(name),
66 }
67 }
68
69 pub fn len(&self) -> usize {
71 self.end_byte - self.start_byte
72 }
73
74 pub fn is_empty(&self) -> bool {
76 self.len() == 0
77 }
78}
79
80pub fn extract_functions(code: &str, lang: Language) -> Result<Vec<FunctionNode>> {
95 extract_functions_with_path(code, lang, None)
96}
97
98fn extract_functions_with_path(
100 code: &str,
101 lang: Language,
102 path: Option<&Path>,
103) -> Result<Vec<FunctionNode>> {
104 let mut parser = Parser::new();
106 parser
107 .set_language(lang)
108 .context("Failed to set language for parser")?;
109
110 let tree = parser
112 .parse(code, None)
113 .ok_or_else(|| PolyDupError::Parsing("Failed to parse source code".to_string()))?;
114
115 let query_source = get_query_for_language(lang)?;
117
118 let query = Query::new(lang, query_source).map_err(|e| PolyDupError::Parsing(e.to_string()))?;
120
121 let mut cursor = QueryCursor::new();
123 let matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
124
125 let mut functions = Vec::new();
127
128 for match_ in matches {
129 let mut func_start = None;
130 let mut func_end = None;
131 let mut func_start_line = None;
132 let mut func_end_line = None;
133 let mut func_name = None;
134 let mut func_body = None;
135
136 for capture in match_.captures {
137 let node = capture.node;
138 let capture_name = &query.capture_names()[capture.index as usize];
139
140 match capture_name.as_str() {
141 "func" => {
142 func_start = Some(node.start_byte());
143 func_end = Some(node.end_byte());
144 func_start_line = Some(node.start_position().row + 1);
146 func_end_line = Some(node.end_position().row + 1);
147 }
148 "function.name" => {
149 func_name = Some(
150 node.utf8_text(code.as_bytes())
151 .with_context(|| {
152 if let Some(p) = path {
153 format!(
154 "Invalid UTF-8 in function name at {}:{}",
155 p.display(),
156 node.start_position().row + 1
157 )
158 } else {
159 format!(
160 "Invalid UTF-8 in function name at line {}",
161 node.start_position().row + 1
162 )
163 }
164 })?
165 .to_string(),
166 );
167 }
168 "function.body" => {
169 func_body = Some(
170 node.utf8_text(code.as_bytes())
171 .with_context(|| {
172 if let Some(p) = path {
173 format!(
174 "Invalid UTF-8 in function body at {}:{}",
175 p.display(),
176 node.start_position().row + 1
177 )
178 } else {
179 format!(
180 "Invalid UTF-8 in function body at line {}",
181 node.start_position().row + 1
182 )
183 }
184 })?
185 .to_string(),
186 );
187 }
188 _ => {}
189 }
190 }
191
192 if let (Some(start), Some(end), Some(start_line), Some(end_line)) =
194 (func_start, func_end, func_start_line, func_end_line)
195 {
196 let body = func_body.unwrap_or_else(|| code[start..end].to_string());
197
198 let function = if let Some(name) = func_name {
199 FunctionNode::with_name(start, end, start_line, end_line, body, name)
200 } else {
201 FunctionNode::new(start, end, start_line, end_line, body)
202 };
203
204 functions.push(function);
205 }
206 }
207
208 Ok(functions)
209}
210
211fn get_query_for_language(lang: Language) -> Result<&'static str> {
213 let rust_lang = tree_sitter_rust::language();
217 let python_lang = tree_sitter_python::language();
218 let javascript_lang = tree_sitter_javascript::language();
219
220 if is_same_language(lang, rust_lang) {
221 Ok(&RUST_QUERY)
222 } else if is_same_language(lang, python_lang) {
223 Ok(&PYTHON_QUERY)
224 } else if is_same_language(lang, javascript_lang) {
225 Ok(&JAVASCRIPT_QUERY)
226 } else {
227 Err(PolyDupError::Parsing("Unsupported language".to_string()))
228 }
229}
230
231fn is_same_language(lang1: Language, lang2: Language) -> bool {
236 lang1.version() == lang2.version() && lang1.node_kind_count() == lang2.node_kind_count()
239}
240
241pub fn extract_rust_functions(code: &str) -> Result<Vec<FunctionNode>> {
243 extract_functions(code, tree_sitter_rust::language())
244}
245
246pub fn extract_python_functions(code: &str) -> Result<Vec<FunctionNode>> {
248 extract_functions(code, tree_sitter_python::language())
249}
250
251pub fn extract_javascript_functions(code: &str) -> Result<Vec<FunctionNode>> {
253 extract_functions(code, tree_sitter_javascript::language())
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Two top-level Rust functions are found, with their names and bodies.
    #[test]
    fn test_extract_rust_function() {
        let code = r#"
fn hello_world() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}
"#;

        let functions = extract_rust_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(functions[0].name.as_deref(), Some("hello_world"));
        assert!(functions[0].body.contains("println!"));

        assert_eq!(functions[1].name.as_deref(), Some("add"));
        assert!(functions[1].body.contains("a + b"));
    }

    /// Two Python `def`s are found and named correctly.
    #[test]
    fn test_extract_python_function() {
        let code = r#"
def greet(name):
    return f"Hello, {name}!"

def multiply(x, y):
    return x * y
"#;

        let functions = extract_python_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("greet"));
        assert_eq!(functions[1].name.as_deref(), Some("multiply"));
    }

    /// A function declaration and an arrow function both count as functions.
    #[test]
    fn test_extract_javascript_function() {
        let code = r#"
function sayHello() {
    console.log("Hello!");
}

const add = (a, b) => {
    return a + b;
};
"#;

        let functions = extract_javascript_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("sayHello"));
        assert!(functions[0].body.contains("console.log"));
    }

    /// `len` is the width of the byte span and `is_empty` follows from it.
    #[test]
    fn test_function_node_length() {
        let node = FunctionNode::new(10, 50, 1, 5, "test body".to_string());
        assert_eq!(node.len(), 40);
        assert!(!node.is_empty());
    }

    /// Empty input parses successfully and yields no functions.
    #[test]
    fn test_empty_code() {
        let functions = extract_rust_functions("").unwrap();
        assert!(functions.is_empty());
    }

    /// Malformed input is expected to succeed with no functions rather than
    /// return an error.
    #[test]
    fn test_invalid_syntax() {
        let code = "fn broken {{{";
        let functions =
            extract_rust_functions(code).expect("malformed source should still yield Ok");
        assert!(functions.is_empty());
    }
}