kcl_lib/walk/
ast_visitor.rs1use anyhow::Result;
2
3use crate::walk::Node;
4
5pub trait Visitable<'tree> {
11 fn children(&self) -> Vec<Node<'tree>>;
14
15 fn node(&self) -> Node<'tree>;
20
21 fn visit<VisitorT>(&self, visitor: VisitorT) -> Result<bool, VisitorT::Error>
25 where
26 VisitorT: Visitor<'tree>,
27 {
28 visitor.visit_node(self.node())
29 }
30}
31
32pub trait Visitor<'tree> {
38 type Error;
40
41 fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error>;
49}
50
51impl<'a, FnT, ErrorT> Visitor<'a> for FnT
52where
53 FnT: Fn(Node<'a>) -> Result<bool, ErrorT>,
54{
55 type Error = ErrorT;
56
57 fn visit_node(&self, n: Node<'a>) -> Result<bool, ErrorT> {
58 self(n)
59 }
60}
61
62impl<'tree> Visitable<'tree> for Node<'tree> {
63 fn node(&self) -> Node<'tree> {
64 *self
65 }
66
67 fn children(&self) -> Vec<Node<'tree>> {
68 match self {
69 Node::Program(n) => n.body.iter().map(|node| node.into()).collect(),
70 Node::ExpressionStatement(n) => {
71 vec![(&n.expression).into()]
72 }
73 Node::BinaryExpression(n) => {
74 vec![(&n.left).into(), (&n.right).into()]
75 }
76 Node::FunctionExpression(n) => {
77 let mut children = n.params.iter().map(|v| v.into()).collect::<Vec<Node>>();
78 children.push((&n.body).into());
79 children
80 }
81 Node::CallExpression(n) => {
82 let mut children = n.arguments.iter().map(|v| v.into()).collect::<Vec<Node>>();
83 children.insert(0, (&n.callee).into());
84 children
85 }
86 Node::CallExpressionKw(n) => {
87 let mut children = n.unlabeled.iter().map(|v| v.into()).collect::<Vec<Node>>();
88
89 children.extend(n.arguments.iter().map(|v| (&v.arg).into()).collect::<Vec<Node>>());
94 children
95 }
96 Node::PipeExpression(n) => n.body.iter().map(|v| v.into()).collect(),
97 Node::ArrayExpression(n) => n.elements.iter().map(|v| v.into()).collect(),
98 Node::ArrayRangeExpression(n) => {
99 vec![(&n.start_element).into(), (&n.end_element).into()]
100 }
101 Node::ObjectExpression(n) => n.properties.iter().map(|v| v.into()).collect(),
102 Node::MemberExpression(n) => {
103 vec![(&n.object).into(), (&n.property).into()]
104 }
105 Node::IfExpression(n) => {
106 let mut children = n.else_ifs.iter().map(|v| v.into()).collect::<Vec<Node>>();
107 children.insert(0, n.cond.as_ref().into());
108 children.push(n.final_else.as_ref().into());
109 children
110 }
111 Node::VariableDeclaration(n) => vec![(&n.declaration).into()],
112 Node::ReturnStatement(n) => {
113 vec![(&n.argument).into()]
114 }
115 Node::VariableDeclarator(n) => {
116 vec![(&n.id).into(), (&n.init).into()]
117 }
118 Node::UnaryExpression(n) => {
119 vec![(&n.argument).into()]
120 }
121 Node::Parameter(n) => {
122 vec![(&n.identifier).into()]
123 }
124 Node::ObjectProperty(n) => {
125 vec![(&n.value).into()]
126 }
127 Node::ElseIf(n) => {
128 vec![(&n.cond).into(), n.then_val.as_ref().into()]
129 }
130 Node::LabelledExpression(e) => {
131 vec![(&e.expr).into(), (&e.label).into()]
132 }
133 Node::Ascription(e) => {
134 vec![(&e.expr).into()]
135 }
136 Node::PipeSubstitution(_)
137 | Node::TagDeclarator(_)
138 | Node::Identifier(_)
139 | Node::ImportStatement(_)
140 | Node::KclNone(_)
141 | Node::Literal(_) => vec![],
142 }
143 }
144}
145
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    // Parses a KCL source string into a program, unwrapping the parse result.
    macro_rules! kcl {
        ( $kcl:expr ) => {{
            $crate::parsing::top_level_parse($kcl).unwrap()
        }};
    }

    #[test]
    fn count_crows() {
        let program = kcl!(
            "\
const crow1 = 1
const crow2 = 2

fn crow3() {
    const crow4 = 3
    crow5()
}
"
        );

        /// Visitor counting `VariableDeclarator`s whose name starts with `crow`.
        #[derive(Debug, Default)]
        struct CountCrows {
            seen: Cell<usize>,
        }

        impl<'tree> Visitor<'tree> for &CountCrows {
            type Error = ();

            fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error> {
                let is_crow = matches!(
                    node,
                    Node::VariableDeclarator(decl) if decl.id.name.starts_with("crow")
                );
                if is_crow {
                    self.seen.set(self.seen.get() + 1);
                }

                // Recurse depth-first, bailing out as soon as a child asks to stop.
                for child in node.children() {
                    let keep_going = Visitable::visit(&child, *self)?;
                    if !keep_going {
                        return Ok(false);
                    }
                }
                Ok(true)
            }
        }

        let root: Node = (&program).into();
        let counter = CountCrows::default();
        Visitable::visit(&root, &counter).unwrap();
        // crow1, crow2, crow3 and crow4 are declared; crow5 is only called.
        assert_eq!(counter.seen.get(), 4);
    }
}