use std::ops::ControlFlow;
use marker_api::{
ast::{EnumVariant, ItemField},
prelude::*,
};
/// Controls which bodies the `traverse_*` functions descend into.
///
/// This affects the bodies of statics, consts and functions reached through
/// [`traverse_item`], and closure bodies reached through [`traverse_expr`].
/// A body passed directly to [`traverse_body`] is always walked.
#[non_exhaustive]
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
pub enum VisitorScope {
    /// Enter item bodies and closure bodies.
    AllBodies,
    /// Do not enter item or closure bodies. This is the default.
    #[default]
    NoBodies,
}
/// Callbacks invoked by the `traverse_*` functions of this module.
///
/// Every method has a no-op default that returns `ControlFlow::Continue(())`,
/// so an implementor only overrides the nodes it is interested in. Returning
/// `ControlFlow::Break(b)` from any callback stops the traversal, and `b` is
/// propagated back to the caller of the `traverse_*` function.
pub trait Visitor<B> {
    /// Selects which bodies the traversal enters.
    ///
    /// Defaults to [`VisitorScope::NoBodies`].
    fn scope(&self) -> VisitorScope {
        VisitorScope::NoBodies
    }

    /// Called for every item, before its nested items, bodies, fields or
    /// variants are traversed.
    fn visit_item<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _item: ItemKind<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }

    /// Called for every field of a struct or union item.
    fn visit_field<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _field: &'ast ItemField<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }

    /// Called for every variant of an enum item, before the variant's
    /// discriminant expression (if any) is traversed.
    fn visit_variant<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _variant: &'ast EnumVariant<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }

    /// Called for every traversed body, before its root expression.
    fn visit_body<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _body: &'ast ast::Body<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }

    /// Called for every statement, before its contents are traversed.
    fn visit_stmt<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _stmt: StmtKind<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }

    /// Called for every expression, before its sub-expressions are traversed.
    fn visit_expr<'ast>(
        &mut self,
        _cx: &'ast MarkerContext<'ast>,
        _expr: ExprKind<'ast>,
    ) -> ControlFlow<B> {
        ControlFlow::Continue(())
    }
}
/// Walks an item and everything reachable from it, calling `visitor` on the way.
///
/// `visit_item` is called for `kind` first. Afterwards:
/// - modules, traits, impls and extern blocks recurse into their contained items;
/// - statics, consts and functions descend into their body, but only when the
///   item has one and `visitor.scope()` is not [`VisitorScope::NoBodies`];
/// - struct and union fields are reported through `visit_field`;
/// - enum variants are reported through `visit_variant`, each followed by its
///   discriminant expression, if present;
/// - extern crates, `use` items, type aliases and unstable items have no children.
///
/// The first `ControlFlow::Break` ends the traversal and is returned unchanged.
///
/// # Panics
///
/// Panics on an `ItemKind` variant that is not handled below.
pub fn traverse_item<'ast, B>(
    cx: &'ast MarkerContext<'ast>,
    visitor: &mut dyn Visitor<B>,
    kind: ItemKind<'ast>,
) -> ControlFlow<B> {
    /// Traverses the body behind `id` if there is one and the visitor's scope
    /// allows entering bodies.
    fn traverse_body_id<'ast, B>(
        cx: &'ast MarkerContext<'ast>,
        visitor: &mut dyn Visitor<B>,
        id: Option<BodyId>,
    ) -> ControlFlow<B> {
        match (visitor.scope(), id) {
            (VisitorScope::NoBodies, _) | (_, None) => ControlFlow::Continue(()),
            (_, Some(body_id)) => traverse_body(cx, visitor, cx.ast().body(body_id)),
        }
    }

    visitor.visit_item(cx, kind)?;
    match kind {
        // Containers: recurse into every nested item.
        ItemKind::Mod(module) => module
            .items()
            .into_iter()
            .try_for_each(|mod_item| traverse_item(cx, visitor, *mod_item)),
        ItemKind::Trait(item) => item
            .items()
            .into_iter()
            .try_for_each(|assoc_item| traverse_item(cx, visitor, assoc_item.as_item())),
        ItemKind::Impl(item) => item
            .items()
            .into_iter()
            .try_for_each(|assoc_item| traverse_item(cx, visitor, assoc_item.as_item())),
        ItemKind::ExternBlock(item) => item
            .items()
            .into_iter()
            .try_for_each(|ext_item| traverse_item(cx, visitor, ext_item.as_item())),

        // Items that may carry a body.
        ItemKind::Static(item) => traverse_body_id(cx, visitor, item.body_id()),
        ItemKind::Const(item) => traverse_body_id(cx, visitor, item.body_id()),
        ItemKind::Fn(item) => traverse_body_id(cx, visitor, item.body_id()),

        // ADTs: fields and variants are reported to the visitor.
        ItemKind::Struct(item) => item
            .fields()
            .into_iter()
            .try_for_each(|field| visitor.visit_field(cx, field)),
        ItemKind::Union(item) => item
            .fields()
            .into_iter()
            .try_for_each(|field| visitor.visit_field(cx, field)),
        // NOTE(review): the fields of enum variants are not forwarded to
        // `visit_field`, unlike struct and union fields. Confirm this is intended.
        ItemKind::Enum(item) => {
            for variant in item.variants() {
                visitor.visit_variant(cx, variant)?;
                if let Some(discriminant) = variant.discriminant() {
                    traverse_expr(cx, visitor, discriminant.expr())?;
                }
            }
            ControlFlow::Continue(())
        },

        // Leaf items: nothing further to traverse.
        ItemKind::ExternCrate(_) | ItemKind::Use(_) | ItemKind::Unstable(_) | ItemKind::TyAlias(_) => {
            ControlFlow::Continue(())
        },
        _ => unreachable!("all items are covered"),
    }
}
/// Walks a body: reports it through `visit_body`, then traverses its root expression.
///
/// This function does not consult `visitor.scope()`. The given body is always
/// walked; the scope only decides whether bodies reached from items or
/// closures are entered.
pub fn traverse_body<'ast, B>(
    cx: &'ast MarkerContext<'ast>,
    visitor: &mut dyn Visitor<B>,
    body: &'ast ast::Body<'ast>,
) -> ControlFlow<B> {
    visitor.visit_body(cx, body)?;
    let root = body.expr();
    traverse_expr(cx, visitor, root)
}
/// Walks a statement: reports it through `visit_stmt`, then traverses its contents.
///
/// - item statements recurse into [`traverse_item`];
/// - `let` statements traverse the initializer (if any), then the `els`
///   expression (if any); the pattern itself is not visited;
/// - expression statements traverse their expression.
///
/// The first `ControlFlow::Break` ends the traversal and is returned unchanged.
///
/// # Panics
///
/// Panics on a `StmtKind` variant that is not handled below.
pub fn traverse_stmt<'ast, B>(
    cx: &'ast MarkerContext<'ast>,
    visitor: &mut dyn Visitor<B>,
    stmt: StmtKind<'ast>,
) -> ControlFlow<B> {
    visitor.visit_stmt(cx, stmt)?;
    match stmt {
        StmtKind::Item(item_stmt) => traverse_item(cx, visitor, item_stmt.item()),
        StmtKind::Let(local) => {
            if let Some(init) = local.init() {
                traverse_expr(cx, visitor, init)?;
            }
            match local.els() {
                Some(els) => traverse_expr(cx, visitor, els),
                None => ControlFlow::Continue(()),
            }
        },
        StmtKind::Expr(expr_stmt) => traverse_expr(cx, visitor, expr_stmt.expr()),
        _ => unreachable!("all statements are covered"),
    }
}
/// Walks an expression and its sub-expressions in pre-order.
///
/// `visitor.visit_expr` is called for `expr` before any of its children. The
/// children are then traversed in the order listed in the matching arm below.
/// Statements inside blocks go through [`traverse_stmt`]. Closure bodies are
/// entered only when `visitor.scope()` is [`VisitorScope::AllBodies`].
/// Literals, paths, `continue` and unstable expressions have no children here.
///
/// The first `ControlFlow::Break` ends the traversal and is returned unchanged.
///
/// # Panics
///
/// Panics on an `ExprKind` variant that is not handled below.
// Allowed because the match below needs one arm per `ExprKind` variant.
#[allow(clippy::too_many_lines)]
pub fn traverse_expr<'ast, B>(
    cx: &'ast MarkerContext<'ast>,
    visitor: &mut dyn Visitor<B>,
    expr: ExprKind<'ast>,
) -> ControlFlow<B> {
    // Pre-order: the visitor sees this expression before its children.
    visitor.visit_expr(cx, expr)?;
    match expr {
        ExprKind::Block(e) => {
            for stmt in e.stmts() {
                traverse_stmt(cx, visitor, *stmt)?;
            }
            // The optional trailing expression of the block comes after its statements.
            if let Some(block_expr) = e.expr() {
                traverse_expr(cx, visitor, block_expr)?;
            }
        },
        ExprKind::Closure(e) => {
            // Closure parameters are not visited. The body is entered only
            // when the visitor requested all bodies.
            if let VisitorScope::AllBodies = visitor.scope() {
                let body = cx.ast().body(e.body_id());
                traverse_body(cx, visitor, body)?;
            }
        },
        ExprKind::UnaryOp(e) => {
            traverse_expr(cx, visitor, e.expr())?;
        },
        ExprKind::Ref(e) => {
            traverse_expr(cx, visitor, e.expr())?;
        },
        ExprKind::BinaryOp(e) => {
            traverse_expr(cx, visitor, e.left())?;
            traverse_expr(cx, visitor, e.right())?;
        },
        ExprKind::Try(e) => {
            traverse_expr(cx, visitor, e.expr())?;
        },
        ExprKind::Assign(e) => {
            // NOTE(review): only the assigned value is traversed. The assignee
            // (left-hand side) is not visited here; confirm this is intended.
            traverse_expr(cx, visitor, e.value())?;
        },
        ExprKind::As(e) => {
            traverse_expr(cx, visitor, e.expr())?;
        },
        ExprKind::Call(e) => {
            traverse_expr(cx, visitor, e.func())?;
            for arg in e.args() {
                traverse_expr(cx, visitor, *arg)?;
            }
        },
        ExprKind::Method(e) => {
            // Receiver first, then the arguments in order.
            traverse_expr(cx, visitor, e.receiver())?;
            for arg in e.args() {
                traverse_expr(cx, visitor, *arg)?;
            }
        },
        ExprKind::Array(e) => {
            for el in e.elements() {
                traverse_expr(cx, visitor, *el)?;
            }
            // Optional length expression; presumably only present for
            // repeat arrays like `[x; N]`. TODO confirm.
            if let Some(len) = e.len() {
                traverse_expr(cx, visitor, len.expr())?;
            }
        },
        ExprKind::Tuple(e) => {
            for el in e.elements() {
                traverse_expr(cx, visitor, *el)?;
            }
        },
        ExprKind::Ctor(e) => {
            // Field initializers first, then the optional base expression
            // (presumably the `..base` of a struct update).
            for field in e.fields() {
                traverse_expr(cx, visitor, field.expr())?;
            }
            if let Some(base) = e.base() {
                traverse_expr(cx, visitor, base)?;
            }
        },
        ExprKind::Range(e) => {
            if let Some(start) = e.start() {
                traverse_expr(cx, visitor, start)?;
            }
            if let Some(end) = e.end() {
                traverse_expr(cx, visitor, end)?;
            }
        },
        ExprKind::Index(e) => {
            traverse_expr(cx, visitor, e.operand())?;
            traverse_expr(cx, visitor, e.index())?;
        },
        ExprKind::Field(e) => {
            traverse_expr(cx, visitor, e.operand())?;
        },
        ExprKind::If(e) => {
            traverse_expr(cx, visitor, e.condition())?;
            traverse_expr(cx, visitor, e.then())?;
            if let Some(els) = e.els() {
                traverse_expr(cx, visitor, els)?;
            }
        },
        ExprKind::Let(e) => {
            // Only the scrutinee is traversed; the pattern is not visited.
            traverse_expr(cx, visitor, e.scrutinee())?;
        },
        ExprKind::Match(e) => {
            traverse_expr(cx, visitor, e.scrutinee())?;
            // Arm patterns are not visited. Each arm's guard (if any) is
            // traversed before the arm's expression.
            for arm in e.arms() {
                if let Some(guard) = arm.guard() {
                    traverse_expr(cx, visitor, guard)?;
                }
                traverse_expr(cx, visitor, arm.expr())?;
            }
        },
        ExprKind::Break(e) => {
            if let Some(val) = e.expr() {
                traverse_expr(cx, visitor, val)?;
            }
        },
        ExprKind::Return(e) => {
            if let Some(val) = e.expr() {
                traverse_expr(cx, visitor, val)?;
            }
        },
        ExprKind::For(e) => {
            traverse_expr(cx, visitor, e.iterable())?;
            traverse_expr(cx, visitor, e.block())?;
        },
        ExprKind::Loop(e) => {
            traverse_expr(cx, visitor, e.block())?;
        },
        ExprKind::While(e) => {
            traverse_expr(cx, visitor, e.condition())?;
            traverse_expr(cx, visitor, e.block())?;
        },
        ExprKind::Await(e) => {
            traverse_expr(cx, visitor, e.expr())?;
        },
        // Leaf expressions: nothing further to traverse.
        ExprKind::IntLit(_)
        | ExprKind::FloatLit(_)
        | ExprKind::StrLit(_)
        | ExprKind::CharLit(_)
        | ExprKind::BoolLit(_)
        | ExprKind::Unstable(_)
        | ExprKind::Path(_)
        | ExprKind::Continue(_) => {
        },
        _ => unreachable!("all expressions are covered"),
    }
    ControlFlow::Continue(())
}
/// An AST node that can be walked with a [`Visitor`] whose break type is `B`.
pub trait Traversable<'ast, B>
where
    Self: Sized + Copy,
{
    /// Walks `self` and its children with `visitor`, stopping at the first
    /// `ControlFlow::Break`, which is returned unchanged.
    fn traverse(self, cx: &'ast MarkerContext<'ast>, visitor: &mut dyn Visitor<B>) -> ControlFlow<B>;

    /// Calls `f` for every expression reached while traversing `self`.
    ///
    /// Returns `Some(b)` as soon as `f` returns `ControlFlow::Break(b)`, which
    /// also ends the traversal. Returns `None` if the traversal ran to completion.
    ///
    /// The internal visitor keeps the default [`VisitorScope::NoBodies`] scope,
    /// so bodies reached through items or closures are not entered.
    fn for_each_expr<F: for<'a> FnMut(ExprKind<'a>) -> ControlFlow<B>>(
        self,
        cx: &'ast MarkerContext<'ast>,
        f: F,
    ) -> Option<B> {
        // Adapter that forwards every visited expression to the callback.
        struct CallbackVisitor<C> {
            callback: C,
        }

        impl<B, C> Visitor<B> for CallbackVisitor<C>
        where
            C: for<'a> FnMut(ExprKind<'a>) -> ControlFlow<B>,
        {
            fn visit_expr<'v_ast>(
                &mut self,
                _cx: &'v_ast MarkerContext<'v_ast>,
                expr: ExprKind<'v_ast>,
            ) -> ControlFlow<B> {
                (self.callback)(expr)
            }
        }

        let mut adapter = CallbackVisitor { callback: f };
        if let ControlFlow::Break(value) = self.traverse(cx, &mut adapter) {
            Some(value)
        } else {
            None
        }
    }
}
// Implements `Traversable` for the node type `$node` by forwarding `traverse`
// to the free function `$walk`. The function must take
// `(&'ast MarkerContext<'ast>, &mut dyn Visitor<B>, $node)` and return
// `ControlFlow<B>`.
macro_rules! impl_traversable_for {
    ($node:ty, $walk:ident) => {
        impl<'ast, B> Traversable<'ast, B> for $node {
            fn traverse(
                self,
                cx: &'ast MarkerContext<'ast>,
                visitor: &mut dyn Visitor<B>,
            ) -> ControlFlow<B> {
                $walk(cx, visitor, self)
            }
        }
    };
}
// Each traversable node type is paired with the free function that walks it.
impl_traversable_for!(ItemKind<'ast>, traverse_item);
impl_traversable_for!(StmtKind<'ast>, traverse_stmt);
impl_traversable_for!(ExprKind<'ast>, traverse_expr);
impl_traversable_for!(&'ast ast::Body<'ast>, traverse_body);
/// Convenience queries for nodes that are traversable with a `bool` break value.
pub trait BoolTraversable<'ast>: Traversable<'ast, bool> {
    /// Returns `true` if a `return` expression or a `?` (try) expression is
    /// reached while traversing `self`.
    ///
    /// `?` counts because it can return early from the enclosing function.
    /// The search runs through [`Traversable::for_each_expr`], whose visitor
    /// keeps the default [`VisitorScope::NoBodies`] scope, so bodies reached
    /// through items or closures are not searched.
    fn contains_return(&self, cx: &'ast MarkerContext<'ast>) -> bool {
        self.for_each_expr(cx, |expr| match expr {
            ExprKind::Return(_) | ExprKind::Try(_) => ControlFlow::Break(true),
            _ => ControlFlow::Continue(()),
        })
        .is_some()
    }
}

// Every node traversable with a `bool` break value gets these queries for free.
impl<'ast, T> BoolTraversable<'ast> for T where T: Traversable<'ast, bool> {}