use crate::error::{Result, SpliceError};
use crate::expand::tree_walker::find_parent_symbol_node;
use crate::symbol::{parser_for_language, Language};
use std::path::Path;
pub mod tree_walker;
/// How far a span is widened around a byte offset by [`expand_symbol`].
///
/// The explicit discriminants are the numeric levels produced by [`ExpansionLevel::as_u8`]
/// and accepted by [`ExpansionLevel::from_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpansionLevel {
    /// Keep the span of the syntax node found at the offset.
    None = 0,
    /// Widen to the enclosing symbol (see [`SymbolExpander::expand_to_body`]).
    Body = 1,
    /// Widen to the block that contains the enclosing symbol.
    ContainingBlock = 2,
}

impl ExpansionLevel {
    /// Returns the numeric form of this level: `0`, `1` or `2`.
    pub fn as_u8(self) -> u8 {
        match self {
            Self::None => 0,
            Self::Body => 1,
            Self::ContainingBlock => 2,
        }
    }

    /// Converts a numeric level back into an [`ExpansionLevel`].
    ///
    /// Returns `None` for any value other than `0`, `1` or `2`.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Indexed by numeric level: a level's position in the table is its discriminant.
        const BY_VALUE: [ExpansionLevel; 3] = [
            ExpansionLevel::None,
            ExpansionLevel::Body,
            ExpansionLevel::ContainingBlock,
        ];
        BY_VALUE.get(usize::from(value)).copied()
    }
}
/// Language-specific knowledge used to widen a tree-sitter node to the symbol it belongs to.
///
/// Each implementation in this module lists the grammar node kinds it treats as symbols and
/// as enclosing blocks for one language (C and C++ share one implementation).
pub trait SymbolExpander {
    /// Returns the `(start, end)` byte span of the symbol that `node` belongs to, or `None`
    /// when no such symbol is found.
    ///
    /// The implementations in this module delegate the search to
    /// `tree_walker::find_parent_symbol_node`, using [`SymbolExpander::is_symbol_kind`] as the
    /// predicate.
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)>;

    /// Returns `true` if nodes of this grammar kind count as symbols
    /// (functions, types, classes, methods, ...).
    fn is_symbol_kind(&self, kind: &str) -> bool;

    /// Returns `true` if nodes of this grammar kind are treated as enclosing blocks
    /// (e.g. modules, impls, classes, or the file's root node).
    fn is_block_kind(&self, kind: &str) -> bool;
}
/// [`SymbolExpander`] for sources parsed with the tree-sitter Rust grammar.
#[derive(Debug, Clone, Copy)]
pub struct RustExpander;

impl SymbolExpander for RustExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        let symbol = find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))?;
        Some((symbol.start_byte(), symbol.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        // Item kinds that name a definition the user may want to expand to.
        const SYMBOL_KINDS: &[&str] = &[
            "function_item",
            "struct_item",
            "enum_item",
            "trait_item",
            "impl_item",
            "mod_item",
            "const_item",
            "static_item",
            "type_item",
        ];
        SYMBOL_KINDS.contains(&node_kind)
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        // `impl` blocks, inline modules, and the file root.
        const BLOCK_KINDS: &[&str] = &["impl_item", "mod_item", "source_file"];
        BLOCK_KINDS.contains(&node_kind)
    }
}
/// [`SymbolExpander`] for sources parsed with the tree-sitter Python grammar.
#[derive(Debug, Clone, Copy)]
pub struct PythonExpander;

impl SymbolExpander for PythonExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        let symbol = find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))?;
        Some((symbol.start_byte(), symbol.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        // Only `def` and `class` introduce symbols in Python.
        node_kind == "function_definition" || node_kind == "class_definition"
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        node_kind == "module" || node_kind == "source_file"
    }
}
/// [`SymbolExpander`] for C and C++ sources (both languages map here in `get_expander`).
#[derive(Debug, Clone, Copy)]
pub struct CppExpander;

impl SymbolExpander for CppExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        let symbol = find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))?;
        Some((symbol.start_byte(), symbol.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        const SYMBOL_KINDS: &[&str] = &[
            "function_definition",
            "class_specifier",
            "struct_specifier",
            "enum_specifier",
            "union_specifier",
            "namespace_definition",
        ];
        SYMBOL_KINDS.contains(&node_kind)
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        // Namespaces and the translation-unit root.
        node_kind == "namespace_definition" || node_kind == "translation_unit"
    }
}
/// [`SymbolExpander`] for sources parsed with the tree-sitter Java grammar.
#[derive(Debug, Clone, Copy)]
pub struct JavaExpander;

impl SymbolExpander for JavaExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        let symbol = find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))?;
        Some((symbol.start_byte(), symbol.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        const SYMBOL_KINDS: &[&str] = &[
            "class_declaration",
            "interface_declaration",
            "method_declaration",
            "constructor_declaration",
            "field_declaration",
            "enum_declaration",
        ];
        SYMBOL_KINDS.contains(&node_kind)
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        // Type declarations are the only containers considered for Java.
        const BLOCK_KINDS: &[&str] = &["class_declaration", "interface_declaration"];
        BLOCK_KINDS.contains(&node_kind)
    }
}
/// [`SymbolExpander`] for sources parsed with the tree-sitter JavaScript grammar.
#[derive(Debug, Clone, Copy)]
pub struct JavaScriptExpander;

impl SymbolExpander for JavaScriptExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        let symbol = find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))?;
        Some((symbol.start_byte(), symbol.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        // Arrow functions count as symbols so anonymous callbacks can be expanded too.
        const SYMBOL_KINDS: &[&str] = &[
            "function_declaration",
            "class_declaration",
            "method_definition",
            "generator_function_declaration",
            "arrow_function",
        ];
        SYMBOL_KINDS.contains(&node_kind)
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        const BLOCK_KINDS: &[&str] = &["class_declaration", "program"];
        BLOCK_KINDS.contains(&node_kind)
    }
}
/// [`SymbolExpander`] for sources parsed with the tree-sitter TypeScript grammar.
///
/// Recognises the JavaScript symbol kinds plus TypeScript-only declarations (interfaces,
/// type aliases, enums, and `abstract class`).
#[derive(Debug, Clone, Copy)]
pub struct TypeScriptExpander;

impl SymbolExpander for TypeScriptExpander {
    fn expand_to_body(&self, node: tree_sitter::Node, source: &[u8]) -> Option<(usize, usize)> {
        find_parent_symbol_node(node, source, |kind| self.is_symbol_kind(kind))
            .map(|n| (n.start_byte(), n.end_byte()))
    }

    fn is_symbol_kind(&self, node_kind: &str) -> bool {
        matches!(
            node_kind,
            "function_declaration"
                | "class_declaration"
                // `abstract class Foo {}` parses to a distinct node kind in the TS grammar;
                // without it such classes were skipped during expansion.
                | "abstract_class_declaration"
                | "interface_declaration"
                | "type_alias_declaration"
                | "method_definition"
                | "generator_function_declaration"
                | "arrow_function"
                | "enum_declaration"
        )
    }

    fn is_block_kind(&self, node_kind: &str) -> bool {
        matches!(
            node_kind,
            "class_declaration"
                | "abstract_class_declaration"
                | "interface_declaration"
                | "module"
                // File root, matching `JavaScriptExpander` (the TS grammar extends the JS one).
                | "program"
        )
    }
}
/// Returns the [`SymbolExpander`] responsible for `language`.
///
/// C and C++ share [`CppExpander`]; every other language has its own expander.
fn get_expander(language: Language) -> Box<dyn SymbolExpander> {
    match language {
        Language::C | Language::Cpp => Box::new(CppExpander),
        Language::Java => Box::new(JavaExpander),
        Language::JavaScript => Box::new(JavaScriptExpander),
        Language::TypeScript => Box::new(TypeScriptExpander),
        Language::Python => Box::new(PythonExpander),
        Language::Rust => Box::new(RustExpander),
    }
}
/// Expands the node at `byte_offset` in the file at `path` to the span selected by `level`.
///
/// Returns the `(start, end)` byte span of the expansion.
///
/// # Errors
///
/// Propagates every error from reading, parsing, or expanding the file
/// (see `expand_symbol_impl`).
pub fn expand_symbol(
    path: &Path,
    byte_offset: usize,
    language: Language,
    level: ExpansionLevel,
) -> Result<(usize, usize)> {
    expand_symbol_impl(path, byte_offset, language, level)
}

/// Like [`expand_symbol`], but takes the expansion level as a plain number.
///
/// `0` selects [`ExpansionLevel::None`] and `2` selects [`ExpansionLevel::ContainingBlock`].
/// `1` and every out-of-range value select [`ExpansionLevel::Body`]; unknown levels are not
/// rejected.
///
/// # Errors
///
/// Same as [`expand_symbol`].
pub fn expand_symbol_with_level(
    path: &Path,
    byte_offset: usize,
    language: Language,
    level: usize,
) -> Result<(usize, usize)> {
    let expansion_level = match level {
        0 => ExpansionLevel::None,
        2 => ExpansionLevel::ContainingBlock,
        // Deliberate fallback: anything that is not 0 or 2 behaves like a body expansion.
        _ => ExpansionLevel::Body,
    };
    expand_symbol(path, byte_offset, language, expansion_level)
}
/// Expands the symbol enclosing `byte_offset` and moves the start of the span to the offset
/// reported by `tree_walker::extract_leading_docs` for that symbol.
///
/// Returns `(doc_start, body_end)`. `body_end` is the end of the span produced by
/// [`SymbolExpander::expand_to_body`]. `doc_start` presumably points at the first leading doc
/// comment, or at the symbol itself if there is none (see `extract_leading_docs`).
///
/// # Errors
///
/// * `SpliceError::Io` if the file cannot be read.
/// * Any error returned by `parser_for_language`.
/// * `SpliceError::Parse` if tree-sitter returns no tree.
/// * `SpliceError::InvalidSpan` if `byte_offset` lies past the end of the file.
/// * `SpliceError::Other` if no enclosing symbol is found.
pub fn expand_to_body_with_docs(
    path: &Path,
    byte_offset: usize,
    language: Language,
) -> Result<(usize, usize)> {
    let source = std::fs::read(path).map_err(|e| SpliceError::Io {
        path: path.to_path_buf(),
        source: e,
    })?;
    let invalid_offset = || SpliceError::InvalidSpan {
        file: path.to_path_buf(),
        start: byte_offset,
        end: byte_offset,
        file_size: source.len(),
    };
    // tree-sitter's `descendant_for_byte_range` does not report an out-of-range offset as
    // `None`; it falls back to the root node. Reject such offsets explicitly.
    if byte_offset > source.len() {
        return Err(invalid_offset());
    }
    let mut parser = parser_for_language(language)?;
    let tree = parser
        .parse(&source, None)
        .ok_or_else(|| SpliceError::Parse {
            file: path.to_path_buf(),
            message: "Parse failed - no tree returned".to_string(),
        })?;
    let root_node = tree.root_node();
    let node = root_node
        .descendant_for_byte_range(byte_offset, byte_offset)
        .ok_or_else(invalid_offset)?;
    let expander = get_expander(language);
    let (body_start, body_end) = expander.expand_to_body(node, &source).ok_or_else(|| {
        SpliceError::Other(format!(
            "Could not expand symbol at offset {} in {}",
            byte_offset,
            path.display()
        ))
    })?;
    let mut body_node = root_node
        .descendant_for_byte_range(body_start, body_end)
        .ok_or_else(|| {
            SpliceError::Other(format!(
                "Could not find expanded body node in {}",
                path.display()
            ))
        })?;
    // `descendant_for_byte_range` returns the *deepest* node covering the range, which can be
    // a child that shares the symbol's exact span. Climb back up through same-span ancestors
    // until we reach the symbol node, so `extract_leading_docs` receives the symbol itself.
    while !expander.is_symbol_kind(body_node.kind()) {
        match body_node.parent() {
            Some(parent) if parent.start_byte() == body_start && parent.end_byte() == body_end => {
                body_node = parent;
            }
            _ => break,
        }
    }
    let doc_start = tree_walker::extract_leading_docs(&body_node, &source);
    Ok((doc_start, body_end))
}
/// Shared implementation of [`expand_symbol`] and [`expand_symbol_with_level`].
///
/// Reads and parses the file at `path`, finds the node covering `byte_offset`, and widens it
/// according to `level`:
///
/// * [`ExpansionLevel::None`]: the span of the node tree-sitter reports at the offset.
/// * [`ExpansionLevel::Body`]: the span computed by [`SymbolExpander::expand_to_body`].
/// * [`ExpansionLevel::ContainingBlock`]: the span `tree_walker::find_containing_block`
///   returns for that body span.
///
/// # Errors
///
/// * `SpliceError::Io` if the file cannot be read.
/// * Any error returned by `parser_for_language`.
/// * `SpliceError::Parse` if tree-sitter returns no tree.
/// * `SpliceError::InvalidSpan` if `byte_offset` lies past the end of the file.
/// * `SpliceError::Other` if no enclosing symbol or containing block is found.
fn expand_symbol_impl(
    path: &Path,
    byte_offset: usize,
    language: Language,
    level: ExpansionLevel,
) -> Result<(usize, usize)> {
    let source = std::fs::read(path).map_err(|e| SpliceError::Io {
        path: path.to_path_buf(),
        source: e,
    })?;
    let invalid_offset = || SpliceError::InvalidSpan {
        file: path.to_path_buf(),
        start: byte_offset,
        end: byte_offset,
        file_size: source.len(),
    };
    // tree-sitter's `descendant_for_byte_range` does not report an out-of-range offset as
    // `None`; it falls back to the root node. Without this check, `ExpansionLevel::None` would
    // silently return the whole file for an offset past EOF.
    if byte_offset > source.len() {
        return Err(invalid_offset());
    }
    let mut parser = parser_for_language(language)?;
    let tree = parser
        .parse(&source, None)
        .ok_or_else(|| SpliceError::Parse {
            file: path.to_path_buf(),
            message: "Parse failed - no tree returned".to_string(),
        })?;
    let root_node = tree.root_node();
    let node = root_node
        .descendant_for_byte_range(byte_offset, byte_offset)
        .ok_or_else(invalid_offset)?;
    let expander = get_expander(language);
    // Shared by the `Body` and `ContainingBlock` arms.
    let no_symbol = || {
        SpliceError::Other(format!(
            "Could not expand symbol at offset {} in {}",
            byte_offset,
            path.display()
        ))
    };
    match level {
        ExpansionLevel::None => Ok((node.start_byte(), node.end_byte())),
        ExpansionLevel::Body => expander.expand_to_body(node, &source).ok_or_else(no_symbol),
        ExpansionLevel::ContainingBlock => {
            let (body_start, body_end) =
                expander.expand_to_body(node, &source).ok_or_else(no_symbol)?;
            tree_walker::find_containing_block(&root_node, body_start, body_end, &source)
                .ok_or_else(|| {
                    SpliceError::Other(format!(
                        "Could not expand to containing block in {}",
                        path.display()
                    ))
                })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `expander` accepts every kind in `symbols` and rejects every kind in
    /// `non_symbols`, naming the offending kind on failure.
    fn check_symbol_kinds(expander: &dyn SymbolExpander, symbols: &[&str], non_symbols: &[&str]) {
        for kind in symbols {
            assert!(expander.is_symbol_kind(kind), "{:?} should be a symbol kind", kind);
        }
        for kind in non_symbols {
            assert!(!expander.is_symbol_kind(kind), "{:?} should not be a symbol kind", kind);
        }
    }

    #[test]
    fn test_expansion_level_conversions() {
        let table = [
            (ExpansionLevel::None, 0_u8),
            (ExpansionLevel::Body, 1),
            (ExpansionLevel::ContainingBlock, 2),
        ];
        for &(level, raw) in &table {
            assert_eq!(level.as_u8(), raw);
            assert_eq!(ExpansionLevel::from_u8(raw), Some(level));
        }
        assert_eq!(ExpansionLevel::from_u8(3), None);
    }

    #[test]
    fn test_rust_expander_symbol_kinds() {
        check_symbol_kinds(
            &RustExpander,
            &[
                "function_item",
                "struct_item",
                "enum_item",
                "trait_item",
                "impl_item",
                "mod_item",
            ],
            &["identifier", "string_literal"],
        );
    }

    #[test]
    fn test_python_expander_symbol_kinds() {
        check_symbol_kinds(
            &PythonExpander,
            &["function_definition", "class_definition"],
            &["identifier", "string"],
        );
    }

    #[test]
    fn test_cpp_expander_symbol_kinds() {
        check_symbol_kinds(
            &CppExpander,
            &[
                "function_definition",
                "class_specifier",
                "struct_specifier",
                "enum_specifier",
            ],
            &["identifier", "string_literal"],
        );
    }

    #[test]
    fn test_java_expander_symbol_kinds() {
        check_symbol_kinds(
            &JavaExpander,
            &["class_declaration", "interface_declaration", "method_declaration"],
            &["identifier", "string_literal"],
        );
    }

    #[test]
    fn test_javascript_expander_symbol_kinds() {
        check_symbol_kinds(
            &JavaScriptExpander,
            &["function_declaration", "class_declaration", "method_definition"],
            &["identifier", "string"],
        );
    }

    #[test]
    fn test_typescript_expander_symbol_kinds() {
        check_symbol_kinds(
            &TypeScriptExpander,
            &[
                "function_declaration",
                "class_declaration",
                "interface_declaration",
                "type_alias_declaration",
            ],
            &["identifier", "string"],
        );
    }
}