Skip to main content

agentic_codebase/parse/
python.rs

1//! Python-specific parsing using tree-sitter.
2//!
3//! Extracts functions, classes, imports, docstrings, and async patterns.
4
5use std::path::Path;
6
7use crate::types::{AcbResult, CodeUnitType, Language, Visibility};
8
9use super::treesitter::{count_complexity, get_node_text, node_to_span};
10use super::{LanguageParser, RawCodeUnit, RawReference, ReferenceKind};
11
/// Python language parser.
///
/// A stateless unit struct: every instance behaves identically, so it is freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct PythonParser;
14
15impl Default for PythonParser {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl PythonParser {
22    /// Create a new Python parser.
23    pub fn new() -> Self {
24        Self
25    }
26
27    fn extract_from_node(
28        &self,
29        node: tree_sitter::Node,
30        source: &str,
31        file_path: &Path,
32        units: &mut Vec<RawCodeUnit>,
33        next_id: &mut u64,
34        parent_qname: &str,
35    ) {
36        let mut cursor = node.walk();
37        for child in node.children(&mut cursor) {
38            match child.kind() {
39                "function_definition" => {
40                    if let Some(unit) = self.extract_function(
41                        child,
42                        source,
43                        file_path,
44                        false,
45                        parent_qname,
46                        next_id,
47                    ) {
48                        let qname = unit.qualified_name.clone();
49                        units.push(unit);
50                        // Recurse into function body for nested definitions
51                        if let Some(body) = child.child_by_field_name("body") {
52                            self.extract_from_node(body, source, file_path, units, next_id, &qname);
53                        }
54                    }
55                }
56                "async_function_definition" | "async function_definition" => {
57                    // tree-sitter-python uses "function_definition" inside decorated nodes
58                    // but async functions may appear differently
59                }
60                "class_definition" => {
61                    if let Some(unit) =
62                        self.extract_class(child, source, file_path, parent_qname, next_id)
63                    {
64                        let qname = unit.qualified_name.clone();
65                        units.push(unit);
66                        if let Some(body) = child.child_by_field_name("body") {
67                            self.extract_from_node(body, source, file_path, units, next_id, &qname);
68                        }
69                    }
70                }
71                "import_statement" | "import_from_statement" => {
72                    if let Some(unit) =
73                        self.extract_import(child, source, file_path, parent_qname, next_id)
74                    {
75                        units.push(unit);
76                    }
77                }
78                "decorated_definition" => {
79                    // Look inside the decorated definition for the actual def/class
80                    self.extract_from_node(child, source, file_path, units, next_id, parent_qname);
81                }
82                _ => {
83                    // Check for assignments at module level (constants)
84                    // and recurse into other compound statements
85                }
86            }
87        }
88    }
89
90    fn extract_function(
91        &self,
92        node: tree_sitter::Node,
93        source: &str,
94        file_path: &Path,
95        _is_nested: bool,
96        parent_qname: &str,
97        next_id: &mut u64,
98    ) -> Option<RawCodeUnit> {
99        let name_node = node.child_by_field_name("name")?;
100        let name = get_node_text(name_node, source).to_string();
101
102        let qname = if parent_qname.is_empty() {
103            name.clone()
104        } else {
105            format!("{}.{}", parent_qname, name)
106        };
107
108        let is_async = node.kind() == "async_function_definition"
109            || node
110                .parent()
111                .map(|p| {
112                    let mut c = p.walk();
113                    let result = p
114                        .children(&mut c)
115                        .any(|ch| ch.kind() == "async" || get_node_text(ch, source) == "async");
116                    result
117                })
118                .unwrap_or(false);
119
120        let span = node_to_span(node);
121
122        // Extract signature from parameters
123        let sig = node.child_by_field_name("parameters").map(|params| {
124            let params_text = get_node_text(params, source);
125            let ret = node
126                .child_by_field_name("return_type")
127                .map(|r| format!(" -> {}", get_node_text(r, source)))
128                .unwrap_or_default();
129            format!("{}{}", params_text, ret)
130        });
131
132        // Extract docstring
133        let doc = self.extract_docstring(node, source);
134
135        // Visibility from name convention
136        let vis = python_visibility(&name);
137
138        // Complexity
139        let complexity_kinds = &[
140            "if_statement",
141            "elif_clause",
142            "for_statement",
143            "while_statement",
144            "try_statement",
145            "except_clause",
146            "with_statement",
147            "boolean_operator",
148            "conditional_expression",
149        ];
150        let complexity = count_complexity(node, complexity_kinds);
151
152        // Check for yield (generator)
153        let is_generator = source[node.byte_range()].contains("yield");
154
155        let id = *next_id;
156        *next_id += 1;
157
158        // Determine if this is a test function
159        let unit_type = if name.starts_with("test_") || name.starts_with("test") {
160            CodeUnitType::Test
161        } else {
162            CodeUnitType::Function
163        };
164
165        let mut unit = RawCodeUnit::new(
166            unit_type,
167            Language::Python,
168            name,
169            file_path.to_path_buf(),
170            span,
171        );
172        unit.temp_id = id;
173        unit.qualified_name = qname;
174        unit.signature = sig;
175        unit.doc = doc;
176        unit.visibility = vis;
177        unit.is_async = is_async;
178        unit.is_generator = is_generator;
179        unit.complexity = complexity;
180
181        // Extract call references from function body
182        if let Some(body) = node.child_by_field_name("body") {
183            self.extract_call_refs(body, source, &mut unit.references);
184        }
185
186        Some(unit)
187    }
188
189    fn extract_class(
190        &self,
191        node: tree_sitter::Node,
192        source: &str,
193        file_path: &Path,
194        parent_qname: &str,
195        next_id: &mut u64,
196    ) -> Option<RawCodeUnit> {
197        let name_node = node.child_by_field_name("name")?;
198        let name = get_node_text(name_node, source).to_string();
199
200        let qname = if parent_qname.is_empty() {
201            name.clone()
202        } else {
203            format!("{}.{}", parent_qname, name)
204        };
205
206        let span = node_to_span(node);
207        let doc = self.extract_docstring(node, source);
208        let vis = python_visibility(&name);
209
210        let id = *next_id;
211        *next_id += 1;
212
213        let mut unit = RawCodeUnit::new(
214            CodeUnitType::Type,
215            Language::Python,
216            name,
217            file_path.to_path_buf(),
218            span,
219        );
220        unit.temp_id = id;
221        unit.qualified_name = qname;
222        unit.doc = doc;
223        unit.visibility = vis;
224
225        // Extract base classes as inheritance references
226        if let Some(args) = node.child_by_field_name("superclasses") {
227            let mut cursor = args.walk();
228            for child in args.children(&mut cursor) {
229                if child.kind() == "identifier" || child.kind() == "attribute" {
230                    let base_name = get_node_text(child, source).to_string();
231                    unit.references.push(RawReference {
232                        name: base_name,
233                        kind: ReferenceKind::Inherit,
234                        span: node_to_span(child),
235                    });
236                }
237            }
238        }
239
240        Some(unit)
241    }
242
243    fn extract_import(
244        &self,
245        node: tree_sitter::Node,
246        source: &str,
247        file_path: &Path,
248        parent_qname: &str,
249        next_id: &mut u64,
250    ) -> Option<RawCodeUnit> {
251        let text = get_node_text(node, source).to_string();
252        let span = node_to_span(node);
253
254        // Derive a name from the import text
255        let import_name = text
256            .trim_start_matches("from ")
257            .trim_start_matches("import ")
258            .split_whitespace()
259            .next()
260            .unwrap_or("unknown")
261            .to_string();
262
263        let id = *next_id;
264        *next_id += 1;
265
266        let mut unit = RawCodeUnit::new(
267            CodeUnitType::Import,
268            Language::Python,
269            import_name.clone(),
270            file_path.to_path_buf(),
271            span,
272        );
273        unit.temp_id = id;
274        unit.qualified_name = if parent_qname.is_empty() {
275            import_name.clone()
276        } else {
277            format!("{}.{}", parent_qname, import_name)
278        };
279
280        unit.references.push(RawReference {
281            name: import_name,
282            kind: ReferenceKind::Import,
283            span,
284        });
285
286        Some(unit)
287    }
288
289    fn extract_docstring(&self, node: tree_sitter::Node, source: &str) -> Option<String> {
290        let body = node.child_by_field_name("body")?;
291        let mut cursor = body.walk();
292        let first_stmt = body.children(&mut cursor).next()?;
293
294        if first_stmt.kind() == "expression_statement" {
295            let mut c2 = first_stmt.walk();
296            let expr = first_stmt.children(&mut c2).next()?;
297            if expr.kind() == "string" {
298                let text = get_node_text(expr, source);
299                return Some(clean_docstring(text));
300            }
301        }
302        None
303    }
304
305    #[allow(clippy::only_used_in_recursion)]
306    fn extract_call_refs(
307        &self,
308        node: tree_sitter::Node,
309        source: &str,
310        refs: &mut Vec<RawReference>,
311    ) {
312        if node.kind() == "call" {
313            if let Some(func) = node.child_by_field_name("function") {
314                let name = get_node_text(func, source).to_string();
315                refs.push(RawReference {
316                    name,
317                    kind: ReferenceKind::Call,
318                    span: node_to_span(node),
319                });
320            }
321        }
322        let mut cursor = node.walk();
323        for child in node.children(&mut cursor) {
324            self.extract_call_refs(child, source, refs);
325        }
326    }
327}
328
329impl LanguageParser for PythonParser {
330    fn extract_units(
331        &self,
332        tree: &tree_sitter::Tree,
333        source: &str,
334        file_path: &Path,
335    ) -> AcbResult<Vec<RawCodeUnit>> {
336        let mut units = Vec::new();
337        let mut next_id = 0u64;
338
339        // Create module unit for the file
340        let module_name = file_path
341            .file_stem()
342            .and_then(|s| s.to_str())
343            .unwrap_or("unknown")
344            .to_string();
345
346        let root_span = node_to_span(tree.root_node());
347        let mut module_unit = RawCodeUnit::new(
348            CodeUnitType::Module,
349            Language::Python,
350            module_name.clone(),
351            file_path.to_path_buf(),
352            root_span,
353        );
354        module_unit.temp_id = next_id;
355        module_unit.qualified_name = module_name.clone();
356        next_id += 1;
357        units.push(module_unit);
358
359        // Extract all definitions
360        self.extract_from_node(
361            tree.root_node(),
362            source,
363            file_path,
364            &mut units,
365            &mut next_id,
366            &module_name,
367        );
368
369        Ok(units)
370    }
371
372    fn is_test_file(&self, path: &Path, source: &str) -> bool {
373        let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
374        name.starts_with("test_")
375            || name.ends_with("_test.py")
376            || path.components().any(|c| c.as_os_str() == "tests")
377            || source.contains("import pytest")
378            || source.contains("import unittest")
379    }
380}
381
382fn python_visibility(name: &str) -> Visibility {
383    if name.starts_with("__") && !name.ends_with("__") {
384        Visibility::Private
385    } else if name.starts_with('_') {
386        Visibility::Internal
387    } else {
388        Visibility::Public
389    }
390}
391
/// Reduce the source text of a Python string literal to its summary line.
///
/// Steps:
/// 1. Drop an optional string prefix (`r`, `u`, `b`, `f`, or a two-letter combination
///    such as `rb`).
/// 2. Remove one matching pair of delimiters (`"""`, `'''`, `"` or `'`), so quote
///    characters inside the text are kept.
/// 3. Return the first non-blank line, trimmed. Per PEP 257, the summary may start on the
///    line after the opening quotes.
///
/// Returns an empty string for an empty or blank literal.
fn clean_docstring(raw: &str) -> String {
    let unprefixed = match raw.find(|c: char| c == '"' || c == '\'') {
        Some(start) if start <= 2 && raw[..start].chars().all(|c| "rRuUbBfF".contains(c)) => {
            &raw[start..]
        }
        _ => raw,
    };

    // Longest delimiters first so `"""` is not mistaken for a one-character `"` quote.
    let body = ["\"\"\"", "'''", "\"", "'"]
        .iter()
        .find_map(|&quote| {
            unprefixed
                .strip_prefix(quote)
                .map(|rest| rest.strip_suffix(quote).unwrap_or(rest))
        })
        .unwrap_or(unprefixed);

    body.lines()
        .map(str::trim)
        .find(|line| !line.is_empty())
        .unwrap_or("")
        .to_string()
}