/// Tree-sitter query locating Kotlin declarations.
///
/// - `function_declaration` nodes are captured as `@function`, with their `name` child as
///   `@function_name`.
/// - `class_declaration` nodes are captured as `@class`, with their name as `@class_name`.
/// - `object_declaration` nodes are *also* captured as `@class`, so Kotlin `object`s are
///   reported alongside classes; their name is captured as `@object_name`.
pub const ELEMENT_QUERY: &str = r"
(function_declaration
name: (identifier) @function_name) @function
(class_declaration
name: (identifier) @class_name) @class
(object_declaration
name: (identifier) @object_name) @class
";
/// Tree-sitter query for call sites.
///
/// Captures an `identifier` that is a direct child of a `call_expression` as `@call`. For
/// example, `println` in `println("hello")`.
///
/// NOTE(review): calls whose callee is not a bare identifier (e.g. a qualified `a.b()`
/// call) are presumably not matched by this pattern. Confirm against the grammar if
/// qualified calls should be tracked.
pub const CALL_QUERY: &str = r"
(call_expression
(identifier) @call)
";
/// Tree-sitter query for references.
///
/// Despite the `@type_ref` capture name, the pattern has no parent constraint, so it
/// captures *every* `identifier` node, not only those in type positions.
pub const REFERENCE_QUERY: &str = r"
(identifier) @type_ref
";
/// Tree-sitter query for import directives.
///
/// Each `import` node is captured as `@import_path`. The capture is the `import` node
/// itself, not a dedicated path child.
pub const IMPORT_QUERY: &str = r"
(import) @import_path
";
/// Tree-sitter query for def-use analysis.
///
/// - The `name` field of a `property_declaration` is captured as `@write.property`.
/// - Every `simple_identifier` is captured as `@read.usage`. The pattern is unconstrained,
///   so this also includes identifiers that were captured as writes.
///
/// NOTE(review): every other query in this file matches `identifier`, not
/// `simple_identifier`. Confirm that `simple_identifier` and a `name` field on
/// `property_declaration` exist in the grammar this query is compiled against. If they do
/// not, `Query::new` will reject this pattern.
pub const DEFUSE_QUERY: &str = r"
(property_declaration
name: (simple_identifier) @write.property)
(simple_identifier) @read.usage
";
use tree_sitter::Node;
use crate::languages::get_node_text;
#[must_use]
pub fn extract_inheritance(node: &Node, source: &str) -> Vec<String> {
let mut inherits = Vec::new();
let Some(delegation) = (0..node.child_count())
.filter_map(|i| node.child(u32::try_from(i).ok()?))
.find(|n| n.kind() == "delegation_specifiers")
else {
return inherits;
};
for spec in (0..delegation.child_count())
.filter_map(|j| delegation.child(u32::try_from(j).ok()?))
.filter(|n| n.kind() == "delegation_specifier")
{
for spec_child in (0..spec.child_count()).filter_map(|k| spec.child(u32::try_from(k).ok()?))
{
match spec_child.kind() {
"constructor_invocation" => {
if let Some(type_node) = spec_child.child(0)
&& let Some(text) = get_node_text(&type_node, source)
{
inherits.push(format!("extends {text}"));
}
}
"type" | "user_type" => {
if let Some(text) = get_node_text(&spec_child, source) {
inherits.push(format!("implements {text}"));
}
}
_ => {}
}
}
}
inherits
}
/// Returns the declared name of a Kotlin `function_declaration` node.
///
/// Returns `None` in three cases:
/// - `node` is any other kind of node;
/// - it has no `name` child;
/// - `get_node_text` yields nothing for that child.
///
/// `_lang` is ignored. It is presumably kept so the signature matches other languages'
/// implementations.
#[must_use]
pub fn extract_function_name(node: &Node, source: &str, _lang: &str) -> Option<String> {
    match node.kind() {
        "function_declaration" => {
            let name_node = node.child_by_field_name("name")?;
            get_node_text(&name_node, source)
        }
        _ => None,
    }
}
/// Returns the nearest `class_declaration` / `object_declaration` that declares `node` as a
/// direct member, or `None` if there is no such declaration.
///
/// The ancestor walk stops at the first enclosing `function_declaration`. A function declared
/// inside another function's body is a local function, not a member of whatever type happens
/// to enclose the outer function, so it must not be attributed to that type.
fn enclosing_type_declaration<'tree>(node: &Node<'tree>) -> Option<Node<'tree>> {
    // Start at the parent so `node` (itself a function_declaration) does not stop the walk.
    let mut current = node.parent();
    while let Some(ancestor) = current {
        match ancestor.kind() {
            "class_declaration" | "object_declaration" => return Some(ancestor),
            // Crossed into an enclosing function body: `node` is a local function.
            "function_declaration" => return None,
            _ => current = ancestor.parent(),
        }
    }
    None
}

/// Returns the name of the class or object that declares the function `node` as a member.
///
/// Returns `None` in these cases:
/// - `node` is not a `function_declaration`;
/// - it is top-level, including top-level extension functions such as `fun String.greet()`;
/// - it is a local function nested inside another function's body.
#[must_use]
pub fn find_receiver_type(node: &Node, source: &str) -> Option<String> {
    if node.kind() != "function_declaration" {
        return None;
    }
    enclosing_type_declaration(node)?
        .child_by_field_name("name")
        .and_then(|n| get_node_text(&n, source))
}

/// Returns the name of the function `node` when it is a member of a class or object.
///
/// Returns `None` for non-`function_declaration` nodes, for top-level functions, and for
/// local functions nested inside another function's body. `_depth` is accepted but ignored.
#[must_use]
pub fn find_method_for_receiver(
    node: &Node,
    source: &str,
    _depth: Option<usize>,
) -> Option<String> {
    if node.kind() != "function_declaration" {
        return None;
    }
    if enclosing_type_declaration(node).is_none() {
        return None;
    }
    node.child_by_field_name("name")
        .and_then(|n| get_node_text(&n, source))
}
#[cfg(all(test, feature = "lang-kotlin"))]
mod tests {
    use super::*;
    use tree_sitter::{Parser, StreamingIterator};

    /// Pre-order depth-first search for the first node of kind `kind`, starting at (and
    /// including) `root`.
    fn find_node<'a>(root: tree_sitter::Node<'a>, kind: &str) -> Option<tree_sitter::Node<'a>> {
        if root.kind() == kind {
            return Some(root);
        }
        let mut cursor = root.walk();
        for child in root.children(&mut cursor) {
            if let Some(n) = find_node(child, kind) {
                return Some(n);
            }
        }
        None
    }

    /// Parses `src` with the Kotlin grammar.
    ///
    /// Panics if the grammar cannot be loaded or parsing fails.
    fn parse_kotlin(src: &str) -> tree_sitter::Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_kotlin_ng::LANGUAGE.into())
            .expect("Error loading Kotlin language");
        parser.parse(src, None).expect("Failed to parse Kotlin")
    }

    /// Compiles `query_src`, runs it over `src`, and returns the source text for every
    /// capture named `capture`, in match order.
    ///
    /// With `field = Some(name)`, the text of the captured node's `name`-field child is
    /// returned instead of the captured node's own text. Captures lacking that child are
    /// skipped.
    fn captured_texts(
        query_src: &str,
        src: &str,
        capture: &str,
        field: Option<&str>,
    ) -> Vec<String> {
        let tree = parse_kotlin(src);
        let language: tree_sitter::Language = tree_sitter_kotlin_ng::LANGUAGE.into();
        let query = tree_sitter::Query::new(&language, query_src)
            .expect("query must compile against the Kotlin grammar");
        let mut cursor = tree_sitter::QueryCursor::new();
        let mut matches = cursor.matches(&query, tree.root_node(), src.as_bytes());
        let mut texts = Vec::new();
        while let Some(mat) = matches.next() {
            for cap in mat.captures {
                if query.capture_names()[cap.index as usize] != capture {
                    continue;
                }
                let target = match field {
                    Some(name) => cap.node.child_by_field_name(name),
                    None => Some(cap.node),
                };
                if let Some(node) = target {
                    texts.push(src[node.start_byte()..node.end_byte()].to_string());
                }
            }
        }
        texts
    }

    /// Parses `src`, locates its first `class_declaration`, and runs `extract_inheritance`
    /// on it.
    fn inheritance_of(src: &str) -> Vec<String> {
        let tree = parse_kotlin(src);
        let class = find_node(tree.root_node(), "class_declaration")
            .expect("class_declaration not found");
        extract_inheritance(&class, src)
    }

    #[test]
    fn test_element_query_free_function() {
        let src = "fun greet(name: String): String { return \"Hello, $name\" }";
        let functions = captured_texts(ELEMENT_QUERY, src, "function", Some("name"));
        assert!(
            functions.iter().any(|f| f == "greet"),
            "expected greet function, got {functions:?}"
        );
    }

    #[test]
    fn test_element_query_method_in_class() {
        let src = "class Animal { fun eat() {} }";
        let classes = captured_texts(ELEMENT_QUERY, src, "class", Some("name"));
        let functions = captured_texts(ELEMENT_QUERY, src, "function", Some("name"));
        assert!(
            classes.iter().any(|c| c == "Animal"),
            "expected Animal class, got {classes:?}"
        );
        assert!(
            functions.iter().any(|f| f == "eat"),
            "expected eat function, got {functions:?}"
        );
    }

    #[test]
    fn test_call_query() {
        let src = "fun main() { println(\"hello\") }";
        let calls = captured_texts(CALL_QUERY, src, "call", None);
        assert!(
            calls.iter().any(|c| c == "println"),
            "expected println call, got {calls:?}"
        );
    }

    #[test]
    fn test_element_query_class_declarations() {
        let src = "class Dog {} object Singleton {}";
        let classes = captured_texts(ELEMENT_QUERY, src, "class", Some("name"));
        assert!(
            classes.iter().any(|c| c == "Dog"),
            "expected Dog class, got {classes:?}"
        );
        assert!(
            classes.iter().any(|c| c == "Singleton"),
            "expected Singleton object, got {classes:?}"
        );
    }

    #[test]
    fn test_import_query() {
        let src = "import java.util.List\nimport kotlin.io.println";
        let imports = captured_texts(IMPORT_QUERY, src, "import_path", None);
        assert!(
            imports.len() >= 2,
            "expected at least 2 imports, got {}",
            imports.len()
        );
    }

    #[test]
    fn test_extract_inheritance_single_superclass() {
        let bases = inheritance_of("class Dog : Animal() {}");
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_multiple_interfaces() {
        let bases = inheritance_of("class Dog : Runnable, Comparable<Dog> {}");
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_inheritance_mixed() {
        let bases = inheritance_of("class Dog : Animal(), Runnable, Comparable<Dog> {}");
        assert!(
            bases.iter().any(|b| b.contains("Animal")),
            "expected extends Animal, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Runnable")),
            "expected implements Runnable, got {bases:?}"
        );
        assert!(
            bases.iter().any(|b| b.contains("Comparable")),
            "expected implements Comparable, got {bases:?}"
        );
    }

    #[test]
    fn test_extract_function_name_free_function() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let node = find_node(root, "function_declaration").expect("function_declaration not found");
        let result = extract_function_name(&node, src, "kotlin");
        assert_eq!(result, Some("greet".to_string()));
    }

    #[test]
    fn test_extract_function_name_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let class_node = find_node(root, "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        let result = extract_function_name(&node, src, "kotlin");
        assert_eq!(result, Some("bar".to_string()));
    }

    #[test]
    fn test_find_receiver_type_top_level_returns_none() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let node = find_node(root, "function_declaration").expect("function_declaration not found");
        let result = find_receiver_type(&node, src);
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_receiver_type_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let class_node = find_node(root, "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        let result = find_receiver_type(&node, src);
        assert_eq!(result, Some("Foo".to_string()));
    }

    #[test]
    fn test_find_receiver_type_extension_function_returns_none() {
        let src = "fun String.greet() {}";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let node = find_node(root, "function_declaration").expect("function_declaration not found");
        let result = find_receiver_type(&node, src);
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_method_for_receiver_top_level_returns_none() {
        let src = "fun greet() {}";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let node = find_node(root, "function_declaration").expect("function_declaration not found");
        let result = find_method_for_receiver(&node, src, None);
        assert_eq!(result, None);
    }

    #[test]
    fn test_find_method_for_receiver_method_in_class() {
        let src = "class Foo { fun bar() {} }";
        let tree = parse_kotlin(src);
        let root = tree.root_node();
        let class_node = find_node(root, "class_declaration").expect("class_declaration not found");
        let node =
            find_node(class_node, "function_declaration").expect("function_declaration not found");
        let result = find_method_for_receiver(&node, src, None);
        assert_eq!(result, Some("bar".to_string()));
    }
}