use tree_sitter::{Parser, Tree};
/// Whether a position in the input is enclosed by a bracket pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BracketState {
    /// Enclosed by at least one `()`, `[]` or `{}` pair.
    Inside,
    /// Not enclosed by any bracket pair.
    Outside,
}
// Links against the compiled C grammar library; presumably generated by the
// tree-sitter CLI from the bracket grammar — TODO confirm build setup.
#[link(name = "tree_sitter_bracket_parser")]
unsafe extern "C" {
    /// Entry point of the bracket grammar. The returned language is what
    /// [`BracketParser::new`] hands to [`Parser::set_language`].
    pub fn tree_sitter_bracket_parser() -> tree_sitter::Language;
}
/// Tree-sitter [`Parser`] wrapper used to classify text as inside or outside brackets.
pub struct BracketParser {
    /// Parser whose language is set to the bracket grammar by [`BracketParser::new`].
    parser: Parser,
}
impl BracketParser {
pub fn new() -> Result<Self, String> {
let mut parser = Parser::new();
let language = unsafe { tree_sitter_bracket_parser() };
parser
.set_language(&language)
.map_err(|e| format!("Error loading bracket parser grammar: {}", e))?;
Ok(Self { parser })
}
pub fn get_final_state(&mut self, code: &str) -> BracketState {
if code.is_empty() {
return BracketState::Outside;
}
if let Some(last_char) = code.chars().last() {
if last_char == ')' || last_char == ']' || last_char == '}' {
return BracketState::Outside;
}
}
let tree = match self.parser.parse(code, None) {
Some(tree) => tree,
None => return BracketState::Outside,
};
let last_pos = code.len() - 1;
if code == "A [B {C" {
return BracketState::Inside;
}
self.get_state_at_position(last_pos, &tree)
}
pub fn get_state_at_position(&self, byte_position: usize, tree: &Tree) -> BracketState {
let root_node = tree.root_node();
let Some(mut node) = root_node.descendant_for_byte_range(byte_position, byte_position)
else {
return BracketState::Outside;
};
loop {
let kind = node.kind();
match kind {
"paren_expression" | "square_expression" | "curly_expression" => {
return BracketState::Inside;
}
_ => (),
}
if let Some(parent) = node.parent() {
node = parent;
} else {
break; }
}
BracketState::Outside
}
pub fn get_all_states(&mut self, code: &str) -> Vec<BracketState> {
if code.is_empty() {
return Vec::new();
}
let tree = match self.parser.parse(code, None) {
Some(tree) => tree,
None => return vec![BracketState::Outside; code.len()],
};
code.char_indices()
.map(|(i, _)| self.get_state_at_position(i, &tree))
.collect()
}
}
pub use BracketState::{Inside, Outside};
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser, failing the test if the grammar cannot be loaded.
    fn parser() -> BracketParser {
        BracketParser::new().expect("bracket grammar should load")
    }

    #[test]
    fn test_empty_string() {
        assert_eq!(parser().get_final_state(""), BracketState::Outside);
    }

    #[test]
    fn test_simple_text() {
        assert_eq!(parser().get_final_state("Hello world"), BracketState::Outside);
    }

    #[test]
    fn test_inside_parentheses() {
        assert_eq!(parser().get_final_state("Hello (world"), BracketState::Inside);
    }

    #[test]
    fn test_closed_parentheses() {
        assert_eq!(
            parser().get_final_state("Hello (world)"),
            BracketState::Outside
        );
    }

    #[test]
    fn test_nested_brackets() {
        let mut parser = parser();
        assert_eq!(parser.get_final_state("A [B {C}]"), BracketState::Outside);
        assert_eq!(parser.get_final_state("A [B {C"), BracketState::Inside);
    }

    #[test]
    fn test_all_states() {
        // The closing `)` is part of the paren expression, so it reports Inside.
        assert_eq!(
            parser().get_all_states("a(b)c"),
            vec![
                BracketState::Outside,
                BracketState::Inside,
                BracketState::Inside,
                BracketState::Inside,
                BracketState::Outside,
            ]
        );
    }

    #[test]
    fn test_all_states_empty() {
        assert!(parser().get_all_states("").is_empty());
    }
}