use crate::error::{Result, SpliceError};
use ropey::Rope;
use std::path::Path;
/// A single Java declaration (type, method, constructor or field) extracted
/// from a source file.
///
/// All positions describe the tree-sitter node of the whole declaration, so the
/// span includes its modifiers and, where present, its body.
#[derive(Debug, Clone, PartialEq)]
pub struct JavaSymbol {
/// Declared name; for a field declaration this is the name of its first
/// variable declarator (`int a, b;` yields `a`).
pub name: String,
/// Which kind of declaration this symbol came from.
pub kind: JavaSymbolKind,
/// Byte offset of the start of the declaration node in the source.
pub byte_start: usize,
/// Byte offset of the end of the declaration node (tree-sitter `end_byte`, exclusive).
pub byte_end: usize,
/// 1-based line containing `byte_start`, computed via a `ropey` rope of the source.
pub line_start: usize,
/// 1-based line containing `byte_end`, computed via a `ropey` rope of the source.
pub line_end: usize,
/// 0-based column of `byte_start`, measured in bytes (not characters) from the line start.
pub col_start: usize,
/// 0-based column of `byte_end`, measured in bytes (not characters) from the line start.
pub col_end: usize,
/// Parameter names read from the declaration's `parameters` list, in source
/// order; empty when the node has no such list (see `extract_parameters` for
/// which parameter forms are recognised).
pub parameters: Vec<String>,
/// Dot-separated names of the enclosing class/interface/enum declarations;
/// empty for a top-level declaration.
pub container_path: String,
/// `container_path` and `name` joined with `.` (just `name` at top level).
/// The Java package is not included.
pub fully_qualified: String,
/// Whether an explicit `public` modifier is written on the declaration
/// (implicit visibility, e.g. of interface members, is not inferred).
pub is_public: bool,
/// Whether an explicit `static` modifier is written on the declaration.
pub is_static: bool,
}
/// The category of Java declaration a symbol was extracted from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JavaSymbolKind {
    /// A `class` declaration.
    Class,
    /// An `interface` declaration.
    Interface,
    /// An `enum` declaration.
    Enum,
    /// A method declaration.
    Method,
    /// A constructor declaration.
    Constructor,
    /// A field declaration.
    Field,
}

impl JavaSymbolKind {
    /// Returns the lowercase label for this kind, e.g. `"method"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Class => "class",
            Self::Interface => "interface",
            Self::Enum => "enum",
            Self::Method => "method",
            Self::Constructor => "constructor",
            Self::Field => "field",
        }
    }
}
/// Parses `source` as Java and returns the class, interface, enum, method,
/// constructor and field declarations found in it.
///
/// `path` is only used to label errors. Symbols are returned in the order the
/// tree walk visits them, so an enclosing type precedes its members.
///
/// # Errors
///
/// Returns `SpliceError::Parse` when the tree-sitter Java grammar cannot be
/// installed in the parser or the parser returns no tree. An invalid-UTF-8
/// `source` propagates the `std::str::from_utf8` error through `?`.
pub fn extract_java_symbols(path: &Path, source: &[u8]) -> Result<Vec<JavaSymbol>> {
    // Every parse failure reports the same file, only the message differs.
    let parse_error = |message: String| SpliceError::Parse {
        file: path.to_path_buf(),
        message,
    };

    let mut parser = tree_sitter::Parser::new();
    let language = tree_sitter_java::language();
    parser
        .set_language(&language)
        .map_err(|e| parse_error(format!("Failed to set Java language: {:?}", e)))?;

    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_error("Parse failed - no tree returned".to_string()))?;

    let text = std::str::from_utf8(source)?;
    let rope = Rope::from_str(text);

    let mut collected = Vec::new();
    extract_symbols(tree.root_node(), source, &rope, &mut collected, "");
    Ok(collected)
}
fn extract_symbols(
node: tree_sitter::Node,
source: &[u8],
rope: &Rope,
symbols: &mut Vec<JavaSymbol>,
container_path: &str,
) {
let kind = node.kind();
let is_public = has_modifier(node, "public");
let is_static = has_modifier(node, "static");
let symbol_kind = match kind {
"class_declaration" => Some(JavaSymbolKind::Class),
"interface_declaration" => Some(JavaSymbolKind::Interface),
"enum_declaration" => Some(JavaSymbolKind::Enum),
"method_declaration" => Some(JavaSymbolKind::Method),
"constructor_declaration" => Some(JavaSymbolKind::Constructor),
"field_declaration" => Some(JavaSymbolKind::Field),
_ => None,
};
if let Some(kind) = symbol_kind {
if let Some(symbol) = extract_symbol(
node,
source,
rope,
kind,
container_path,
is_public,
is_static,
) {
let name = symbol.name.clone();
symbols.push(symbol);
if matches!(
kind,
JavaSymbolKind::Class | JavaSymbolKind::Interface | JavaSymbolKind::Enum
) {
let new_container = if container_path.is_empty() {
name.clone()
} else {
format!("{}.{}", container_path, name)
};
if let Some(body) = node.child_by_field_name("body") {
extract_symbols(body, source, rope, symbols, &new_container);
}
return;
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if matches!(
kind,
"class_declaration" | "interface_declaration" | "enum_declaration"
) && matches!(child.kind(), "class_body" | "interface_body" | "enum_body")
{
continue;
}
if kind == "field_declaration" && child.kind() == "variable_declarator" {
continue;
}
extract_symbols(child, source, rope, symbols, container_path);
}
}
/// Reports whether `node` has a `modifiers` child that contains a child of
/// kind `modifier` (for example `"public"` or `"static"`).
///
/// Only modifiers written in the source are seen; implicit ones are not
/// inferred.
fn has_modifier(node: tree_sitter::Node, modifier: &str) -> bool {
    let mut outer = node.walk();
    // Results are bound before returning so the child iterators, which borrow
    // the cursors, are dropped before the cursors themselves.
    let found = node
        .children(&mut outer)
        .filter(|child| child.kind() == "modifiers")
        .any(|modifiers| {
            let mut inner = modifiers.walk();
            let matched = modifiers
                .children(&mut inner)
                .any(|keyword| keyword.kind() == modifier);
            matched
        });
    found
}
/// Builds a `JavaSymbol` for a declaration node, or returns `None` when
/// `extract_name` finds no name for it.
///
/// Lines are reported 1-based and columns as 0-based byte offsets from the
/// start of their line, both derived from `rope`. `fully_qualified` is `name`
/// prefixed with `container_path` and a `.` when the container is non-empty.
fn extract_symbol(
    node: tree_sitter::Node,
    source: &[u8],
    rope: &Rope,
    kind: JavaSymbolKind,
    container_path: &str,
    is_public: bool,
    is_static: bool,
) -> Option<JavaSymbol> {
    let name = extract_name(node, source)?;

    // Maps a byte offset to its zero-based line and its byte column in that line.
    let locate = |byte: usize| -> (usize, usize) {
        let line = rope.char_to_line(rope.byte_to_char(byte));
        (line, byte - rope.line_to_byte(line))
    };

    let (byte_start, byte_end) = (node.start_byte(), node.end_byte());
    let (first_line, col_start) = locate(byte_start);
    let (last_line, col_end) = locate(byte_end);

    let fully_qualified = match container_path {
        "" => name.clone(),
        outer => format!("{}.{}", outer, name),
    };

    Some(JavaSymbol {
        kind,
        byte_start,
        byte_end,
        line_start: first_line + 1,
        line_end: last_line + 1,
        col_start,
        col_end,
        parameters: extract_parameters(node, source),
        container_path: container_path.to_string(),
        fully_qualified,
        name,
        is_public,
        is_static,
    })
}
fn extract_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
let kind = node.kind();
match kind {
"class_declaration" | "interface_declaration" | "enum_declaration" => node
.child_by_field_name("name")
.and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
"method_declaration" | "constructor_declaration" => node
.child_by_field_name("name")
.and_then(|n| n.utf8_text(source).ok().map(|s| s.to_string())),
"field_declaration" => {
for child in node.children(&mut node.walk()) {
if child.kind() == "variable_declarator" {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
return Some(name.to_string());
}
}
}
}
None
}
_ => None,
}
}
fn extract_parameters(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
let mut parameters = Vec::new();
if let Some(params) = node.child_by_field_name("parameters") {
for param in params.children(&mut params.walk()) {
if param.kind() == "formal_parameter" {
if let Some(name_node) = param.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
parameters.push(name.to_string());
}
}
}
}
}
parameters
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts symbols from `source` as `test.java`, failing the test if
    /// extraction returns an error.
    fn symbols_of(source: &[u8]) -> Vec<JavaSymbol> {
        extract_java_symbols(Path::new("test.java"), source)
            .expect("extraction should succeed for valid Java")
    }

    #[test]
    fn test_extract_simple_class() {
        let symbols = symbols_of(b"class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[0].kind.as_str(), "class");
    }

    #[test]
    fn test_extract_class_with_method() {
        let symbols = symbols_of(b"class MyClass { void method() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[0].kind.as_str(), "class");
        assert_eq!(symbols[1].name, "method");
        assert_eq!(symbols[1].kind.as_str(), "method");
    }

    #[test]
    fn test_extract_class_with_field() {
        let symbols = symbols_of(b"class MyClass { private int field; }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyClass");
        assert_eq!(symbols[1].name, "field");
        assert_eq!(symbols[1].kind.as_str(), "field");
    }

    #[test]
    fn test_extract_interface() {
        let symbols = symbols_of(b"interface MyInterface { void method(); }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "MyInterface");
        assert_eq!(symbols[0].kind.as_str(), "interface");
        assert_eq!(symbols[1].name, "method");
        assert_eq!(symbols[1].kind.as_str(), "method");
    }

    #[test]
    fn test_extract_enum() {
        let symbols = symbols_of(b"enum Color { RED, GREEN, BLUE }\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "Color");
        assert_eq!(symbols[0].kind.as_str(), "enum");
    }

    #[test]
    fn test_extract_class_with_constructor() {
        let symbols = symbols_of(b"class Foo { Foo() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].name, "Foo");
        assert_eq!(symbols[0].kind.as_str(), "class");
        assert_eq!(symbols[1].name, "Foo");
        assert_eq!(symbols[1].kind.as_str(), "constructor");
    }

    #[test]
    fn test_extract_method_with_parameters() {
        let symbols = symbols_of(b"class MyClass { void add(int a, int b) {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].parameters, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_public_class() {
        let symbols = symbols_of(b"public class MyClass {}\n");
        assert_eq!(symbols.len(), 1);
        assert_eq!(symbols[0].name, "MyClass");
        assert!(symbols[0].is_public);
    }

    #[test]
    fn test_extract_static_method() {
        let symbols = symbols_of(b"class MyClass { static void method() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].name, "method");
        assert!(symbols[1].is_static);
    }

    #[test]
    fn test_nested_types_build_container_paths() {
        let symbols = symbols_of(b"class Outer { class Inner { void run() {} } }\n");
        let qualified: Vec<&str> = symbols.iter().map(|s| s.fully_qualified.as_str()).collect();
        assert_eq!(qualified, vec!["Outer", "Outer.Inner", "Outer.Inner.run"]);
        assert_eq!(symbols[0].container_path, "");
        assert_eq!(symbols[1].container_path, "Outer");
        assert_eq!(symbols[2].container_path, "Outer.Inner");
    }

    #[test]
    fn test_enum_members_are_qualified_by_enum() {
        let symbols = symbols_of(b"enum E { A; void f() {} }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[0].kind, JavaSymbolKind::Enum);
        assert_eq!(symbols[1].kind, JavaSymbolKind::Method);
        assert_eq!(symbols[1].fully_qualified, "E.f");
    }

    #[test]
    fn test_positions_are_one_based_lines_and_byte_columns() {
        let symbols = symbols_of(b"class A {\n  void m() {}\n}\n");
        assert_eq!(symbols.len(), 2);

        let class = &symbols[0];
        assert_eq!((class.byte_start, class.byte_end), (0, 25));
        assert_eq!((class.line_start, class.line_end), (1, 3));
        assert_eq!((class.col_start, class.col_end), (0, 1));

        let method = &symbols[1];
        assert_eq!((method.byte_start, method.byte_end), (12, 23));
        assert_eq!((method.line_start, method.line_end), (2, 2));
        assert_eq!((method.col_start, method.col_end), (2, 13));
    }

    #[test]
    fn test_field_with_multiple_modifiers() {
        let symbols = symbols_of(b"class C { public static final int X = 1; }\n");
        assert_eq!(symbols.len(), 2);
        assert_eq!(symbols[1].name, "X");
        assert_eq!(symbols[1].kind, JavaSymbolKind::Field);
        assert!(symbols[1].is_public);
        assert!(symbols[1].is_static);
    }
}