codebank/parser/lang/
python.rs

1use crate::{
2    Error, FileUnit, FunctionUnit, LanguageParser, ModuleUnit, PythonParser, Result, StructUnit,
3    Visibility,
4};
5use std::fs;
6use std::ops::{Deref, DerefMut};
7use std::path::Path;
8use tree_sitter::{Node, Parser};
9
10// Helper function to get the text of a node
11fn get_node_text(node: Node, source_code: &str) -> Option<String> {
12    node.utf8_text(source_code.as_bytes())
13        .ok()
14        .map(String::from)
15}
16
17// Helper function to get the text of the first child node of a specific kind
18fn get_child_node_text<'a>(node: Node<'a>, kind: &str, source_code: &'a str) -> Option<String> {
19    node.children(&mut node.walk())
20        .find(|child| child.kind() == kind)
21        .and_then(|child| child.utf8_text(source_code.as_bytes()).ok())
22        .map(String::from)
23}
24
25impl PythonParser {
26    pub fn try_new() -> Result<Self> {
27        let mut parser = Parser::new();
28        let language = tree_sitter_python::LANGUAGE;
29        parser
30            .set_language(&language.into())
31            .map_err(|e| Error::TreeSitter(e.to_string()))?;
32        Ok(Self { parser })
33    }
34
35    // Extract docstring from a node
36    fn extract_documentation(&self, node: Node, source_code: &str) -> Option<String> {
37        let mut cursor = node.walk();
38        let mut children = node.children(&mut cursor);
39
40        // For function/class nodes, we need to skip the definition line
41        if node.kind() == "function_definition" || node.kind() == "class_definition" {
42            children.next(); // Skip the function/class definition line
43        }
44
45        // Look for the docstring
46        for child in children {
47            match child.kind() {
48                "block" => {
49                    // For function/class bodies, look in the block
50                    let mut body_cursor = child.walk();
51                    let mut body_children = child.children(&mut body_cursor);
52                    if let Some(first_expr) = body_children.next() {
53                        if first_expr.kind() == "expression_statement" {
54                            if let Some(string) = first_expr
55                                .children(&mut first_expr.walk())
56                                .find(|c| c.kind() == "string")
57                            {
58                                return self.clean_docstring(string, source_code);
59                            }
60                        }
61                    }
62                }
63                "expression_statement" => {
64                    // For module level docstrings
65                    if let Some(string) = child
66                        .children(&mut child.walk())
67                        .find(|c| c.kind() == "string")
68                    {
69                        return self.clean_docstring(string, source_code);
70                    }
71                }
72                "ERROR" => {
73                    // For ERROR nodes, try to get the string content directly
74                    let mut error_cursor = child.walk();
75                    let error_children = child.children(&mut error_cursor);
76                    for error_child in error_children {
77                        if error_child.kind() == "string" {
78                            if let Some(string_content) = error_child
79                                .children(&mut error_child.walk())
80                                .find(|c| c.kind() == "string_content")
81                            {
82                                if let Some(content) = get_node_text(string_content, source_code) {
83                                    return Some(content.trim().to_string());
84                                }
85                            }
86                        }
87                    }
88                }
89                _ => continue,
90            }
91        }
92        None
93    }
94
95    // Helper to clean up docstring content
96    fn clean_docstring(&self, node: Node, source_code: &str) -> Option<String> {
97        let doc = get_node_text(node, source_code)?;
98        // Clean up the docstring - handle both single and triple quotes
99        let doc = if doc.starts_with("\"\"\"") && doc.ends_with("\"\"\"") {
100            // Handle triple quotes
101            doc[3..doc.len() - 3].trim()
102        } else if doc.starts_with("'''") && doc.ends_with("'''") {
103            // Handle triple single quotes
104            doc[3..doc.len() - 3].trim()
105        } else {
106            // Handle single quotes
107            doc.trim_matches('"').trim_matches('\'').trim()
108        };
109        Some(doc.to_string())
110    }
111
112    // Extract decorators from a node
113    fn extract_decorators(&self, node: Node, source_code: &str) -> Vec<String> {
114        let mut decorators = Vec::new();
115        let mut cursor = node.walk();
116
117        // Look for decorators before the function/class definition
118        for child in node.children(&mut cursor) {
119            if child.kind() == "decorator" {
120                if let Some(text) = get_node_text(child, source_code) {
121                    decorators.push(text);
122                }
123            }
124        }
125        decorators
126    }
127
128    // Parse function and extract its details
129    fn parse_function(&self, node: Node, source_code: &str) -> Result<FunctionUnit> {
130        // If this is a decorated function, get the actual function definition
131        let function_node = if node.kind() == "decorated_definition" {
132            node.children(&mut node.walk())
133                .find(|child| child.kind() == "function_definition")
134                .unwrap_or(node)
135        } else {
136            node
137        };
138
139        let name = get_child_node_text(function_node, "identifier", source_code)
140            .unwrap_or_else(|| "unknown".to_string());
141        let documentation = self.extract_documentation(function_node, source_code);
142        let attributes = self.extract_decorators(node, source_code);
143        let source = get_node_text(function_node, source_code);
144        let visibility = if name.starts_with('_') {
145            Visibility::Private
146        } else {
147            Visibility::Public
148        };
149
150        let mut signature = None;
151        let mut body = None;
152
153        if let Some(src) = &source {
154            if let Some(body_start_idx) = src.find(':') {
155                signature = Some(src[0..body_start_idx].trim().to_string());
156                body = Some(src[body_start_idx + 1..].trim().to_string());
157            }
158        }
159
160        Ok(FunctionUnit {
161            name,
162            visibility,
163            documentation,
164            source,
165            signature,
166            body,
167            attributes,
168        })
169    }
170
171    // Parse class and extract its details
172    fn parse_class(&self, node: Node, source_code: &str) -> Result<StructUnit> {
173        // If this is a decorated class, get the actual class definition
174        let class_node = if node.kind() == "decorated_definition" {
175            node.children(&mut node.walk())
176                .find(|child| child.kind() == "class_definition")
177                .unwrap_or(node)
178        } else {
179            node
180        };
181
182        let name = get_child_node_text(class_node, "identifier", source_code)
183            .unwrap_or_else(|| "unknown".to_string());
184        let documentation = self.extract_documentation(class_node, source_code);
185        let attributes = self.extract_decorators(node, source_code);
186        let source = get_node_text(class_node, source_code);
187        let visibility = if name.starts_with('_') {
188            Visibility::Private
189        } else {
190            Visibility::Public
191        };
192
193        // TODO: parse class head
194        let head = format!("class {}", name);
195
196        // Extract methods from class body
197        let mut methods = Vec::new();
198        let mut cursor = class_node.walk();
199        for child in class_node.children(&mut cursor) {
200            if child.kind() == "block" {
201                let mut block_cursor = child.walk();
202                for method_node in child.children(&mut block_cursor) {
203                    match method_node.kind() {
204                        "function_definition" | "decorated_definition" => {
205                            if let Ok(method) = self.parse_function(method_node, source_code) {
206                                methods.push(method);
207                            }
208                        }
209                        _ => continue,
210                    }
211                }
212            }
213        }
214
215        Ok(StructUnit {
216            name,
217            head,
218            visibility,
219            documentation,
220            source,
221            attributes,
222            methods,
223        })
224    }
225
226    #[allow(dead_code)]
227    // Parse module and extract its details
228    fn parse_module(&self, node: Node, source_code: &str) -> Result<ModuleUnit> {
229        let name = get_child_node_text(node, "identifier", source_code)
230            .unwrap_or_else(|| "unknown".to_string());
231        let document = self.extract_documentation(node, source_code);
232        let source = get_node_text(node, source_code);
233        let visibility = if name.starts_with('_') {
234            Visibility::Private
235        } else {
236            Visibility::Public
237        };
238
239        Ok(ModuleUnit {
240            name,
241            visibility,
242            document,
243            source,
244            attributes: Vec::new(),
245            declares: Vec::new(),
246            functions: Vec::new(),
247            structs: Vec::new(),
248            traits: Vec::new(),
249            impls: Vec::new(),
250            submodules: Vec::new(),
251        })
252    }
253}
254
255impl LanguageParser for PythonParser {
256    fn parse_file(&mut self, file_path: &Path) -> Result<FileUnit> {
257        let source_code = fs::read_to_string(file_path).map_err(Error::Io)?;
258        let tree = self
259            .parse(source_code.as_bytes(), None)
260            .ok_or_else(|| Error::TreeSitter("Failed to parse Python file".to_string()))?;
261
262        let mut file_unit = FileUnit {
263            path: file_path.to_path_buf(),
264            source: Some(source_code.clone()),
265            document: None,
266            declares: Vec::new(),
267            modules: Vec::new(),
268            functions: Vec::new(),
269            structs: Vec::new(),
270            traits: Vec::new(),
271            impls: Vec::new(),
272        };
273
274        let root_node = tree.root_node();
275
276        // First look for module docstring
277        {
278            let mut cursor = root_node.walk();
279            let mut children = root_node.children(&mut cursor);
280
281            if let Some(first_expr) = children.next() {
282                if first_expr.kind() == "expression_statement" {
283                    if let Some(string) = first_expr
284                        .children(&mut first_expr.walk())
285                        .find(|c| c.kind() == "string")
286                    {
287                        if let Some(doc) = get_node_text(string, &source_code) {
288                            // Clean up the docstring - handle both single and triple quotes
289                            let doc = doc
290                                .trim_start_matches(r#"""""#)
291                                .trim_end_matches(r#"""""#)
292                                .trim_start_matches(r#"'''"#)
293                                .trim_end_matches(r#"'''"#)
294                                .trim_start_matches('"')
295                                .trim_end_matches('"')
296                                .trim_start_matches('\'')
297                                .trim_end_matches('\'')
298                                .trim();
299                            file_unit.document = Some(doc.to_string());
300                        }
301                    }
302                }
303            }
304        }
305
306        // Process imports first
307        {
308            let mut cursor = root_node.walk();
309            for node in root_node.children(&mut cursor) {
310                if node.kind() == "import_statement" || node.kind() == "import_from_statement" {
311                    if let Some(import_text) = get_node_text(node, &source_code) {
312                        file_unit.declares.push(crate::DeclareStatements {
313                            source: import_text,
314                            kind: crate::DeclareKind::Import,
315                        });
316                    }
317                }
318            }
319        }
320
321        // Then process all top-level nodes
322        let mut cursor = root_node.walk();
323        for node in root_node.children(&mut cursor) {
324            match node.kind() {
325                "function_definition" => {
326                    let func = self.parse_function(node, &source_code)?;
327                    file_unit.functions.push(func);
328                }
329                "class_definition" => {
330                    let class = self.parse_class(node, &source_code)?;
331                    file_unit.structs.push(class);
332                }
333                "decorated_definition" => {
334                    let mut node_cursor = node.walk();
335                    let children: Vec<_> = node.children(&mut node_cursor).collect();
336                    if let Some(def_node) = children.iter().find(|n| {
337                        n.kind() == "function_definition" || n.kind() == "class_definition"
338                    }) {
339                        match def_node.kind() {
340                            "function_definition" => {
341                                let func = self.parse_function(node, &source_code)?;
342                                file_unit.functions.push(func);
343                            }
344                            "class_definition" => {
345                                let class = self.parse_class(node, &source_code)?;
346                                file_unit.structs.push(class);
347                            }
348                            _ => {}
349                        }
350                    }
351                }
352                _ => continue,
353            }
354        }
355
356        Ok(file_unit)
357    }
358}
359
360impl Deref for PythonParser {
361    type Target = Parser;
362
363    fn deref(&self) -> &Self::Target {
364        &self.parser
365    }
366}
367
368impl DerefMut for PythonParser {
369    fn deref_mut(&mut self) -> &mut Self::Target {
370        &mut self.parser
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Writes `content` to `test.py` inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the path; callers keep it
    /// bound for as long as they need the file (the `tempfile` crate removes
    /// the directory when the guard is dropped).
    fn create_test_file(content: &str) -> Result<(tempfile::TempDir, PathBuf)> {
        let dir = tempfile::tempdir().map_err(Error::Io)?;
        let file_path = dir.path().join("test.py");
        fs::write(&file_path, content).map_err(Error::Io)?;
        Ok((dir, file_path))
    }

    /// Parses `content` as a Python file with a fresh `PythonParser`.
    fn parse_source(content: &str) -> Result<FileUnit> {
        // `_dir` stays alive until the file has been read by `parse_file`.
        let (_dir, file_path) = create_test_file(content)?;
        let mut parser = PythonParser::try_new()?;
        parser.parse_file(&file_path)
    }

    #[test]
    fn test_parse_function() -> Result<()> {
        let content = r#"
def hello_world():
    """This is a docstring."""
    print("Hello, World!")
"#;
        let file_unit = parse_source(content)?;

        assert_eq!(file_unit.functions.len(), 1);
        let func = &file_unit.functions[0];
        assert_eq!(func.name, "hello_world");
        assert_eq!(func.visibility, Visibility::Public);
        assert_eq!(func.documentation, Some("This is a docstring.".to_string()));
        Ok(())
    }

    #[test]
    fn test_parse_class() -> Result<()> {
        let content = r#"
@dataclass
class Person:
    """A person class."""
    def __init__(self, name: str):
        self.name = name
"#;
        let file_unit = parse_source(content)?;

        assert_eq!(file_unit.structs.len(), 1);
        let class = &file_unit.structs[0];
        assert_eq!(class.name, "Person");
        assert_eq!(class.visibility, Visibility::Public);
        assert_eq!(class.documentation, Some("A person class.".to_string()));
        assert_eq!(class.attributes.len(), 1);
        assert_eq!(class.attributes[0], "@dataclass");

        // The method declared in the class body must be picked up as well.
        assert_eq!(class.methods.len(), 1);
        assert_eq!(class.methods[0].name, "__init__");
        assert_eq!(class.methods[0].visibility, Visibility::Private);
        Ok(())
    }

    #[test]
    fn test_parse_private_members() -> Result<()> {
        let content = r#"
def _private_function():
    """A private function."""
    pass

class _PrivateClass:
    """A private class."""
    pass
"#;
        let file_unit = parse_source(content)?;

        assert_eq!(file_unit.functions[0].visibility, Visibility::Private);
        assert_eq!(file_unit.structs[0].visibility, Visibility::Private);
        Ok(())
    }

    // Covers a plain (non-triple) quoted module docstring; the triple-quoted
    // form is exercised by the next test.
    #[test]
    fn test_parse_module_docstring() -> Result<()> {
        let content = r#""This is a module docstring."

def hello_world():
    pass
"#;
        let file_unit = parse_source(content)?;

        assert_eq!(
            file_unit.document,
            Some("This is a module docstring.".to_string())
        );
        Ok(())
    }

    #[test]
    fn test_parse_module_docstring_with_triple_quotes() -> Result<()> {
        let content = r#"'''This is a module docstring with triple quotes.'''

def hello_world():
    pass
"#;
        let file_unit = parse_source(content)?;

        assert_eq!(
            file_unit.document,
            Some("This is a module docstring with triple quotes.".to_string())
        );
        Ok(())
    }
}
483}