Skip to main content

aptu_coder_core/languages/
kotlin.rs

1// SPDX-FileCopyrightText: 2026 aptu-coder contributors
2// SPDX-License-Identifier: Apache-2.0
3
/// Tree-sitter query for extracting Kotlin elements (functions and classes).
///
/// Captures:
/// - `@function` on a `function_declaration`, with `@function_name` on its `name`
///   identifier. Patterns match at any depth, so both top-level functions and methods
///   inside class bodies are captured.
/// - `@class` on a `class_declaration`, with `@class_name` on its `name` identifier.
/// - `@class` on an `object_declaration` (e.g. `object Singleton {}`), with `@object_name`
///   on its `name` identifier. Objects are reported under the same `@class` capture as
///   classes; only the name capture differs.
///
/// No other declaration kinds are listed, so e.g. `companion_object` nodes are not matched.
pub const ELEMENT_QUERY: &str = r"
(function_declaration
  name: (identifier) @function_name) @function
(class_declaration
  name: (identifier) @class_name) @class
(object_declaration
  name: (identifier) @object_name) @class
";
13
/// Tree-sitter query for extracting function calls.
///
/// `@call` captures an `identifier` that is a *direct* child of a `call_expression`. This
/// is the callee of a plain call such as `println("hello")`.
///
/// Tree-sitter child patterns only match direct children. A call whose callee is a compound
/// expression (e.g. `obj.method()`) therefore presumably produces no `@call` capture.
/// NOTE(review): confirm against the grammar if method calls need to be tracked.
pub const CALL_QUERY: &str = r"
(call_expression
  (identifier) @call)
";
19
/// Tree-sitter query for extracting type references.
///
/// NOTE(review): despite the `@type_ref` capture name, this pattern matches *every*
/// `identifier` node in the tree: variables, functions and types alike. No type-position
/// filtering happens here. Presumably consumers filter or tolerate the extra matches;
/// TODO confirm.
pub const REFERENCE_QUERY: &str = r"
(identifier) @type_ref
";
24
/// Tree-sitter query for extracting Kotlin imports.
///
/// `@import_path` captures each whole `import` node, one match per import directive
/// (e.g. `import java.util.List`). It does not capture only the dotted path inside it.
pub const IMPORT_QUERY: &str = r"
(import) @import_path
";
29
/// Tree-sitter query for extracting definition and use sites.
///
/// Write-site patterns capture the `identifier` that is a direct child of a
/// `variable_declaration` inside a `property_declaration`:
/// - (single): captures `val x = ...` or `var x = ...` via `variable_declaration`.
/// - (multi): captures `val (a, b) = ...` via `multi_variable_declaration`, once per name.
///
/// The read-site pattern `(identifier) @read.usage` matches every identifier in the tree.
/// That includes the identifiers also captured as `@write.property`. NOTE(review):
/// presumably the def-use extractor resolves this overlap (e.g. by preferring the write
/// capture); confirm.
///
/// Only property declarations are covered by a write pattern. Other binding forms, such as
/// function parameters or plain reassignments, are not matched as writes by this query.
pub const DEFUSE_QUERY: &str = r"
; write site: val/var x = ... (single variable declaration)
(property_declaration
  (variable_declaration
    (identifier) @write.property))
; write site: val (a, b) = ... (destructuring declaration)
(property_declaration
  (multi_variable_declaration
    (variable_declaration
      (identifier) @write.property)))
; read site: any identifier reference -- intentionally broad, consistent with Python/Go/Rust patterns
(identifier) @read.usage
";
47
48use tree_sitter::Node;
49
50use crate::languages::get_node_text;
51
52/// Extract inheritance information from a Kotlin class node.
53#[must_use]
54pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
55    let mut inherits = Vec::new();
56
57    // Find the delegation_specifiers child of the class node.
58    // Grammar: optional(seq(':', $.delegation_specifiers))
59    let Some(delegation) = (0..node.child_count())
60        .filter_map(|i| node.child(u32::try_from(i).ok()?))
61        .find(|n| n.kind() == "delegation_specifiers")
62    else {
63        return inherits;
64    };
65
66    // Each delegation_specifier holds either a constructor_invocation (superclass)
67    // or a user_type (interface).
68    for spec in (0..delegation.child_count())
69        .filter_map(|j| delegation.child(u32::try_from(j).ok()?))
70        .filter(|n| n.kind() == "delegation_specifier")
71    {
72        for spec_child in (0..spec.child_count()).filter_map(|k| spec.child(u32::try_from(k).ok()?))
73        {
74            match spec_child.kind() {
75                "constructor_invocation" => {
76                    // Superclass: constructor_invocation = type + value_arguments.
77                    // The first child is the type node.
78                    if let Some(type_node) = spec_child.child(0)
79                        && let Some(text) = get_node_text(&type_node, source)
80                    {
81                        inherits.push(format!("extends {text}"));
82                    }
83                }
84                "type" | "user_type" => {
85                    // Interface: direct type without constructor call.
86                    if let Some(text) = get_node_text(&spec_child, source) {
87                        inherits.push(format!("implements {text}"));
88                    }
89                }
90                _ => {}
91            }
92        }
93    }
94
95    inherits
96}
97
98/// Extract the function name from a Kotlin `function_declaration` node.
99#[must_use]
100pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
101    if node.kind() != "function_declaration" {
102        return None;
103    }
104    node.child_by_field_name("name")
105        .and_then(|n| get_node_text(&n, source))
106}
107
108/// Find the receiver type (enclosing class or object) for a Kotlin function.
109///
110/// Returns `None` for top-level functions (including extension functions) and
111/// functions whose only enclosing type is a `companion_object`.
112#[must_use]
113pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
114    if node.kind() != "function_declaration" {
115        return None;
116    }
117    let mut current = *node;
118    while let Some(parent) = current.parent() {
119        match parent.kind() {
120            "class_declaration" | "object_declaration" => {
121                return parent
122                    .child_by_field_name("name")
123                    .and_then(|n| get_node_text(&n, source));
124            }
125            _ => {
126                current = parent;
127            }
128        }
129    }
130    None
131}
132
133/// Find the method name when a function lives inside a named type body.
134///
135/// Returns `None` for top-level functions and functions inside `companion_object`
136/// that have no enclosing `class_declaration` or `object_declaration`.
137#[must_use]
138pub fn find_method_for_receiver(
139    node: &Node,
140    source: &str,
141    _depth: Option<usize>,
142) -> Option<String> {
143    if node.kind() != "function_declaration" {
144        return None;
145    }
146    let mut current = *node;
147    let mut in_type_body = false;
148    while let Some(parent) = current.parent() {
149        match parent.kind() {
150            "class_declaration" | "object_declaration" => {
151                in_type_body = true;
152                break;
153            }
154            _ => {
155                current = parent;
156            }
157        }
158    }
159    if !in_type_body {
160        return None;
161    }
162    node.child_by_field_name("name")
163        .and_then(|n| get_node_text(&n, source))
164}
165
#[cfg(all(test, feature = "lang-kotlin"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, Query, QueryCursor, StreamingIterator};

    /// Pre-order depth-first search for the first node of `kind` at or below `root`.
    fn find_node<'a>(root: tree_sitter::Node<'a>, kind: &str) -> Option<tree_sitter::Node<'a>> {
        if root.kind() == kind {
            return Some(root);
        }
        let mut cursor = root.walk();
        for child in root.children(&mut cursor) {
            if let Some(n) = find_node(child, kind) {
                return Some(n);
            }
        }
        None
    }

    /// Parses `src` with the Kotlin grammar, panicking if the grammar fails to load.
    fn parse_kotlin(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_kotlin_ng::LANGUAGE.into())
            .expect("Error loading Kotlin language");
        parser.parse(src, None).expect("Failed to parse Kotlin")
    }

    /// Compiles `query_src` (called `label` in panic messages) and runs it over `src`.
    /// Returns source text for every capture named `capture`, in match order.
    ///
    /// With `field = Some(f)`, the text of the captured node's `f` field is collected
    /// instead, and captures lacking that field are skipped. With `None`, the captured
    /// node's own text is collected.
    fn captured_text(
        src: &str,
        query_src: &str,
        label: &str,
        capture: &str,
        field: Option<&str>,
    ) -> Vec<String> {
        let tree = parse_kotlin(src);
        let query = Query::new(&tree_sitter_kotlin_ng::LANGUAGE.into(), query_src)
            .unwrap_or_else(|e| panic!("{label} must be valid: {e:?}"));
        let capture_names = query.capture_names();
        let mut cursor = QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());

        let mut texts = Vec::new();
        while let Some(mat) = matches.next() {
            for cap in mat.captures {
                if capture_names[cap.index as usize] != capture {
                    continue;
                }
                let target = match field {
                    Some(f) => cap.node.child_by_field_name(f),
                    None => Some(cap.node),
                };
                if let Some(node) = target {
                    texts.push(src[node.start_byte()..node.end_byte()].to_string());
                }
            }
        }
        texts
    }

    /// Parses `src`, finds its first `class_declaration` and returns `extract_inheritance` for it.
    fn inheritance_of(src: &str) -> Vec<String> {
        let tree = parse_kotlin(src);
        let class =
            find_node(tree.root_node(), "class_declaration").expect("class_declaration not found");
        extract_inheritance(&class, src)
    }

    /// Parses `src`, finds its first `function_declaration` (pre-order) and applies `f` to it.
    fn with_first_function<R>(src: &str, f: impl FnOnce(&tree_sitter::Node<'_>) -> R) -> R {
        let tree = parse_kotlin(src);
        let node = find_node(tree.root_node(), "function_declaration")
            .expect("function_declaration not found");
        f(&node)
    }

    #[test]
    fn test_element_query_free_function() {
        // Arrange: free function at top level
        let src = "fun greet(name: String): String { return \"Hello, $name\" }";

        // Act -- ELEMENT_QUERY must compile and match the function
        let functions =
            captured_text(src, ELEMENT_QUERY, "ELEMENT_QUERY", "function", Some("name"));

        // Assert
        assert!(
            functions.iter().any(|f| f == "greet"),
            "expected greet function, got {functions:?}"
        );
    }

    #[test]
    fn test_element_query_method_in_class() {
        // Arrange: method inside a class
        let src = "class Animal { fun eat() {} }";

        // Act -- ELEMENT_QUERY must compile and match both the class and the method
        let classes = captured_text(src, ELEMENT_QUERY, "ELEMENT_QUERY", "class", Some("name"));
        let functions =
            captured_text(src, ELEMENT_QUERY, "ELEMENT_QUERY", "function", Some("name"));

        // Assert
        assert!(
            classes.iter().any(|c| c == "Animal"),
            "expected Animal class, got {classes:?}"
        );
        assert!(
            functions.iter().any(|f| f == "eat"),
            "expected eat function, got {functions:?}"
        );
    }

    #[test]
    fn test_call_query() {
        // Arrange: function call
        let src = "fun main() { println(\"hello\") }";

        // Act -- CALL_QUERY must compile and match the call
        let calls = captured_text(src, CALL_QUERY, "CALL_QUERY", "call", None);

        // Assert
        assert!(
            calls.iter().any(|c| c == "println"),
            "expected println call, got {calls:?}"
        );
    }

    #[test]
    fn test_element_query_class_declarations() {
        // Arrange: class and object declarations (both reported under @class)
        let src = "class Dog {} object Singleton {}";

        // Act
        let classes = captured_text(src, ELEMENT_QUERY, "ELEMENT_QUERY", "class", Some("name"));

        // Assert
        assert!(
            classes.iter().any(|c| c == "Dog"),
            "expected Dog class, got {classes:?}"
        );
        assert!(
            classes.iter().any(|c| c == "Singleton"),
            "expected Singleton object, got {classes:?}"
        );
    }

    #[test]
    fn test_import_query() {
        // Arrange: import statements
        let src = "import java.util.List\nimport kotlin.io.println";

        // Act -- each match carries exactly one @import_path capture
        let imports = captured_text(src, IMPORT_QUERY, "IMPORT_QUERY", "import_path", None);

        // Assert
        assert!(
            imports.len() >= 2,
            "expected at least 2 imports, got {}",
            imports.len()
        );
    }

    #[test]
    fn test_extract_inheritance_single_superclass() {
        // Arrange: class with single superclass (constructor invocation with parens)
        let src = "class Dog : Animal() {}";

        // Act
        let bases = inheritance_of(src);

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_multiple_interfaces() {
        // Arrange: class with multiple interfaces (no parens)
        let src = "class Dog : Runnable, Comparable<Dog> {}";

        // Act
        let bases = inheritance_of(src);

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_mixed() {
        // Arrange: class with superclass and interfaces
        let src = "class Dog : Animal(), Runnable, Comparable<Dog> {}";

        // Act
        let bases = inheritance_of(src);

        // Assert
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_function_name_free_function() {
        let src = "fun greet() {}";
        let result = with_first_function(src, |n| extract_function_name(n, src, "kotlin"));
        assert_eq!(result, Some("greet".to_string()));
    }

    #[test]
    fn test_extract_function_name_method_in_class() {
        // The only function_declaration (hence the first in pre-order) is the method `bar`.
        let src = "class Foo { fun bar() {} }";
        let result = with_first_function(src, |n| extract_function_name(n, src, "kotlin"));
        assert_eq!(result, Some("bar".to_string()));
    }

    #[test]
    fn test_find_receiver_type_top_level_returns_none() {
        let src = "fun greet() {}";
        let result = with_first_function(src, |n| find_receiver_type(n, src));
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_receiver_type_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let result = with_first_function(src, |n| find_receiver_type(n, src));
        assert_eq!(result, Some("Foo".to_string()));
    }

    #[test]
    fn test_find_receiver_type_method_in_object() {
        let src = "object Registry { fun register() {} }";
        let result = with_first_function(src, |n| find_receiver_type(n, src));
        assert_eq!(result, Some("Registry".to_string()));
    }

    #[test]
    fn test_find_receiver_type_extension_function_returns_none() {
        let src = "fun String.greet() {}";
        let result = with_first_function(src, |n| find_receiver_type(n, src));
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_method_for_receiver_top_level_returns_none() {
        let src = "fun greet() {}";
        let result = with_first_function(src, |n| find_method_for_receiver(n, src, None));
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_method_for_receiver_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let result = with_first_function(src, |n| find_method_for_receiver(n, src, None));
        assert_eq!(result, Some("bar".to_string()));
    }

    #[test]
    fn test_defuse_kotlin_val_declaration() {
        // Arrange: val declaration with write and read
        let source = r#"
fun main() {
    val x = 42
    val y = x + 1
}
"#;
        // Act
        let sites = crate::parser::SemanticExtractor::extract_def_use_for_file(
            source,
            "kotlin",
            "x",
            "src/main.kt",
            None,
        );

        // Assert
        assert!(
            !sites.is_empty(),
            "expected at least one def-use site for 'x'"
        );
        let has_write = sites
            .iter()
            .any(|s| s.kind == crate::types::DefUseKind::Write);
        let has_read = sites
            .iter()
            .any(|s| s.kind == crate::types::DefUseKind::Read);
        assert!(has_write, "expected a write site for 'x'");
        assert!(has_read, "expected a read site for 'x'");
    }

    #[test]
    fn test_defuse_kotlin_multi_variable_declaration() {
        // Arrange: destructuring assignment with multiple write sites
        let source = r#"
fun main() {
    val (a, b) = Pair(1, 2)
}
"#;
        // Act
        let sites_a = crate::parser::SemanticExtractor::extract_def_use_for_file(
            source,
            "kotlin",
            "a",
            "src/main.kt",
            None,
        );

        // Assert
        assert!(
            !sites_a.is_empty(),
            "expected at least one def-use site for 'a'"
        );
        let has_write_a = sites_a
            .iter()
            .any(|s| s.kind == crate::types::DefUseKind::Write);
        assert!(has_write_a, "expected a write site for 'a'");
    }
}