neural_shared/parser/
python.rs

1//! Python parser using tree-sitter
2
3use super::{Location, ParsedFile, Parser, Symbol, SymbolKind};
4use crate::Result;
5use std::path::Path;
6use tree_sitter::{Node, Parser as TSParser, Tree};
7
/// Extracts definitions, call-site usages, and entry points from Python
/// source using the tree-sitter Python grammar.
#[derive(Debug, Default, Clone, Copy)]
pub struct PythonParser;
9
10impl PythonParser {
11    pub fn new() -> Result<Self> {
12        Ok(Self)
13    }
14
15    fn extract_definitions(&self, tree: &Tree, source: &str, file_path: &str) -> Vec<Symbol> {
16        let mut definitions = Vec::new();
17        let root = tree.root_node();
18
19        self.traverse_for_definitions(root, source, file_path, &mut definitions, None);
20
21        definitions
22    }
23
24    fn traverse_for_definitions(
25        &self,
26        node: Node,
27        source: &str,
28        file_path: &str,
29        definitions: &mut Vec<Symbol>,
30        current_class: Option<String>,
31    ) {
32        let kind = node.kind();
33
34        match kind {
35            "function_definition" => {
36                // Extract function name
37                if let Some(name_node) = node.child_by_field_name("name") {
38                    let name = name_node
39                        .utf8_text(source.as_bytes())
40                        .unwrap_or("")
41                        .to_string();
42                    let pos = name_node.start_position();
43
44                    let symbol_kind = if let Some(ref class_name) = current_class {
45                        SymbolKind::Method {
46                            class_name: class_name.clone(),
47                        }
48                    } else {
49                        SymbolKind::Function
50                    };
51
52                    definitions.push(Symbol::new(
53                        name,
54                        symbol_kind,
55                        Location {
56                            file: file_path.to_string(),
57                            line: pos.row + 1,
58                            column: pos.column,
59                        },
60                    ));
61                }
62            }
63            "class_definition" => {
64                // Extract class name
65                if let Some(name_node) = node.child_by_field_name("name") {
66                    let name = name_node
67                        .utf8_text(source.as_bytes())
68                        .unwrap_or("")
69                        .to_string();
70                    let pos = name_node.start_position();
71
72                    definitions.push(Symbol::new(
73                        name.clone(),
74                        SymbolKind::Class,
75                        Location {
76                            file: file_path.to_string(),
77                            line: pos.row + 1,
78                            column: pos.column,
79                        },
80                    ));
81
82                    // Traverse class body with class context
83                    let mut cursor = node.walk();
84                    for child in node.children(&mut cursor) {
85                        self.traverse_for_definitions(
86                            child,
87                            source,
88                            file_path,
89                            definitions,
90                            Some(name.clone()),
91                        );
92                    }
93                    return; // Don't traverse children again below
94                }
95            }
96            _ => {}
97        }
98
99        // Traverse children
100        let mut cursor = node.walk();
101        for child in node.children(&mut cursor) {
102            self.traverse_for_definitions(
103                child,
104                source,
105                file_path,
106                definitions,
107                current_class.clone(),
108            );
109        }
110    }
111
112    fn extract_usages(&self, tree: &Tree, source: &str, file_path: &str) -> Vec<Symbol> {
113        let mut usages = Vec::new();
114        let root = tree.root_node();
115
116        self.traverse_for_usages(root, source, file_path, &mut usages);
117
118        usages
119    }
120
121    fn traverse_for_usages(
122        &self,
123        node: Node,
124        source: &str,
125        file_path: &str,
126        usages: &mut Vec<Symbol>,
127    ) {
128        let kind = node.kind();
129
130        match kind {
131            "call" => {
132                // Extract function name being called
133                if let Some(func_node) = node.child_by_field_name("function") {
134                    let name = self.extract_call_name(func_node, source);
135                    if !name.is_empty() {
136                        let pos = func_node.start_position();
137                        usages.push(Symbol::new(
138                            name,
139                            SymbolKind::Function, // We don't know if it's a function or method yet
140                            Location {
141                                file: file_path.to_string(),
142                                line: pos.row + 1,
143                                column: pos.column,
144                            },
145                        ));
146                    }
147                }
148            }
149            "identifier" => {
150                // Track variable usages (for future enhancement)
151                // For now, we focus on function calls
152            }
153            _ => {}
154        }
155
156        // Traverse children
157        let mut cursor = node.walk();
158        for child in node.children(&mut cursor) {
159            self.traverse_for_usages(child, source, file_path, usages);
160        }
161    }
162
163    fn extract_call_name(&self, node: Node, source: &str) -> String {
164        match node.kind() {
165            "identifier" => node.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
166            "attribute" => {
167                // For obj.method() calls, extract the method name
168                if let Some(attr_node) = node.child_by_field_name("attribute") {
169                    attr_node
170                        .utf8_text(source.as_bytes())
171                        .unwrap_or("")
172                        .to_string()
173                } else {
174                    String::new()
175                }
176            }
177            _ => String::new(),
178        }
179    }
180
181    fn extract_entry_points(&self, tree: &Tree, source: &str) -> Vec<String> {
182        let mut entry_points = Vec::new();
183        let root = tree.root_node();
184
185        self.traverse_for_entry_points(root, source, &mut entry_points);
186
187        entry_points
188    }
189
190    fn traverse_for_entry_points(&self, node: Node, source: &str, entry_points: &mut Vec<String>) {
191        let kind = node.kind();
192
193        // Detect if __name__ == "__main__" pattern
194        if kind == "if_statement" {
195            if let Some(condition) = node.child_by_field_name("condition") {
196                let condition_text = condition.utf8_text(source.as_bytes()).unwrap_or("");
197                // Look for __name__ == "__main__" or "__main__" == __name__
198                if condition_text.contains("__name__") && condition_text.contains("\"__main__\"") {
199                    // Extract calls in the if block
200                    if let Some(consequence) = node.child_by_field_name("consequence") {
201                        self.extract_calls_from_block(consequence, source, entry_points);
202                    }
203                }
204            }
205        }
206
207        // Also detect functions that start with "test_" as entry points (pytest convention)
208        if kind == "function_definition" {
209            if let Some(name_node) = node.child_by_field_name("name") {
210                let name = name_node.utf8_text(source.as_bytes()).unwrap_or("");
211                if name.starts_with("test_") {
212                    entry_points.push(name.to_string());
213                }
214            }
215        }
216
217        // Traverse children
218        let mut cursor = node.walk();
219        for child in node.children(&mut cursor) {
220            self.traverse_for_entry_points(child, source, entry_points);
221        }
222    }
223
224    fn extract_calls_from_block(&self, node: Node, source: &str, entry_points: &mut Vec<String>) {
225        let kind = node.kind();
226
227        if kind == "call" {
228            if let Some(func_node) = node.child_by_field_name("function") {
229                let name = self.extract_call_name(func_node, source);
230                if !name.is_empty() {
231                    entry_points.push(name);
232                }
233            }
234        }
235
236        // Traverse children
237        let mut cursor = node.walk();
238        for child in node.children(&mut cursor) {
239            self.extract_calls_from_block(child, source, entry_points);
240        }
241    }
242}
243
244impl Parser for PythonParser {
245    fn parse(&self, source: &str, file_path: &Path) -> Result<ParsedFile> {
246        // Parser needs to be mutable, so we need to use interior mutability
247        // For now, we'll create a new parser each time (not ideal but works for MVP)
248        let mut parser = TSParser::new();
249        parser.set_language(tree_sitter_python::language())?;
250
251        let tree = parser
252            .parse(source, None)
253            .ok_or_else(|| anyhow::anyhow!("Failed to parse Python file"))?;
254
255        let file_path_str = file_path.to_string_lossy().to_string();
256
257        let definitions = self.extract_definitions(&tree, source, &file_path_str);
258        let usages = self.extract_usages(&tree, source, &file_path_str);
259        let entry_points = self.extract_entry_points(&tree, source);
260
261        Ok(ParsedFile {
262            path: file_path_str,
263            definitions,
264            usages,
265            entry_points,
266        })
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `test.py`, panicking if parsing fails.
    fn parse_py(source: &str) -> ParsedFile {
        PythonParser::new()
            .unwrap()
            .parse(source, Path::new("test.py"))
            .expect("parsing valid Python should succeed")
    }

    #[test]
    fn test_parse_simple_function() {
        let parsed = parse_py(
            r#"
def hello():
    print("Hello, world!")
"#,
        );

        assert_eq!(parsed.definitions.len(), 1);
        assert_eq!(parsed.definitions[0].name, "hello");
    }

    #[test]
    fn test_parse_class_with_methods() {
        let parsed = parse_py(
            r#"
class Calculator:
    def add(self, a, b):
        return a + b

    def subtract(self, a, b):
        return a - b
"#,
        );

        // 1 class + 2 methods, recorded in source order.
        assert_eq!(parsed.definitions.len(), 3);
        assert_eq!(parsed.definitions[0].name, "Calculator");
        assert_eq!(parsed.definitions[1].name, "add");
        assert_eq!(parsed.definitions[2].name, "subtract");
    }

    #[test]
    fn test_parse_function_calls() {
        let parsed = parse_py(
            r#"
def foo():
    pass

def bar():
    foo()
    print("test")
"#,
        );

        assert_eq!(parsed.definitions.len(), 2); // foo, bar
        assert!(parsed.usages.iter().any(|u| u.name == "foo"));
        assert!(parsed.usages.iter().any(|u| u.name == "print"));
    }

    #[test]
    fn test_entry_points() {
        let parsed = parse_py(
            r#"
def main():
    pass

def test_something():
    pass

if __name__ == "__main__":
    main()
"#,
        );

        assert!(parsed.entry_points.iter().any(|e| e == "main"));
        assert!(parsed.entry_points.iter().any(|e| e == "test_something"));
    }
}