Skip to main content

antlr4_runtime/
tree.rs

1use crate::errors::AntlrError;
2use crate::token::{CommonToken, Token};
3use std::collections::BTreeMap;
4
5#[derive(Clone, Debug, Eq, PartialEq)]
6pub enum ParseTree {
7    Rule(RuleNode),
8    Terminal(TerminalNode),
9    Error(ErrorNode),
10}
11
12impl ParseTree {
13    pub fn text(&self) -> String {
14        match self {
15            Self::Rule(rule) => rule.text(),
16            Self::Terminal(node) => node.text(),
17            Self::Error(node) => node.text(),
18        }
19    }
20
21    pub fn to_string_tree(&self, rule_names: &[String]) -> String {
22        match self {
23            Self::Rule(rule) => rule.to_string_tree(rule_names),
24            Self::Terminal(node) => escape_tree_text(&node.text()),
25            Self::Error(node) => escape_tree_text(&node.text()),
26        }
27    }
28
29    /// Finds the first rule node with `rule_index` in a depth-first walk.
30    pub fn first_rule(&self, rule_index: usize) -> Option<&Self> {
31        match self {
32            Self::Rule(rule) => {
33                if rule.context().rule_index() == rule_index {
34                    return Some(self);
35                }
36                rule.context()
37                    .children()
38                    .iter()
39                    .find_map(|child| child.first_rule(rule_index))
40            }
41            Self::Terminal(_) | Self::Error(_) => None,
42        }
43    }
44
45    /// Finds the stop token for the first rule node with `rule_index`.
46    pub fn first_rule_stop(&self, rule_index: usize) -> Option<&CommonToken> {
47        let Self::Rule(rule) = self else {
48            return None;
49        };
50        if rule.context().rule_index() == rule_index {
51            return rule.context().stop();
52        }
53        rule.context()
54            .children()
55            .iter()
56            .find_map(|child| child.first_rule_stop(rule_index))
57    }
58
59    /// Reads an integer return value from the first rule node with
60    /// `rule_index`, matching ANTLR's `$label.value` resolution for labeled
61    /// rule references in the runtime testsuite.
62    pub fn first_rule_int_return(&self, rule_index: usize, name: &str) -> Option<i64> {
63        let Self::Rule(rule) = self else {
64            return None;
65        };
66        if rule.context().rule_index() == rule_index {
67            return rule.context().int_return(name);
68        }
69        rule.context()
70            .children()
71            .iter()
72            .find_map(|child| child.first_rule_int_return(rule_index, name))
73    }
74
75    /// Finds the first recovery error token in a depth-first walk.
76    pub fn first_error_token(&self) -> Option<&CommonToken> {
77        match self {
78            Self::Rule(rule) => rule
79                .context()
80                .children()
81                .iter()
82                .find_map(Self::first_error_token),
83            Self::Terminal(_) => None,
84            Self::Error(node) => Some(node.symbol()),
85        }
86    }
87
88    /// Returns the first rule invocation stack for `rule_index`, ordered from
89    /// the selected rule outward to the root rule.
90    pub fn rule_invocation_stack(
91        &self,
92        rule_index: usize,
93        rule_names: &[String],
94    ) -> Option<Vec<String>> {
95        let mut stack = Vec::new();
96        if self.find_rule_path(rule_index, rule_names, &mut stack) {
97            stack.reverse();
98            return Some(stack);
99        }
100        None
101    }
102
103    fn find_rule_path(
104        &self,
105        rule_index: usize,
106        rule_names: &[String],
107        stack: &mut Vec<String>,
108    ) -> bool {
109        let Self::Rule(rule) = self else {
110            return false;
111        };
112        let current_index = rule.context().rule_index();
113        stack.push(
114            rule_names
115                .get(current_index)
116                .map_or("<unknown>", String::as_str)
117                .to_owned(),
118        );
119        if current_index == rule_index
120            || rule
121                .context()
122                .children()
123                .iter()
124                .any(|child| child.find_rule_path(rule_index, rule_names, stack))
125        {
126            return true;
127        }
128        stack.pop();
129        false
130    }
131}
132
/// Returns `text` with newline, carriage return and tab characters replaced
/// by their two-character backslash escapes (`\n`, `\r`, `\t`).
///
/// Every other character is copied through unchanged.
fn escape_tree_text(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    for c in text.chars() {
        let replacement = match c {
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(sequence) = replacement {
            out.push_str(sequence);
        } else {
            out.push(c);
        }
    }
    out
}
145
146#[derive(Clone, Debug, Eq, PartialEq)]
147pub struct RuleNode {
148    context: ParserRuleContext,
149}
150
151impl RuleNode {
152    pub const fn new(context: ParserRuleContext) -> Self {
153        Self { context }
154    }
155
156    pub const fn context(&self) -> &ParserRuleContext {
157        &self.context
158    }
159
160    pub const fn context_mut(&mut self) -> &mut ParserRuleContext {
161        &mut self.context
162    }
163
164    pub fn text(&self) -> String {
165        self.context.text()
166    }
167
168    pub fn to_string_tree(&self, rule_names: &[String]) -> String {
169        self.context.to_string_tree(rule_names)
170    }
171}
172
173#[derive(Clone, Debug, Eq, PartialEq)]
174pub struct ParserRuleContext {
175    rule_index: usize,
176    invoking_state: isize,
177    alt_number: usize,
178    start: Option<CommonToken>,
179    stop: Option<CommonToken>,
180    int_returns: Option<Box<IntReturns>>,
181    children: Vec<ParseTree>,
182    exception: Option<AntlrError>,
183}
184
185#[derive(Clone, Debug, Default, Eq, PartialEq)]
186struct IntReturns(BTreeMap<String, i64>);
187
188impl ParserRuleContext {
189    pub const fn new(rule_index: usize, invoking_state: isize) -> Self {
190        Self {
191            rule_index,
192            invoking_state,
193            alt_number: 0,
194            start: None,
195            stop: None,
196            int_returns: None,
197            children: Vec::new(),
198            exception: None,
199        }
200    }
201
202    pub const fn rule_index(&self) -> usize {
203        self.rule_index
204    }
205
206    pub const fn invoking_state(&self) -> isize {
207        self.invoking_state
208    }
209
210    pub const fn alt_number(&self) -> usize {
211        self.alt_number
212    }
213
214    pub const fn set_alt_number(&mut self, alt_number: usize) {
215        self.alt_number = alt_number;
216    }
217
218    pub const fn start(&self) -> Option<&CommonToken> {
219        self.start.as_ref()
220    }
221
222    pub const fn stop(&self) -> Option<&CommonToken> {
223        self.stop.as_ref()
224    }
225
226    pub fn set_start(&mut self, token: CommonToken) {
227        self.start = Some(token);
228    }
229
230    pub fn set_stop(&mut self, token: CommonToken) {
231        self.stop = Some(token);
232    }
233
234    /// Stores a generated integer return value on this rule context.
235    pub fn set_int_return(&mut self, name: impl Into<String>, value: i64) {
236        self.int_returns
237            .get_or_insert_with(Box::default)
238            .0
239            .insert(name.into(), value);
240    }
241
242    /// Reads a generated integer return value from this rule context.
243    pub fn int_return(&self, name: &str) -> Option<i64> {
244        self.int_returns
245            .as_ref()
246            .and_then(|values| values.0.get(name).copied())
247    }
248
249    pub const fn exception(&self) -> Option<&AntlrError> {
250        self.exception.as_ref()
251    }
252
253    pub fn set_exception(&mut self, error: AntlrError) {
254        self.exception = Some(error);
255    }
256
257    pub fn children(&self) -> &[ParseTree] {
258        &self.children
259    }
260
261    pub fn add_child(&mut self, child: ParseTree) {
262        self.children.push(child);
263    }
264
265    pub fn text(&self) -> String {
266        self.children.iter().map(ParseTree::text).collect()
267    }
268
269    pub fn to_string_tree(&self, rule_names: &[String]) -> String {
270        let name = rule_names
271            .get(self.rule_index)
272            .map_or("<unknown>", String::as_str);
273        let display_name = if self.alt_number == 0 {
274            name.to_owned()
275        } else {
276            format!("{name}:{}", self.alt_number)
277        };
278        if self.children.is_empty() {
279            return display_name;
280        }
281        let children = self
282            .children
283            .iter()
284            .map(|child| child.to_string_tree(rule_names))
285            .collect::<Vec<_>>()
286            .join(" ");
287        format!("({display_name} {children})")
288    }
289}
290
291#[derive(Clone, Debug, Eq, PartialEq)]
292pub struct TerminalNode {
293    token: CommonToken,
294}
295
296impl TerminalNode {
297    pub const fn new(token: CommonToken) -> Self {
298        Self { token }
299    }
300
301    pub const fn symbol(&self) -> &CommonToken {
302        &self.token
303    }
304
305    pub fn text(&self) -> String {
306        self.token.text().unwrap_or("").to_owned()
307    }
308}
309
310#[derive(Clone, Debug, Eq, PartialEq)]
311pub struct ErrorNode {
312    token: CommonToken,
313}
314
315impl ErrorNode {
316    pub const fn new(token: CommonToken) -> Self {
317        Self { token }
318    }
319
320    pub const fn symbol(&self) -> &CommonToken {
321        &self.token
322    }
323
324    pub fn text(&self) -> String {
325        self.token.text().unwrap_or("").to_owned()
326    }
327}
328
329pub trait ParseTreeListener {
330    fn enter_every_rule(&mut self, _ctx: &ParserRuleContext) -> Result<(), AntlrError> {
331        Ok(())
332    }
333
334    fn exit_every_rule(&mut self, _ctx: &ParserRuleContext) -> Result<(), AntlrError> {
335        Ok(())
336    }
337
338    fn visit_terminal(&mut self, _node: &TerminalNode) -> Result<(), AntlrError> {
339        Ok(())
340    }
341
342    fn visit_error_node(&mut self, _node: &ErrorNode) -> Result<(), AntlrError> {
343        Ok(())
344    }
345}
346
347#[derive(Debug, Default)]
348pub struct ParseTreeWalker;
349
350impl ParseTreeWalker {
351    /// Walks a parse tree depth-first, invoking listener callbacks in ANTLR's
352    /// enter/child/exit order for rule nodes.
353    pub fn walk<L: ParseTreeListener>(
354        listener: &mut L,
355        tree: &ParseTree,
356    ) -> Result<(), AntlrError> {
357        match tree {
358            ParseTree::Rule(rule) => {
359                listener.enter_every_rule(rule.context())?;
360                for child in rule.context().children() {
361                    Self::walk(listener, child)?;
362                }
363                listener.exit_every_rule(rule.context())
364            }
365            ParseTree::Terminal(node) => listener.visit_terminal(node),
366            ParseTree::Error(node) => listener.visit_error_node(node),
367        }
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::CommonToken;

    /// Builds a terminal leaf from `CommonToken::new(1)` carrying `text`.
    fn leaf(text: &'static str) -> ParseTree {
        ParseTree::Terminal(TerminalNode::new(CommonToken::new(1).with_text(text)))
    }

    /// Builds a rule 0 node whose only child is a rule 1 node holding the
    /// single terminal `x`.
    fn nested_tree() -> ParseTree {
        let mut inner = ParserRuleContext::new(1, -1);
        inner.add_child(leaf("x"));

        let mut outer = ParserRuleContext::new(0, -1);
        outer.add_child(ParseTree::Rule(RuleNode::new(inner)));
        ParseTree::Rule(RuleNode::new(outer))
    }

    /// Converts string literals into the owned rule-name table the API takes.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn renders_rule_tree() {
        let mut ctx = ParserRuleContext::new(0, -1);
        ctx.add_child(leaf("x"));
        let tree = ParseTree::Rule(RuleNode::new(ctx));

        assert_eq!(tree.to_string_tree(&names(&["expr"])), "(expr x)");
    }

    #[test]
    fn finds_first_rule_depth_first() {
        let tree = nested_tree();

        let rule = tree.first_rule(1).expect("nested rule should be found");
        assert_eq!(rule.to_string_tree(&names(&["root", "child"])), "(child x)");
        assert!(tree.first_rule(2).is_none());
    }

    #[test]
    fn reports_rule_invocation_stack_from_leaf_to_root() {
        let tree = nested_tree();

        assert_eq!(
            tree.rule_invocation_stack(1, &names(&["s", "a"])),
            Some(names(&["a", "s"]))
        );
    }
}