1use crate::errors::AntlrError;
2use crate::token::{CommonToken, Token};
3use std::collections::BTreeMap;
4
5#[derive(Clone, Debug, Eq, PartialEq)]
6pub enum ParseTree {
7 Rule(RuleNode),
8 Terminal(TerminalNode),
9 Error(ErrorNode),
10}
11
12impl ParseTree {
13 pub fn text(&self) -> String {
14 match self {
15 Self::Rule(rule) => rule.text(),
16 Self::Terminal(node) => node.text(),
17 Self::Error(node) => node.text(),
18 }
19 }
20
21 pub fn to_string_tree(&self, rule_names: &[String]) -> String {
22 match self {
23 Self::Rule(rule) => rule.to_string_tree(rule_names),
24 Self::Terminal(node) => escape_tree_text(&node.text()),
25 Self::Error(node) => escape_tree_text(&node.text()),
26 }
27 }
28
29 pub fn first_rule(&self, rule_index: usize) -> Option<&Self> {
31 match self {
32 Self::Rule(rule) => {
33 if rule.context().rule_index() == rule_index {
34 return Some(self);
35 }
36 rule.context()
37 .children()
38 .iter()
39 .find_map(|child| child.first_rule(rule_index))
40 }
41 Self::Terminal(_) | Self::Error(_) => None,
42 }
43 }
44
45 pub fn first_rule_stop(&self, rule_index: usize) -> Option<&CommonToken> {
47 let Self::Rule(rule) = self else {
48 return None;
49 };
50 if rule.context().rule_index() == rule_index {
51 return rule.context().stop();
52 }
53 rule.context()
54 .children()
55 .iter()
56 .find_map(|child| child.first_rule_stop(rule_index))
57 }
58
59 pub fn first_rule_int_return(&self, rule_index: usize, name: &str) -> Option<i64> {
63 let Self::Rule(rule) = self else {
64 return None;
65 };
66 if rule.context().rule_index() == rule_index {
67 return rule.context().int_return(name);
68 }
69 rule.context()
70 .children()
71 .iter()
72 .find_map(|child| child.first_rule_int_return(rule_index, name))
73 }
74
75 pub fn first_error_token(&self) -> Option<&CommonToken> {
77 match self {
78 Self::Rule(rule) => rule
79 .context()
80 .children()
81 .iter()
82 .find_map(Self::first_error_token),
83 Self::Terminal(_) => None,
84 Self::Error(node) => Some(node.symbol()),
85 }
86 }
87
88 pub fn rule_invocation_stack(
91 &self,
92 rule_index: usize,
93 rule_names: &[String],
94 ) -> Option<Vec<String>> {
95 let mut stack = Vec::new();
96 if self.find_rule_path(rule_index, rule_names, &mut stack) {
97 stack.reverse();
98 return Some(stack);
99 }
100 None
101 }
102
103 fn find_rule_path(
104 &self,
105 rule_index: usize,
106 rule_names: &[String],
107 stack: &mut Vec<String>,
108 ) -> bool {
109 let Self::Rule(rule) = self else {
110 return false;
111 };
112 let current_index = rule.context().rule_index();
113 stack.push(
114 rule_names
115 .get(current_index)
116 .map_or("<unknown>", String::as_str)
117 .to_owned(),
118 );
119 if current_index == rule_index
120 || rule
121 .context()
122 .children()
123 .iter()
124 .any(|child| child.find_rule_path(rule_index, rule_names, stack))
125 {
126 return true;
127 }
128 stack.pop();
129 false
130 }
131}
132
/// Returns `text` with newline, carriage return and tab characters replaced
/// by the two-character sequences `\n`, `\r` and `\t`.
///
/// All other characters are copied unchanged.
fn escape_tree_text(text: &str) -> String {
    text.chars()
        .fold(String::with_capacity(text.len()), |mut out, ch| {
            match ch {
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                other => out.push(other),
            }
            out
        })
}
145
146#[derive(Clone, Debug, Eq, PartialEq)]
147pub struct RuleNode {
148 context: ParserRuleContext,
149}
150
151impl RuleNode {
152 pub const fn new(context: ParserRuleContext) -> Self {
153 Self { context }
154 }
155
156 pub const fn context(&self) -> &ParserRuleContext {
157 &self.context
158 }
159
160 pub const fn context_mut(&mut self) -> &mut ParserRuleContext {
161 &mut self.context
162 }
163
164 pub fn text(&self) -> String {
165 self.context.text()
166 }
167
168 pub fn to_string_tree(&self, rule_names: &[String]) -> String {
169 self.context.to_string_tree(rule_names)
170 }
171}
172
173#[derive(Clone, Debug, Eq, PartialEq)]
174pub struct ParserRuleContext {
175 rule_index: usize,
176 invoking_state: isize,
177 alt_number: usize,
178 start: Option<CommonToken>,
179 stop: Option<CommonToken>,
180 int_returns: Option<Box<IntReturns>>,
181 children: Vec<ParseTree>,
182 exception: Option<AntlrError>,
183}
184
185#[derive(Clone, Debug, Default, Eq, PartialEq)]
186struct IntReturns(BTreeMap<String, i64>);
187
188impl ParserRuleContext {
189 pub const fn new(rule_index: usize, invoking_state: isize) -> Self {
190 Self {
191 rule_index,
192 invoking_state,
193 alt_number: 0,
194 start: None,
195 stop: None,
196 int_returns: None,
197 children: Vec::new(),
198 exception: None,
199 }
200 }
201
202 pub const fn rule_index(&self) -> usize {
203 self.rule_index
204 }
205
206 pub const fn invoking_state(&self) -> isize {
207 self.invoking_state
208 }
209
210 pub const fn alt_number(&self) -> usize {
211 self.alt_number
212 }
213
214 pub const fn set_alt_number(&mut self, alt_number: usize) {
215 self.alt_number = alt_number;
216 }
217
218 pub const fn start(&self) -> Option<&CommonToken> {
219 self.start.as_ref()
220 }
221
222 pub const fn stop(&self) -> Option<&CommonToken> {
223 self.stop.as_ref()
224 }
225
226 pub fn set_start(&mut self, token: CommonToken) {
227 self.start = Some(token);
228 }
229
230 pub fn set_stop(&mut self, token: CommonToken) {
231 self.stop = Some(token);
232 }
233
234 pub fn set_int_return(&mut self, name: impl Into<String>, value: i64) {
236 self.int_returns
237 .get_or_insert_with(Box::default)
238 .0
239 .insert(name.into(), value);
240 }
241
242 pub fn int_return(&self, name: &str) -> Option<i64> {
244 self.int_returns
245 .as_ref()
246 .and_then(|values| values.0.get(name).copied())
247 }
248
249 pub const fn exception(&self) -> Option<&AntlrError> {
250 self.exception.as_ref()
251 }
252
253 pub fn set_exception(&mut self, error: AntlrError) {
254 self.exception = Some(error);
255 }
256
257 pub fn children(&self) -> &[ParseTree] {
258 &self.children
259 }
260
261 pub fn add_child(&mut self, child: ParseTree) {
262 self.children.push(child);
263 }
264
265 pub fn text(&self) -> String {
266 self.children.iter().map(ParseTree::text).collect()
267 }
268
269 pub fn to_string_tree(&self, rule_names: &[String]) -> String {
270 let name = rule_names
271 .get(self.rule_index)
272 .map_or("<unknown>", String::as_str);
273 let display_name = if self.alt_number == 0 {
274 name.to_owned()
275 } else {
276 format!("{name}:{}", self.alt_number)
277 };
278 if self.children.is_empty() {
279 return display_name;
280 }
281 let children = self
282 .children
283 .iter()
284 .map(|child| child.to_string_tree(rule_names))
285 .collect::<Vec<_>>()
286 .join(" ");
287 format!("({display_name} {children})")
288 }
289}
290
291#[derive(Clone, Debug, Eq, PartialEq)]
292pub struct TerminalNode {
293 token: CommonToken,
294}
295
296impl TerminalNode {
297 pub const fn new(token: CommonToken) -> Self {
298 Self { token }
299 }
300
301 pub const fn symbol(&self) -> &CommonToken {
302 &self.token
303 }
304
305 pub fn text(&self) -> String {
306 self.token.text().unwrap_or("").to_owned()
307 }
308}
309
310#[derive(Clone, Debug, Eq, PartialEq)]
311pub struct ErrorNode {
312 token: CommonToken,
313}
314
315impl ErrorNode {
316 pub const fn new(token: CommonToken) -> Self {
317 Self { token }
318 }
319
320 pub const fn symbol(&self) -> &CommonToken {
321 &self.token
322 }
323
324 pub fn text(&self) -> String {
325 self.token.text().unwrap_or("").to_owned()
326 }
327}
328
329pub trait ParseTreeListener {
330 fn enter_every_rule(&mut self, _ctx: &ParserRuleContext) -> Result<(), AntlrError> {
331 Ok(())
332 }
333
334 fn exit_every_rule(&mut self, _ctx: &ParserRuleContext) -> Result<(), AntlrError> {
335 Ok(())
336 }
337
338 fn visit_terminal(&mut self, _node: &TerminalNode) -> Result<(), AntlrError> {
339 Ok(())
340 }
341
342 fn visit_error_node(&mut self, _node: &ErrorNode) -> Result<(), AntlrError> {
343 Ok(())
344 }
345}
346
347#[derive(Debug, Default)]
348pub struct ParseTreeWalker;
349
350impl ParseTreeWalker {
351 pub fn walk<L: ParseTreeListener>(
354 listener: &mut L,
355 tree: &ParseTree,
356 ) -> Result<(), AntlrError> {
357 match tree {
358 ParseTree::Rule(rule) => {
359 listener.enter_every_rule(rule.context())?;
360 for child in rule.context().children() {
361 Self::walk(listener, child)?;
362 }
363 listener.exit_every_rule(rule.context())
364 }
365 ParseTree::Terminal(node) => listener.visit_terminal(node),
366 ParseTree::Error(node) => listener.visit_error_node(node),
367 }
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::CommonToken;

    /// Builds a terminal leaf from `CommonToken::new(1)` carrying `text`.
    fn terminal(text: &str) -> ParseTree {
        ParseTree::Terminal(TerminalNode::new(CommonToken::new(1).with_text(text)))
    }

    /// Builds a rule node with index `rule_index`, invoking state -1 and the
    /// given children.
    fn rule(rule_index: usize, children: Vec<ParseTree>) -> ParseTree {
        let mut context = ParserRuleContext::new(rule_index, -1);
        for child in children {
            context.add_child(child);
        }
        ParseTree::Rule(RuleNode::new(context))
    }

    /// Builds rule 0 containing rule 1 containing the terminal `x`.
    fn nested_tree() -> ParseTree {
        rule(0, vec![rule(1, vec![terminal("x")])])
    }

    #[test]
    fn renders_rule_tree() {
        let tree = rule(0, vec![terminal("x")]);
        let names = ["expr".to_owned()];
        assert_eq!(tree.to_string_tree(&names), "(expr x)");
    }

    #[test]
    fn finds_first_rule_depth_first() {
        let tree = nested_tree();
        let names = ["root".to_owned(), "child".to_owned()];

        let found = tree.first_rule(1).expect("nested rule should be found");
        assert_eq!(found.to_string_tree(&names), "(child x)");
        assert!(tree.first_rule(2).is_none());
    }

    #[test]
    fn reports_rule_invocation_stack_from_leaf_to_root() {
        let tree = nested_tree();
        let names = ["s".to_owned(), "a".to_owned()];

        let expected = vec!["a".to_owned(), "s".to_owned()];
        assert_eq!(tree.rule_invocation_stack(1, &names), Some(expected));
    }
}