acp/ast/languages/
python.rs

1//! @acp:module "Python Extractor"
2//! @acp:summary "Symbol extraction for Python source files"
3//! @acp:domain cli
4//! @acp:layer parsing
5
6use super::{node_text, LanguageExtractor};
7use crate::ast::{
8    ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
/// Python language extractor.
///
/// Stateless unit type; all behaviour is provided by its `LanguageExtractor`
/// implementation and the private helper methods in this file.
pub struct PythonExtractor;
15
16impl LanguageExtractor for PythonExtractor {
17    fn language(&self) -> Language {
18        tree_sitter_python::LANGUAGE.into()
19    }
20
21    fn name(&self) -> &'static str {
22        "python"
23    }
24
25    fn extensions(&self) -> &'static [&'static str] {
26        &["py", "pyi"]
27    }
28
29    fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30        let mut symbols = Vec::new();
31        let root = tree.root_node();
32        self.extract_symbols_recursive(&root, source, &mut symbols, None);
33        Ok(symbols)
34    }
35
36    fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37        let mut imports = Vec::new();
38        let root = tree.root_node();
39        self.extract_imports_recursive(&root, source, &mut imports);
40        Ok(imports)
41    }
42
43    fn extract_calls(
44        &self,
45        tree: &Tree,
46        source: &str,
47        current_function: Option<&str>,
48    ) -> Result<Vec<FunctionCall>> {
49        let mut calls = Vec::new();
50        let root = tree.root_node();
51        self.extract_calls_recursive(&root, source, &mut calls, current_function);
52        Ok(calls)
53    }
54
55    fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56        // Look for docstring (first expression statement in function/class body)
57        let body = node.child_by_field_name("body")?;
58        let mut cursor = body.walk();
59        // Only check the first child - docstring must be first statement
60        let first_child = body.children(&mut cursor).next()?;
61        if first_child.kind() == "expression_statement" {
62            if let Some(string) = first_child.child(0) {
63                if string.kind() == "string" {
64                    let text = node_text(&string, source);
65                    return Some(Self::clean_docstring(text));
66                }
67            }
68        }
69        None
70    }
71}
72
73impl PythonExtractor {
74    fn extract_symbols_recursive(
75        &self,
76        node: &Node,
77        source: &str,
78        symbols: &mut Vec<ExtractedSymbol>,
79        parent: Option<&str>,
80    ) {
81        match node.kind() {
82            "function_definition" => {
83                if let Some(sym) = self.extract_function(node, source, parent) {
84                    symbols.push(sym);
85                    // Don't recurse into function body for nested functions
86                    // (they'll be handled when we visit them)
87                }
88            }
89
90            "class_definition" => {
91                if let Some(sym) = self.extract_class(node, source, parent) {
92                    let class_name = sym.name.clone();
93                    symbols.push(sym);
94
95                    // Extract class methods
96                    if let Some(body) = node.child_by_field_name("body") {
97                        self.extract_class_members(&body, source, symbols, Some(&class_name));
98                    }
99                    return; // Don't recurse further
100                }
101            }
102
103            "decorated_definition" => {
104                // Handle decorated functions/classes
105                // Find the first decorator to get definition_start_line
106                let decorator_start = node.start_position().row + 1;
107
108                let mut cursor = node.walk();
109                for child in node.children(&mut cursor) {
110                    if child.kind() == "function_definition" {
111                        if let Some(mut sym) = self.extract_function(&child, source, parent) {
112                            // Set definition_start_line to before the decorators
113                            sym.definition_start_line = Some(decorator_start);
114                            symbols.push(sym);
115                        }
116                    } else if child.kind() == "class_definition" {
117                        if let Some(mut sym) = self.extract_class(&child, source, parent) {
118                            let class_name = sym.name.clone();
119                            sym.definition_start_line = Some(decorator_start);
120                            symbols.push(sym);
121
122                            // Extract class methods
123                            if let Some(body) = child.child_by_field_name("body") {
124                                self.extract_class_members(
125                                    &body,
126                                    source,
127                                    symbols,
128                                    Some(&class_name),
129                                );
130                            }
131                        }
132                    }
133                }
134                return;
135            }
136
137            _ => {}
138        }
139
140        // Recurse into children
141        let mut cursor = node.walk();
142        for child in node.children(&mut cursor) {
143            self.extract_symbols_recursive(&child, source, symbols, parent);
144        }
145    }
146
147    fn extract_function(
148        &self,
149        node: &Node,
150        source: &str,
151        parent: Option<&str>,
152    ) -> Option<ExtractedSymbol> {
153        let name_node = node.child_by_field_name("name")?;
154        let name = node_text(&name_node, source).to_string();
155
156        let mut sym = ExtractedSymbol::new(
157            name.clone(),
158            SymbolKind::Function,
159            node.start_position().row + 1,
160            node.end_position().row + 1,
161        )
162        .with_columns(node.start_position().column, node.end_position().column);
163
164        // Check if this is async
165        let text = node_text(node, source);
166        if text.starts_with("async") {
167            sym = sym.async_fn();
168        }
169
170        // Check visibility (Python convention: _name is private, __name is very private)
171        if name.starts_with("__") && !name.ends_with("__") {
172            sym.visibility = Visibility::Private;
173        } else if name.starts_with('_') {
174            sym.visibility = Visibility::Protected;
175        } else {
176            sym = sym.exported();
177        }
178
179        // Extract parameters
180        if let Some(params) = node.child_by_field_name("parameters") {
181            self.extract_parameters(&params, source, &mut sym);
182        }
183
184        // Extract return type annotation
185        if let Some(ret_type) = node.child_by_field_name("return_type") {
186            sym.return_type = Some(
187                node_text(&ret_type, source)
188                    .trim_start_matches("->")
189                    .trim()
190                    .to_string(),
191            );
192        }
193
194        // Extract docstring
195        sym.doc_comment = self.extract_doc_comment(node, source);
196
197        if let Some(p) = parent {
198            sym = sym.with_parent(p);
199            sym.kind = SymbolKind::Method;
200        }
201
202        sym.signature = Some(self.build_function_signature(node, source));
203
204        // For non-decorated functions, definition_start_line equals start_line
205        if sym.definition_start_line.is_none() {
206            sym.definition_start_line = Some(node.start_position().row + 1);
207        }
208
209        Some(sym)
210    }
211
212    fn extract_class(
213        &self,
214        node: &Node,
215        source: &str,
216        parent: Option<&str>,
217    ) -> Option<ExtractedSymbol> {
218        let name_node = node.child_by_field_name("name")?;
219        let name = node_text(&name_node, source).to_string();
220
221        let mut sym = ExtractedSymbol::new(
222            name.clone(),
223            SymbolKind::Class,
224            node.start_position().row + 1,
225            node.end_position().row + 1,
226        )
227        .with_columns(node.start_position().column, node.end_position().column);
228
229        // Python classes starting with _ are considered internal
230        if name.starts_with('_') {
231            sym.visibility = Visibility::Protected;
232        } else {
233            sym = sym.exported();
234        }
235
236        // Extract docstring
237        sym.doc_comment = self.extract_doc_comment(node, source);
238
239        if let Some(p) = parent {
240            sym = sym.with_parent(p);
241        }
242
243        // For non-decorated classes, definition_start_line equals start_line
244        if sym.definition_start_line.is_none() {
245            sym.definition_start_line = Some(node.start_position().row + 1);
246        }
247
248        Some(sym)
249    }
250
251    fn extract_class_members(
252        &self,
253        body: &Node,
254        source: &str,
255        symbols: &mut Vec<ExtractedSymbol>,
256        class_name: Option<&str>,
257    ) {
258        let mut cursor = body.walk();
259        for child in body.children(&mut cursor) {
260            match child.kind() {
261                "function_definition" => {
262                    if let Some(sym) = self.extract_function(&child, source, class_name) {
263                        symbols.push(sym);
264                    }
265                }
266                "decorated_definition" => {
267                    // Capture decorator start line for definition_start_line
268                    let decorator_start = child.start_position().row + 1;
269
270                    let mut inner_cursor = child.walk();
271                    for inner in child.children(&mut inner_cursor) {
272                        if inner.kind() == "function_definition" {
273                            if let Some(mut sym) = self.extract_function(&inner, source, class_name)
274                            {
275                                // Set definition_start_line to before the decorators
276                                sym.definition_start_line = Some(decorator_start);
277
278                                // Check for @staticmethod or @classmethod
279                                let deco_text = node_text(&child, source);
280                                if deco_text.contains("@staticmethod") {
281                                    sym = sym.static_fn();
282                                }
283                                symbols.push(sym);
284                            }
285                        }
286                    }
287                }
288                _ => {}
289            }
290        }
291    }
292
293    fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
294        let mut cursor = params.walk();
295        for child in params.children(&mut cursor) {
296            match child.kind() {
297                "identifier" => {
298                    let name = node_text(&child, source);
299                    // Skip 'self' and 'cls'
300                    if name != "self" && name != "cls" {
301                        sym.add_parameter(Parameter {
302                            name: name.to_string(),
303                            type_info: None,
304                            default_value: None,
305                            is_rest: false,
306                            is_optional: false,
307                        });
308                    }
309                }
310                "typed_parameter" => {
311                    let name = child
312                        .child_by_field_name("name")
313                        .map(|n| node_text(&n, source).to_string())
314                        .unwrap_or_default();
315
316                    if name != "self" && name != "cls" {
317                        let type_info = child
318                            .child_by_field_name("type")
319                            .map(|n| node_text(&n, source).to_string());
320
321                        sym.add_parameter(Parameter {
322                            name,
323                            type_info,
324                            default_value: None,
325                            is_rest: false,
326                            is_optional: false,
327                        });
328                    }
329                }
330                "default_parameter" | "typed_default_parameter" => {
331                    let name = child
332                        .child_by_field_name("name")
333                        .map(|n| node_text(&n, source).to_string())
334                        .unwrap_or_default();
335
336                    if name != "self" && name != "cls" {
337                        let type_info = child
338                            .child_by_field_name("type")
339                            .map(|n| node_text(&n, source).to_string());
340                        let default_value = child
341                            .child_by_field_name("value")
342                            .map(|n| node_text(&n, source).to_string());
343
344                        sym.add_parameter(Parameter {
345                            name,
346                            type_info,
347                            default_value,
348                            is_rest: false,
349                            is_optional: true,
350                        });
351                    }
352                }
353                "list_splat_pattern" | "dictionary_splat_pattern" => {
354                    let text = node_text(&child, source);
355                    let name = text.trim_start_matches('*').to_string();
356                    let is_kwargs = text.starts_with("**");
357
358                    sym.add_parameter(Parameter {
359                        name,
360                        type_info: None,
361                        default_value: None,
362                        is_rest: !is_kwargs,
363                        is_optional: true,
364                    });
365                }
366                _ => {}
367            }
368        }
369    }
370
371    fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
372        match node.kind() {
373            "import_statement" => {
374                if let Some(import) = self.parse_import(node, source) {
375                    imports.push(import);
376                }
377            }
378            "import_from_statement" => {
379                if let Some(import) = self.parse_from_import(node, source) {
380                    imports.push(import);
381                }
382            }
383            _ => {}
384        }
385
386        let mut cursor = node.walk();
387        for child in node.children(&mut cursor) {
388            self.extract_imports_recursive(&child, source, imports);
389        }
390    }
391
392    fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
393        let mut import = Import {
394            source: String::new(),
395            names: Vec::new(),
396            is_default: false,
397            is_namespace: false,
398            line: node.start_position().row + 1,
399        };
400
401        let mut cursor = node.walk();
402        for child in node.children(&mut cursor) {
403            match child.kind() {
404                "dotted_name" => {
405                    let name = node_text(&child, source).to_string();
406                    import.source = name.clone();
407                    import.names.push(ImportedName { name, alias: None });
408                }
409                "aliased_import" => {
410                    let name = child
411                        .child_by_field_name("name")
412                        .map(|n| node_text(&n, source).to_string())
413                        .unwrap_or_default();
414                    let alias = child
415                        .child_by_field_name("alias")
416                        .map(|n| node_text(&n, source).to_string());
417
418                    import.source = name.clone();
419                    import.names.push(ImportedName { name, alias });
420                }
421                _ => {}
422            }
423        }
424
425        Some(import)
426    }
427
428    fn parse_from_import(&self, node: &Node, source: &str) -> Option<Import> {
429        let module = node
430            .child_by_field_name("module_name")
431            .map(|n| node_text(&n, source).to_string())
432            .unwrap_or_default();
433
434        let mut import = Import {
435            source: module,
436            names: Vec::new(),
437            is_default: false,
438            is_namespace: false,
439            line: node.start_position().row + 1,
440        };
441
442        let mut cursor = node.walk();
443        for child in node.children(&mut cursor) {
444            match child.kind() {
445                "wildcard_import" => {
446                    import.is_namespace = true;
447                    import.names.push(ImportedName {
448                        name: "*".to_string(),
449                        alias: None,
450                    });
451                }
452                "dotted_name" | "identifier" => {
453                    import.names.push(ImportedName {
454                        name: node_text(&child, source).to_string(),
455                        alias: None,
456                    });
457                }
458                "aliased_import" => {
459                    let name = child
460                        .child_by_field_name("name")
461                        .map(|n| node_text(&n, source).to_string())
462                        .unwrap_or_default();
463                    let alias = child
464                        .child_by_field_name("alias")
465                        .map(|n| node_text(&n, source).to_string());
466
467                    import.names.push(ImportedName { name, alias });
468                }
469                _ => {}
470            }
471        }
472
473        Some(import)
474    }
475
476    fn extract_calls_recursive(
477        &self,
478        node: &Node,
479        source: &str,
480        calls: &mut Vec<FunctionCall>,
481        current_function: Option<&str>,
482    ) {
483        if node.kind() == "call" {
484            if let Some(call) = self.parse_call(node, source, current_function) {
485                calls.push(call);
486            }
487        }
488
489        let func_name = if node.kind() == "function_definition" {
490            node.child_by_field_name("name")
491                .map(|n| node_text(&n, source))
492        } else {
493            None
494        };
495
496        let current = func_name
497            .map(String::from)
498            .or_else(|| current_function.map(String::from));
499
500        let mut cursor = node.walk();
501        for child in node.children(&mut cursor) {
502            self.extract_calls_recursive(&child, source, calls, current.as_deref());
503        }
504    }
505
506    fn parse_call(
507        &self,
508        node: &Node,
509        source: &str,
510        current_function: Option<&str>,
511    ) -> Option<FunctionCall> {
512        let function = node.child_by_field_name("function")?;
513
514        let (callee, is_method, receiver) = match function.kind() {
515            "attribute" => {
516                let object = function
517                    .child_by_field_name("object")
518                    .map(|n| node_text(&n, source).to_string());
519                let attr = function
520                    .child_by_field_name("attribute")
521                    .map(|n| node_text(&n, source).to_string())?;
522                (attr, true, object)
523            }
524            "identifier" => (node_text(&function, source).to_string(), false, None),
525            _ => return None,
526        };
527
528        Some(FunctionCall {
529            caller: current_function.unwrap_or("<module>").to_string(),
530            callee,
531            line: node.start_position().row + 1,
532            is_method,
533            receiver,
534        })
535    }
536
537    fn build_function_signature(&self, node: &Node, source: &str) -> String {
538        let async_kw = if node_text(node, source).starts_with("async") {
539            "async "
540        } else {
541            ""
542        };
543
544        let name = node
545            .child_by_field_name("name")
546            .map(|n| node_text(&n, source))
547            .unwrap_or("unknown");
548
549        let params = node
550            .child_by_field_name("parameters")
551            .map(|n| node_text(&n, source))
552            .unwrap_or("()");
553
554        let return_type = node
555            .child_by_field_name("return_type")
556            .map(|n| format!(" {}", node_text(&n, source)))
557            .unwrap_or_default();
558
559        format!("{}def {}{}{}", async_kw, name, params, return_type)
560    }
561
562    fn clean_docstring(text: &str) -> String {
563        // Remove quotes
564        let text = text
565            .trim_start_matches("\"\"\"")
566            .trim_start_matches("'''")
567            .trim_end_matches("\"\"\"")
568            .trim_end_matches("'''")
569            .trim();
570
571        text.to_string()
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as Python and returns the symbols the extractor finds.
    fn symbols_in(source: &str) -> Vec<ExtractedSymbol> {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .unwrap();
        let tree: Tree = parser.parse(source, None).unwrap();
        PythonExtractor.extract_symbols(&tree, source).unwrap()
    }

    /// A top-level function is reported once, as a `Function`.
    #[test]
    fn test_extract_function() {
        let found = symbols_in(
            r#"
def greet(name: str) -> str:
    """Greet someone."""
    return f"Hello, {name}!"
"#,
        );

        assert_eq!(found.len(), 1);
        let greet = &found[0];
        assert_eq!(greet.name, "greet");
        assert_eq!(greet.kind, SymbolKind::Function);
    }

    /// A class and its plain and decorated methods are all reported.
    #[test]
    fn test_extract_class() {
        let found = symbols_in(
            r#"
class UserService:
    """A service for managing users."""

    def __init__(self, name: str):
        self.name = name

    def greet(self) -> str:
        return f"Hello, {self.name}!"

    @staticmethod
    def create():
        return UserService("default")
"#,
        );

        assert!(found
            .iter()
            .any(|s| s.name == "UserService" && s.kind == SymbolKind::Class));

        for method in ["__init__", "greet", "create"] {
            assert!(found
                .iter()
                .any(|s| s.name == method && s.kind == SymbolKind::Method));
        }
    }

    /// `async def` sets the async flag on the symbol.
    #[test]
    fn test_extract_async_function() {
        let found = symbols_in(
            r#"
async def fetch_data(url: str) -> dict:
    """Fetch data from URL."""
    pass
"#,
        );

        assert_eq!(found.len(), 1);
        let fetch = &found[0];
        assert_eq!(fetch.name, "fetch_data");
        assert!(fetch.is_async);
    }

    /// Decorators move `definition_start_line` up but leave `start_line` on `def`.
    #[test]
    fn test_decorated_function_definition_start_line() {
        let found = symbols_in(
            r#"
@decorator1
@decorator2
def my_function():
    pass
"#,
        );

        assert_eq!(found.len(), 1);
        let decorated = &found[0];
        assert_eq!(decorated.name, "my_function");
        // Line 2 holds the first decorator (@decorator1).
        assert_eq!(decorated.definition_start_line, Some(2));
        // Line 4 holds the actual `def`.
        assert_eq!(decorated.start_line, 4);
    }

    /// Without decorators both line fields point at the `def` line.
    #[test]
    fn test_non_decorated_function_definition_start_line() {
        let found = symbols_in(
            r#"
def simple_function():
    pass
"#,
        );

        assert_eq!(found.len(), 1);
        let plain = &found[0];
        assert_eq!(plain.definition_start_line, Some(2));
        assert_eq!(plain.start_line, 2);
    }
}