Skip to main content

aptu_coder_core/languages/
java.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query for extracting Java elements (methods and classes).
///
/// Captures:
/// - `@function` on each `method_declaration`, with its name identifier as `@method_name`;
/// - `@class` on each `class_declaration`, `interface_declaration` and `enum_declaration`,
///   with the name identifier as `@class_name`, `@interface_name` or `@enum_name`.
///
/// Constructors (`constructor_declaration`) and records (`record_declaration`) are not
/// matched by this query.
pub const ELEMENT_QUERY: &str = r"
(method_declaration
  name: (identifier) @method_name) @function
(class_declaration
  name: (identifier) @class_name) @class
(interface_declaration
  name: (identifier) @interface_name) @class
(enum_declaration
  name: (identifier) @enum_name) @class
";

/// Tree-sitter query for extracting function calls.
///
/// Captures the `name` identifier of each `method_invocation` as `@call`, so `obj.run()`
/// yields `run`. Object creation (`new Foo()`) is a different node kind and is not matched.
pub const CALL_QUERY: &str = r"
(method_invocation
  name: (identifier) @call)
";

/// Tree-sitter query for extracting type references.
///
/// Captures every `type_identifier` node in the tree as `@type_ref`.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";

/// Tree-sitter query for extracting Java imports.
///
/// Captures each whole `import_declaration` node as `@import_path`. The captured node is
/// the entire declaration, not only its dotted path.
pub const IMPORT_QUERY: &str = r"
(import_declaration) @import_path
";

/// Tree-sitter query for extracting definition and use sites.
///
/// The capture name encodes the site kind:
/// - `@write.local`: the name declared by a `local_variable_declaration`'s declarator;
/// - `@write.assign`: a bare identifier on the left of an `assignment_expression`
///   (field targets such as `this.x = ...` are not matched);
/// - `@writeread.update`: an identifier operand of an `update_expression` (`i++`, `--i`);
/// - `@read.usage`: every `identifier` node. Identifiers matched by the patterns above
///   therefore also receive this capture; presumably the consumer resolves the
///   overlap — verify against the caller.
pub const DEFUSE_QUERY: &str = r"
(local_variable_declaration declarator: (variable_declarator name: (identifier) @write.local))
(assignment_expression left: (identifier) @write.assign)
(update_expression (identifier) @writeread.update)
(identifier) @read.usage
";
38
39use tree_sitter::Node;
40
41/// Extract function name from a Java method declaration.
42#[must_use]
43pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
44    if node.kind() != "method_declaration" {
45        return None;
46    }
47    node.child_by_field_name("name").and_then(|n| {
48        let end = n.end_byte();
49        if end <= source.len() {
50            Some(source[n.start_byte()..end].to_string())
51        } else {
52            None
53        }
54    })
55}
56
57/// Find receiver type (enclosing class/interface/enum) for a Java method.
58#[must_use]
59pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
60    if node.kind() != "method_declaration" {
61        return None;
62    }
63
64    // Walk ancestors to find enclosing class, interface, or enum
65    let mut current = *node;
66    while let Some(parent) = current.parent() {
67        match parent.kind() {
68            "class_declaration" | "interface_declaration" | "enum_declaration" => {
69                // Found the enclosing type, extract its name
70                return parent.child_by_field_name("name").and_then(|n| {
71                    let end = n.end_byte();
72                    if end <= source.len() {
73                        Some(source[n.start_byte()..end].to_string())
74                    } else {
75                        None
76                    }
77                });
78            }
79            _ => {
80                current = parent;
81            }
82        }
83    }
84
85    None
86}
87
88/// Find method name when inside a class/interface/enum body.
89#[must_use]
90pub fn find_method_for_receiver(
91    node: &Node,
92    source: &str,
93    _depth: Option<usize>,
94) -> Option<String> {
95    if node.kind() != "method_declaration" {
96        return None;
97    }
98
99    // Verify that the method is inside a class, interface, or enum
100    let mut current = *node;
101    let mut in_type_body = false;
102    while let Some(parent) = current.parent() {
103        match parent.kind() {
104            "class_declaration" | "interface_declaration" | "enum_declaration" => {
105                in_type_body = true;
106                break;
107            }
108            _ => {
109                current = parent;
110            }
111        }
112    }
113
114    if !in_type_body {
115        return None;
116    }
117
118    // Return the method name
119    node.child_by_field_name("name").and_then(|n| {
120        let end = n.end_byte();
121        if end <= source.len() {
122            Some(source[n.start_byte()..end].to_string())
123        } else {
124            None
125        }
126    })
127}
128
129/// Extract inheritance information from a Java class node.
130#[must_use]
131pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
132    let mut inherits = Vec::new();
133
134    // Extract superclass (extends)
135    if let Some(superclass) = node.child_by_field_name("superclass") {
136        for i in 0..superclass.named_child_count() {
137            if let Some(child) = superclass.named_child(u32::try_from(i).unwrap_or(u32::MAX))
138                && child.kind() == "type_identifier"
139            {
140                let text = &source[child.start_byte()..child.end_byte()];
141                inherits.push(format!("extends {text}"));
142            }
143        }
144    }
145
146    // Extract interfaces (implements)
147    if let Some(interfaces) = node.child_by_field_name("interfaces") {
148        for i in 0..interfaces.named_child_count() {
149            if let Some(type_list) = interfaces.named_child(u32::try_from(i).unwrap_or(u32::MAX)) {
150                for j in 0..type_list.named_child_count() {
151                    if let Some(type_node) =
152                        type_list.named_child(u32::try_from(j).unwrap_or(u32::MAX))
153                        && type_node.kind() == "type_identifier"
154                    {
155                        let text = &source[type_node.start_byte()..type_node.end_byte()];
156                        inherits.push(format!("implements {text}"));
157                    }
158                }
159            }
160        }
161    }
162
163    inherits
164}
165
#[cfg(all(test, feature = "lang-java"))]
mod tests {
    use super::*;
    use crate::DefUseKind;
    use crate::parser::SemanticExtractor;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parses `src` with the Java grammar, panicking if the grammar cannot be loaded or
    /// parsing fails.
    fn parse_java(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_java::LANGUAGE.into())
            .expect("Error loading Java language");
        parser.parse(src, None).expect("Failed to parse Java")
    }

    /// Returns the first node of `kind` reached by an explicit-stack depth-first walk from
    /// `root`, or `None` if no such node exists.
    ///
    /// Shared by the tests below, which previously each inlined this same search.
    fn find_first_of_kind<'tree>(
        root: tree_sitter::Node<'tree>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'tree>> {
        let mut stack = vec![root];
        while let Some(node) = stack.pop() {
            if node.kind() == kind {
                return Some(node);
            }
            for i in 0..node.child_count() {
                if let Some(child) = node.child(u32::try_from(i).unwrap_or(u32::MAX)) {
                    stack.push(child);
                }
            }
        }
        None
    }

    #[test]
    fn test_extract_function_name() {
        // Arrange: method inside a class
        let src = "class C { void foo() {} }";
        let tree = parse_java(src);
        let method_node = find_first_of_kind(tree.root_node(), "method_declaration")
            .expect("expected method_declaration");

        // Act
        let result = extract_function_name(&method_node, src, "java");

        // Assert
        assert_eq!(result, Some("foo".to_string()));
    }

    #[test]
    fn test_find_receiver_type() {
        // Arrange: method inside a class
        let src = "class MyClass { void bar() {} }";
        let tree = parse_java(src);
        let method_node = find_first_of_kind(tree.root_node(), "method_declaration")
            .expect("expected method_declaration");

        // Act
        let result = find_receiver_type(&method_node, src);

        // Assert
        assert_eq!(result, Some("MyClass".to_string()));
    }

    #[test]
    fn test_find_method_for_receiver() {
        // Arrange: method inside a class
        let src = "class C { void baz() {} }";
        let tree = parse_java(src);
        let method_node = find_first_of_kind(tree.root_node(), "method_declaration")
            .expect("expected method_declaration");

        // Act
        let result = find_method_for_receiver(&method_node, src, None);

        // Assert
        assert_eq!(result, Some("baz".to_string()));
    }

    #[test]
    fn test_java_element_query_happy_path() {
        // Arrange
        let src = "class Animal { void eat() {} }";
        let tree = parse_java(src);
        let root = tree.root_node();

        // Act -- verify ELEMENT_QUERY compiles and matches class + method
        let query = tree_sitter::Query::new(&tree_sitter_java::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, root, src.as_bytes());

        let mut captured_classes: Vec<String> = Vec::new();
        let mut captured_functions: Vec<String> = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                let name = query.capture_names()[capture.index as usize];
                let node = capture.node;
                match name {
                    "class" => {
                        if let Some(n) = node.child_by_field_name("name") {
                            captured_classes.push(src[n.start_byte()..n.end_byte()].to_string());
                        }
                    }
                    "function" => {
                        if let Some(n) = node.child_by_field_name("name") {
                            captured_functions.push(src[n.start_byte()..n.end_byte()].to_string());
                        }
                    }
                    _ => {}
                }
            }
        }

        // Assert
        assert!(
            captured_classes.contains(&"Animal".to_string()),
            "expected Animal class, got {:?}",
            captured_classes
        );
        assert!(
            captured_functions.contains(&"eat".to_string()),
            "expected eat function, got {:?}",
            captured_functions
        );
    }

    #[test]
    fn test_java_extract_inheritance() {
        // Arrange
        let src = "class Dog extends Animal implements ICanRun, ICanSwim {}";
        let tree = parse_java(src);

        // Act -- find the class_declaration node and call extract_inheritance
        let class = find_first_of_kind(tree.root_node(), "class_declaration")
            .expect("class_declaration not found");
        let bases = extract_inheritance(&class, src);

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {:?}",
            bases
        );
        assert!(
            bases.iter().any(|b| b.contains("ICanRun")),
            "expected implements ICanRun, got {:?}",
            bases
        );
        assert!(
            bases.iter().any(|b| b.contains("ICanSwim")),
            "expected implements ICanSwim, got {:?}",
            bases
        );
    }

    #[test]
    fn test_defuse_query_write_site() {
        // Arrange
        let src = "class C { void m() { int z = 5; } }\n";
        let sites =
            SemanticExtractor::extract_def_use_for_file(src, "java", "z", "test.java", None);
        assert!(!sites.is_empty(), "defuse sites should not be empty");
        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
        assert!(has_write, "should contain a Write DefUseSite");
    }
}
367        assert!(!sites.is_empty(), "defuse sites should not be empty");
368        let has_write = sites.iter().any(|s| matches!(s.kind, DefUseKind::Write));
369        assert!(has_write, "should contain a Write DefUseSite");
370    }
371}