// Source: aptu_coder_core/languages/python.rs

// SPDX-FileCopyrightText: 2026 aptu-coder contributors
// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Python elements (functions and classes).
///
/// Captures:
/// - `@function` / `@func_name`: a function definition and its name. A decorated
///   function is matched by the first alternative, which captures the whole
///   `decorated_definition` node (decorators included) as `@function`.
/// - `@class` / `@class_name`: a `class_definition` and its name.
///
/// NOTE(review): the plain `function_definition` alternative also matches the function
/// nested inside a `decorated_definition`, so a decorated function can yield two
/// `@function` matches (the wrapper and the inner definition). Callers presumably
/// tolerate or de-duplicate this; confirm. Decorated classes are only captured through
/// their inner `class_definition`, so their `@class` span excludes the decorators.
pub const ELEMENT_QUERY: &str = r"
[(decorated_definition
   definition: (function_definition
     name: (identifier) @func_name)) @function
 (function_definition
   name: (identifier) @func_name) @function]
(class_definition
  name: (identifier) @class_name) @class
";

/// Tree-sitter query for extracting function calls.
///
/// Captures the callee name as `@call` for:
/// - plain calls such as `foo()`, where the callee is an `identifier`;
/// - attribute calls such as `obj.foo()` or `pkg.mod.foo()`, where only the final
///   attribute name (`foo`) is captured, not the receiver.
///
/// Calls whose callee is some other expression (e.g. `handlers[0]()`) are not matched.
pub const CALL_QUERY: &str = r"
(call
  function: (identifier) @call)
(call
  function: (attribute attribute: (identifier) @call))
";

/// Tree-sitter query for extracting type references.
///
/// The Python grammar has no `type_identifier` node, so type names are matched as
/// plain `identifier`s in two positions, both captured as `@type_ref`:
/// - `(type (identifier))`: a bare name used as a type, e.g. `int` in `x: int`;
/// - `(generic_type (identifier))`: the base name of a parameterized type, e.g. `list`
///   in `list[int]`.
///
/// Only direct `identifier` children are matched. Dotted names such as `typing.Any`
/// are presumably `attribute` nodes and therefore not captured; confirm if needed.
pub const REFERENCE_QUERY: &str = r"
(type (identifier) @type_ref)
(generic_type (identifier) @type_ref)
";

/// Tree-sitter query for extracting Python imports.
///
/// Captures each whole statement (not just the module path) as `@import_path`:
/// - `import_statement`, e.g. `import os, sys`;
/// - `import_from_statement`, e.g. `from a.b import c`.
///
/// NOTE(review): tree-sitter-python appears to parse `from __future__ import ...` as a
/// separate `future_import_statement` node, which this query does not match. Confirm
/// whether that omission is intended.
pub const IMPORT_QUERY: &str = r"
(import_statement) @import_path
(import_from_statement) @import_path
";

/// Tree-sitter query for extracting definition and use sites.
///
/// Capture names encode the site kind as `<access>.<form>`:
/// - `write.assign`: the target of `x = ...`;
/// - `write.tuple`: each name in `a, b = ...` (`pattern_list`) or `(a, b) = ...`
///   (`tuple_pattern`);
/// - `write.list`: each name in `[a, b] = ...` (`list_pattern`);
/// - `writeread.augmented`: the target of an augmented assignment such as `x += ...`;
/// - `write.named`: the walrus target in `(x := ...)`;
/// - `read.usage`: every `identifier` node.
///
/// Because `read.usage` matches all identifiers, each write site is also matched as a
/// read. The consumer presumably resolves that overlap; verify against the caller.
/// Only direct `identifier` children of the patterns are matched, so nested targets
/// (e.g. `b` and `c` in `a, (b, c) = ...`) and non-assignment bindings such as `for`
/// or `with` targets are not reported as writes.
pub const DEFUSE_QUERY: &str = r"
(assignment left: (identifier) @write.assign)
(assignment left: (pattern_list (identifier) @write.tuple))
(assignment left: (tuple_pattern (identifier) @write.tuple))
(assignment left: (list_pattern (identifier) @write.list))
(augmented_assignment left: (identifier) @writeread.augmented)
(named_expression name: (identifier) @write.named)
(identifier) @read.usage
";

47use tree_sitter::Node;
48
49/// Extract inheritance information from a Python class node.
50#[must_use]
51pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
52    let mut inherits = Vec::new();
53
54    // Get superclasses field from class_definition
55    if let Some(superclasses) = node.child_by_field_name("superclasses") {
56        // superclasses contains an argument_list
57        for i in 0..superclasses.named_child_count() {
58            if let Some(child) = superclasses.named_child(u32::try_from(i).unwrap_or(u32::MAX))
59                && matches!(child.kind(), "identifier" | "attribute")
60            {
61                let text = &source[child.start_byte()..child.end_byte()];
62                inherits.push(text.to_string());
63            }
64        }
65    }
66
67    inherits
68}
69
#[cfg(all(test, feature = "lang-python"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `src` with the Python grammar, panicking on any setup or parse failure.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Error loading Python language");
        parser.parse(src, None).expect("Failed to parse Python")
    }

    /// Run [`ELEMENT_QUERY`] over `src` and return `(class_names, function_names)`.
    ///
    /// Names are read from the `name` field of each `@class` / `@function` capture. A
    /// capture without a `name` field (such as a `decorated_definition` wrapper) is
    /// skipped.
    fn collect_element_names(src: &str) -> (Vec<String>, Vec<String>) {
        let tree = parse_python(src);
        let query = tree_sitter::Query::new(&tree_sitter_python::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());

        let mut classes: Vec<String> = Vec::new();
        let mut functions: Vec<String> = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                let names = match query.capture_names()[capture.index as usize] {
                    "class" => &mut classes,
                    "function" => &mut functions,
                    _ => continue,
                };
                if let Some(n) = capture.node.child_by_field_name("name") {
                    names.push(src[n.start_byte()..n.end_byte()].to_string());
                }
            }
        }
        (classes, functions)
    }

    /// Depth-first search of `tree` for a node of the given `kind`.
    fn find_first_kind<'t>(
        tree: &'t tree_sitter::Tree,
        kind: &str,
    ) -> Option<tree_sitter::Node<'t>> {
        let mut stack = vec![tree.root_node()];
        while let Some(node) = stack.pop() {
            if node.kind() == kind {
                return Some(node);
            }
            for i in 0..node.child_count() {
                if let Some(child) = node.child(u32::try_from(i).unwrap_or(u32::MAX)) {
                    stack.push(child);
                }
            }
        }
        None
    }

    /// Parse `src`, locate its `class_definition`, and return its bases.
    fn bases_of(src: &str) -> Vec<String> {
        let tree = parse_python(src);
        let class = find_first_kind(&tree, "class_definition").expect("class_definition not found");
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_python_element_query_happy_path() {
        // Arrange
        let src = "def greet(name): pass\nclass Greeter:\n    pass\n";

        // Act
        let (classes, functions) = collect_element_names(src);

        // Assert
        assert!(
            classes.contains(&"Greeter".to_string()),
            "expected Greeter class, got {classes:?}"
        );
        assert!(
            functions.contains(&"greet".to_string()),
            "expected greet function, got {functions:?}"
        );
    }

    #[test]
    fn test_python_element_query_decorated_function() {
        // Arrange: a decorated function must still be reported by name.
        let src = "@staticmethod\ndef helper():\n    pass\n";

        // Act
        let (classes, functions) = collect_element_names(src);

        // Assert
        assert!(classes.is_empty(), "expected no classes, got {classes:?}");
        assert!(
            functions.contains(&"helper".to_string()),
            "expected helper function, got {functions:?}"
        );
    }

    #[test]
    fn test_python_extract_inheritance() {
        // Arrange / Act
        let bases = bases_of("class Cat(Animal, Domestic): pass\n");

        // Assert
        assert!(
            bases.contains(&"Animal".to_string()),
            "expected Animal, got {bases:?}"
        );
        assert!(
            bases.contains(&"Domestic".to_string()),
            "expected Domestic, got {bases:?}"
        );
    }

    #[test]
    fn test_python_extract_inheritance_dotted_base_and_keyword() {
        // Arrange / Act: dotted bases are kept verbatim; keyword arguments are skipped.
        let bases = bases_of("class Handler(pkg.Base, metaclass=Meta): pass\n");

        // Assert
        assert_eq!(bases, vec!["pkg.Base".to_string()]);
    }

    #[test]
    fn test_python_extract_inheritance_no_bases() {
        // Arrange / Act
        let bases = bases_of("class Plain:\n    pass\n");

        // Assert
        assert!(bases.is_empty(), "expected no bases, got {bases:?}");
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange
        let src = "x = 1\n";
        // Act
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);
        // Assert
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }

    #[test]
    fn test_defuse_python_augmented_assignment() {
        // Arrange: augmented assignment += is WriteRead
        let src = "x = 1\nx += 2\n";
        // Act
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);
        // Assert
        assert!(
            !sites.is_empty(),
            "augmented assignment should produce defuse sites"
        );
        let has_writeread = sites
            .iter()
            .any(|s| matches!(s.kind, DefUseKind::WriteRead));
        assert!(has_writeread, "augmented assignment should be WriteRead");
    }

    #[test]
    fn test_defuse_python_tuple_unpack() {
        // Arrange: tuple unpack captures all LHS identifiers as Write
        let src = "a, b = (1, 2)\n";
        // Act
        let sites_a =
            SemanticExtractor::extract_def_use_for_file(src, "python", "a", "test.py", None);
        let sites_b =
            SemanticExtractor::extract_def_use_for_file(src, "python", "b", "test.py", None);
        // Assert
        assert!(
            !sites_a.is_empty(),
            "tuple unpack a should produce defuse sites"
        );
        assert!(
            !sites_b.is_empty(),
            "tuple unpack b should produce defuse sites"
        );
        let a_write = sites_a.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        let b_write = sites_b.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(a_write, "tuple unpack a should be Write");
        assert!(b_write, "tuple unpack b should be Write");
    }

    #[test]
    fn test_defuse_python_walrus() {
        // Arrange: walrus operator := is Write (named_expression)
        let src = "if (x := 42):\n    pass\n";
        // Act
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);
        // Assert
        assert!(
            !sites.is_empty(),
            "walrus operator should produce defuse sites"
        );
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "walrus operator should be Write");
    }
}