acp/ast/languages/
python.rs

1//! @acp:module "Python Extractor"
2//! @acp:summary "Symbol extraction for Python source files"
3//! @acp:domain cli
4//! @acp:layer parsing
5
6use super::{node_text, LanguageExtractor};
7use crate::ast::{
8    ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
/// Python language extractor.
///
/// A stateless unit type: all behaviour lives in its [`LanguageExtractor`]
/// implementation, which handles `.py` sources and `.pyi` stub files.
#[derive(Debug, Clone, Copy, Default)]
pub struct PythonExtractor;
15
16impl LanguageExtractor for PythonExtractor {
17    fn language(&self) -> Language {
18        tree_sitter_python::LANGUAGE.into()
19    }
20
21    fn name(&self) -> &'static str {
22        "python"
23    }
24
25    fn extensions(&self) -> &'static [&'static str] {
26        &["py", "pyi"]
27    }
28
29    fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30        let mut symbols = Vec::new();
31        let root = tree.root_node();
32        self.extract_symbols_recursive(&root, source, &mut symbols, None);
33        Ok(symbols)
34    }
35
36    fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37        let mut imports = Vec::new();
38        let root = tree.root_node();
39        self.extract_imports_recursive(&root, source, &mut imports);
40        Ok(imports)
41    }
42
43    fn extract_calls(
44        &self,
45        tree: &Tree,
46        source: &str,
47        current_function: Option<&str>,
48    ) -> Result<Vec<FunctionCall>> {
49        let mut calls = Vec::new();
50        let root = tree.root_node();
51        self.extract_calls_recursive(&root, source, &mut calls, current_function);
52        Ok(calls)
53    }
54
55    fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56        // Look for docstring (first expression statement in function/class body)
57        let body = node.child_by_field_name("body")?;
58        let mut cursor = body.walk();
59        // Only check the first child - docstring must be first statement
60        let first_child = body.children(&mut cursor).next()?;
61        if first_child.kind() == "expression_statement" {
62            if let Some(string) = first_child.child(0) {
63                if string.kind() == "string" {
64                    let text = node_text(&string, source);
65                    return Some(Self::clean_docstring(text));
66                }
67            }
68        }
69        None
70    }
71}
72
73impl PythonExtractor {
74    fn extract_symbols_recursive(
75        &self,
76        node: &Node,
77        source: &str,
78        symbols: &mut Vec<ExtractedSymbol>,
79        parent: Option<&str>,
80    ) {
81        match node.kind() {
82            "function_definition" => {
83                if let Some(sym) = self.extract_function(node, source, parent) {
84                    symbols.push(sym);
85                    // Don't recurse into function body for nested functions
86                    // (they'll be handled when we visit them)
87                }
88            }
89
90            "class_definition" => {
91                if let Some(sym) = self.extract_class(node, source, parent) {
92                    let class_name = sym.name.clone();
93                    symbols.push(sym);
94
95                    // Extract class methods
96                    if let Some(body) = node.child_by_field_name("body") {
97                        self.extract_class_members(&body, source, symbols, Some(&class_name));
98                    }
99                    return; // Don't recurse further
100                }
101            }
102
103            "decorated_definition" => {
104                // Handle decorated functions/classes
105                let mut cursor = node.walk();
106                for child in node.children(&mut cursor) {
107                    if child.kind() == "function_definition" || child.kind() == "class_definition" {
108                        self.extract_symbols_recursive(&child, source, symbols, parent);
109                    }
110                }
111                return;
112            }
113
114            _ => {}
115        }
116
117        // Recurse into children
118        let mut cursor = node.walk();
119        for child in node.children(&mut cursor) {
120            self.extract_symbols_recursive(&child, source, symbols, parent);
121        }
122    }
123
124    fn extract_function(
125        &self,
126        node: &Node,
127        source: &str,
128        parent: Option<&str>,
129    ) -> Option<ExtractedSymbol> {
130        let name_node = node.child_by_field_name("name")?;
131        let name = node_text(&name_node, source).to_string();
132
133        let mut sym = ExtractedSymbol::new(
134            name.clone(),
135            SymbolKind::Function,
136            node.start_position().row + 1,
137            node.end_position().row + 1,
138        )
139        .with_columns(node.start_position().column, node.end_position().column);
140
141        // Check if this is async
142        let text = node_text(node, source);
143        if text.starts_with("async") {
144            sym = sym.async_fn();
145        }
146
147        // Check visibility (Python convention: _name is private, __name is very private)
148        if name.starts_with("__") && !name.ends_with("__") {
149            sym.visibility = Visibility::Private;
150        } else if name.starts_with('_') {
151            sym.visibility = Visibility::Protected;
152        } else {
153            sym = sym.exported();
154        }
155
156        // Extract parameters
157        if let Some(params) = node.child_by_field_name("parameters") {
158            self.extract_parameters(&params, source, &mut sym);
159        }
160
161        // Extract return type annotation
162        if let Some(ret_type) = node.child_by_field_name("return_type") {
163            sym.return_type = Some(
164                node_text(&ret_type, source)
165                    .trim_start_matches("->")
166                    .trim()
167                    .to_string(),
168            );
169        }
170
171        // Extract docstring
172        sym.doc_comment = self.extract_doc_comment(node, source);
173
174        if let Some(p) = parent {
175            sym = sym.with_parent(p);
176            sym.kind = SymbolKind::Method;
177        }
178
179        sym.signature = Some(self.build_function_signature(node, source));
180
181        Some(sym)
182    }
183
184    fn extract_class(
185        &self,
186        node: &Node,
187        source: &str,
188        parent: Option<&str>,
189    ) -> Option<ExtractedSymbol> {
190        let name_node = node.child_by_field_name("name")?;
191        let name = node_text(&name_node, source).to_string();
192
193        let mut sym = ExtractedSymbol::new(
194            name.clone(),
195            SymbolKind::Class,
196            node.start_position().row + 1,
197            node.end_position().row + 1,
198        )
199        .with_columns(node.start_position().column, node.end_position().column);
200
201        // Python classes starting with _ are considered internal
202        if name.starts_with('_') {
203            sym.visibility = Visibility::Protected;
204        } else {
205            sym = sym.exported();
206        }
207
208        // Extract docstring
209        sym.doc_comment = self.extract_doc_comment(node, source);
210
211        if let Some(p) = parent {
212            sym = sym.with_parent(p);
213        }
214
215        Some(sym)
216    }
217
218    fn extract_class_members(
219        &self,
220        body: &Node,
221        source: &str,
222        symbols: &mut Vec<ExtractedSymbol>,
223        class_name: Option<&str>,
224    ) {
225        let mut cursor = body.walk();
226        for child in body.children(&mut cursor) {
227            match child.kind() {
228                "function_definition" => {
229                    if let Some(sym) = self.extract_function(&child, source, class_name) {
230                        symbols.push(sym);
231                    }
232                }
233                "decorated_definition" => {
234                    let mut inner_cursor = child.walk();
235                    for inner in child.children(&mut inner_cursor) {
236                        if inner.kind() == "function_definition" {
237                            if let Some(mut sym) = self.extract_function(&inner, source, class_name)
238                            {
239                                // Check for @staticmethod or @classmethod
240                                let deco_text = node_text(&child, source);
241                                if deco_text.contains("@staticmethod") {
242                                    sym = sym.static_fn();
243                                }
244                                symbols.push(sym);
245                            }
246                        }
247                    }
248                }
249                _ => {}
250            }
251        }
252    }
253
254    fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
255        let mut cursor = params.walk();
256        for child in params.children(&mut cursor) {
257            match child.kind() {
258                "identifier" => {
259                    let name = node_text(&child, source);
260                    // Skip 'self' and 'cls'
261                    if name != "self" && name != "cls" {
262                        sym.add_parameter(Parameter {
263                            name: name.to_string(),
264                            type_info: None,
265                            default_value: None,
266                            is_rest: false,
267                            is_optional: false,
268                        });
269                    }
270                }
271                "typed_parameter" => {
272                    let name = child
273                        .child_by_field_name("name")
274                        .map(|n| node_text(&n, source).to_string())
275                        .unwrap_or_default();
276
277                    if name != "self" && name != "cls" {
278                        let type_info = child
279                            .child_by_field_name("type")
280                            .map(|n| node_text(&n, source).to_string());
281
282                        sym.add_parameter(Parameter {
283                            name,
284                            type_info,
285                            default_value: None,
286                            is_rest: false,
287                            is_optional: false,
288                        });
289                    }
290                }
291                "default_parameter" | "typed_default_parameter" => {
292                    let name = child
293                        .child_by_field_name("name")
294                        .map(|n| node_text(&n, source).to_string())
295                        .unwrap_or_default();
296
297                    if name != "self" && name != "cls" {
298                        let type_info = child
299                            .child_by_field_name("type")
300                            .map(|n| node_text(&n, source).to_string());
301                        let default_value = child
302                            .child_by_field_name("value")
303                            .map(|n| node_text(&n, source).to_string());
304
305                        sym.add_parameter(Parameter {
306                            name,
307                            type_info,
308                            default_value,
309                            is_rest: false,
310                            is_optional: true,
311                        });
312                    }
313                }
314                "list_splat_pattern" | "dictionary_splat_pattern" => {
315                    let text = node_text(&child, source);
316                    let name = text.trim_start_matches('*').to_string();
317                    let is_kwargs = text.starts_with("**");
318
319                    sym.add_parameter(Parameter {
320                        name,
321                        type_info: None,
322                        default_value: None,
323                        is_rest: !is_kwargs,
324                        is_optional: true,
325                    });
326                }
327                _ => {}
328            }
329        }
330    }
331
332    fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
333        match node.kind() {
334            "import_statement" => {
335                if let Some(import) = self.parse_import(node, source) {
336                    imports.push(import);
337                }
338            }
339            "import_from_statement" => {
340                if let Some(import) = self.parse_from_import(node, source) {
341                    imports.push(import);
342                }
343            }
344            _ => {}
345        }
346
347        let mut cursor = node.walk();
348        for child in node.children(&mut cursor) {
349            self.extract_imports_recursive(&child, source, imports);
350        }
351    }
352
353    fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
354        let mut import = Import {
355            source: String::new(),
356            names: Vec::new(),
357            is_default: false,
358            is_namespace: false,
359            line: node.start_position().row + 1,
360        };
361
362        let mut cursor = node.walk();
363        for child in node.children(&mut cursor) {
364            match child.kind() {
365                "dotted_name" => {
366                    let name = node_text(&child, source).to_string();
367                    import.source = name.clone();
368                    import.names.push(ImportedName { name, alias: None });
369                }
370                "aliased_import" => {
371                    let name = child
372                        .child_by_field_name("name")
373                        .map(|n| node_text(&n, source).to_string())
374                        .unwrap_or_default();
375                    let alias = child
376                        .child_by_field_name("alias")
377                        .map(|n| node_text(&n, source).to_string());
378
379                    import.source = name.clone();
380                    import.names.push(ImportedName { name, alias });
381                }
382                _ => {}
383            }
384        }
385
386        Some(import)
387    }
388
389    fn parse_from_import(&self, node: &Node, source: &str) -> Option<Import> {
390        let module = node
391            .child_by_field_name("module_name")
392            .map(|n| node_text(&n, source).to_string())
393            .unwrap_or_default();
394
395        let mut import = Import {
396            source: module,
397            names: Vec::new(),
398            is_default: false,
399            is_namespace: false,
400            line: node.start_position().row + 1,
401        };
402
403        let mut cursor = node.walk();
404        for child in node.children(&mut cursor) {
405            match child.kind() {
406                "wildcard_import" => {
407                    import.is_namespace = true;
408                    import.names.push(ImportedName {
409                        name: "*".to_string(),
410                        alias: None,
411                    });
412                }
413                "dotted_name" | "identifier" => {
414                    import.names.push(ImportedName {
415                        name: node_text(&child, source).to_string(),
416                        alias: None,
417                    });
418                }
419                "aliased_import" => {
420                    let name = child
421                        .child_by_field_name("name")
422                        .map(|n| node_text(&n, source).to_string())
423                        .unwrap_or_default();
424                    let alias = child
425                        .child_by_field_name("alias")
426                        .map(|n| node_text(&n, source).to_string());
427
428                    import.names.push(ImportedName { name, alias });
429                }
430                _ => {}
431            }
432        }
433
434        Some(import)
435    }
436
437    fn extract_calls_recursive(
438        &self,
439        node: &Node,
440        source: &str,
441        calls: &mut Vec<FunctionCall>,
442        current_function: Option<&str>,
443    ) {
444        if node.kind() == "call" {
445            if let Some(call) = self.parse_call(node, source, current_function) {
446                calls.push(call);
447            }
448        }
449
450        let func_name = if node.kind() == "function_definition" {
451            node.child_by_field_name("name")
452                .map(|n| node_text(&n, source))
453        } else {
454            None
455        };
456
457        let current = func_name
458            .map(String::from)
459            .or_else(|| current_function.map(String::from));
460
461        let mut cursor = node.walk();
462        for child in node.children(&mut cursor) {
463            self.extract_calls_recursive(&child, source, calls, current.as_deref());
464        }
465    }
466
467    fn parse_call(
468        &self,
469        node: &Node,
470        source: &str,
471        current_function: Option<&str>,
472    ) -> Option<FunctionCall> {
473        let function = node.child_by_field_name("function")?;
474
475        let (callee, is_method, receiver) = match function.kind() {
476            "attribute" => {
477                let object = function
478                    .child_by_field_name("object")
479                    .map(|n| node_text(&n, source).to_string());
480                let attr = function
481                    .child_by_field_name("attribute")
482                    .map(|n| node_text(&n, source).to_string())?;
483                (attr, true, object)
484            }
485            "identifier" => (node_text(&function, source).to_string(), false, None),
486            _ => return None,
487        };
488
489        Some(FunctionCall {
490            caller: current_function.unwrap_or("<module>").to_string(),
491            callee,
492            line: node.start_position().row + 1,
493            is_method,
494            receiver,
495        })
496    }
497
498    fn build_function_signature(&self, node: &Node, source: &str) -> String {
499        let async_kw = if node_text(node, source).starts_with("async") {
500            "async "
501        } else {
502            ""
503        };
504
505        let name = node
506            .child_by_field_name("name")
507            .map(|n| node_text(&n, source))
508            .unwrap_or("unknown");
509
510        let params = node
511            .child_by_field_name("parameters")
512            .map(|n| node_text(&n, source))
513            .unwrap_or("()");
514
515        let return_type = node
516            .child_by_field_name("return_type")
517            .map(|n| format!(" {}", node_text(&n, source)))
518            .unwrap_or_default();
519
520        format!("{}def {}{}{}", async_kw, name, params, return_type)
521    }
522
523    fn clean_docstring(text: &str) -> String {
524        // Remove quotes
525        let text = text
526            .trim_start_matches("\"\"\"")
527            .trim_start_matches("'''")
528            .trim_end_matches("\"\"\"")
529            .trim_end_matches("'''")
530            .trim();
531
532        text.to_string()
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the Python grammar. Returns the tree plus an owned
    /// copy of the text so both can be handed to the extractor.
    fn parse_py(source: &str) -> (Tree, String) {
        let grammar: Language = tree_sitter_python::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&grammar).unwrap();
        let tree = parser.parse(source, None).unwrap();
        (tree, source.to_owned())
    }

    /// Parses `source` and runs symbol extraction on it.
    fn symbols_of(source: &str) -> Vec<ExtractedSymbol> {
        let (tree, text) = parse_py(source);
        PythonExtractor.extract_symbols(&tree, &text).unwrap()
    }

    /// Whether any extracted symbol has exactly this name and kind.
    fn has_symbol(symbols: &[ExtractedSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_extract_function() {
        let source = r#"
def greet(name: str) -> str:
    """Greet someone."""
    return f"Hello, {name}!"
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let greet = &symbols[0];
        assert_eq!(greet.name, "greet");
        assert_eq!(greet.kind, SymbolKind::Function);
    }

    #[test]
    fn test_extract_class() {
        let source = r#"
class UserService:
    """A service for managing users."""

    def __init__(self, name: str):
        self.name = name

    def greet(self) -> str:
        return f"Hello, {self.name}!"

    @staticmethod
    def create():
        return UserService("default")
"#;
        let symbols = symbols_of(source);

        assert!(has_symbol(&symbols, "UserService", SymbolKind::Class));
        assert!(has_symbol(&symbols, "__init__", SymbolKind::Method));
        assert!(has_symbol(&symbols, "greet", SymbolKind::Method));
        assert!(has_symbol(&symbols, "create", SymbolKind::Method));
    }

    #[test]
    fn test_extract_async_function() {
        let source = r#"
async def fetch_data(url: str) -> dict:
    """Fetch data from URL."""
    pass
"#;
        let symbols = symbols_of(source);

        assert_eq!(symbols.len(), 1);
        let fetch = &symbols[0];
        assert_eq!(fetch.name, "fetch_data");
        assert!(fetch.is_async);
    }
}
614}