Skip to main content

aptu_coder_core/languages/
python.rs

// SPDX-FileCopyrightText: 2026 aptu-coder contributors
// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Python elements (functions and classes).
///
/// Captures:
/// - `@function`: a `function_definition`, or the `decorated_definition` that wraps one.
/// - `@func_name`: the function's name identifier.
/// - `@class`: a `class_definition`.
/// - `@class_name`: the class's name identifier.
///
/// NOTE(review): a decorated function appears to satisfy both alternatives (once as the
/// outer `decorated_definition`, once as the inner `function_definition`), so consumers
/// presumably need to deduplicate -- confirm against the caller.
pub const ELEMENT_QUERY: &str = r"
[(decorated_definition
   definition: (function_definition
     name: (identifier) @func_name)) @function
 (function_definition
   name: (identifier) @func_name) @function]
(class_definition
  name: (identifier) @class_name) @class
";
13
/// Tree-sitter query for extracting function calls.
///
/// Captures `@call` on the callee name: the bare identifier of a direct call
/// (`foo()`), or the trailing attribute identifier of an attribute call
/// (`obj.method()` captures `method`).
pub const CALL_QUERY: &str = r"
(call
  function: (identifier) @call)
(call
  function: (attribute attribute: (identifier) @call))
";
21
/// Tree-sitter query for extracting type references.
///
/// The Python grammar has no `type_identifier` node, so type names are captured as
/// `@type_ref` from an `identifier` directly inside a `type` node (annotations) or
/// directly inside a `generic_type` node (parameterized types).
pub const REFERENCE_QUERY: &str = r"
(type (identifier) @type_ref)
(generic_type (identifier) @type_ref)
";
29
/// Tree-sitter query for extracting Python imports.
///
/// Captures the whole statement node as `@import_path` for both `import ...`
/// (`import_statement`) and `from ... import ...` (`import_from_statement`); the
/// consumer is left to pull module and item names out of the captured node.
pub const IMPORT_QUERY: &str = r"
(import_statement) @import_path
(import_from_statement) @import_path
";
35
/// Tree-sitter query for extracting definition and use sites.
///
/// Captures:
/// - `@write.assign`: identifier on the left of a plain assignment (`x = ...`).
/// - `@writeread.augmented`: identifier on the left of an augmented assignment (`x += ...`).
/// - `@write.named`: identifier bound by a named expression (`(x := ...)`).
/// - `@read.usage`: every identifier node.
///
/// NOTE(review): `@read.usage` also matches the identifiers captured by the patterns
/// above; presumably the consumer gives the more specific capture precedence -- confirm.
pub const DEFUSE_QUERY: &str = r"
(assignment left: (identifier) @write.assign)
(augmented_assignment left: (identifier) @writeread.augmented)
(named_expression name: (identifier) @write.named)
(identifier) @read.usage
";
43
44use tree_sitter::Node;
45
46/// Extract inheritance information from a Python class node.
47#[must_use]
48pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
49    let mut inherits = Vec::new();
50
51    // Get superclasses field from class_definition
52    if let Some(superclasses) = node.child_by_field_name("superclasses") {
53        // superclasses contains an argument_list
54        for i in 0..superclasses.named_child_count() {
55            if let Some(child) = superclasses.named_child(u32::try_from(i).unwrap_or(u32::MAX))
56                && matches!(child.kind(), "identifier" | "attribute")
57            {
58                let text = &source[child.start_byte()..child.end_byte()];
59                inherits.push(text.to_string());
60            }
61        }
62    }
63
64    inherits
65}
66
#[cfg(all(test, feature = "lang-python"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `src` with the Python grammar, panicking if the grammar cannot be loaded
    /// or the parser returns no tree.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let language: tree_sitter::Language = tree_sitter_python::LANGUAGE.into();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("Error loading Python language");
        parser.parse(src, None).expect("Failed to parse Python")
    }

    /// Depth-first search from `root` for a node whose kind equals `kind`.
    fn find_node_of_kind<'tree>(
        root: tree_sitter::Node<'tree>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'tree>> {
        let mut pending = vec![root];
        while let Some(current) = pending.pop() {
            if current.kind() == kind {
                return Some(current);
            }
            let mut cursor = current.walk();
            pending.extend(current.children(&mut cursor));
        }
        None
    }

    #[test]
    fn test_python_element_query_happy_path() {
        // Arrange
        let src = "def greet(name): pass\nclass Greeter:\n    pass\n";
        let tree = parse_python(src);
        let language: tree_sitter::Language = tree_sitter_python::LANGUAGE.into();

        // Act
        let query =
            tree_sitter::Query::new(&language, ELEMENT_QUERY).expect("ELEMENT_QUERY must be valid");
        let capture_names = query.capture_names();
        let mut cursor = tree_sitter::QueryCursor::new();
        // `matches` is a streaming iterator, so it must be driven with `while let`.
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());

        let mut class_names: Vec<String> = Vec::new();
        let mut function_names: Vec<String> = Vec::new();
        while let Some(found) = matches.next() {
            for capture in found.captures {
                // Only the element-level captures matter here; name captures are ignored.
                let bucket = match capture_names[capture.index as usize] {
                    "class" => &mut class_names,
                    "function" => &mut function_names,
                    _ => continue,
                };
                if let Some(name_node) = capture.node.child_by_field_name("name") {
                    bucket.push(src[name_node.start_byte()..name_node.end_byte()].to_string());
                }
            }
        }

        // Assert
        assert!(
            class_names.iter().any(|name| name == "Greeter"),
            "expected Greeter class, got {class_names:?}"
        );
        assert!(
            function_names.iter().any(|name| name == "greet"),
            "expected greet function, got {function_names:?}"
        );
    }

    #[test]
    fn test_python_extract_inheritance() {
        // Arrange
        let src = "class Cat(Animal, Domestic): pass\n";
        let tree = parse_python(src);

        // Act -- find class_definition node
        let class = find_node_of_kind(tree.root_node(), "class_definition")
            .expect("class_definition not found");
        let bases = extract_inheritance(&class, src);

        // Assert
        assert!(
            bases.iter().any(|base| base == "Animal"),
            "expected Animal, got {bases:?}"
        );
        assert!(
            bases.iter().any(|base| base == "Domestic"),
            "expected Domestic, got {bases:?}"
        );
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange
        let src = "x = 1\n";

        // Act
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);

        // Assert
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites
            .iter()
            .any(|site| matches!(site.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}