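//! Python language support (`aptu_coder_core/languages/python.rs`): tree-sitter query sources
//! for definitions, calls, type references, imports and def-use sites, plus base-class extraction.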
/// Tree-sitter query matching Python function and class definitions.
///
/// Captures:
/// - `@function` — a function definition. For a decorated function the capture is the
///   enclosing `decorated_definition` node, so the span includes the decorators.
/// - `@func_name` — the function's name identifier.
/// - `@class` — a `class_definition` node. Decorators are not included here.
/// - `@class_name` — the class's name identifier.
///
/// NOTE(review): the second alternative matches every `function_definition`, including the one
/// nested inside a `decorated_definition`. A decorated function is therefore reported by two
/// matches (once per alternative). Confirm that consumers deduplicate.
pub const ELEMENT_QUERY: &str = r"
[(decorated_definition
  definition: (function_definition
    name: (identifier) @func_name)) @function
 (function_definition
   name: (identifier) @func_name) @function]
(class_definition
  name: (identifier) @class_name) @class
";

/// Tree-sitter query capturing the callee name of Python calls as `@call`.
///
/// Two forms are matched:
/// - a bare name, where `foo(...)` captures `foo`;
/// - the final attribute of an attribute callee, where `obj.foo(...)` captures `foo`.
///
/// Calls whose callee is any other expression (for example `f()()` or `x[0]()`) are not captured.
pub const CALL_QUERY: &str = r"
(call
  function: (identifier) @call)
(call
  function: (attribute attribute: (identifier) @call))
";

/// Tree-sitter query capturing identifiers used in Python type positions as `@type_ref`.
///
/// Two kinds of identifier are matched:
/// - identifiers that are direct children of a `type` node;
/// - identifiers that are direct children of a `generic_type` node (presumably the generic's base
///   name, e.g. `List` in `List[int]`).
///
/// Only bare `identifier` nodes are matched, so dotted types such as `typing.List` are not captured.
pub const REFERENCE_QUERY: &str = r"
(type (identifier) @type_ref)
(generic_type (identifier) @type_ref)
";

/// Tree-sitter query capturing Python import statements as `@import_path`.
///
/// Both `import ...` and `from ... import ...` statements are matched. Despite the capture name,
/// the capture is the whole statement node, not only the module path.
pub const IMPORT_QUERY: &str = r"
(import_statement) @import_path
(import_from_statement) @import_path
";

/// Tree-sitter query classifying Python identifier occurrences for def-use analysis.
///
/// Captures:
/// - `@write.assign` — the target of a plain assignment (`x = ...`).
/// - `@writeread.augmented` — the target of an augmented assignment (`x += ...`).
/// - `@write.named` — the target of an assignment expression (`(x := ...)`).
/// - `@read.usage` — every identifier.
///
/// Only bare-identifier targets are classified as writes. Tuple, attribute and subscript targets
/// are not matched by the write patterns.
///
/// NOTE(review): `@read.usage` also matches the identifiers captured by the write patterns, so a
/// write site is reported twice. Presumably the consumer resolves the overlap — confirm.
pub const DEFUSE_QUERY: &str = r"
(assignment left: (identifier) @write.assign)
(augmented_assignment left: (identifier) @writeread.augmented)
(named_expression name: (identifier) @write.named)
(identifier) @read.usage
";

use tree_sitter::Node;

46#[must_use]
48pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
49 let mut inherits = Vec::new();
50
51 if let Some(superclasses) = node.child_by_field_name("superclasses") {
53 for i in 0..superclasses.named_child_count() {
55 if let Some(child) = superclasses.named_child(u32::try_from(i).unwrap_or(u32::MAX))
56 && matches!(child.kind(), "identifier" | "attribute")
57 {
58 let text = &source[child.start_byte()..child.end_byte()];
59 inherits.push(text.to_string());
60 }
61 }
62 }
63
64 inherits
65}
66
#[cfg(all(test, feature = "lang-python"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parses `src` with the tree-sitter Python grammar, panicking on failure.
    fn parse_python(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_python::LANGUAGE.into())
            .expect("Error loading Python language");
        parser.parse(src, None).expect("Failed to parse Python")
    }

    /// Returns the first node of kind `kind` reached by a depth-first walk from `root`.
    fn find_first<'tree>(
        root: tree_sitter::Node<'tree>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'tree>> {
        let mut stack = vec![root];
        while let Some(node) = stack.pop() {
            if node.kind() == kind {
                return Some(node);
            }
            let mut cursor = node.walk();
            stack.extend(node.children(&mut cursor));
        }
        None
    }

    /// Parses `src`, finds its first class definition and returns its extracted base classes.
    fn bases_of(src: &str) -> Vec<String> {
        let tree = parse_python(src);
        let class = find_first(tree.root_node(), "class_definition")
            .expect("class_definition not found");
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_python_element_query_happy_path() {
        let src = concat!(
            "def greet(name): pass\n",
            "@decorator\n",
            "def wrapped(): pass\n",
            "class Greeter:\n",
            "    pass\n",
        );
        let tree = parse_python(src);
        let root = tree.root_node();

        let query = tree_sitter::Query::new(&tree_sitter_python::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, root, src.as_bytes());

        let mut captured_classes: Vec<String> = Vec::new();
        let mut captured_functions: Vec<String> = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                let name = query.capture_names()[capture.index as usize];
                let node = capture.node;
                // Names are read through the definition node's `name` field, so this also checks
                // that @class / @function capture definition nodes. A `decorated_definition` has
                // no `name` field; decorated functions are found via the plain
                // `function_definition` alternative instead.
                match name {
                    "class" => {
                        if let Some(n) = node.child_by_field_name("name") {
                            captured_classes.push(src[n.start_byte()..n.end_byte()].to_string());
                        }
                    }
                    "function" => {
                        if let Some(n) = node.child_by_field_name("name") {
                            captured_functions
                                .push(src[n.start_byte()..n.end_byte()].to_string());
                        }
                    }
                    _ => {}
                }
            }
        }

        assert!(
            captured_classes.contains(&"Greeter".to_string()),
            "expected Greeter class, got {:?}",
            captured_classes
        );
        assert!(
            captured_functions.contains(&"greet".to_string()),
            "expected greet function, got {:?}",
            captured_functions
        );
        assert!(
            captured_functions.contains(&"wrapped".to_string()),
            "expected decorated wrapped function, got {:?}",
            captured_functions
        );
    }

    #[test]
    fn test_python_extract_inheritance() {
        let bases = bases_of("class Cat(Animal, Domestic): pass\n");
        assert_eq!(
            bases,
            ["Animal", "Domestic"],
            "bases should be returned in source order"
        );
    }

    #[test]
    fn test_python_extract_inheritance_dotted_and_keyword() {
        // Dotted bases are kept; keyword arguments such as `metaclass=` are skipped.
        let bases = bases_of("class Cat(abc.ABC, metaclass=Meta): pass\n");
        assert_eq!(bases, ["abc.ABC"]);
    }

    #[test]
    fn test_python_extract_inheritance_without_superclasses() {
        let bases = bases_of("class Plain:\n    pass\n");
        assert!(bases.is_empty(), "expected no bases, got {:?}", bases);
    }

    #[test]
    fn test_defuse_query_write_site() {
        let src = "x = 1\n";
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "python", "x", "test.py", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}