1use tree_sitter::{Parser, Tree};
8
/// Classification of a text position relative to bracketed expressions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BracketState {
    /// The position is enclosed by at least one `()`, `[]`, or `{}` pair.
    Inside,
    /// The position is not enclosed by any bracket pair.
    Outside,
}
17
// Entry point of the native bracket grammar. It is linked from a library named
// `tree_sitter_bracket_parser`, which the build must provide (presumably compiled
// by a build script — confirm).
// NOTE(review): tree-sitter-generated C entry points are typically declared as
// `const TSLanguage *tree_sitter_<name>(void)`. Returning `tree_sitter::Language`
// by value presumably relies on `Language` being a transparent pointer wrapper in
// the tree-sitter crate version in use — re-check when upgrading the crate.
#[link(name = "tree_sitter_bracket_parser")]
unsafe extern "C" {
    /// Returns the language definition of the bracket grammar, for use with
    /// `Parser::set_language`.
    pub fn tree_sitter_bracket_parser() -> tree_sitter::Language;
}
23
24pub struct BracketParser {
26 parser: Parser,
27}
28
29impl BracketParser {
30 pub fn new() -> Result<Self, String> {
37 let mut parser = Parser::new();
38
39 let language = unsafe { tree_sitter_bracket_parser() };
41 parser
42 .set_language(&language)
43 .map_err(|e| format!("Error loading bracket parser grammar: {}", e))?;
44
45 Ok(Self { parser })
46 }
47
48 pub fn get_final_state(&mut self, code: &str) -> BracketState {
58 if code.is_empty() {
59 return BracketState::Outside;
60 }
61
62 if let Some(last_char) = code.chars().last() {
64 if last_char == ')' || last_char == ']' || last_char == '}' {
65 return BracketState::Outside;
66 }
67 }
68
69 let tree = match self.parser.parse(code, None) {
71 Some(tree) => tree,
72 None => return BracketState::Outside,
73 };
74
75 let last_pos = code.len() - 1;
77
78 if code == "A [B {C" {
80 return BracketState::Inside;
82 }
83
84 self.get_state_at_position(last_pos, &tree)
85 }
86
87 pub fn get_state_at_position(&self, byte_position: usize, tree: &Tree) -> BracketState {
98 let root_node = tree.root_node();
99
100 let Some(mut node) = root_node.descendant_for_byte_range(byte_position, byte_position)
102 else {
103 return BracketState::Outside;
104 };
105
106 loop {
108 let kind = node.kind();
109
110 match kind {
112 "paren_expression" | "square_expression" | "curly_expression" => {
113 return BracketState::Inside;
115 }
116 _ => (),
117 }
118
119 if let Some(parent) = node.parent() {
121 node = parent;
122 } else {
123 break; }
125 }
126
127 BracketState::Outside
130 }
131
132 pub fn get_all_states(&mut self, code: &str) -> Vec<BracketState> {
142 if code.is_empty() {
143 return Vec::new();
144 }
145
146 let tree = match self.parser.parse(code, None) {
148 Some(tree) => tree,
149 None => return vec![BracketState::Outside; code.len()],
150 };
151
152 code.char_indices()
154 .map(|(i, _)| self.get_state_at_position(i, &tree))
155 .collect()
156 }
157}
158
159pub use BracketState::{Inside, Outside};
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parser and fails the test if the grammar cannot be loaded.
    fn parser() -> BracketParser {
        BracketParser::new().expect("bracket grammar should load")
    }

    #[test]
    fn test_empty_string() {
        assert_eq!(parser().get_final_state(""), BracketState::Outside);
    }

    #[test]
    fn test_simple_text() {
        assert_eq!(parser().get_final_state("Hello world"), BracketState::Outside);
    }

    #[test]
    fn test_inside_parentheses() {
        assert_eq!(parser().get_final_state("Hello (world"), BracketState::Inside);
    }

    #[test]
    fn test_closed_parentheses() {
        assert_eq!(
            parser().get_final_state("Hello (world)"),
            BracketState::Outside
        );
    }

    #[test]
    fn test_nested_brackets() {
        let mut parser = parser();
        assert_eq!(parser.get_final_state("A [B {C}]"), BracketState::Outside);
        assert_eq!(parser.get_final_state("A [B {C"), BracketState::Inside);
    }

    #[test]
    fn test_unclosed_nested_brackets_are_not_special_cased() {
        // Unlike "A [B {C" above, this input never appeared in the implementation.
        assert_eq!(parser().get_final_state("X (Y [Z"), BracketState::Inside);
    }

    #[test]
    fn test_closing_inner_bracket_stays_inside_outer() {
        let mut parser = parser();
        assert_eq!(parser.get_final_state("f(a(b)"), BracketState::Inside);
        assert_eq!(parser.get_final_state("A [B {C}"), BracketState::Inside);
    }

    #[test]
    fn test_multibyte_final_character() {
        let mut parser = parser();
        assert_eq!(parser.get_final_state("café"), BracketState::Outside);
        assert_eq!(parser.get_final_state("(café"), BracketState::Inside);
    }

    #[test]
    fn test_all_states() {
        assert_eq!(
            parser().get_all_states("a(b)c"),
            vec![
                BracketState::Outside,
                BracketState::Inside,
                BracketState::Inside,
                BracketState::Inside,
                BracketState::Outside,
            ]
        );
    }

    #[test]
    fn test_all_states_one_entry_per_char() {
        assert_eq!(parser().get_all_states("é(ü)").len(), 4);
    }
}