// kcl_lib/walk/ast_visitor.rs

1use anyhow::Result;
2
3use crate::walk::Node;
4
5/// Walk-specific trait adding the ability to traverse the KCL AST.
6///
7/// This trait is implemented on [Node] to handle the fairly tricky bit of
8/// recursing into the AST in a single place, as well as helpers for traversing
9/// the tree. for callers to use.
10pub trait Visitable<'tree> {
11    /// Return a `Vec<Node>` for all *direct* children of this AST node. This
12    /// should only contain direct descendants.
13    fn children(&self) -> Vec<Node<'tree>>;
14
15    /// Return `self` as a [Node]. Generally speaking, the [Visitable] trait
16    /// is only going to be implemented on [Node], so this is purely used by
17    /// helpers that are generic over a [Visitable] and want to deref back
18    /// into a [Node].
19    fn node(&self) -> Node<'tree>;
20
21    /// Call the provided [Visitor] in order to Visit `self`. This will
22    /// only be called on `self` -- the [Visitor] is responsible for
23    /// recursing into any children, if desired.
24    fn visit<VisitorT>(&self, visitor: VisitorT) -> Result<bool, VisitorT::Error>
25    where
26        VisitorT: Visitor<'tree>,
27    {
28        visitor.visit_node(self.node())
29    }
30}
31
32/// Trait used to enable visiting members of KCL AST.
33///
34/// Implementing this trait enables the implementer to be invoked over
35/// members of KCL AST by using the [Visitable::visit] function on
36/// a [Node].
37pub trait Visitor<'tree> {
38    /// Error type returned by the [Self::visit] function.
39    type Error;
40
41    /// Visit a KCL AST [Node].
42    ///
43    /// In general, implementers likely wish to check to see if a Node is what
44    /// they're looking for, and either descend into that [Node]'s children (by
45    /// calling [Visitable::children] on [Node] to get children nodes,
46    /// calling [Visitable::visit] on each node of interest), or perform
47    /// some action.
48    fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error>;
49}
50
51impl<'a, FnT, ErrorT> Visitor<'a> for FnT
52where
53    FnT: Fn(Node<'a>) -> Result<bool, ErrorT>,
54{
55    type Error = ErrorT;
56
57    fn visit_node(&self, n: Node<'a>) -> Result<bool, ErrorT> {
58        self(n)
59    }
60}
61
62impl<'tree> Visitable<'tree> for Node<'tree> {
63    fn node(&self) -> Node<'tree> {
64        *self
65    }
66
67    fn children(&self) -> Vec<Node<'tree>> {
68        match self {
69            Node::Program(n) => n.body.iter().map(|node| node.into()).collect(),
70            Node::ExpressionStatement(n) => {
71                vec![(&n.expression).into()]
72            }
73            Node::BinaryExpression(n) => {
74                vec![(&n.left).into(), (&n.right).into()]
75            }
76            Node::FunctionExpression(n) => {
77                let mut children = n.params.iter().map(|v| v.into()).collect::<Vec<Node>>();
78                children.push((&n.body).into());
79                children
80            }
81            Node::CallExpressionKw(n) => {
82                let mut children: Vec<Node<'_>> =
83                    Vec::with_capacity(1 + if n.unlabeled.is_some() { 1 } else { 0 } + n.arguments.len());
84                children.push((&n.callee).into());
85                children.extend(n.unlabeled.iter().map(Node::from));
86
87                // TODO: this is wrong but it's what the old walk code was doing.
88                // We likely need a real LabeledArg AST node, but I don't
89                // want to tango with it since it's a lot deeper than
90                // adding it to the enum.
91                children.extend(n.arguments.iter().map(|v| Node::from(&v.arg)));
92                children
93            }
94            Node::PipeExpression(n) => n.body.iter().map(|v| v.into()).collect(),
95            Node::ArrayExpression(n) => n.elements.iter().map(|v| v.into()).collect(),
96            Node::ArrayRangeExpression(n) => {
97                vec![(&n.start_element).into(), (&n.end_element).into()]
98            }
99            Node::ObjectExpression(n) => n.properties.iter().map(|v| v.into()).collect(),
100            Node::MemberExpression(n) => {
101                vec![(&n.object).into(), (&n.property).into()]
102            }
103            Node::IfExpression(n) => {
104                let mut children = n.else_ifs.iter().map(|v| v.into()).collect::<Vec<Node>>();
105                children.insert(0, n.cond.as_ref().into());
106                children.push(n.final_else.as_ref().into());
107                children
108            }
109            Node::VariableDeclaration(n) => vec![(&n.declaration).into()],
110            Node::TypeDeclaration(n) => vec![(&n.name).into()],
111            Node::ReturnStatement(n) => {
112                vec![(&n.argument).into()]
113            }
114            Node::VariableDeclarator(n) => {
115                vec![(&n.id).into(), (&n.init).into()]
116            }
117            Node::UnaryExpression(n) => {
118                vec![(&n.argument).into()]
119            }
120            Node::Parameter(n) => {
121                vec![(&n.identifier).into()]
122            }
123            Node::ObjectProperty(n) => {
124                vec![(&n.value).into()]
125            }
126            Node::ElseIf(n) => {
127                vec![(&n.cond).into(), n.then_val.as_ref().into()]
128            }
129            Node::LabelledExpression(e) => {
130                vec![(&e.expr).into(), (&e.label).into()]
131            }
132            Node::AscribedExpression(e) => {
133                vec![(&e.expr).into()]
134            }
135            Node::Name(n) => Some((&n.name).into())
136                .into_iter()
137                .chain(n.path.iter().map(|n| n.into()))
138                .collect(),
139            Node::PipeSubstitution(_)
140            | Node::TagDeclarator(_)
141            | Node::Identifier(_)
142            | Node::ImportStatement(_)
143            | Node::KclNone(_)
144            | Node::Literal(_) => vec![],
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    /// Parse a KCL source string into a program, panicking on parse failure.
    macro_rules! kcl {
        ( $kcl:expr ) => {{
            $crate::parsing::top_level_parse($kcl).unwrap()
        }};
    }

    #[test]
    fn count_crows() {
        let program = kcl!(
            "\
const crow1 = 1
const crow2 = 2

fn crow3() {
    const crow4 = 3
    crow5()
}
"
        );

        /// Counts variable declarators whose name starts with "crow".
        ///
        /// `visit_node` only receives `&self`, so the counter needs interior
        /// mutability; a `Cell` suffices because the visitor runs on a single
        /// thread (no `Box` or lock needed).
        #[derive(Debug, Default)]
        struct CountCrows {
            n: Cell<usize>,
        }

        impl<'tree> Visitor<'tree> for &CountCrows {
            type Error = ();

            fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error> {
                if let Node::VariableDeclarator(vd) = node {
                    if vd.id.name.starts_with("crow") {
                        self.n.set(self.n.get() + 1);
                    }
                }

                // Recurse into every child, stopping early if any child
                // visit asks to stop.
                for child in node.children().iter() {
                    if !child.visit(*self)? {
                        return Ok(false);
                    }
                }

                Ok(true)
            }
        }

        let prog: Node = (&program).into();
        let count_crows: CountCrows = Default::default();
        Visitable::visit(&prog, &count_crows).unwrap();
        assert_eq!(count_crows.n.get(), 4);
    }
}
205}