Skip to main content

code_analyze_core/languages/
java.rs

// SPDX-FileCopyrightText: 2026 code-analyze-mcp contributors
// SPDX-License-Identifier: Apache-2.0
/// Tree-sitter query that captures Java declarations.
///
/// Captures:
/// - `@function`: each `method_declaration`, with its name as `@method_name`.
/// - `@class`: each `class_declaration`, `interface_declaration` and
///   `enum_declaration`, with the name captured as `@class_name`,
///   `@interface_name` or `@enum_name` respectively.
///
/// NOTE(review): constructors and records presumably use other node kinds in
/// the Java grammar and would not be matched here; confirm that is intended.
pub const ELEMENT_QUERY: &str = r"
(method_declaration
  name: (identifier) @method_name) @function
(class_declaration
  name: (identifier) @class_name) @class
(interface_declaration
  name: (identifier) @interface_name) @class
(enum_declaration
  name: (identifier) @enum_name) @class
";
14
/// Tree-sitter query that captures the name identifier of every
/// `method_invocation` as `@call`.
///
/// Only the method name is captured, not the rest of the invocation.
pub const CALL_QUERY: &str = r"
(method_invocation
  name: (identifier) @call)
";
20
/// Tree-sitter query that captures every `type_identifier` node as `@type_ref`.
pub const REFERENCE_QUERY: &str = r"
(type_identifier) @type_ref
";
25
/// Tree-sitter query that captures every Java `import_declaration` as
/// `@import_path`.
///
/// The capture spans the whole declaration node rather than only the dotted
/// name inside it.
pub const IMPORT_QUERY: &str = r"
(import_declaration) @import_path
";
30
use tree_sitter::Node;
32
33/// Extract inheritance information from a Java class node.
34#[must_use]
35pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
36    let mut inherits = Vec::new();
37
38    // Extract superclass (extends)
39    if let Some(superclass) = node.child_by_field_name("superclass") {
40        for i in 0..superclass.named_child_count() {
41            if let Some(child) = superclass.named_child(u32::try_from(i).unwrap_or(u32::MAX))
42                && child.kind() == "type_identifier"
43            {
44                let text = &source[child.start_byte()..child.end_byte()];
45                inherits.push(format!("extends {text}"));
46            }
47        }
48    }
49
50    // Extract interfaces (implements)
51    if let Some(interfaces) = node.child_by_field_name("interfaces") {
52        for i in 0..interfaces.named_child_count() {
53            if let Some(type_list) = interfaces.named_child(u32::try_from(i).unwrap_or(u32::MAX)) {
54                for j in 0..type_list.named_child_count() {
55                    if let Some(type_node) =
56                        type_list.named_child(u32::try_from(j).unwrap_or(u32::MAX))
57                        && type_node.kind() == "type_identifier"
58                    {
59                        let text = &source[type_node.start_byte()..type_node.end_byte()];
60                        inherits.push(format!("implements {text}"));
61                    }
62                }
63            }
64        }
65    }
66
67    inherits
68}
69
#[cfg(all(test, feature = "lang-java"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, StreamingIterator};

    /// Parse `src` with the Java grammar, panicking if the parser cannot be set up.
    fn parse_java(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_java::LANGUAGE.into())
            .expect("Error loading Java language");
        parser.parse(src, None).expect("Failed to parse Java")
    }

    /// Depth-first search from `root` for a node whose kind equals `kind`.
    fn find_first_of_kind<'tree>(
        root: tree_sitter::Node<'tree>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'tree>> {
        let mut stack = vec![root];
        while let Some(node) = stack.pop() {
            if node.kind() == kind {
                return Some(node);
            }
            let mut cursor = node.walk();
            stack.extend(node.children(&mut cursor));
        }
        None
    }

    #[test]
    fn test_java_element_query_happy_path() {
        // Arrange
        let src = "class Animal { void eat() {} }";
        let tree = parse_java(src);
        let root = tree.root_node();

        // Act -- verify ELEMENT_QUERY compiles and matches class + method
        let query = tree_sitter::Query::new(&tree_sitter_java::LANGUAGE.into(), ELEMENT_QUERY)
            .expect("ELEMENT_QUERY must be valid");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, root, src.as_bytes());

        let mut captured_classes: Vec<String> = Vec::new();
        let mut captured_functions: Vec<String> = Vec::new();
        // `matches` is a streaming iterator, so it has to be driven with `while let`.
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                // Route the declaration's name to the list that matches the capture.
                let target = match query.capture_names()[capture.index as usize] {
                    "class" => &mut captured_classes,
                    "function" => &mut captured_functions,
                    _ => continue,
                };
                if let Some(name) = capture.node.child_by_field_name("name") {
                    target.push(src[name.start_byte()..name.end_byte()].to_string());
                }
            }
        }

        // Assert
        assert_eq!(captured_classes, ["Animal"], "expected exactly the Animal class");
        assert_eq!(captured_functions, ["eat"], "expected exactly the eat function");
    }

    #[test]
    fn test_java_extract_inheritance() {
        // Arrange
        let src = "class Dog extends Animal implements ICanRun, ICanSwim {}";
        let tree = parse_java(src);

        // Act -- find the class_declaration node and call extract_inheritance
        let class = find_first_of_kind(tree.root_node(), "class_declaration")
            .expect("class_declaration not found");
        let bases = extract_inheritance(&class, src);

        // Assert -- pin the keyword prefix and the order, not just the type names
        assert_eq!(
            bases,
            ["extends Animal", "implements ICanRun", "implements ICanSwim"],
            "unexpected inheritance list"
        );
    }
}