use anyhow::Result;
use crate::walk::Node;
pub trait Visitable<'tree> {
fn children(&self) -> Vec<Node<'tree>>;
fn node(&self) -> Node<'tree>;
fn visit<VisitorT>(&self, visitor: VisitorT) -> Result<bool, VisitorT::Error>
where
VisitorT: Visitor<'tree>,
{
visitor.visit_node(self.node())
}
}
/// Callback interface that [`Visitable::visit`] invokes with a [`Node`].
pub trait Visitor<'tree> {
/// Error type that [`Visitor::visit_node`] may return.
type Error;
/// Handles a single `node`.
///
/// NOTE(review): the meaning of the returned `bool` is not defined by this trait; the
/// test visitor in this module treats `false` as "stop visiting" — confirm against
/// other implementors before relying on it.
fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error>;
}
impl<'a, FnT, ErrorT> Visitor<'a> for FnT
where
FnT: Fn(Node<'a>) -> Result<bool, ErrorT>,
{
type Error = ErrorT;
fn visit_node(&self, n: Node<'a>) -> Result<bool, ErrorT> {
self(n)
}
}
/// Makes every [`Node`] visitable and defines which nodes count as its direct children.
impl<'tree> Visitable<'tree> for Node<'tree> {
    /// Returns a copy of this node.
    fn node(&self) -> Node<'tree> {
        *self
    }

    /// Lists the direct children of this node, in the order each arm below assembles them.
    ///
    /// The variants grouped in the last arm report no children.
    fn children(&self) -> Vec<Node<'tree>> {
        match self {
            // --- Variants whose children are one homogeneous sequence. ---
            Node::Program(n) => n.body.iter().map(|item| item.into()).collect(),
            Node::Block(n) => n.items.iter().map(|item| item.into()).collect(),
            Node::PipeExpression(n) => n.body.iter().map(|stage| stage.into()).collect(),
            Node::ArrayExpression(n) => n.elements.iter().map(|element| element.into()).collect(),
            Node::ObjectExpression(n) => n.properties.iter().map(|property| property.into()).collect(),

            // --- Variants with exactly one child. ---
            Node::ExpressionStatement(n) => vec![(&n.expression).into()],
            Node::VariableDeclaration(n) => vec![(&n.declaration).into()],
            Node::TypeDeclaration(n) => vec![(&n.name).into()],
            Node::ReturnStatement(n) => vec![(&n.argument).into()],
            Node::UnaryExpression(n) => vec![(&n.argument).into()],
            Node::Parameter(n) => vec![(&n.identifier).into()],
            Node::ObjectProperty(n) => vec![(&n.value).into()],
            Node::AscribedExpression(n) => vec![(&n.expr).into()],

            // --- Variants with exactly two children. ---
            Node::BinaryExpression(n) => vec![(&n.left).into(), (&n.right).into()],
            Node::ArrayRangeExpression(n) => vec![(&n.start_element).into(), (&n.end_element).into()],
            Node::MemberExpression(n) => vec![(&n.object).into(), (&n.property).into()],
            Node::VariableDeclarator(n) => vec![(&n.id).into(), (&n.init).into()],
            Node::ElseIf(n) => vec![(&n.cond).into(), n.then_val.as_ref().into()],
            Node::LabelledExpression(n) => vec![(&n.expr).into(), (&n.label).into()],

            // --- Variants that combine fixed children with a sequence. ---
            // Parameters first, then the function body.
            Node::FunctionExpression(n) => n
                .params
                .iter()
                .map(|param| param.into())
                .chain(std::iter::once((&n.body).into()))
                .collect(),
            // Callee, then the optional unlabeled argument, then each labeled argument's value.
            Node::CallExpressionKw(n) => {
                let expected = 1 + usize::from(n.unlabeled.is_some()) + n.arguments.len();
                let mut children: Vec<Node<'_>> = Vec::with_capacity(expected);
                children.push((&n.callee).into());
                children.extend(n.unlabeled.iter().map(Node::from));
                children.extend(n.arguments.iter().map(|labeled| Node::from(&labeled.arg)));
                children
            }
            // Condition, then every else-if, then the final else.
            // NOTE(review): `Node::ElseIf` yields its `then_val`, but no then-branch is yielded
            // here — confirm whether `IfExpression` has one that should follow `cond`.
            Node::IfExpression(n) => std::iter::once(n.cond.as_ref().into())
                .chain(n.else_ifs.iter().map(|else_if| else_if.into()))
                .chain(std::iter::once(n.final_else.as_ref().into()))
                .collect(),
            // The name itself, followed by each entry of its path.
            Node::Name(n) => std::iter::once((&n.name).into())
                .chain(n.path.iter().map(|segment| segment.into()))
                .collect(),
            // Each argument's value, then the block body.
            Node::SketchBlock(n) => {
                let mut children: Vec<Node<'_>> = Vec::with_capacity(n.arguments.len() + 1);
                children.extend(n.arguments.iter().map(|labeled| Node::from(&labeled.arg)));
                children.push((&n.body).into());
                children
            }
            // The initial value when one is present, otherwise nothing.
            Node::SketchVar(n) => match &n.initial {
                Some(initial) => vec![initial.as_ref().into()],
                None => Vec::new(),
            },

            // --- Variants that report no children. ---
            Node::PipeSubstitution(_)
            | Node::TagDeclarator(_)
            | Node::Identifier(_)
            | Node::ImportStatement(_)
            | Node::KclNone(_)
            | Node::NumericLiteral(_)
            | Node::Literal(_) => Vec::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Runs `top_level_parse` on the given source expression and unwraps the result.
    macro_rules! kcl {
        ( $source:expr_2021 ) => {{ $crate::parsing::top_level_parse($source).unwrap() }};
    }

    /// Walks a small program with a recursive visitor and counts the variable declarators
    /// whose name starts with "crow".
    #[test]
    fn count_crows() {
        let program = kcl!(
            "\
const crow1 = 1
const crow2 = 2
fn crow3() {
const crow4 = 3
crow5()
}
"
        );

        /// Visitor that keeps its tally behind a mutex so it can count through `&self`.
        #[derive(Debug, Default)]
        struct CountCrows {
            n: Box<Mutex<usize>>,
        }

        impl<'tree> Visitor<'tree> for &CountCrows {
            type Error = ();

            fn visit_node(&self, node: Node<'tree>) -> Result<bool, Self::Error> {
                match node {
                    Node::VariableDeclarator(decl) if decl.id.name.starts_with("crow") => {
                        *self.n.lock().unwrap() += 1;
                    }
                    _ => {}
                }

                // `Visitable::visit` does not recurse on its own, so descend manually and
                // stop as soon as any child reports `false`.
                for child in &node.children() {
                    let keep_going = child.visit(*self)?;
                    if !keep_going {
                        return Ok(false);
                    }
                }
                Ok(true)
            }
        }

        let root: Node = (&program).into();
        let counter = CountCrows::default();
        Visitable::visit(&root, &counter).unwrap();

        let total = *counter.n.lock().unwrap();
        assert_eq!(total, 4);
    }
}