kcl_lib/walk/
ast_visitor.rs1use anyhow::Result;
2
3use crate::walk::Node;
4
5pub trait Visitable<'tree> {
11 fn children(&self) -> Vec<Node<'tree>>;
14
15 fn node(&self) -> Node<'tree>;
20
21 fn visit<VisitorT>(&self, visitor: VisitorT) -> Result<bool, VisitorT::Error>
25 where
26 VisitorT: Visitor<'tree>,
27 {
28 visitor.visit_node(self.node())
29 }
30}
31
32pub trait Visitor<'tree> {
38 type Error;
40
41 fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error>;
49}
50
51impl<'a, FnT, ErrorT> Visitor<'a> for FnT
52where
53 FnT: Fn(Node<'a>) -> Result<bool, ErrorT>,
54{
55 type Error = ErrorT;
56
57 fn visit_node(&self, n: Node<'a>) -> Result<bool, ErrorT> {
58 self(n)
59 }
60}
61
62impl<'tree> Visitable<'tree> for Node<'tree> {
63 fn node(&self) -> Node<'tree> {
64 *self
65 }
66
67 fn children(&self) -> Vec<Node<'tree>> {
68 match self {
69 Node::Program(n) => n.body.iter().map(|node| node.into()).collect(),
70 Node::ExpressionStatement(n) => {
71 vec![(&n.expression).into()]
72 }
73 Node::BinaryExpression(n) => {
74 vec![(&n.left).into(), (&n.right).into()]
75 }
76 Node::FunctionExpression(n) => {
77 let mut children = n.params.iter().map(|v| v.into()).collect::<Vec<Node>>();
78 children.push((&n.body).into());
79 children
80 }
81 Node::CallExpressionKw(n) => {
82 let mut children: Vec<Node<'_>> =
83 Vec::with_capacity(1 + if n.unlabeled.is_some() { 1 } else { 0 } + n.arguments.len());
84 children.push((&n.callee).into());
85 children.extend(n.unlabeled.iter().map(Node::from));
86
87 children.extend(n.arguments.iter().map(|v| Node::from(&v.arg)));
92 children
93 }
94 Node::PipeExpression(n) => n.body.iter().map(|v| v.into()).collect(),
95 Node::ArrayExpression(n) => n.elements.iter().map(|v| v.into()).collect(),
96 Node::ArrayRangeExpression(n) => {
97 vec![(&n.start_element).into(), (&n.end_element).into()]
98 }
99 Node::ObjectExpression(n) => n.properties.iter().map(|v| v.into()).collect(),
100 Node::MemberExpression(n) => {
101 vec![(&n.object).into(), (&n.property).into()]
102 }
103 Node::IfExpression(n) => {
104 let mut children = n.else_ifs.iter().map(|v| v.into()).collect::<Vec<Node>>();
105 children.insert(0, n.cond.as_ref().into());
106 children.push(n.final_else.as_ref().into());
107 children
108 }
109 Node::VariableDeclaration(n) => vec![(&n.declaration).into()],
110 Node::TypeDeclaration(n) => vec![(&n.name).into()],
111 Node::ReturnStatement(n) => {
112 vec![(&n.argument).into()]
113 }
114 Node::VariableDeclarator(n) => {
115 vec![(&n.id).into(), (&n.init).into()]
116 }
117 Node::UnaryExpression(n) => {
118 vec![(&n.argument).into()]
119 }
120 Node::Parameter(n) => {
121 vec![(&n.identifier).into()]
122 }
123 Node::ObjectProperty(n) => {
124 vec![(&n.value).into()]
125 }
126 Node::ElseIf(n) => {
127 vec![(&n.cond).into(), n.then_val.as_ref().into()]
128 }
129 Node::LabelledExpression(e) => {
130 vec![(&e.expr).into(), (&e.label).into()]
131 }
132 Node::AscribedExpression(e) => {
133 vec![(&e.expr).into()]
134 }
135 Node::Name(n) => Some((&n.name).into())
136 .into_iter()
137 .chain(n.path.iter().map(|n| n.into()))
138 .collect(),
139 Node::PipeSubstitution(_)
140 | Node::TagDeclarator(_)
141 | Node::Identifier(_)
142 | Node::ImportStatement(_)
143 | Node::KclNone(_)
144 | Node::Literal(_) => vec![],
145 }
146 }
147}
148
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    /// Parses a KCL source string, panicking on parse failure.
    macro_rules! kcl {
        ( $kcl:expr ) => {{
            $crate::parsing::top_level_parse($kcl).unwrap()
        }};
    }

    #[test]
    fn count_crows() {
        let program = kcl!(
            "\
const crow1 = 1
const crow2 = 2

fn crow3() {
    const crow4 = 3
    crow5()
}
"
        );

        /// Counts variable declarators whose name starts with `crow`.
        ///
        /// The walk below runs on a single thread, so a `Cell` is enough for
        /// the counter. No `Mutex`, and no `Box` around it, is needed.
        #[derive(Debug, Default)]
        struct CountCrows {
            n: Cell<usize>,
        }

        impl<'tree> Visitor<'tree> for &CountCrows {
            type Error = ();

            fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error> {
                if let Node::VariableDeclarator(vd) = node {
                    if vd.id.name.starts_with("crow") {
                        self.n.set(self.n.get() + 1);
                    }
                }

                // `Visitable::visit` does not recurse on its own, so walk the
                // children here. Stop as soon as any child reports `false`.
                for child in node.children().iter() {
                    if !child.visit(*self)? {
                        return Ok(false);
                    }
                }

                Ok(true)
            }
        }

        let prog: Node = (&program).into();
        let count_crows: CountCrows = Default::default();
        Visitable::visit(&prog, &count_crows).unwrap();
        // Expected: crow1, crow2, crow3 and crow4. `crow5()` is only a call.
        assert_eq!(count_crows.n.get(), 4);
    }
}