Skip to main content

aptu_coder_core/languages/
kotlin.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
3
/// Tree-sitter query for extracting Kotlin elements (functions and classes).
///
/// Captures:
/// - `@function` / `@function_name`: a `function_declaration` and its `name`
///   identifier. Top-level functions and functions inside a class body are
///   both matched.
/// - `@class` / `@class_name`: a `class_declaration` and its `name` identifier.
/// - `@class` / `@object_name`: an `object_declaration` (`object Foo`) and its
///   `name` identifier. Objects share the `@class` capture with classes, so a
///   consumer keyed on `@class` sees both.
pub const ELEMENT_QUERY: &str = r"
(function_declaration
  name: (identifier) @function_name) @function
(class_declaration
  name: (identifier) @class_name) @class
(object_declaration
  name: (identifier) @object_name) @class
";
13
/// Tree-sitter query for extracting function calls.
///
/// Captures, as `@call`, any `identifier` that is a direct child of a
/// `call_expression`. For a plain call such as `println("hi")` that is the
/// callee name.
// NOTE(review): a qualified call like `obj.method()` presumably has a
// non-identifier callee node (a navigation expression), so it would not be
// matched by this pattern. Confirm against the grammar if method calls matter.
pub const CALL_QUERY: &str = r"
(call_expression
  (identifier) @call)
";
19
/// Tree-sitter query for extracting type references.
///
/// Captures every `identifier` node as `@type_ref`. The pattern has no parent
/// constraint, so value, property, and function names are captured as well as
/// names used in type positions.
// NOTE(review): if only type names are wanted, a narrower pattern such as
// `(user_type (identifier) @type_ref)` may be what was intended. Confirm
// against how callers filter these captures.
pub const REFERENCE_QUERY: &str = r"
(identifier) @type_ref
";
24
/// Tree-sitter query for extracting Kotlin imports.
///
/// Captures each whole `import` node (one per `import ...` directive) as
/// `@import_path`. The captured node spans the full directive (presumably
/// including the `import` keyword), not only the dotted path.
pub const IMPORT_QUERY: &str = r"
(import) @import_path
";
29
/// Tree-sitter query for extracting definition and use sites.
///
/// Captures:
/// - `@write.property`: the declared name of a single-name `val`/`var`
///   property, i.e. the `identifier` inside the property's
///   `variable_declaration` (`val count = 0` yields `count`). Destructuring
///   declarations (`val (a, b) = pair`) are not matched by this pattern.
/// - `@read.usage`: every `identifier` node. Declared names also match this
///   pattern, so a declaration site yields both a write and a read capture.
///
/// Node names follow the tree-sitter-kotlin-ng grammar used by the other
/// queries in this module (`identifier`, `import`). The older
/// tree-sitter-kotlin grammar's `simple_identifier` node type is not part of
/// that grammar, and naming it would make query compilation fail.
pub const DEFUSE_QUERY: &str = r"
(property_declaration
  (variable_declaration
    (identifier) @write.property))
(identifier) @read.usage
";
36
37use tree_sitter::Node;
38
39use crate::languages::get_node_text;
40
41/// Extract inheritance information from a Kotlin class node.
42#[must_use]
43pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
44    let mut inherits = Vec::new();
45
46    // Find the delegation_specifiers child of the class node.
47    // Grammar: optional(seq(':', $.delegation_specifiers))
48    let Some(delegation) = (0..node.child_count())
49        .filter_map(|i| node.child(u32::try_from(i).ok()?))
50        .find(|n| n.kind() == "delegation_specifiers")
51    else {
52        return inherits;
53    };
54
55    // Each delegation_specifier holds either a constructor_invocation (superclass)
56    // or a user_type (interface).
57    for spec in (0..delegation.child_count())
58        .filter_map(|j| delegation.child(u32::try_from(j).ok()?))
59        .filter(|n| n.kind() == "delegation_specifier")
60    {
61        for spec_child in (0..spec.child_count()).filter_map(|k| spec.child(u32::try_from(k).ok()?))
62        {
63            match spec_child.kind() {
64                "constructor_invocation" => {
65                    // Superclass: constructor_invocation = type + value_arguments.
66                    // The first child is the type node.
67                    if let Some(type_node) = spec_child.child(0)
68                        && let Some(text) = get_node_text(&type_node, source)
69                    {
70                        inherits.push(format!("extends {text}"));
71                    }
72                }
73                "type" | "user_type" => {
74                    // Interface: direct type without constructor call.
75                    if let Some(text) = get_node_text(&spec_child, source) {
76                        inherits.push(format!("implements {text}"));
77                    }
78                }
79                _ => {}
80            }
81        }
82    }
83
84    inherits
85}
86
87/// Extract the function name from a Kotlin `function_declaration` node.
88#[must_use]
89pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
90    if node.kind() != "function_declaration" {
91        return None;
92    }
93    node.child_by_field_name("name")
94        .and_then(|n| get_node_text(&n, source))
95}
96
97/// Find the receiver type (enclosing class or object) for a Kotlin function.
98///
99/// Returns `None` for top-level functions (including extension functions) and
100/// functions whose only enclosing type is a `companion_object`.
101#[must_use]
102pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
103    if node.kind() != "function_declaration" {
104        return None;
105    }
106    let mut current = *node;
107    while let Some(parent) = current.parent() {
108        match parent.kind() {
109            "class_declaration" | "object_declaration" => {
110                return parent
111                    .child_by_field_name("name")
112                    .and_then(|n| get_node_text(&n, source));
113            }
114            _ => {
115                current = parent;
116            }
117        }
118    }
119    None
120}
121
122/// Find the method name when a function lives inside a named type body.
123///
124/// Returns `None` for top-level functions and functions inside `companion_object`
125/// that have no enclosing `class_declaration` or `object_declaration`.
126#[must_use]
127pub fn find_method_for_receiver(
128    node: &Node,
129    source: &str,
130    _depth: Option<usize>,
131) -> Option<String> {
132    if node.kind() != "function_declaration" {
133        return None;
134    }
135    let mut current = *node;
136    let mut in_type_body = false;
137    while let Some(parent) = current.parent() {
138        match parent.kind() {
139            "class_declaration" | "object_declaration" => {
140                in_type_body = true;
141                break;
142            }
143            _ => {
144                current = parent;
145            }
146        }
147    }
148    if !in_type_body {
149        return None;
150    }
151    node.child_by_field_name("name")
152        .and_then(|n| get_node_text(&n, source))
153}
154
#[cfg(all(test, feature = "lang-kotlin"))]
mod tests {
    // Unit tests for the Kotlin queries and node helpers, run against the
    // tree_sitter_kotlin_ng grammar.
    use super::*;
    use tree_sitter::{Parser, StreamingIterator};

    /// Pre-order depth-first search for the first node of `kind`, starting at
    /// (and including) `root`.
    fn find_node<'a>(root: tree_sitter::Node<'a>, kind: &str) -> Option<tree_sitter::Node<'a>> {
        if root.kind() == kind {
            return Some(root);
        }
        let mut cursor = root.walk();
        for child in root.children(&mut cursor) {
            if let Some(n) = find_node(child, kind) {
                return Some(n);
            }
        }
        None
    }

    /// The Kotlin grammar every test parses and queries with.
    fn kotlin_language() -> tree_sitter::Language {
        tree_sitter_kotlin_ng::LANGUAGE.into()
    }

    fn parse_kotlin(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&kotlin_language())
            .expect("Error loading Kotlin language");
        parser.parse(src, None).expect("Failed to parse Kotlin")
    }

    /// Parse `src`, run `query_src` over it, and return the source text of every
    /// capture named `capture`, in match order.
    ///
    /// When `via_name_field` is set, the text of the captured node's `name` field
    /// is collected instead, and captures without a `name` child are skipped.
    /// This is how the element tests turn `@function` / `@class` captures into
    /// declared names.
    ///
    /// Panics if `query_src` does not compile against the Kotlin grammar.
    fn capture_texts(
        query_src: &str,
        src: &str,
        capture: &str,
        via_name_field: bool,
    ) -> Vec<String> {
        let tree = parse_kotlin(src);
        let query = tree_sitter::Query::new(&kotlin_language(), query_src)
            .expect("query must compile against the Kotlin grammar");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());

        let mut texts = Vec::new();
        while let Some(mat) = matches.next() {
            for cap in mat.captures {
                if query.capture_names()[cap.index as usize] != capture {
                    continue;
                }
                let target = if via_name_field {
                    cap.node.child_by_field_name("name")
                } else {
                    Some(cap.node)
                };
                if let Some(node) = target {
                    texts.push(src[node.start_byte()..node.end_byte()].to_string());
                }
            }
        }
        texts
    }

    /// Parse `src` and run [`extract_inheritance`] on its first `class_declaration`.
    fn inheritance_of(src: &str) -> Vec<String> {
        let tree = parse_kotlin(src);
        let class =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_element_query_free_function() {
        // Arrange: free function at top level
        let src = "fun greet(name: String): String { return \"Hello, $name\" }";

        // Act -- ELEMENT_QUERY must compile and capture the function
        let functions = capture_texts(ELEMENT_QUERY, src, "function", true);

        // Assert
        assert!(
            functions.iter().any(|f| f == "greet"),
            "expected greet function, got {functions:?}"
        );
    }

    #[test]
    fn test_element_query_method_in_class() {
        // Arrange: method inside a class
        let src = "class Animal { fun eat() {} }";

        // Act -- ELEMENT_QUERY must capture both the class and its method
        let classes = capture_texts(ELEMENT_QUERY, src, "class", true);
        let functions = capture_texts(ELEMENT_QUERY, src, "function", true);

        // Assert
        assert!(
            classes.iter().any(|c| c == "Animal"),
            "expected Animal class, got {classes:?}"
        );
        assert!(
            functions.iter().any(|f| f == "eat"),
            "expected eat function, got {functions:?}"
        );
    }

    #[test]
    fn test_call_query() {
        // Arrange: function call
        let src = "fun main() { println(\"hello\") }";

        // Act -- CALL_QUERY must compile and capture the callee
        let calls = capture_texts(CALL_QUERY, src, "call", false);

        // Assert
        assert!(
            calls.iter().any(|c| c == "println"),
            "expected println call, got {calls:?}"
        );
    }

    #[test]
    fn test_element_query_class_declarations() {
        // Arrange: a class and an object declaration
        let src = "class Dog {} object Singleton {}";

        // Act -- objects share the @class capture with classes
        let classes = capture_texts(ELEMENT_QUERY, src, "class", true);

        // Assert
        assert!(
            classes.iter().any(|c| c == "Dog"),
            "expected Dog class, got {classes:?}"
        );
        assert!(
            classes.iter().any(|c| c == "Singleton"),
            "expected Singleton object, got {classes:?}"
        );
    }

    #[test]
    fn test_import_query() {
        // Arrange: import statements
        let src = "import java.util.List\nimport kotlin.io.println";

        // Act -- IMPORT_QUERY must compile; each match carries one @import_path
        let import_count = capture_texts(IMPORT_QUERY, src, "import_path", false).len();

        // Assert
        assert!(
            import_count >= 2,
            "expected at least 2 imports, got {import_count}"
        );
    }

    #[test]
    fn test_extract_inheritance_single_superclass() {
        // Superclass via constructor invocation (with parens)
        let bases = inheritance_of("class Dog : Animal() {}");
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_multiple_interfaces() {
        // Interfaces only (no parens)
        let bases = inheritance_of("class Dog : Runnable, Comparable<Dog> {}");
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_mixed() {
        // Superclass plus interfaces
        let bases = inheritance_of("class Dog : Animal(), Runnable, Comparable<Dog> {}");
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_function_name_free_function() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(extract_function_name(&node, src, "kotlin"), Some("greet".to_string()));
    }

    #[test]
    fn test_extract_function_name_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        // find the inner function_declaration (bar), not the class
        let class_node =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        assert_eq!(extract_function_name(&node, src, "kotlin"), Some("bar".to_string()));
    }

    #[test]
    fn test_extract_function_name_non_function_returns_none() {
        let src = "class Foo {}";
        let tree = parse_kotlin(src);
        let node =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        assert_eq!(extract_function_name(&node, src, "kotlin"), None);
    }

    #[test]
    fn test_find_receiver_type_top_level_returns_none() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(find_receiver_type(&node, src), None);
    }

    #[test]
    fn test_find_receiver_type_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        let class_node =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        assert_eq!(find_receiver_type(&node, src), Some("Foo".to_string()));
    }

    #[test]
    fn test_find_receiver_type_method_in_object() {
        let src = "object Registry { fun lookup() {} }";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(find_receiver_type(&node, src), Some("Registry".to_string()));
        assert_eq!(find_method_for_receiver(&node, src, None), Some("lookup".to_string()));
    }

    #[test]
    fn test_find_receiver_type_extension_function_returns_none() {
        let src = "fun String.greet() {}";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(find_receiver_type(&node, src), None);
    }

    #[test]
    fn test_find_method_for_receiver_top_level_returns_none() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(find_method_for_receiver(&node, src, None), None);
    }

    #[test]
    fn test_find_method_for_receiver_extension_function_returns_none() {
        let src = "fun String.greet() {}";
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        assert_eq!(find_method_for_receiver(&node, src, None), None);
    }

    #[test]
    fn test_find_method_for_receiver_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        let class_node =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        assert_eq!(find_method_for_receiver(&node, src, None), Some("bar".to_string()));
    }
}