Skip to main content

aptu_coder_core/languages/
kotlin.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
3
/// Tree-sitter query that captures Kotlin declarations.
///
/// Three patterns are matched, each requiring a `name` field holding an `identifier`:
///
/// - `function_declaration`: the name is captured as `@function_name`, the whole node as `@function`.
/// - `class_declaration`: the name is captured as `@class_name`, the whole node as `@class`.
/// - `object_declaration`: the name is captured as `@object_name`; the whole node is also
///   captured as `@class`, so objects and classes share the same outer capture.
pub const ELEMENT_QUERY: &str = r"
(function_declaration
  name: (identifier) @function_name) @function
(class_declaration
  name: (identifier) @class_name) @class
(object_declaration
  name: (identifier) @object_name) @class
";
13
/// Tree-sitter query that captures the callee of a Kotlin call.
///
/// Matches a `call_expression` that has an `identifier` as a direct child and captures that
/// identifier as `@call`. Only this one shape is matched; a call whose callee is not a bare
/// `identifier` child is not captured by this pattern.
pub const CALL_QUERY: &str = r"
(call_expression
  (identifier) @call)
";
19
/// Tree-sitter query used to collect type references.
///
/// The pattern is deliberately broad: it captures *every* `identifier` node as `@type_ref`,
/// not only identifiers that appear in a type position. Consumers that need real type
/// references have to filter the captures themselves.
pub const REFERENCE_QUERY: &str = r"
(identifier) @type_ref
";
24
/// Tree-sitter query that captures Kotlin import statements.
///
/// Each `import` node is captured whole as `@import_path`; the capture is the complete
/// `import` node rather than a child node nested inside it.
pub const IMPORT_QUERY: &str = r"
(import) @import_path
";
29
/// Tree-sitter query that captures definition and use sites.
///
/// - `@write.property`: the `name` field of a `property_declaration`.
/// - `@read.usage`: every `simple_identifier` node.
///
/// NOTE(review): this query names the node `simple_identifier`, while every other query in
/// this file uses `identifier`, and the tests in this file load the `tree_sitter_kotlin_ng`
/// grammar. Confirm that `simple_identifier` (and a `name` field on `property_declaration`)
/// exist in that grammar; `tree_sitter::Query::new` rejects a query that references an
/// unknown node type. No test in this file compiles this query.
pub const DEFUSE_QUERY: &str = r"
(property_declaration
  name: (simple_identifier) @write.property)
(simple_identifier) @read.usage
";
36
37use tree_sitter::Node;
38
39use crate::languages::get_node_text;
40
41/// Extract inheritance information from a Kotlin class node.
42#[must_use]
43pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
44    let mut inherits = Vec::new();
45
46    // Find the delegation_specifiers child of the class node.
47    // Grammar: optional(seq(':', $.delegation_specifiers))
48    let Some(delegation) = (0..node.child_count())
49        .filter_map(|i| node.child(u32::try_from(i).ok()?))
50        .find(|n| n.kind() == "delegation_specifiers")
51    else {
52        return inherits;
53    };
54
55    // Each delegation_specifier holds either a constructor_invocation (superclass)
56    // or a user_type (interface).
57    for spec in (0..delegation.child_count())
58        .filter_map(|j| delegation.child(u32::try_from(j).ok()?))
59        .filter(|n| n.kind() == "delegation_specifier")
60    {
61        for spec_child in (0..spec.child_count()).filter_map(|k| spec.child(u32::try_from(k).ok()?))
62        {
63            match spec_child.kind() {
64                "constructor_invocation" => {
65                    // Superclass: constructor_invocation = type + value_arguments.
66                    // The first child is the type node.
67                    if let Some(type_node) = spec_child.child(0)
68                        && let Some(text) = get_node_text(&type_node, source)
69                    {
70                        inherits.push(format!("extends {text}"));
71                    }
72                }
73                "type" | "user_type" => {
74                    // Interface: direct type without constructor call.
75                    if let Some(text) = get_node_text(&spec_child, source) {
76                        inherits.push(format!("implements {text}"));
77                    }
78                }
79                _ => {}
80            }
81        }
82    }
83
84    inherits
85}
86
#[cfg(all(test, feature = "lang-kotlin"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator, Tree};

    /// Parses `src` with the Kotlin grammar; panics if the grammar cannot be loaded
    /// or the parser produces no tree.
    fn parse_kotlin(src: &str) -> Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_kotlin_ng::LANGUAGE.into())
            .expect("Error loading Kotlin language");
        parser.parse(src, None).expect("Failed to parse Kotlin")
    }

    /// Compiles `query_src` against the Kotlin grammar, panicking with `msg` on failure.
    fn compile(query_src: &str, msg: &str) -> Query {
        Query::new(&tree_sitter_kotlin_ng::LANGUAGE.into(), query_src).expect(msg)
    }

    /// Runs `query` over `root` and returns every node captured under the name `wanted`.
    fn captured_nodes<'tree>(
        query: &Query,
        root: tree_sitter::Node<'tree>,
        src: &str,
        wanted: &str,
    ) -> Vec<tree_sitter::Node<'tree>> {
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(query, root, src.as_bytes());

        let mut nodes = Vec::new();
        while let Some(mat) = matches.next() {
            for capture in mat.captures {
                if query.capture_names()[capture.index as usize] == wanted {
                    nodes.push(capture.node);
                }
            }
        }
        nodes
    }

    /// Returns the text of the `name` field of every node captured under `wanted`.
    /// Captured nodes without a `name` field are skipped.
    fn declared_names(
        query: &Query,
        root: tree_sitter::Node<'_>,
        src: &str,
        wanted: &str,
    ) -> Vec<String> {
        captured_nodes(query, root, src, wanted)
            .into_iter()
            .filter_map(|node| node.child_by_field_name("name"))
            .map(|name| src[name.start_byte()..name.end_byte()].to_string())
            .collect()
    }

    /// Depth-first search for a `class_declaration` node; panics when none exists.
    fn first_class_declaration(tree: &Tree) -> tree_sitter::Node<'_> {
        let mut pending = vec![tree.root_node()];
        while let Some(node) = pending.pop() {
            if node.kind() == "class_declaration" {
                return node;
            }
            pending.extend(node.children(&mut node.walk()));
        }
        panic!("class_declaration not found")
    }

    /// Parses `src`, locates a class declaration and returns its inheritance entries.
    fn bases_of(src: &str) -> Vec<String> {
        let tree = parse_kotlin(src);
        let class = first_class_declaration(&tree);
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_element_query_free_function() {
        // Arrange: a single function at top level.
        let src = "fun greet(name: String): String { return \"Hello, $name\" }";
        let tree = parse_kotlin(src);

        // Act: ELEMENT_QUERY must compile and capture the function.
        let query = compile(ELEMENT_QUERY, "ELEMENT_QUERY must be valid");
        let captured_functions = declared_names(&query, tree.root_node(), src, "function");

        // Assert
        assert!(
            captured_functions.contains(&"greet".to_string()),
            "expected greet function, got {:?}",
            captured_functions
        );
    }

    #[test]
    fn test_element_query_method_in_class() {
        // Arrange: a method nested inside a class body.
        let src = "class Animal { fun eat() {} }";
        let tree = parse_kotlin(src);

        // Act: ELEMENT_QUERY must capture both the class and its method.
        let query = compile(ELEMENT_QUERY, "ELEMENT_QUERY must be valid");
        let captured_classes = declared_names(&query, tree.root_node(), src, "class");
        let captured_functions = declared_names(&query, tree.root_node(), src, "function");

        // Assert
        assert!(
            captured_classes.contains(&"Animal".to_string()),
            "expected Animal class, got {:?}",
            captured_classes
        );
        assert!(
            captured_functions.contains(&"eat".to_string()),
            "expected eat function, got {:?}",
            captured_functions
        );
    }

    #[test]
    fn test_call_query() {
        // Arrange: a function body containing one call.
        let src = "fun main() { println(\"hello\") }";
        let tree = parse_kotlin(src);

        // Act: CALL_QUERY must compile and capture the callee identifier.
        let query = compile(CALL_QUERY, "CALL_QUERY must be valid");
        let captured_calls: Vec<String> = captured_nodes(&query, tree.root_node(), src, "call")
            .into_iter()
            .map(|node| src[node.start_byte()..node.end_byte()].to_string())
            .collect();

        // Assert
        assert!(
            captured_calls.contains(&"println".to_string()),
            "expected println call, got {:?}",
            captured_calls
        );
    }

    #[test]
    fn test_element_query_class_declarations() {
        // Arrange: a class and an object (both are reported under the `class` capture).
        let src = "class Dog {} object Singleton {}";
        let tree = parse_kotlin(src);

        // Act: ELEMENT_QUERY must match every declaration kind.
        let query = compile(ELEMENT_QUERY, "ELEMENT_QUERY must be valid");
        let captured_classes = declared_names(&query, tree.root_node(), src, "class");

        // Assert
        assert!(
            captured_classes.contains(&"Dog".to_string()),
            "expected Dog class, got {:?}",
            captured_classes
        );
        assert!(
            captured_classes.contains(&"Singleton".to_string()),
            "expected Singleton object, got {:?}",
            captured_classes
        );
    }

    #[test]
    fn test_import_query() {
        // Arrange: two import statements.
        let src = "import java.util.List\nimport kotlin.io.println";
        let tree = parse_kotlin(src);

        // Act: IMPORT_QUERY must compile and match the imports.
        let query = compile(IMPORT_QUERY, "IMPORT_QUERY must be valid");
        let mut cursor = QueryCursor::new();
        let import_count = cursor
            .matches(&query, tree.root_node(), src.as_bytes())
            .count();

        // Assert
        assert!(
            import_count >= 2,
            "expected at least 2 imports, got {}",
            import_count
        );
    }

    #[test]
    fn test_extract_inheritance_single_superclass() {
        // Arrange + Act: one superclass, written as a constructor invocation with parens.
        let bases = bases_of("class Dog : Animal() {}");

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {:?}",
            bases
        );
    }

    #[test]
    fn test_extract_inheritance_multiple_interfaces() {
        // Arrange + Act: several interfaces, none with parens.
        let bases = bases_of("class Dog : Runnable, Comparable<Dog> {}");

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {:?}",
            bases
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {:?}",
            bases
        );
    }

    #[test]
    fn test_extract_inheritance_mixed() {
        // Arrange + Act: a superclass followed by interfaces.
        let bases = bases_of("class Dog : Animal(), Runnable, Comparable<Dog> {}");

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {:?}",
            bases
        );
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {:?}",
            bases
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {:?}",
            bases
        );
    }
}