Skip to main content

code_analyze_core/languages/
python.rs

// SPDX-FileCopyrightText: 2026 code-analyze-mcp contributors
// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Python elements (functions and classes).
///
/// Captures, per match:
/// - `@function` — a whole `function_definition` node, with `@func_name` on its
///   name identifier;
/// - `@class` — a whole `class_definition` node, with `@class_name` on its name
///   identifier.
///
/// The patterns are not anchored to the module root, so nested definitions
/// (methods, inner functions, inner classes) match as well.
pub const ELEMENT_QUERY: &str = r"
(function_definition
  name: (identifier) @func_name) @function
(class_definition
  name: (identifier) @class_name) @class
";
10
/// Tree-sitter query for extracting function calls.
///
/// Both patterns capture a single identifier as `@call`:
/// - for a plain call, the callee name (`foo()` → `foo`);
/// - for a call through an attribute, the final attribute name
///   (`obj.method()` → `method`, `a.b.c()` → `c`).
///
/// Calls whose callee is any other expression (for example `f()()` or
/// `handlers[k]()`) are not captured.
pub const CALL_QUERY: &str = r"
(call
  function: (identifier) @call)
(call
  function: (attribute attribute: (identifier) @call))
";
18
/// Tree-sitter query for extracting type references.
///
/// The Python grammar has no `type_identifier` node, so type names are found
/// structurally:
/// - `(type (identifier) @type_ref)` captures a bare name used directly as a
///   type, e.g. in an annotation;
/// - `(generic_type (identifier) @type_ref)` captures the base name of a
///   parameterized type, e.g. `List` in `List[int]`.
///
/// A dotted type such as `typing.Any` is not matched by either pattern, because
/// neither pattern accepts an `attribute` node.
pub const REFERENCE_QUERY: &str = r"
(type (identifier) @type_ref)
(generic_type (identifier) @type_ref)
";
26
/// Tree-sitter query for extracting Python imports.
///
/// Captures the entire statement node as `@import_path` for both
/// `import a.b` (`import_statement`) and `from x import y`
/// (`import_from_statement`). The captured text is therefore the whole
/// statement, not only the module path.
///
/// `from __future__ import ...` is a distinct grammar node
/// (`future_import_statement`) and is not matched.
pub const IMPORT_QUERY: &str = r"
(import_statement) @import_path
(import_from_statement) @import_path
";
32
/// Tree-sitter query for extracting definition and use sites.
///
/// The prefix before the `.` in each capture name presumably selects the site
/// kind in the consumer. Confirm this against the code that maps captures to
/// `DefUseKind`.
///
/// - `@write.assign` — a bare name bound by `=`. This covers the whole target
///   (`x = 1`) and each name of an unparenthesised tuple target
///   (`a, b = pair`). It also covers `for`-loop targets (`for k in ...`,
///   `for k, v in ...`), which Python binds with the same rules as assignment.
/// - `@writeread.augmented` — the target of an augmented assignment (`x += 1`),
///   which both reads and writes the name.
/// - `@write.named` — the name bound by an assignment expression (`(x := f())`).
/// - `@read.usage` — every `identifier` node, including ones also captured as a
///   write above.
///
/// The original three write patterns keep their relative order ahead of the
/// tuple and loop patterns. The catch-all read pattern stays last.
///
/// Several other binding forms have no write pattern here:
/// - parameters;
/// - `with ... as` and `except ... as`;
/// - comprehension targets;
/// - imports;
/// - parenthesised or bracketed targets such as `(a, b) = pair`.
pub const DEFUSE_QUERY: &str = r"
(assignment left: (identifier) @write.assign)
(augmented_assignment left: (identifier) @writeread.augmented)
(named_expression name: (identifier) @write.named)
(assignment left: (pattern_list (identifier) @write.assign))
(for_statement left: (identifier) @write.assign)
(for_statement left: (pattern_list (identifier) @write.assign))
(identifier) @read.usage
";
40
41use tree_sitter::Node;
42
43/// Extract inheritance information from a Python class node.
44#[must_use]
45pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
46    let mut inherits = Vec::new();
47
48    // Get superclasses field from class_definition
49    if let Some(superclasses) = node.child_by_field_name("superclasses") {
50        // superclasses contains an argument_list
51        for i in 0..superclasses.named_child_count() {
52            if let Some(child) = superclasses.named_child(u32::try_from(i).unwrap_or(u32::MAX))
53                && matches!(child.kind(), "identifier" | "attribute")
54            {
55                let text = &source[child.start_byte()..child.end_byte()];
56                inherits.push(text.to_string());
57            }
58        }
59    }
60
61    inherits
62}
63
#[cfg(all(test, feature = "lang-python"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `src` with the tree-sitter Python grammar, panicking on failure.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Error loading Python language");
        parser.parse(src, None).expect("Failed to parse Python")
    }

    /// Depth-first search for the first node of kind `kind` at or below `root`.
    fn find_first_kind<'t>(
        root: tree_sitter::Node<'t>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'t>> {
        let mut stack = vec![root];
        while let Some(node) = stack.pop() {
            if node.kind() == kind {
                return Some(node);
            }
            let mut cursor = node.walk();
            stack.extend(node.children(&mut cursor));
        }
        None
    }

    /// Parse `src` and run `extract_inheritance` on its first class definition.
    fn bases_of_first_class(src: &str) -> Vec<String> {
        let tree = parse_python(src);
        let class = find_first_kind(tree.root_node(), "class_definition")
            .expect("class_definition not found");
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_python_element_query_happy_path() {
        // Arrange
        let src = "def greet(name): pass\nclass Greeter:\n    pass\n";
        let tree = parse_python(src);
        let root = tree.root_node();

        // Act
        let query = tree_sitter::Query::new(&tree_sitter_python::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, root, src.as_bytes());

        let mut captured_classes: Vec<String> = Vec::new();
        let mut captured_functions: Vec<String> = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                // Only the whole-definition captures are inspected; their `name`
                // field gives the identifier text.
                let bucket = match query.capture_names()[capture.index as usize] {
                    "class" => &mut captured_classes,
                    "function" => &mut captured_functions,
                    _ => continue,
                };
                if let Some(n) = capture.node.child_by_field_name("name") {
                    bucket.push(src[n.start_byte()..n.end_byte()].to_string());
                }
            }
        }

        // Assert
        assert!(
            captured_classes.iter().any(|c| c == "Greeter"),
            "expected Greeter class, got {captured_classes:?}"
        );
        assert!(
            captured_functions.iter().any(|f| f == "greet"),
            "expected greet function, got {captured_functions:?}"
        );
    }

    #[test]
    fn test_python_extract_inheritance() {
        let bases = bases_of_first_class("class Cat(Animal, Domestic): pass\n");
        assert_eq!(bases, ["Animal", "Domestic"], "bases should be in source order");
    }

    #[test]
    fn test_python_extract_inheritance_dotted_base_skips_keyword() {
        let bases = bases_of_first_class("class Model(models.Model, metaclass=Meta): pass\n");
        assert_eq!(bases, ["models.Model"]);
    }

    #[test]
    fn test_python_extract_inheritance_no_bases() {
        assert!(bases_of_first_class("class Plain:\n    pass\n").is_empty());
        assert!(bases_of_first_class("class Empty(): pass\n").is_empty());
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange
        let src = "x = 1\n";
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}