Skip to main content

code_analyze_core/languages/
python.rs

1// SPDX-FileCopyrightText: 2026 code-analyze-mcp contributors
2// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Python elements (functions and classes).
///
/// Captures:
/// - `@function` / `@func_name`: each `function_definition` node and its name identifier.
/// - `@class` / `@class_name`: each `class_definition` node and its name identifier.
///
/// The patterns are not anchored to the module root, so nested definitions
/// (methods, inner functions, inner classes) match as well.
///
/// NOTE(review): tree-sitter-python wraps decorated definitions in a parent
/// `decorated_definition` node, so the `@function`/`@class` spans likely
/// exclude decorator lines — confirm if callers rely on those spans.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  name: (identifier) @func_name) @function
(class_definition
  name: (identifier) @class_name) @class
";
10
/// Tree-sitter query for extracting function calls.
///
/// Both patterns capture the callee name as `@call`:
/// - a bare call such as `foo()` captures `foo`;
/// - an attribute/method call such as `obj.method()` captures only the final
///   attribute name `method`; the receiver `obj` is not captured.
///
/// Calls whose callee is some other expression kind (for example a subscript,
/// as in `handlers[0]()`) are not captured by either pattern.
pub const CALL_QUERY: &str = r"
(call
  function: (identifier) @call)
(call
  function: (attribute attribute: (identifier) @call))
";
18
/// Tree-sitter query for extracting type references.
///
/// The Python grammar has no `type_identifier` node, so type names are
/// captured as `@type_ref` from two shapes:
/// - `(type (identifier))`: a plain name used directly as a type annotation,
///   e.g. `int` in `x: int`;
/// - `(generic_type (identifier))`: the leading name of a parameterized type.
///
/// Child patterns in tree-sitter queries match direct children only.
///
/// NOTE(review): dotted types such as `typing.Optional` do not have a direct
/// `identifier` child and are likely not captured — confirm whether that's
/// intended.
pub const REFERENCE_QUERY: &str = r"
(type (identifier) @type_ref)
(generic_type (identifier) @type_ref)
";
26
/// Tree-sitter query for extracting Python imports.
///
/// Captures the whole statement node as `@import_path` for both
/// `import x` (`import_statement`) and `from x import y`
/// (`import_from_statement`). Callers receive the full statement span, not a
/// pre-split module path.
///
/// NOTE(review): tree-sitter-python parses `from __future__ import ...` as a
/// separate `future_import_statement` node, which this query does not match —
/// confirm whether those should be reported too.
pub const IMPORT_QUERY: &str = r"
(import_statement) @import_path
(import_from_statement) @import_path
";
32
33use tree_sitter::Node;
34
35/// Extract inheritance information from a Python class node.
36#[must_use]
37pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
38    let mut inherits = Vec::new();
39
40    // Get superclasses field from class_definition
41    if let Some(superclasses) = node.child_by_field_name("superclasses") {
42        // superclasses contains an argument_list
43        for i in 0..superclasses.named_child_count() {
44            if let Some(child) = superclasses.named_child(u32::try_from(i).unwrap_or(u32::MAX))
45                && matches!(child.kind(), "identifier" | "attribute")
46            {
47                let text = &source[child.start_byte()..child.end_byte()];
48                inherits.push(text.to_string());
49            }
50        }
51    }
52
53    inherits
54}
55
#[cfg(all(test, feature = "lang-python"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `src` with the tree-sitter Python grammar, panicking if the
    /// grammar cannot be loaded or parsing fails.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Error loading Python language");
        parser.parse(src, None).expect("Failed to parse Python")
    }

    /// Parse `src` and run [`extract_inheritance`] on its first top-level
    /// `class_definition`.
    fn bases_of_first_class(src: &str) -> Vec<String> {
        let tree = parse_python(src);
        let root = tree.root_node();
        let mut cursor = root.walk();
        let class_node = root
            .named_children(&mut cursor)
            .find(|node| node.kind() == "class_definition")
            .expect("class_definition not found");
        extract_inheritance(&class_node, src)
    }

    #[test]
    fn test_python_element_query_happy_path() {
        // Arrange
        let src = "def greet(name): pass\nclass Greeter:\n    pass\n";
        let tree = parse_python(src);
        let query = tree_sitter::Query::new(&tree_sitter_python::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");

        // Act -- collect the text of every @func_name / @class_name capture
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());
        let mut functions: Vec<&str> = Vec::new();
        let mut classes: Vec<&str> = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                let text = &src[capture.node.byte_range()];
                match query.capture_names()[capture.index as usize] {
                    "func_name" => functions.push(text),
                    "class_name" => classes.push(text),
                    _ => {}
                }
            }
        }

        // Assert -- exactly one of each, nothing extra
        assert_eq!(functions, ["greet"]);
        assert_eq!(classes, ["Greeter"]);
    }

    #[test]
    fn test_python_extract_inheritance() {
        let bases = bases_of_first_class("class Cat(Animal, Domestic): pass\n");
        assert_eq!(
            bases,
            ["Animal", "Domestic"],
            "bases should be returned in declaration order"
        );
    }

    #[test]
    fn test_python_extract_inheritance_dotted_base_and_keyword_argument() {
        // `metaclass=...` is a keyword argument, not a base class, and must be skipped.
        let bases = bases_of_first_class("class Model(abc.ABC, metaclass=Registry): pass\n");
        assert_eq!(bases, ["abc.ABC"]);
    }

    #[test]
    fn test_python_extract_inheritance_without_bases() {
        assert!(bases_of_first_class("class Plain: pass\n").is_empty());
        assert!(bases_of_first_class("class Empty(): pass\n").is_empty());
    }
}