//! Kotlin parser using tree-sitter
//!
//! Extracts classes, interfaces, objects, functions, imports, and call relationships from Kotlin source code.
use crate::models::{Class, Function};
use crate::parsers::ParseResult;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use streaming_iterator::StreamingIterator;
use tree_sitter::{Node, Parser, Query, QueryCursor};
/// Read the Kotlin file at `path` and extract all code entities from it.
///
/// # Errors
/// Fails if the file cannot be read as UTF-8 text, or if [`parse_source`]
/// fails on its contents.
pub fn parse(path: &Path) -> Result<ParseResult> {
    std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read file: {}", path.display()))
        .and_then(|contents| parse_source(&contents, path))
}
/// Parse Kotlin source text that is already in memory.
///
/// `path` is only used to label the extracted entities (file path and
/// qualified names); nothing is read from disk, which makes this handy for
/// tests.
///
/// # Errors
/// Fails if the Kotlin grammar cannot be loaded, if tree-sitter returns no
/// tree, or if one of the extraction queries cannot be compiled.
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let mut parser = Parser::new();
    parser
        .set_language(&tree_sitter_kotlin::LANGUAGE.into())
        .context("Failed to set Kotlin language")?;
    let tree = parser
        .parse(source, None)
        .context("Failed to parse Kotlin source")?;

    let root = tree.root_node();
    let bytes = source.as_bytes();
    let mut result = ParseResult::default();

    // Call extraction attributes callers using the functions gathered by the
    // earlier passes, so it must run last.
    extract_top_level_functions(&root, bytes, path, &mut result)?;
    extract_classes_and_objects(&root, bytes, path, &mut result)?;
    extract_imports(&root, bytes, &mut result)?;
    extract_calls(&root, bytes, path, &mut result)?;

    Ok(result)
}
/// Extract function declarations that sit directly under the file root.
///
/// Functions nested inside classes, objects or other functions are not matched
/// by the query; class methods are collected by `extract_class_methods`.
///
/// # Errors
/// Fails only if the tree-sitter query cannot be compiled.
fn extract_top_level_functions(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let query_str = r#"
(source_file
(function_declaration
(simple_identifier) @func_name
(function_value_parameters) @params
) @func
)
"#;
    let language = tree_sitter_kotlin::LANGUAGE;
    let query = Query::new(&language.into(), query_str).context("Failed to create function query")?;
    let names = query.capture_names();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(m) = matches.next() {
        let mut decl = None;
        let mut params = None;
        let mut func_name = String::new();
        for capture in m.captures {
            match names[capture.index as usize] {
                "func" => decl = Some(capture.node),
                "func_name" => {
                    func_name = capture.node.utf8_text(source).unwrap_or("").to_string();
                }
                "params" => params = Some(capture.node),
                _ => {}
            }
        }
        let Some(decl) = decl else {
            continue;
        };
        let line_start = decl.start_position().row as u32 + 1;
        let line_end = decl.end_position().row as u32 + 1;
        result.functions.push(Function {
            // Field initialisers run in order, so `func_name` is formatted
            // here before it is moved into `name` below.
            qualified_name: format!("{}::{}:{}", path.display(), func_name, line_start),
            file_path: path.to_path_buf(),
            line_start,
            line_end,
            parameters: extract_parameters(params, source),
            return_type: extract_return_type(&decl, source),
            is_async: is_suspend_function(&decl, source),
            complexity: Some(calculate_complexity(&decl, source)),
            name: func_name,
        });
    }
    Ok(())
}
/// Report whether a function declaration carries the `suspend` modifier.
///
/// Every child of each direct `modifiers` child is compared textually with
/// `"suspend"`.
fn is_suspend_function(node: &Node, source: &[u8]) -> bool {
    let mut cursor = node.walk();
    let found = node
        .children(&mut cursor)
        .filter(|child| child.kind() == "modifiers")
        .any(|list| {
            let mut list_cursor = list.walk();
            let has_suspend = list
                .children(&mut list_cursor)
                .any(|modifier| matches!(modifier.utf8_text(source), Ok("suspend")));
            has_suspend
        });
    found
}
/// Extract the declared return type of a function declaration, if any.
///
/// The return type is the first named node after the `:` token among the
/// declaration's direct children. Only that position is considered, so the
/// receiver type of an extension function (`fun Int.isEven(): Boolean`) is
/// never reported as the return type. That receiver is also a direct type
/// child of the declaration in the tree-sitter-kotlin grammar. Return types of
/// any kind are returned: user, nullable, function and parenthesized types.
/// Comments (extra nodes) and leading `type_modifiers` between `:` and the type
/// are skipped.
///
/// Returns `None` when no return type is declared.
fn extract_return_type(func_node: &Node, source: &[u8]) -> Option<String> {
    let mut cursor = func_node.walk();
    let mut after_colon = false;
    for child in func_node.children(&mut cursor) {
        if !after_colon {
            after_colon = child.kind() == ":";
            continue;
        }
        if !child.is_named() || child.is_extra() || child.kind() == "type_modifiers" {
            continue;
        }
        return child.utf8_text(source).ok().map(str::to_string);
    }
    None
}
/// Collect every class, interface and object declaration in the file.
///
/// Starts `extract_types_recursive` at the root with no enclosing type. This
/// never fails; the `Result` keeps the signature uniform with the other
/// extractors.
fn extract_classes_and_objects(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let no_parent: Option<&str> = None;
    extract_types_recursive(root, source, path, result, no_parent);
    Ok(())
}
/// Recursively walk `node` and record every class, interface and object
/// declaration beneath it.
///
/// For each declaration this builds a [`Class`], named `Parent.Child` when
/// `parent_type` is set. It then extracts the type's methods, pushes the type
/// to `result.classes`, and walks the type's body with the new type as parent
/// so nested types are recorded too. Classes, interfaces and objects all get
/// this same treatment, so types nested inside objects and interfaces are not
/// lost. A declaration whose name cannot be determined is skipped together
/// with its contents. Every other node is descended into with the current
/// `parent_type`.
fn extract_types_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    parent_type: Option<&str>,
) {
    for child in node.children(&mut node.walk()) {
        let parsed = match child.kind() {
            "class_declaration" => parse_class_node(&child, source, path, parent_type),
            "interface_declaration" => parse_interface_node(&child, source, path, parent_type),
            "object_declaration" => parse_object_node(&child, source, path, parent_type),
            _ => {
                // Not a type declaration (including class bodies): keep looking below.
                extract_types_recursive(&child, source, path, result, parent_type);
                continue;
            }
        };
        let Some(type_def) = parsed else {
            continue;
        };
        let type_name = type_def.name.clone();
        extract_class_methods(&child, source, path, result, &type_name);
        result.classes.push(type_def);
        if let Some(body) = find_class_body(&child) {
            extract_types_recursive(&body, source, path, result, Some(type_name.as_str()));
        }
    }
}
/// Return the body (`class_body` or `enum_class_body`) of a type declaration,
/// if it has one.
///
/// The explicit `'tree` lifetime ties the returned node to the syntax tree
/// rather than to the borrow of `class_node`. `&Node` carries two lifetimes,
/// so with everything elided the output lifetime cannot be inferred.
fn find_class_body<'tree>(class_node: &Node<'tree>) -> Option<Node<'tree>> {
    let mut cursor = class_node.walk();
    let body = class_node
        .children(&mut cursor)
        .find(|child| matches!(child.kind(), "class_body" | "enum_class_body"));
    body
}
/// Build a [`Class`] record from a `class_declaration` node.
///
/// Nested types are named `Parent.Child`. The qualified name has the form
/// `<path>::<name>:<line>`. When `get_class_kind` reports a kind (e.g. `data`),
/// an extra `<kind>::` segment is inserted before the name.
///
/// Returns `None` when no name can be found for the declaration.
fn parse_class_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    let short_name = extract_class_name(node, source)?;
    let name = match parent {
        Some(outer) => format!("{}.{}", outer, short_name),
        None => short_name,
    };
    let line_start = node.start_position().row as u32 + 1;

    let kind = get_class_kind(node, source);
    let kind_segment = if kind.is_empty() {
        String::new()
    } else {
        format!("{}::", kind)
    };
    let qualified_name = format!("{}::{}{}:{}", path.display(), kind_segment, name, line_start);

    Some(Class {
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end: node.end_position().row as u32 + 1,
        methods: extract_method_names(node, source),
        bases: extract_delegation_specifiers(node, source),
        name,
    })
}
/// Return the declared name of a type.
///
/// The name is the text of the first direct child that is a `type_identifier`
/// or `simple_identifier`. Returns `None` if there is no such child or its
/// text is not valid UTF-8.
fn extract_class_name(node: &Node, source: &[u8]) -> Option<String> {
    let mut cursor = node.walk();
    let ident = node
        .children(&mut cursor)
        .find(|child| matches!(child.kind(), "type_identifier" | "simple_identifier"))?;
    ident.utf8_text(source).ok().map(String::from)
}
/// Get the kind of class (enum, data, sealed, etc.)
fn get_class_kind(node: &Node, source: &[u8]) -> String {
for child in node.children(&mut node.walk()) {
if child.kind() == "modifiers" {
for modifier in child.children(&mut child.walk()) {
if let Ok(text) = modifier.utf8_text(source) {
match text {
"enum" => return "enum".to_string(),
"data" => return "data".to_string(),
"sealed" => return "sealed".to_string(),
"abstract" => return "abstract".to_string(),
"open" => return "open".to_string(),
_ => {}
}
}
}
}
}
String::new()
}
/// Build a [`Class`] record from an `interface_declaration` node.
///
/// Uses the same `Parent.Child` naming as classes. The qualified name always
/// carries an `interface` segment: `<path>::interface::<name>:<line>`.
///
/// Returns `None` when no name can be found for the declaration.
fn parse_interface_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    let declared = extract_class_name(node, source)?;
    let name = if let Some(outer) = parent {
        format!("{}.{}", outer, declared)
    } else {
        declared
    };
    let (line_start, line_end) = (
        node.start_position().row as u32 + 1,
        node.end_position().row as u32 + 1,
    );
    Some(Class {
        qualified_name: format!("{}::interface::{}:{}", path.display(), name, line_start),
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods: extract_method_names(node, source),
        bases: extract_delegation_specifiers(node, source),
        name,
    })
}
/// Build a [`Class`] record from an `object_declaration` node (a Kotlin
/// singleton).
///
/// Uses the same `Parent.Child` naming as classes. The qualified name always
/// carries an `object` segment: `<path>::object::<name>:<line>`.
///
/// Returns `None` when no name can be found for the declaration.
fn parse_object_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    let declared = extract_class_name(node, source)?;
    let name = if let Some(outer) = parent {
        format!("{}.{}", outer, declared)
    } else {
        declared
    };
    let (line_start, line_end) = (
        node.start_position().row as u32 + 1,
        node.end_position().row as u32 + 1,
    );
    Some(Class {
        qualified_name: format!("{}::object::{}:{}", path.display(), name, line_start),
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods: extract_method_names(node, source),
        bases: extract_delegation_specifiers(node, source),
        name,
    })
}
/// Extract base types from delegation specifiers
fn extract_delegation_specifiers(node: &Node, source: &[u8]) -> Vec<String> {
let mut bases = Vec::new();
for child in node.children(&mut node.walk()) {
if child.kind() == "delegation_specifiers" {
for spec in child.children(&mut child.walk()) {
if spec.kind() == "delegation_specifier" {
// Get the type from the specifier
for type_child in spec.children(&mut spec.walk()) {
if type_child.kind() == "user_type" || type_child.kind() == "constructor_invocation" {
if let Ok(text) = type_child.utf8_text(source) {
// Clean up constructor invocation
let base = text.split('(').next().unwrap_or(text).to_string();
if !base.is_empty() {
bases.push(base);
}
}
break;
}
}
}
}
}
}
bases
}
/// List the names of the functions declared in a type's body, in source order.
///
/// The walk descends only through `class_body` / `enum_class_body` nodes, so
/// methods of nested type declarations are not included. Each function's name
/// is its first direct `simple_identifier` child.
fn extract_method_names(type_node: &Node, source: &[u8]) -> Vec<String> {
    fn collect(container: &Node, source: &[u8], out: &mut Vec<String>) {
        let mut cursor = container.walk();
        for member in container.children(&mut cursor) {
            match member.kind() {
                "function_declaration" => {
                    let mut member_cursor = member.walk();
                    let ident = member
                        .children(&mut member_cursor)
                        .find(|c| c.kind() == "simple_identifier");
                    if let Some(Ok(name)) = ident.map(|i| i.utf8_text(source)) {
                        out.push(name.to_string());
                    }
                }
                "class_body" | "enum_class_body" => collect(&member, source, out),
                _ => {}
            }
        }
    }

    let mut names = Vec::new();
    collect(type_node, source, &mut names);
    names
}
/// Push a [`Function`] for every method declared in the body of `class_node`.
///
/// Methods are qualified with `class_name`. The walk descends only through
/// `class_body` / `enum_class_body` nodes, so methods of nested type
/// declarations are not visited here.
fn extract_class_methods(
    class_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    class_name: &str,
) {
    let mut cursor = class_node.walk();
    for member in class_node.children(&mut cursor) {
        match member.kind() {
            "function_declaration" => {
                // `Option` iterates over zero or one items.
                result
                    .functions
                    .extend(parse_method_node(&member, source, path, class_name));
            }
            "class_body" | "enum_class_body" => {
                extract_class_methods(&member, source, path, result, class_name);
            }
            _ => {}
        }
    }
}
/// Build a [`Function`] record for a method declared inside `class_name`.
///
/// The method name is the first non-empty direct `simple_identifier` child.
/// The qualified name has the form `<path>::<class>.<method>:<line>`.
///
/// Returns `None` if no name is found, or if an identifier's text is not
/// valid UTF-8.
fn parse_method_node(node: &Node, source: &[u8], path: &Path, class_name: &str) -> Option<Function> {
    let mut name = String::new();
    let mut params_node = None;
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "simple_identifier" if name.is_empty() => {
                name = child.utf8_text(source).ok()?.to_string();
            }
            "function_value_parameters" => params_node = Some(child),
            _ => {}
        }
    }
    if name.is_empty() {
        return None;
    }

    let line_start = node.start_position().row as u32 + 1;
    Some(Function {
        qualified_name: format!("{}::{}.{}:{}", path.display(), class_name, name, line_start),
        file_path: path.to_path_buf(),
        line_start,
        line_end: node.end_position().row as u32 + 1,
        parameters: extract_parameters(params_node, source),
        return_type: extract_return_type(node, source),
        is_async: is_suspend_function(node, source),
        complexity: Some(calculate_complexity(node, source)),
        name,
    })
}
/// Return the parameter names in a `function_value_parameters` node, in order.
///
/// Each `parameter` / `function_value_parameter` child contributes the text of
/// its first `simple_identifier`. Passing `None` yields an empty list.
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
    let mut names = Vec::new();
    if let Some(list) = params_node {
        let mut cursor = list.walk();
        for param in list.children(&mut cursor) {
            if !matches!(param.kind(), "parameter" | "function_value_parameter") {
                continue;
            }
            let mut param_cursor = param.walk();
            let ident = param
                .children(&mut param_cursor)
                .find(|n| n.kind() == "simple_identifier");
            if let Some(Ok(text)) = ident.map(|n| n.utf8_text(source)) {
                names.push(text.to_string());
            }
        }
    }
    names
}
/// Append the dotted path of every `import` header to `result.imports`.
///
/// Only the `identifier` child of each `import_header` is captured. Presumably
/// a trailing `.*` or an `as` alias is not part of it; confirm against the
/// grammar.
///
/// # Errors
/// Fails only if the tree-sitter query cannot be compiled.
fn extract_imports(root: &Node, source: &[u8], result: &mut ParseResult) -> Result<()> {
    let query_str = r#"
(import_header
(identifier) @import_path
)
"#;
    let query = Query::new(&tree_sitter_kotlin::LANGUAGE.into(), query_str)
        .context("Failed to create import query")?;
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(found) = matches.next() {
        let paths = found
            .captures
            .iter()
            .filter_map(|capture| capture.node.utf8_text(source).ok())
            .map(str::to_string);
        result.imports.extend(paths);
    }
    Ok(())
}
/// Record `(caller, callee)` pairs for every call expression in the file.
///
/// Callers are resolved by line range against the functions already in
/// `result.functions`. If two functions share the exact same
/// `(line_start, line_end)` span, the one collected later wins the map entry.
/// This never fails; the `Result` keeps the signature uniform with the other
/// extractors.
fn extract_calls(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let scope_map: HashMap<(u32, u32), String> = result
        .functions
        .iter()
        .map(|func| ((func.line_start, func.line_end), func.qualified_name.clone()))
        .collect();
    extract_calls_recursive(root, source, path, &scope_map, result);
    Ok(())
}
/// Depth-first walk that pushes a `(caller, callee)` pair for each
/// `call_expression` whose start line lies inside a known function range and
/// whose callee can be named.
///
/// `path` is not read here; it is only passed down the recursion.
fn extract_calls_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    scope_map: &HashMap<(u32, u32), String>,
    result: &mut ParseResult,
) {
    if node.kind() == "call_expression" {
        let line = node.start_position().row as u32 + 1;
        let caller = find_containing_scope(line, scope_map);
        let callee = extract_call_target(node, source);
        if let Some(edge) = caller.zip(callee) {
            result.calls.push(edge);
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(&child, source, path, scope_map, result);
    }
}
/// Find the innermost function scope whose inclusive line range contains `line`.
///
/// Among all ranges containing `line`, the one with the smallest span wins.
/// Ties are broken deterministically: the later start line wins, then the
/// lexicographically smaller qualified name. This matters because `HashMap`
/// iteration order is randomized per process, so without a tiebreak caller
/// attribution could differ between runs.
///
/// Returns `None` when no range contains `line`.
fn find_containing_scope(line: u32, scope_map: &HashMap<(u32, u32), String>) -> Option<String> {
    scope_map
        .iter()
        .filter(|entry| {
            let (start, end) = *entry.0;
            start <= line && line <= end
        })
        .min_by(|a, b| {
            let ((a_start, a_end), a_name) = (*a.0, a.1);
            let ((b_start, b_end), b_name) = (*b.0, b.1);
            (a_end - a_start)
                .cmp(&(b_end - b_start))
                .then_with(|| b_start.cmp(&a_start))
                .then_with(|| a_name.cmp(b_name))
        })
        .map(|(_, name)| name.clone())
}
/// Name the callee of a `call_expression` node.
///
/// Children of kind `call_suffix` are skipped. A `simple_identifier` or
/// `navigation_expression` child yields its text (e.g. `foo` or `obj.method`).
/// Any other child yields its text unless that text is empty or starts with
/// `(`, in which case the search continues. Returns `None` if nothing matches.
fn extract_call_target(node: &Node, source: &[u8]) -> Option<String> {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let kind = child.kind();
        if kind == "call_suffix" {
            continue;
        }
        let text = child.utf8_text(source);
        if matches!(kind, "simple_identifier" | "navigation_expression") {
            return text.ok().map(str::to_owned);
        }
        if let Ok(candidate) = text {
            if !candidate.is_empty() && !candidate.starts_with('(') {
                return Some(candidate.to_owned());
            }
        }
    }
    None
}
/// Approximate the cyclomatic complexity of a function node.
///
/// Starts at 1 and adds one for the node itself and for every descendant whose
/// kind is a branching construct: `if`, `while`/`for`/`do-while` loops, `when`
/// entries, `catch` blocks, conjunction and disjunction expressions, the elvis
/// operator, and lambda literals. Nested functions and lambdas in the body are
/// counted as well. `_source` is unused.
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    fn branch_points(node: &Node) -> u32 {
        let here = u32::from(matches!(
            node.kind(),
            "if_expression"
                | "while_statement"
                | "for_statement"
                | "do_while_statement"
                | "when_entry"
                | "catch_block"
                | "conjunction_expression"
                | "disjunction_expression"
                | "elvis_expression"
                | "lambda_literal"
        ));
        let mut cursor = node.walk();
        let below: u32 = node
            .children(&mut cursor)
            .map(|child| branch_points(&child))
            .sum();
        here + below
    }

    1 + branch_points(node)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parse `source` as the contents of a file named `test.kt`.
    fn parse_kt(source: &str) -> ParseResult {
        let path = PathBuf::from("test.kt");
        parse_source(source, &path).unwrap()
    }

    #[test]
    fn test_parse_simple_function() {
        let source = r#"
fun hello(name: String): String {
return "Hello, $name!"
}
"#;
        let parsed = parse_kt(source);
        assert_eq!(parsed.functions.len(), 1);
        assert_eq!(parsed.functions[0].name, "hello");
    }

    #[test]
    fn test_parse_suspend_function() {
        let source = r#"
suspend fun fetchData(url: String): String {
return ""
}
"#;
        let parsed = parse_kt(source);
        assert_eq!(parsed.functions.len(), 1);
        assert!(parsed.functions[0].is_async);
    }

    #[test]
    fn test_parse_class() {
        let source = r#"
class Person(val name: String, val age: Int) {
fun greet() {
println("Hello, $name")
}
}
"#;
        let parsed = parse_kt(source);
        assert_eq!(parsed.classes.len(), 1);
        assert_eq!(parsed.classes[0].name, "Person");
    }

    #[test]
    fn test_parse_data_class() {
        let source = r#"
data class User(val id: Int, val name: String)
"#;
        let parsed = parse_kt(source);
        assert_eq!(parsed.classes.len(), 1);
        assert!(parsed.classes[0].qualified_name.contains("data"));
    }

    #[test]
    fn test_parse_object() {
        let source = r#"
object Singleton {
fun getInstance(): Singleton = this
}
"#;
        let parsed = parse_kt(source);
        assert_eq!(parsed.classes.len(), 1);
        assert!(parsed.classes[0].qualified_name.contains("object"));
    }

    #[test]
    fn test_parse_imports() {
        let source = r#"
import kotlin.collections.List
import kotlinx.coroutines.launch
fun main() {}
"#;
        let parsed = parse_kt(source);
        assert!(parsed.imports.len() >= 2);
    }
}