use crate::error::{Result, SpliceError};
use crate::ingest::imports::ImportKind;
use std::path::Path;
pub fn extract_java_imports(path: &Path, source: &[u8]) -> Result<Vec<super::ImportFact>> {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter_java::language())
.map_err(|e| SpliceError::Parse {
file: path.to_path_buf(),
message: format!("Failed to set Java language: {:?}", e),
})?;
let tree = parser
.parse(source, None)
.ok_or_else(|| SpliceError::Parse {
file: path.to_path_buf(),
message: "Parse failed - no tree returned".to_string(),
})?;
let mut imports = Vec::new();
extract_import_statements(tree.root_node(), source, &mut imports);
Ok(imports)
}
/// Walks the subtree rooted at `node` in pre-order and appends a fact for every
/// `import_declaration` node that `extract_import_declaration` accepts.
///
/// The walk does not descend into an `import_declaration` once it has been
/// handled. It is iterative (driven by a `TreeCursor`) rather than recursive,
/// so stack usage does not grow with the depth of the syntax tree. Deeply nested
/// code, such as long generated expression chains, therefore cannot overflow the stack.
fn extract_import_statements(
    node: tree_sitter::Node,
    source: &[u8],
    imports: &mut Vec<super::ImportFact>,
) {
    // A cursor created from `node` treats `node` as its root: `goto_parent`
    // and `goto_next_sibling` both return false there, which ends the walk.
    let mut cursor = node.walk();
    loop {
        let current = cursor.node();
        let descend = if current.kind() == "import_declaration" {
            if let Some(import) = extract_import_declaration(current, source) {
                imports.push(import);
            }
            false
        } else {
            true
        };

        if descend && cursor.goto_first_child() {
            continue;
        }

        // No (further) children to visit: advance to the next sibling,
        // climbing back up until one exists or the root is reached.
        loop {
            if cursor.goto_next_sibling() {
                break;
            }
            if !cursor.goto_parent() {
                return;
            }
        }
    }
}
/// Converts a single `import_declaration` node into an import fact.
///
/// The direct children of `node` are inspected:
/// - a `static` keyword marks the import as `ImportKind::JavaStaticImport`;
/// - an `identifier` / `scoped_identifier` supplies the dotted path segments;
/// - an `asterisk` marks the import as a glob (`import a.b.*;`).
///
/// Returns `None` when no path segments were found. `file_path` is left empty
/// and `imported_names` is always empty. `byte_span` covers the whole declaration.
fn extract_import_declaration(node: tree_sitter::Node, source: &[u8]) -> Option<super::ImportFact> {
    let mut saw_static = false;
    let mut saw_asterisk = false;
    let mut segments: Vec<String> = Vec::new();

    let mut walker = node.walk();
    for part in node.children(&mut walker) {
        match part.kind() {
            "static" => saw_static = true,
            "scoped_identifier" | "identifier" => {
                extract_path_segments(part, source, &mut segments)
            }
            "asterisk" => saw_asterisk = true,
            _ => {}
        }
    }

    if segments.is_empty() {
        return None;
    }

    let kind = match saw_static {
        true => ImportKind::JavaStaticImport,
        false => ImportKind::JavaImport,
    };
    let span = (node.start_byte(), node.end_byte());

    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind: kind,
        path: segments,
        imported_names: Vec::new(),
        is_glob: saw_asterisk,
        is_reexport: false,
        byte_span: span,
    })
}
/// Appends the dotted-name segments of `node` to `path`, left to right.
///
/// An `identifier` contributes its own text. A `scoped_identifier` (`a.b.c`) is
/// flattened by recursing into its `identifier` / `scoped_identifier` children,
/// so `java.util.List` yields `["java", "util", "List"]`. Other node kinds
/// contribute nothing.
fn extract_path_segments(node: tree_sitter::Node, source: &[u8], path: &mut Vec<String>) {
    match node.kind() {
        "identifier" => {
            // Decode lossily rather than skipping on invalid UTF-8: dropping a
            // segment would silently produce a shorter, wrong path (`a.c` for `a.b.c`).
            // `.get` avoids a panic if `source` is not the buffer that was parsed.
            if let Some(bytes) = source.get(node.start_byte()..node.end_byte()) {
                path.push(String::from_utf8_lossy(bytes).into_owned());
            }
        }
        "scoped_identifier" => {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if matches!(child.kind(), "identifier" | "scoped_identifier") {
                    extract_path_segments(child, source, path);
                }
            }
        }
        _ => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Return type shared by every test so extractor errors propagate via `?`.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Synthetic file name passed to the extractor in every test.
    const TEST_FILE: &str = "test.java";

    #[test]
    fn test_extract_simple_import() -> TestResult {
        let found = extract_java_imports(Path::new(TEST_FILE), b"import java.util.List;\n")?;
        assert_eq!(found.len(), 1);
        let first = &found[0];
        assert_eq!(first.import_kind, ImportKind::JavaImport);
        assert_eq!(first.path, vec!["java", "util", "List"]);
        Ok(())
    }

    #[test]
    fn test_extract_static_import() -> TestResult {
        let found =
            extract_java_imports(Path::new(TEST_FILE), b"import static java.lang.Math.PI;\n")?;
        assert_eq!(found.len(), 1);
        let first = &found[0];
        assert_eq!(first.import_kind, ImportKind::JavaStaticImport);
        assert_eq!(first.path, vec!["java", "lang", "Math", "PI"]);
        Ok(())
    }

    #[test]
    fn test_extract_wildcard_import() -> TestResult {
        let found = extract_java_imports(Path::new(TEST_FILE), b"import java.util.*;\n")?;
        assert_eq!(found.len(), 1);
        let first = &found[0];
        assert_eq!(first.import_kind, ImportKind::JavaImport);
        assert!(first.is_glob);
        Ok(())
    }

    #[test]
    fn test_extract_static_wildcard_import() -> TestResult {
        let found =
            extract_java_imports(Path::new(TEST_FILE), b"import static java.lang.Math.*;\n")?;
        assert_eq!(found.len(), 1);
        let first = &found[0];
        assert_eq!(first.import_kind, ImportKind::JavaStaticImport);
        assert!(first.is_glob);
        Ok(())
    }

    #[test]
    fn test_extract_multiple_imports() -> TestResult {
        let src: &[u8] = b"import java.util.List;\nimport java.util.ArrayList;\n";
        let found = extract_java_imports(Path::new(TEST_FILE), src)?;
        assert_eq!(found.len(), 2);
        // Facts must come back in source order.
        assert_eq!(found[0].path, vec!["java", "util", "List"]);
        assert_eq!(found[1].path, vec!["java", "util", "ArrayList"]);
        Ok(())
    }

    #[test]
    fn test_import_has_byte_span() -> TestResult {
        let found = extract_java_imports(Path::new(TEST_FILE), b"import java.util.List;\n")?;
        assert_eq!(found.len(), 1);
        // The span covers `import java.util.List;` (22 bytes) and excludes the newline.
        assert_eq!(found[0].byte_span, (0, 22));
        Ok(())
    }
}