use crate::models::{Class, Function};
use crate::parsers::{ImportInfo, ParseResult};
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use tree_sitter::{Node, Parser, Query, QueryCursor};
/// Parses the Java source file located at `path`.
///
/// The file is read fully into memory as text and then handed to
/// [`parse_source`].
///
/// # Errors
///
/// Returns an error carrying the message `Failed to read file: <path>` when
/// the file cannot be read, and propagates any error from [`parse_source`].
pub fn parse(path: &Path) -> Result<ParseResult> {
    std::fs::read_to_string(path)
        .with_context(|| format!("Failed to read file: {}", path.display()))
        .and_then(|contents| parse_source(&contents, path))
}
/// Parses Java `source` text and collects its classes, methods, imports and
/// call edges into a [`ParseResult`].
///
/// `path` is not read from disk here; it is only passed on to the extraction
/// helpers, which embed it in qualified names and file-path fields.
///
/// # Errors
///
/// Returns an error when the Java grammar cannot be installed on the parser,
/// when tree-sitter yields no tree for the input, or when one of the
/// extraction passes fails.
pub fn parse_source(source: &str, path: &Path) -> Result<ParseResult> {
    let java = tree_sitter_java::LANGUAGE;
    let mut java_parser = Parser::new();
    java_parser
        .set_language(&java.into())
        .context("Failed to set Java language")?;

    let syntax_tree = java_parser
        .parse(source, None)
        .context("Failed to parse Java source")?;
    let root_node = syntax_tree.root_node();
    let bytes = source.as_bytes();

    let mut parsed = ParseResult::default();
    // Order matters: call extraction maps call sites onto the functions
    // gathered by the class pass, so that pass has to run first.
    extract_classes_and_interfaces(&root_node, bytes, path, &mut parsed)?;
    extract_imports(&root_node, bytes, &mut parsed)?;
    extract_calls(&root_node, bytes, path, &mut parsed)?;
    Ok(parsed)
}
/// Entry point for type extraction: walks the tree below `root` and appends
/// every discovered type and its methods to `result`.
///
/// Always returns `Ok(())`; the `Result` return type matches the other
/// extraction passes.
fn extract_classes_and_interfaces(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    // The walk starts outside of any type, hence no enclosing class name.
    let no_enclosing_type: Option<&str> = None;
    extract_classes_recursive(root, source, path, result, no_enclosing_type);
    Ok(())
}
/// Recursively walks the children of `node`, recording every class,
/// interface, enum and record declaration (plus their methods) in `result`.
///
/// `parent_class` is the dotted name of the enclosing type, if any; it is
/// forwarded to the `parse_*_node` helpers so nested types get names such as
/// `Outer.Inner`.
///
/// After a type has been recorded, its body is searched for nested types
/// with that type as the new parent. This is done for all four declaration
/// kinds, so types nested inside interfaces, enums and records are found as
/// well as those nested inside classes.
fn extract_classes_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    parent_class: Option<&str>,
) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let kind = child.kind();
        let parsed = match kind {
            "class_declaration" => parse_class_node(&child, source, path, parent_class),
            "interface_declaration" => parse_interface_node(&child, source, path, parent_class),
            "enum_declaration" => parse_enum_node(&child, source, path, parent_class),
            "record_declaration" => parse_record_node(&child, source, path, parent_class),
            _ => {
                // Not a type declaration: keep looking further down with the
                // same enclosing type.
                extract_classes_recursive(&child, source, path, result, parent_class);
                continue;
            }
        };

        // A declaration without a readable name is skipped entirely.
        let Some(class) = parsed else {
            continue;
        };
        let type_name = class.name.clone();

        // Interfaces only contribute method declarations; the other kinds
        // may also declare constructors.
        if kind == "interface_declaration" {
            extract_interface_methods(&child, source, path, result, &type_name);
        } else {
            extract_class_methods(&child, source, path, result, &type_name);
        }
        result.classes.push(class);

        if let Some(body) = child.child_by_field_name("body") {
            extract_classes_recursive(&body, source, path, result, Some(&type_name));
        }
    }
}
/// Builds a [`Class`] from a `class_declaration` node.
///
/// * `parent` - dotted name of the enclosing type; when present the class is
///   named `Parent.Name`.
///
/// The qualified name has the form `<path>::<name>:<first line>`, with
/// 1-based line numbers. `bases` holds the superclass (if any) followed by
/// the implemented interfaces.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_class_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    /// Appends the text of every type found under `container`, looking
    /// through an intermediate `type_list` wrapper when there is one.
    fn collect_type_names(container: &Node, source: &[u8], out: &mut Vec<String>) {
        let mut cursor = container.walk();
        for child in container.children(&mut cursor) {
            match child.kind() {
                "type_list" => collect_type_names(&child, source, out),
                "type_identifier" | "generic_type" | "scoped_type_identifier" => {
                    if let Ok(text) = child.utf8_text(source) {
                        out.push(text.to_string());
                    }
                }
                _ => {}
            }
        }
    }

    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let full_name = match parent {
        Some(parent_name) => format!("{}.{}", parent_name, name),
        None => name,
    };

    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::{}:{}", path.display(), full_name, line_start);

    let mut bases = Vec::new();
    if let Some(superclass) = node.child_by_field_name("superclass") {
        if let Ok(text) = superclass.utf8_text(source) {
            // The node text includes the `extends` keyword. Strip the keyword
            // and then any whitespace, so a tab or newline after `extends`
            // is handled as well as a single space.
            let base = text.strip_prefix("extends").unwrap_or(text).trim();
            if !base.is_empty() {
                bases.push(base.to_string());
            }
        }
    }
    if let Some(interfaces) = node.child_by_field_name("interfaces") {
        // In the tree-sitter Java grammar the implemented types sit inside a
        // `type_list` child rather than directly under this node, so a
        // direct-children scan alone would never see them.
        collect_type_names(&interfaces, source, &mut bases);
    }

    let methods = extract_method_names(node, source);
    Some(Class {
        name: full_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods,
        bases,
    })
}
/// Builds a [`Class`] from an `interface_declaration` node.
///
/// * `parent` - dotted name of the enclosing type; when present the
///   interface is named `Parent.Name`.
///
/// The qualified name has the form `<path>::interface::<name>:<first line>`,
/// with 1-based line numbers. `bases` lists the interfaces named in the
/// `extends` clause.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_interface_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    /// Appends the text of every type found under `container`, looking
    /// through an intermediate `type_list` wrapper when there is one.
    fn collect_type_names(container: &Node, source: &[u8], out: &mut Vec<String>) {
        let mut cursor = container.walk();
        for child in container.children(&mut cursor) {
            match child.kind() {
                "type_list" => collect_type_names(&child, source, out),
                "type_identifier" | "generic_type" | "scoped_type_identifier" => {
                    if let Ok(text) = child.utf8_text(source) {
                        out.push(text.to_string());
                    }
                }
                _ => {}
            }
        }
    }

    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let full_name = match parent {
        Some(parent_name) => format!("{}.{}", parent_name, name),
        None => name,
    };

    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::interface::{}:{}", path.display(), full_name, line_start);

    let mut bases = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if child.kind() == "extends_interfaces" {
            // In the tree-sitter Java grammar the extended types sit inside a
            // `type_list` child of `extends_interfaces`, so they have to be
            // looked up one level deeper than the clause itself.
            collect_type_names(&child, source, &mut bases);
        }
    }

    let methods = extract_method_names(node, source);
    Some(Class {
        name: full_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods,
        bases,
    })
}
/// Builds a [`Class`] from an `enum_declaration` node.
///
/// * `parent` - dotted name of the enclosing type; when present the enum is
///   named `Parent.Name`.
///
/// The qualified name has the form `<path>::enum::<name>:<first line>`, with
/// 1-based line numbers. `bases` lists the interfaces the enum implements
/// and is empty when there is no `implements` clause.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_enum_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    /// Appends the text of every type found under `container`, looking
    /// through an intermediate `type_list` wrapper when there is one.
    fn collect_type_names(container: &Node, source: &[u8], out: &mut Vec<String>) {
        let mut cursor = container.walk();
        for child in container.children(&mut cursor) {
            match child.kind() {
                "type_list" => collect_type_names(&child, source, out),
                "type_identifier" | "generic_type" | "scoped_type_identifier" => {
                    if let Ok(text) = child.utf8_text(source) {
                        out.push(text.to_string());
                    }
                }
                _ => {}
            }
        }
    }

    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let full_name = match parent {
        Some(parent_name) => format!("{}.{}", parent_name, name),
        None => name,
    };

    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::enum::{}:{}", path.display(), full_name, line_start);

    // Enums cannot extend a class but may implement interfaces; record those
    // the same way class declarations do.
    let mut bases = Vec::new();
    if let Some(interfaces) = node.child_by_field_name("interfaces") {
        collect_type_names(&interfaces, source, &mut bases);
    }

    let methods = extract_method_names(node, source);
    Some(Class {
        name: full_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods,
        bases,
    })
}
/// Builds a [`Class`] from a `record_declaration` node.
///
/// * `parent` - dotted name of the enclosing type; when present the record
///   is named `Parent.Name`.
///
/// The qualified name has the form `<path>::record::<name>:<first line>`,
/// with 1-based line numbers. `bases` lists the interfaces the record
/// implements and is empty when there is no `implements` clause.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_record_node(node: &Node, source: &[u8], path: &Path, parent: Option<&str>) -> Option<Class> {
    /// Appends the text of every type found under `container`, looking
    /// through an intermediate `type_list` wrapper when there is one.
    fn collect_type_names(container: &Node, source: &[u8], out: &mut Vec<String>) {
        let mut cursor = container.walk();
        for child in container.children(&mut cursor) {
            match child.kind() {
                "type_list" => collect_type_names(&child, source, out),
                "type_identifier" | "generic_type" | "scoped_type_identifier" => {
                    if let Ok(text) = child.utf8_text(source) {
                        out.push(text.to_string());
                    }
                }
                _ => {}
            }
        }
    }

    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();
    let full_name = match parent {
        Some(parent_name) => format!("{}.{}", parent_name, name),
        None => name,
    };

    let line_start = node.start_position().row as u32 + 1;
    let line_end = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::record::{}:{}", path.display(), full_name, line_start);

    // Records cannot extend a class but may implement interfaces; record
    // those the same way class declarations do.
    let mut bases = Vec::new();
    if let Some(interfaces) = node.child_by_field_name("interfaces") {
        collect_type_names(&interfaces, source, &mut bases);
    }

    let methods = extract_method_names(node, source);
    Some(Class {
        name: full_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start,
        line_end,
        methods,
        bases,
    })
}
fn extract_method_names(class_node: &Node, source: &[u8]) -> Vec<String> {
let mut methods = Vec::new();
let body = class_node.child_by_field_name("body");
let body_node = body.as_ref().unwrap_or(class_node);
for child in body_node.children(&mut body_node.walk()) {
if child.kind() == "method_declaration" {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
methods.push(name.to_string());
}
}
} else if child.kind() == "constructor_declaration" {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
methods.push(format!("<init>:{}", name));
}
}
}
}
methods
}
/// Appends a [`Function`] to `result.functions` for every method and
/// constructor declared directly in the body of `class_node`.
///
/// `class_name` is forwarded to the per-member parsers, which use it in the
/// qualified names they build. When the node has no `body` field, its own
/// children are scanned instead.
fn extract_class_methods(
    class_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    class_name: &str,
) {
    /// Parses the members found among the direct children of `container`.
    fn collect_members(
        container: &Node,
        source: &[u8],
        path: &Path,
        result: &mut ParseResult,
        class_name: &str,
    ) {
        let mut cursor = container.walk();
        for member in container.children(&mut cursor) {
            let parsed = match member.kind() {
                // In the tree-sitter Java grammar an enum's methods and
                // constructors are grouped under this wrapper (after the
                // constant list) instead of sitting directly in the body.
                "enum_body_declarations" => {
                    collect_members(&member, source, path, result, class_name);
                    continue;
                }
                "method_declaration" => parse_method_node(&member, source, path, class_name),
                "constructor_declaration" => {
                    parse_constructor_node(&member, source, path, class_name)
                }
                _ => None,
            };
            if let Some(func) = parsed {
                result.functions.push(func);
            }
        }
    }

    let body = class_node.child_by_field_name("body");
    let body_node = body.as_ref().unwrap_or(class_node);
    collect_members(body_node, source, path, result, class_name);
}
/// Appends a [`Function`] to `result.functions` for every method declared
/// directly in the body of `iface_node`.
///
/// `iface_name` is forwarded to [`parse_method_node`] as the owning type
/// name. When the node has no `body` field, its own children are scanned
/// instead.
fn extract_interface_methods(
    iface_node: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
    iface_name: &str,
) {
    let container = iface_node
        .child_by_field_name("body")
        .unwrap_or(*iface_node);
    let mut cursor = container.walk();
    let declared = container
        .children(&mut cursor)
        .filter(|member| member.kind() == "method_declaration")
        .filter_map(|member| parse_method_node(&member, source, path, iface_name));
    result.functions.extend(declared);
}
/// Builds a [`Function`] from a `method_declaration` node.
///
/// * `class_name` - name of the owning type, used in the qualified name
///   `<path>::<class>.<method>:<first line>`.
///
/// The return type is the text of the node's `type` field when present.
/// Line numbers are 1-based, `is_async` is always `false`, and `complexity`
/// comes from [`calculate_complexity`].
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_method_node(node: &Node, source: &[u8], path: &Path, class_name: &str) -> Option<Function> {
    let name = node
        .child_by_field_name("name")?
        .utf8_text(source)
        .ok()?
        .to_string();

    let parameters = extract_parameters(node.child_by_field_name("parameters"), source);
    let return_type = match node.child_by_field_name("type") {
        Some(type_node) => type_node.utf8_text(source).ok().map(str::to_string),
        None => None,
    };

    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::{}.{}:{}", path.display(), class_name, name, first_line);
    let complexity = calculate_complexity(node, source);

    Some(Function {
        name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start: first_line,
        line_end: last_line,
        parameters,
        return_type,
        is_async: false,
        complexity: Some(complexity),
    })
}
/// Builds a [`Function`] from a `constructor_declaration` node.
///
/// The function is named `<init>:<Name>`, where `<Name>` is the text of the
/// node's `name` field. Its qualified name is
/// `<path>::<class>.<init>:<first line>` and its return type is reported as
/// `class_name`. Line numbers are 1-based and `is_async` is always `false`.
///
/// Returns `None` when the node has no `name` field or the name is not valid
/// UTF-8.
fn parse_constructor_node(node: &Node, source: &[u8], path: &Path, class_name: &str) -> Option<Function> {
    let declared_name = node.child_by_field_name("name")?.utf8_text(source).ok()?;
    let display_name = format!("<init>:{}", declared_name);

    let parameters = extract_parameters(node.child_by_field_name("parameters"), source);

    let first_line = node.start_position().row as u32 + 1;
    let last_line = node.end_position().row as u32 + 1;
    let qualified_name = format!("{}::{}.<init>:{}", path.display(), class_name, first_line);
    let complexity = calculate_complexity(node, source);

    Some(Function {
        name: display_name,
        qualified_name,
        file_path: path.to_path_buf(),
        line_start: first_line,
        line_end: last_line,
        parameters,
        return_type: Some(class_name.to_string()),
        is_async: false,
        complexity: Some(complexity),
    })
}
/// Returns the parameter names declared in a `formal_parameters` node, in
/// declaration order.
///
/// Both regular parameters and varargs (`Type... name`) are included.
/// Returns an empty vector when `params_node` is `None`.
fn extract_parameters(params_node: Option<Node>, source: &[u8]) -> Vec<String> {
    let Some(node) = params_node else {
        return Vec::new();
    };

    let mut params = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let name_node = match child.kind() {
            "formal_parameter" => child.child_by_field_name("name"),
            // In the tree-sitter Java grammar a varargs parameter carries no
            // `name` field of its own: the identifier is the `name` of its
            // `variable_declarator` child. The direct field lookup is tried
            // first in case the grammar in use does provide it.
            "spread_parameter" => child.child_by_field_name("name").or_else(|| {
                let mut inner = child.walk();
                let declarator = child
                    .children(&mut inner)
                    .find(|part| part.kind() == "variable_declarator");
                declarator.and_then(|d| d.child_by_field_name("name"))
            }),
            _ => None,
        };
        if let Some(text) = name_node.and_then(|n| n.utf8_text(source).ok()) {
            params.push(text.to_string());
        }
    }
    params
}
/// Collects the imported paths of all `import` declarations below `root` and
/// appends them to `result.imports` as runtime imports.
///
/// The path is the text of the declaration's `scoped_identifier` or
/// `identifier` child, i.e. without the `import` keyword or the trailing
/// semicolon.
///
/// # Errors
///
/// Returns an error when the tree-sitter query cannot be compiled.
fn extract_imports(root: &Node, source: &[u8], result: &mut ParseResult) -> Result<()> {
    let import_query_src = r#"
(import_declaration
(scoped_identifier) @import_path
)
(import_declaration
(identifier) @import_path
)
"#;
    let java = tree_sitter_java::LANGUAGE;
    let query =
        Query::new(&java.into(), import_query_src).context("Failed to create import query")?;

    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, *root, source);
    while let Some(m) = matches.next() {
        // Captures whose text is not valid UTF-8 are skipped.
        let paths = m
            .captures
            .iter()
            .filter_map(|capture| capture.node.utf8_text(source).ok())
            .map(|text| ImportInfo::runtime(text.to_string()));
        result.imports.extend(paths);
    }
    Ok(())
}
/// Records call edges for the whole tree below `root` in `result.calls`.
///
/// A lookup table from `(line_start, line_end)` to qualified name is built
/// from the functions already present in `result`, so the function
/// extraction pass must have run before this one. Functions that share the
/// exact same line range share one table entry; the one listed last wins.
///
/// Always returns `Ok(())`.
fn extract_calls(
    root: &Node,
    source: &[u8],
    path: &Path,
    result: &mut ParseResult,
) -> Result<()> {
    let scope_map: HashMap<(u32, u32), String> = result
        .functions
        .iter()
        .map(|func| ((func.line_start, func.line_end), func.qualified_name.clone()))
        .collect();

    extract_calls_recursive(root, source, path, &scope_map, result);
    Ok(())
}
/// Walks `node` and all of its descendants, appending a `(caller, callee)`
/// pair to `result.calls` for each method invocation and object creation.
///
/// * The caller is the qualified name of the function whose line range
///   contains the call site (see [`find_containing_scope`]), falling back to
///   the file path when no function contains it.
/// * A method invocation's callee is `object.method` when the invocation has
///   a readable `object` field, otherwise just the method name.
/// * An object creation's callee is `new <Type>`.
fn extract_calls_recursive(
    node: &Node,
    source: &[u8],
    path: &Path,
    scope_map: &HashMap<(u32, u32), String>,
    result: &mut ParseResult,
) {
    match node.kind() {
        "method_invocation" => {
            let line = node.start_position().row as u32 + 1;
            let caller = find_containing_scope(line, scope_map)
                .unwrap_or_else(|| path.display().to_string());
            let method_name = node
                .child_by_field_name("name")
                .and_then(|n| n.utf8_text(source).ok());
            if let Some(callee) = method_name {
                let receiver = node
                    .child_by_field_name("object")
                    .and_then(|obj| obj.utf8_text(source).ok());
                let full_callee = match receiver {
                    Some(obj) => format!("{}.{}", obj, callee),
                    None => callee.to_string(),
                };
                result.calls.push((caller, full_callee));
            }
        }
        "object_creation_expression" => {
            let line = node.start_position().row as u32 + 1;
            let caller = find_containing_scope(line, scope_map)
                .unwrap_or_else(|| path.display().to_string());
            let created_type = node
                .child_by_field_name("type")
                .and_then(|t| t.utf8_text(source).ok());
            if let Some(callee) = created_type {
                result.calls.push((caller, format!("new {}", callee)));
            }
        }
        _ => {}
    }

    // Calls can be nested (arguments, chained receivers), so always descend.
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_calls_recursive(&child, source, path, scope_map, result);
    }
}
/// Returns the name of the narrowest scope whose inclusive line range
/// contains `line`, or `None` when no range contains it.
///
/// `scope_map` maps `(first_line, last_line)` to a scope name. "Narrowest"
/// means the smallest `last_line - first_line`. When several containing
/// ranges are equally narrow, the one with the smallest start line (then the
/// smallest end line) is chosen, so the result does not depend on the hash
/// map's iteration order.
fn find_containing_scope(line: u32, scope_map: &HashMap<(u32, u32), String>) -> Option<String> {
    scope_map
        .iter()
        .filter(|(range, _)| range.0 <= line && line <= range.1)
        // The filter guarantees `range.0 <= range.1`, so the subtraction
        // cannot underflow. Map keys are unique, which makes this sort key a
        // total order over the candidates.
        .min_by_key(|(range, _)| (range.1 - range.0, range.0, range.1))
        .map(|(_, name)| name.clone())
}
/// Computes a cyclomatic-style complexity score for the subtree at `node`.
///
/// The score starts at 1 and grows by one for each of the following found
/// anywhere in the subtree (nested lambdas and anonymous classes included):
///
/// * `if`, `while`, `for`, enhanced `for` and `do` statements
/// * `catch` clauses
/// * `switch` arms, in both the colon form and the arrow form
/// * ternary expressions
/// * `&&` and `||` operators
/// * lambda expressions
///
/// `_source` is unused; it is kept so the signature stays unchanged.
fn calculate_complexity(node: &Node, _source: &[u8]) -> u32 {
    fn count_branches(node: &Node, complexity: &mut u32) {
        match node.kind() {
            "if_statement"
            | "while_statement"
            | "for_statement"
            | "enhanced_for_statement"
            | "do_statement"
            | "catch_clause"
            | "switch_block_statement_group"
            // Arrow-form arms (`case X -> ...`) are `switch_rule` nodes in
            // the tree-sitter Java grammar. `switch_expression_arm` is kept
            // from the original list and is harmless if it never matches.
            | "switch_rule"
            | "switch_expression_arm"
            | "ternary_expression"
            | "lambda_expression" => {
                *complexity += 1;
            }
            "binary_expression" => {
                // Only short-circuit operators add a decision point.
                let mut cursor = node.walk();
                for part in node.children(&mut cursor) {
                    if matches!(part.kind(), "&&" | "||") {
                        *complexity += 1;
                    }
                }
            }
            _ => {}
        }

        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            count_branches(&child, complexity);
        }
    }

    let mut complexity = 1;
    count_branches(node, &mut complexity);
    complexity
}
// Unit tests for the Java parser. Each test feeds an inline Java snippet to
// `parse_source` and checks the extracted classes, imports or functions.
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// A single public class with one static method is reported as one class
// whose method list contains `main`.
#[test]
fn test_parse_simple_class() {
let source = r#"
public class HelloWorld {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}
"#;
let path = PathBuf::from("HelloWorld.java");
let result = parse_source(source, &path).unwrap();
assert_eq!(result.classes.len(), 1);
let class = &result.classes[0];
assert_eq!(class.name, "HelloWorld");
assert!(class.methods.contains(&"main".to_string()));
}
// The superclass named in `extends` shows up in the class's bases.
// Only `Parent` is asserted here; the implemented interfaces are not checked.
#[test]
fn test_parse_class_with_inheritance() {
let source = r#"
public class Child extends Parent implements Runnable, Serializable {
@Override
public void run() {}
}
"#;
let path = PathBuf::from("Child.java");
let result = parse_source(source, &path).unwrap();
assert_eq!(result.classes.len(), 1);
let class = &result.classes[0];
assert_eq!(class.name, "Child");
assert!(class.bases.iter().any(|b| b.contains("Parent")));
}
// An interface declaration is reported as an entry in `classes`.
#[test]
fn test_parse_interface() {
let source = r#"
public interface MyInterface {
void doSomething();
default void doDefault() {}
}
"#;
let path = PathBuf::from("MyInterface.java");
let result = parse_source(source, &path).unwrap();
assert_eq!(result.classes.len(), 1);
let iface = &result.classes[0];
assert_eq!(iface.name, "MyInterface");
}
// Regular imports are collected with their full dotted path.
// The static import in the snippet is not asserted on.
#[test]
fn test_parse_imports() {
let source = r#"
import java.util.List;
import java.util.Map;
import static java.lang.Math.PI;
public class Test {}
"#;
let path = PathBuf::from("Test.java");
let result = parse_source(source, &path).unwrap();
assert!(result.imports.iter().any(|i| i.path.contains("java.util.List")));
assert!(result.imports.iter().any(|i| i.path.contains("java.util.Map")));
}
// Every method of a class yields one entry in `functions`.
#[test]
fn test_parse_methods() {
let source = r#"
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
}
"#;
let path = PathBuf::from("Calculator.java");
let result = parse_source(source, &path).unwrap();
assert_eq!(result.functions.len(), 2);
assert!(result.functions.iter().any(|f| f.name == "add"));
assert!(result.functions.iter().any(|f| f.name == "subtract"));
}
// Lambdas inside method bodies must not be counted as class methods:
// the expected three entries are the constructor plus the two methods.
#[test]
fn test_method_count_excludes_nested_lambdas() {
let source = r#"
public class StreamProcessor {
private List<String> items;
public StreamProcessor() {
this.items = new ArrayList<>();
}
public List<String> process() {
// These lambdas should NOT be counted as methods
return items.stream()
.filter(item -> item != null)
.map(item -> item.toUpperCase())
.collect(Collectors.toList());
}
public void registerCallback(Consumer<String> callback) {
// Lambda passed to method - not a class method
items.forEach(item -> callback.accept(item));
}
}
"#;
let path = PathBuf::from("StreamProcessor.java");
let result = parse_source(source, &path).unwrap();
let class = &result.classes[0];
assert_eq!(class.name, "StreamProcessor");
assert_eq!(
class.methods.len(),
3,
"Expected 3 methods, got {:?}",
class.methods
);
}
// Methods of an anonymous class created inside a method body must not be
// attributed to the enclosing class.
#[test]
fn test_method_count_excludes_anonymous_classes() {
let source = r#"
public class EventHandler {
public void setup() {
// Anonymous class - its methods should NOT count as EventHandler methods
button.addListener(new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
handleClick();
}
});
}
private void handleClick() {
System.out.println("clicked");
}
}
"#;
let path = PathBuf::from("EventHandler.java");
let result = parse_source(source, &path).unwrap();
// Look the class up by name rather than by index, so the check does not
// depend on the order of `result.classes`.
let main_class = result.classes.iter()
.find(|c| c.name == "EventHandler")
.expect("Should find EventHandler class");
assert_eq!(
main_class.methods.len(),
2,
"Expected 2 methods (setup, handleClick), got {:?}",
main_class.methods
);
}
}