dupe_core/
parsing.rs

1//! Tree-sitter parsing logic for extracting functions from source code
2//!
3//! This module provides the core parsing functionality to extract function
4//! definitions from Rust, Python, and JavaScript codebases using Tree-sitter.
5
6use anyhow::Context;
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9use tree_sitter::{Language, Parser, Query, QueryCursor};
10
11use crate::error::{PolyDupError, Result};
12use crate::queries::{JAVASCRIPT_QUERY, PYTHON_QUERY, RUST_QUERY};
13
14/// Represents a parsed function node from source code
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
16pub struct FunctionNode {
17    /// Starting byte offset in the source file
18    pub start_byte: usize,
19    /// Ending byte offset in the source file
20    pub end_byte: usize,
21    /// Starting line number (1-indexed)
22    pub start_line: usize,
23    /// Ending line number (1-indexed)
24    pub end_line: usize,
25    /// The function body as a string
26    pub body: String,
27    /// Optional function name (if captured by query)
28    pub name: Option<String>,
29}
30
31impl FunctionNode {
32    /// Creates a new FunctionNode
33    pub fn new(
34        start_byte: usize,
35        end_byte: usize,
36        start_line: usize,
37        end_line: usize,
38        body: String,
39    ) -> Self {
40        Self {
41            start_byte,
42            end_byte,
43            start_line,
44            end_line,
45            body,
46            name: None,
47        }
48    }
49
50    /// Creates a new FunctionNode with a name
51    pub fn with_name(
52        start_byte: usize,
53        end_byte: usize,
54        start_line: usize,
55        end_line: usize,
56        body: String,
57        name: String,
58    ) -> Self {
59        Self {
60            start_byte,
61            end_byte,
62            start_line,
63            end_line,
64            body,
65            name: Some(name),
66        }
67    }
68
69    /// Returns the length of the function in bytes
70    pub fn len(&self) -> usize {
71        self.end_byte - self.start_byte
72    }
73
74    /// Returns true if the function is empty
75    pub fn is_empty(&self) -> bool {
76        self.len() == 0
77    }
78}
79
80/// Extracts all function definitions from the given source code
81///
82/// # Arguments
83/// * `code` - The source code to parse
84/// * `lang` - The Tree-sitter Language grammar to use
85///
86/// # Returns
87/// * `Result<Vec<FunctionNode>>` - A vector of extracted function nodes
88///
89/// # Errors
90/// Returns an error if:
91/// - The parser fails to parse the code
92/// - The query compilation fails
93/// - Invalid UTF-8 is encountered
94pub fn extract_functions(code: &str, lang: Language) -> Result<Vec<FunctionNode>> {
95    extract_functions_with_path(code, lang, None)
96}
97
98/// Internal function that accepts optional path for better error messages
99fn extract_functions_with_path(
100    code: &str,
101    lang: Language,
102    path: Option<&Path>,
103) -> Result<Vec<FunctionNode>> {
104    // Create a new parser
105    let mut parser = Parser::new();
106    parser
107        .set_language(lang)
108        .context("Failed to set language for parser")?;
109
110    // Parse the source code
111    let tree = parser
112        .parse(code, None)
113        .ok_or_else(|| PolyDupError::Parsing("Failed to parse source code".to_string()))?;
114
115    // Get the appropriate query for the language
116    let query_source = get_query_for_language(lang)?;
117
118    // Compile the query
119    let query = Query::new(lang, query_source).map_err(|e| PolyDupError::Parsing(e.to_string()))?;
120
121    // Execute the query
122    let mut cursor = QueryCursor::new();
123    let matches = cursor.matches(&query, tree.root_node(), code.as_bytes());
124
125    // Extract function nodes from query matches
126    let mut functions = Vec::new();
127
128    for match_ in matches {
129        let mut func_start = None;
130        let mut func_end = None;
131        let mut func_start_line = None;
132        let mut func_end_line = None;
133        let mut func_name = None;
134        let mut func_body = None;
135
136        for capture in match_.captures {
137            let node = capture.node;
138            let capture_name = &query.capture_names()[capture.index as usize];
139
140            match capture_name.as_str() {
141                "func" => {
142                    func_start = Some(node.start_byte());
143                    func_end = Some(node.end_byte());
144                    // Tree-sitter rows are 0-indexed, convert to 1-indexed for humans
145                    func_start_line = Some(node.start_position().row + 1);
146                    func_end_line = Some(node.end_position().row + 1);
147                }
148                "function.name" => {
149                    func_name = Some(
150                        node.utf8_text(code.as_bytes())
151                            .with_context(|| {
152                                if let Some(p) = path {
153                                    format!(
154                                        "Invalid UTF-8 in function name at {}:{}",
155                                        p.display(),
156                                        node.start_position().row + 1
157                                    )
158                                } else {
159                                    format!(
160                                        "Invalid UTF-8 in function name at line {}",
161                                        node.start_position().row + 1
162                                    )
163                                }
164                            })?
165                            .to_string(),
166                    );
167                }
168                "function.body" => {
169                    func_body = Some(
170                        node.utf8_text(code.as_bytes())
171                            .with_context(|| {
172                                if let Some(p) = path {
173                                    format!(
174                                        "Invalid UTF-8 in function body at {}:{}",
175                                        p.display(),
176                                        node.start_position().row + 1
177                                    )
178                                } else {
179                                    format!(
180                                        "Invalid UTF-8 in function body at line {}",
181                                        node.start_position().row + 1
182                                    )
183                                }
184                            })?
185                            .to_string(),
186                    );
187                }
188                _ => {}
189            }
190        }
191
192        // Create FunctionNode if we have the required information
193        if let (Some(start), Some(end), Some(start_line), Some(end_line)) =
194            (func_start, func_end, func_start_line, func_end_line)
195        {
196            let body = func_body.unwrap_or_else(|| code[start..end].to_string());
197
198            let function = if let Some(name) = func_name {
199                FunctionNode::with_name(start, end, start_line, end_line, body, name)
200            } else {
201                FunctionNode::new(start, end, start_line, end_line, body)
202            };
203
204            functions.push(function);
205        }
206    }
207
208    Ok(functions)
209}
210
211/// Returns the appropriate query string for a given Tree-sitter Language
212fn get_query_for_language(lang: Language) -> Result<&'static str> {
213    // Compare language pointers to identify which language we're dealing with
214    // This is necessary because Language doesn't implement PartialEq
215
216    let rust_lang = tree_sitter_rust::language();
217    let python_lang = tree_sitter_python::language();
218    let javascript_lang = tree_sitter_javascript::language();
219
220    if is_same_language(lang, rust_lang) {
221        Ok(&RUST_QUERY)
222    } else if is_same_language(lang, python_lang) {
223        Ok(&PYTHON_QUERY)
224    } else if is_same_language(lang, javascript_lang) {
225        Ok(&JAVASCRIPT_QUERY)
226    } else {
227        Err(PolyDupError::Parsing("Unsupported language".to_string()))
228    }
229}
230
231/// Compares two Tree-sitter Language instances
232///
233/// Since Language doesn't implement PartialEq, we compare their internal
234/// pointers as a proxy for equality.
235fn is_same_language(lang1: Language, lang2: Language) -> bool {
236    // Languages are considered equal if they have the same version and node kind count
237    // This is a heuristic but works well in practice
238    lang1.version() == lang2.version() && lang1.node_kind_count() == lang2.node_kind_count()
239}
240
241/// Convenience function to extract functions from Rust code
242pub fn extract_rust_functions(code: &str) -> Result<Vec<FunctionNode>> {
243    extract_functions(code, tree_sitter_rust::language())
244}
245
246/// Convenience function to extract functions from Python code
247pub fn extract_python_functions(code: &str) -> Result<Vec<FunctionNode>> {
248    extract_functions(code, tree_sitter_python::language())
249}
250
251/// Convenience function to extract functions from JavaScript code
252pub fn extract_javascript_functions(code: &str) -> Result<Vec<FunctionNode>> {
253    extract_functions(code, tree_sitter_javascript::language())
254}
255
// Unit tests for function extraction and the `FunctionNode` helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_rust_function() {
        let code = r#"
fn hello_world() {
    println!("Hello, world!");
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}
"#;

        let functions = extract_rust_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        // Check first function
        assert_eq!(functions[0].name.as_deref(), Some("hello_world"));
        assert!(functions[0].body.contains("println!"));

        // Check second function
        assert_eq!(functions[1].name.as_deref(), Some("add"));
        assert!(functions[1].body.contains("a + b"));
    }

    #[test]
    fn test_extract_python_function() {
        let code = r#"
def greet(name):
    return f"Hello, {name}!"

def multiply(x, y):
    return x * y
"#;

        let functions = extract_python_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("greet"));
        assert_eq!(functions[1].name.as_deref(), Some("multiply"));
    }

    #[test]
    fn test_extract_javascript_function() {
        let code = r#"
function sayHello() {
    console.log("Hello!");
}

const add = (a, b) => {
    return a + b;
};
"#;

        let functions = extract_javascript_functions(code).unwrap();
        assert_eq!(functions.len(), 2);

        assert_eq!(functions[0].name.as_deref(), Some("sayHello"));
        assert!(functions[0].body.contains("console.log"));
    }

    #[test]
    fn test_function_node_length() {
        let node = FunctionNode::new(10, 50, 1, 5, "test body".to_string());
        assert_eq!(node.len(), 40);
        assert!(!node.is_empty());
    }

    #[test]
    fn test_function_node_zero_length_is_empty() {
        let node = FunctionNode::new(7, 7, 1, 1, String::new());
        assert_eq!(node.len(), 0);
        assert!(node.is_empty());
    }

    #[test]
    fn test_function_node_with_name() {
        let node = FunctionNode::with_name(0, 4, 1, 1, "body".to_string(), "f".to_string());
        assert_eq!(node.name.as_deref(), Some("f"));
        assert_eq!(node.body, "body");
    }

    #[test]
    fn test_empty_code() {
        let functions = extract_rust_functions("").unwrap();
        assert_eq!(functions.len(), 0);
    }

    #[test]
    fn test_invalid_syntax() {
        let code = "fn broken {{{";
        let result = extract_rust_functions(code);
        // Should parse but find no complete functions
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0);
    }
}
341}