1use std::any::Any;
3use std::borrow::Borrow;
4
5use std::fmt::{Debug, Formatter};
6use std::iter::from_fn;
7use std::marker::PhantomData;
8use std::ops::Deref;
9use std::rc::Rc;
10
11use crate::atn::INVALID_ALT;
12use crate::char_stream::InputData;
13use crate::int_stream::EOF;
14use crate::interval_set::Interval;
15use crate::parser::ParserNodeType;
16use crate::parser_rule_context::{ParserRuleContext, RuleContextExt};
17use crate::recognizer::Recognizer;
18use crate::rule_context::{CustomRuleContext, RuleContext};
19use crate::token::Token;
20use crate::token_factory::TokenFactory;
21use crate::{interval_set, trees, CoerceTo};
22use better_any::{Tid, TidAble};
23use std::mem;
24
25#[allow(missing_docs)]
27pub trait Tree<'input>: RuleContext<'input> {
28    fn get_parent(&self) -> Option<Rc<<Self::Ctx as ParserNodeType<'input>>::Type>> { None }
29    fn has_parent(&self) -> bool { false }
30    fn get_payload(&self) -> Box<dyn Any> { unimplemented!() }
31    fn get_child(&self, _i: usize) -> Option<Rc<<Self::Ctx as ParserNodeType<'input>>::Type>> {
32        None
33    }
34    fn get_child_count(&self) -> usize { 0 }
35    fn get_children<'a>(
36        &'a self,
37    ) -> Box<dyn Iterator<Item = Rc<<Self::Ctx as ParserNodeType<'input>>::Type>> + 'a>
38    where
39        'input: 'a,
40    {
41        let mut index = 0;
42        let iter = from_fn(move || {
43            if index < self.get_child_count() {
44                index += 1;
45                self.get_child(index - 1)
46            } else {
47                None
48            }
49        });
50
51        Box::new(iter)
52    }
53    }
55
56pub trait ParseTree<'input>: Tree<'input> {
58    fn get_source_interval(&self) -> Interval { interval_set::INVALID }
63
64    fn get_text(&self) -> String { String::new() }
72
73    fn to_string_tree(
77        &self,
78        r: &dyn Recognizer<'input, TF = Self::TF, Node = Self::Ctx>,
79    ) -> String {
80        trees::string_tree(self, r.get_rule_names())
81    }
82}
83
/// Marker type parameter used by the [`TerminalNode`] alias of [`LeafNode`].
#[doc(hidden)]
#[derive(Debug)]
pub struct NoError;

/// Marker type parameter used by the [`ErrorNode`] alias of [`LeafNode`].
#[doc(hidden)]
#[derive(Debug)]
pub struct IsError;
113
/// A parse-tree leaf wrapping a single token.
///
/// `T` is a type-level marker: the [`TerminalNode`] alias uses [`NoError`] and the
/// [`ErrorNode`] alias uses [`IsError`].
pub struct LeafNode<'input, Node: ParserNodeType<'input>, T: 'static> {
    /// The token this leaf wraps.
    pub symbol: <Node::TF as TokenFactory<'input>>::Tok,
    // Marker field: carries the `T` type parameter without storing a value.
    iserror: PhantomData<T>,
}
// Implements `better_any::TidAble` for `LeafNode` through the `tid!` macro.
better_any::tid! { impl <'input, Node, T:'static> TidAble<'input> for LeafNode<'input, Node, T> where Node:ParserNodeType<'input> }
121
122impl<'input, Node: ParserNodeType<'input>, T: 'static> CustomRuleContext<'input>
123    for LeafNode<'input, Node, T>
124{
125    type TF = Node::TF;
126    type Ctx = Node;
127
128    fn get_rule_index(&self) -> usize { usize::max_value() }
129
130    fn get_node_text(&self, rule_names: &[&str]) -> String {
131        self.symbol.borrow().get_text().to_display()
132    }
133}
134
135impl<'input, Node: ParserNodeType<'input>, T: 'static> ParserRuleContext<'input>
136    for LeafNode<'input, Node, T>
137{
138}
139
140impl<'input, Node: ParserNodeType<'input>, T: 'static> Tree<'input> for LeafNode<'input, Node, T> {}
141
142impl<'input, Node: ParserNodeType<'input>, T: 'static> RuleContext<'input>
143    for LeafNode<'input, Node, T>
144{
145}
146
147impl<'input, Node: ParserNodeType<'input>, T: 'static> ParseTree<'input>
154    for LeafNode<'input, Node, T>
155{
156    fn get_source_interval(&self) -> Interval {
157        let i = self.symbol.borrow().get_token_index();
158        Interval { a: i, b: i }
159    }
160
161    fn get_text(&self) -> String { self.symbol.borrow().get_text().to_display() }
162}
163
164impl<'input, Node: ParserNodeType<'input>, T: 'static> Debug for LeafNode<'input, Node, T> {
165    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166        if self.symbol.borrow().get_token_type() == EOF {
167            f.write_str("<EOF>")
168        } else {
169            let a = self.symbol.borrow().get_text().to_display();
170            f.write_str(&a)
171        }
172    }
173}
174
175impl<'input, Node: ParserNodeType<'input>, T: 'static> LeafNode<'input, Node, T> {
176    pub fn new(symbol: <Node::TF as TokenFactory<'input>>::Tok) -> Self {
178        Self {
179            symbol,
180            iserror: Default::default(),
181        }
182    }
183}
184
185pub type TerminalNode<'input, NodeType> = LeafNode<'input, NodeType, NoError>;
187
188impl<'input, Node: ParserNodeType<'input>, Listener: ParseTreeListener<'input, Node> + ?Sized>
189    Listenable<Listener> for TerminalNode<'input, Node>
190{
191    fn enter(&self, listener: &mut Listener) { listener.visit_terminal(self) }
192
193    fn exit(&self, _listener: &mut Listener) {
194        }
196}
197
198impl<'input, Node: ParserNodeType<'input>, Visitor: ParseTreeVisitor<'input, Node> + ?Sized>
199    Visitable<Visitor> for TerminalNode<'input, Node>
200{
201    fn accept(&self, visitor: &mut Visitor) { visitor.visit_terminal(self) }
202}
203
204pub type ErrorNode<'input, NodeType> = LeafNode<'input, NodeType, IsError>;
207
208impl<'input, Node: ParserNodeType<'input>, Listener: ParseTreeListener<'input, Node> + ?Sized>
209    Listenable<Listener> for ErrorNode<'input, Node>
210{
211    fn enter(&self, listener: &mut Listener) { listener.visit_error_node(self) }
212
213    fn exit(&self, _listener: &mut Listener) {
214        }
216}
217
218impl<'input, Node: ParserNodeType<'input>, Visitor: ParseTreeVisitor<'input, Node> + ?Sized>
219    Visitable<Visitor> for ErrorNode<'input, Node>
220{
221    fn accept(&self, visitor: &mut Visitor) { visitor.visit_error_node(self) }
222}
223
224pub trait ParseTreeVisitorCompat<'input>: VisitChildren<'input, Self::Node> {
225    type Node: ParserNodeType<'input>;
226    type Return: Default;
227
228    fn temp_result(&mut self) -> &mut Self::Return;
230
231    fn visit(&mut self, node: &<Self::Node as ParserNodeType<'input>>::Type) -> Self::Return {
232        self.visit_node(&node);
233        mem::take(self.temp_result())
234    }
235
236    fn visit_terminal(&mut self, _node: &TerminalNode<'input, Self::Node>) -> Self::Return {
238        Self::Return::default()
239    }
240    fn visit_error_node(&mut self, _node: &ErrorNode<'input, Self::Node>) -> Self::Return {
242        Self::Return::default()
243    }
244
245    fn visit_children(
246        &mut self,
247        node: &<Self::Node as ParserNodeType<'input>>::Type,
248    ) -> Self::Return {
249        let mut result = Self::Return::default();
250        for node in node.get_children() {
251            if !self.should_visit_next_child(&node, &result) {
252                break;
253            }
254
255            let child_result = self.visit(&node);
256            result = self.aggregate_results(result, child_result);
257        }
258        return result;
259    }
260
261    fn aggregate_results(&self, aggregate: Self::Return, next: Self::Return) -> Self::Return {
262        next
263    }
264
265    fn should_visit_next_child(
266        &self,
267        node: &<Self::Node as ParserNodeType<'input>>::Type,
268        current: &Self::Return,
269    ) -> bool {
270        true
271    }
272}
273
274impl<'input, Node, T> ParseTreeVisitor<'input, Node> for T
281where
282    Node: ParserNodeType<'input>,
283    Node::Type: VisitableDyn<Self>,
284    T: ParseTreeVisitorCompat<'input, Node = Node>,
285{
286    fn visit_terminal(&mut self, node: &TerminalNode<'input, Node>) {
287        let result = <Self as ParseTreeVisitorCompat>::visit_terminal(self, node);
288        *<Self as ParseTreeVisitorCompat>::temp_result(self) = result;
289    }
290
291    fn visit_error_node(&mut self, node: &ErrorNode<'input, Node>) {
292        let result = <Self as ParseTreeVisitorCompat>::visit_error_node(self, node);
293        *<Self as ParseTreeVisitorCompat>::temp_result(self) = result;
294    }
295
296    fn visit_children(&mut self, node: &Node::Type) {
297        let result = <Self as ParseTreeVisitorCompat>::visit_children(self, node);
298        *<Self as ParseTreeVisitorCompat>::temp_result(self) = result;
299    }
300}
301
302pub trait ParseTreeVisitor<'input, Node: ParserNodeType<'input>>:
304    VisitChildren<'input, Node>
305{
306    fn visit_terminal(&mut self, _node: &TerminalNode<'input, Node>) {}
311    fn visit_error_node(&mut self, _node: &ErrorNode<'input, Node>) {}
313    fn visit_children(&mut self, node: &Node::Type) {
315        node.get_children()
316            .for_each(|child| self.visit_node(&child))
317    }
318}
319
/// Entry point for dispatching a single node to a visitor.
///
/// A blanket impl provides `visit_node` for every [`ParseTreeVisitor`] whose node type
/// implements [`VisitableDyn`] for it, by calling the node's `accept_dyn`.
pub trait VisitChildren<'input, Node: ParserNodeType<'input>> {
    /// Visits a single `node`.
    fn visit_node(&mut self, node: &Node::Type);
}
328
329impl<'input, Node, T> VisitChildren<'input, Node> for T
330where
331    Node: ParserNodeType<'input>,
332    T: ParseTreeVisitor<'input, Node> + ?Sized,
333    Node::Type: VisitableDyn<T>,
335{
336    fn visit_node(&mut self, node: &Node::Type) { node.accept_dyn(self) }
343}
344
/// Visitor hook implemented by tree nodes.
///
/// Implementations forward to the visitor method that matches the node. For example, the
/// `TerminalNode` and `ErrorNode` impls call `visit_terminal` and `visit_error_node`.
pub trait Visitable<Vis: ?Sized> {
    /// Hands `self` to the visitor `_vis`.
    ///
    /// # Panics
    ///
    /// The default body is `unreachable!`; it only runs if an implementor that is actually
    /// reached did not override it.
    fn accept(&self, _vis: &mut Vis) {
        unreachable!("should have been properly implemented by generated context when reachable")
    }
}
353
/// Node-side hook invoked by the blanket `VisitChildren::visit_node` impl.
#[doc(hidden)]
pub trait VisitableDyn<Vis: ?Sized> {
    /// Hands `self` to the visitor `_vis`.
    ///
    /// # Panics
    ///
    /// The default body is `unreachable!`; it only runs if an implementor that is actually
    /// reached did not override it.
    fn accept_dyn(&self, _vis: &mut Vis) {
        unreachable!("should have been properly implemented by generated context when reachable")
    }
}
361
362pub trait ParseTreeListener<'input, Node: ParserNodeType<'input>> {
364    fn visit_terminal(&mut self, _node: &TerminalNode<'input, Node>) {}
366    fn visit_error_node(&mut self, _node: &ErrorNode<'input, Node>) {}
368    fn enter_every_rule(&mut self, _ctx: &Node::Type) {}
370    fn exit_every_rule(&mut self, _ctx: &Node::Type) {}
372}
373
/// Node-side hook through which [`ParseTreeWalker`] notifies a listener of type `T`.
pub trait Listenable<T: ?Sized> {
    /// Called before the node's children are walked; a no-op by default.
    fn enter(&self, _target: &mut T) {}

    /// Called after the node's children are walked; a no-op by default.
    fn exit(&self, _target: &mut T) {}
}
382
/// Depth-first driver that feeds a parse tree to a [`ParseTreeListener`].
///
/// Holds no data: the `PhantomData` field only ties together the `'input`, `'a`, `Node`
/// and `T` parameters. `T` defaults to a `dyn ParseTreeListener<'input, Node>` trait object.
#[derive(Debug)]
pub struct ParseTreeWalker<'input, 'a, Node, T = dyn ParseTreeListener<'input, Node> + 'a>(
    PhantomData<fn(&'a T) -> &'input Node::Type>,
)
where
    Node: ParserNodeType<'input>,
    T: ParseTreeListener<'input, Node> + ?Sized;
397
398impl<'input, 'a, Node, T> ParseTreeWalker<'input, 'a, Node, T>
399where
400    Node: ParserNodeType<'input>,
401    T: ParseTreeListener<'input, Node> + 'a + ?Sized,
402    Node::Type: Listenable<T>,
403{
404    pub fn walk<Listener, Ctx>(mut listener: Box<Listener>, t: &Ctx) -> Box<Listener>
406    where
407        Listener: CoerceTo<T>,
410        Ctx: CoerceTo<Node::Type>,
411    {
412        Self::walk_inner(listener.as_mut().coerce_mut_to(), t.coerce_ref_to());
414
415        listener
418    }
419
420    fn walk_inner(listener: &mut T, t: &Node::Type) {
421        t.enter(listener);
422
423        for child in t.get_children() {
424            Self::walk_inner(listener, child.deref())
425        }
426
427        t.exit(listener);
428    }
429}