use crate::error::Error;
use crate::util::log::debug;
use crate::util::token_stream::{TokenStreamExt, TokenVecExt};
use proc_macro2::{Group, Ident, Literal, Punct, TokenStream, TokenTree};
use std::collections::VecDeque;
pub struct StreamWalker<'a, V: StreamVisitor> {
visitor: &'a mut V,
}
impl<'a, V: StreamVisitor> StreamWalker<'a, V> {
pub fn new(visitor: &'a mut V) -> Self {
Self { visitor }
}
pub fn walk(&mut self, stream: TokenStream) -> Result<TokenStream, Error> {
debug!("Walking a stream: \"{}\"", stream);
let mut stack = VecDeque::<(usize, Vec<TokenTree>)>::new();
stack.push_back((0, stream.into_vec()));
let mut ctx = VisitorCtx::new(stack);
loop {
match (
ctx.stack().len(),
ctx.current_group_len(),
ctx.is_current_group_exhausted(),
ctx.current_token(),
) {
(0, _, _, _) => {
unreachable!()
}
(1, _, Some(true), _) => {
return match self
.visitor
.exit_group_mut(&ctx, ctx.current_group().unwrap())?
{
StreamVisitorAction::Continue => {
Ok(ctx.current_group().unwrap().to_owned().into_token_stream())
}
StreamVisitorAction::Skip => {
ctx.pop_remove();
Ok(TokenStream::new())
}
StreamVisitorAction::Replace(new_stream) => {
debug!(
"Replacing a group: \"{}\", with \"{}\"",
ctx.current_group().unwrap().to_token_stream(),
new_stream,
);
ctx.replace_current_group(new_stream);
self.visitor.after_replace_mut(&ctx)?;
Ok(ctx.current_group().unwrap().to_token_stream())
}
};
}
(_, _, Some(true), _) => {
match self
.visitor
.exit_group_mut(&ctx, ctx.current_group().unwrap())?
{
StreamVisitorAction::Continue => {
ctx.pop_fold();
ctx.advance_current_group();
}
StreamVisitorAction::Skip => {
ctx.pop_remove();
}
StreamVisitorAction::Replace(new_stream) => {
debug!(
"Replacing a group: \"{}\", with \"{}\"",
ctx.current_group().unwrap().to_token_stream(),
new_stream,
);
ctx.replace_current_group(new_stream);
ctx.pop_fold();
ctx.advance_current_group();
self.visitor.after_replace_mut(&ctx)?;
}
};
}
(_, _, _, Some(TokenTree::Group(group))) => {
match self.visitor.visit_group_mut(&ctx, group)? {
StreamVisitorAction::Continue => {
ctx.push_group(group.stream().to_vec());
match self
.visitor
.enter_group_mut(&ctx, ctx.current_group().unwrap())?
{
StreamVisitorAction::Continue => {}
StreamVisitorAction::Skip => {
ctx.pop_remove();
}
StreamVisitorAction::Replace(new_stream) => {
debug!(
"Replacing a group: {}, with {}",
ctx.current_group().unwrap().to_token_stream(),
new_stream,
);
ctx.replace_current_group(new_stream);
self.visitor.after_replace_mut(&ctx)?;
}
}
}
StreamVisitorAction::Skip => {
ctx.remove_current_token();
}
StreamVisitorAction::Replace(new_stream) if !new_stream.is_empty() => {
debug!(
"Replacing a token: \"{}\", with \"{}\", in \"{}\"",
ctx.current_token().unwrap(),
new_stream,
ctx.current_group().unwrap().to_token_stream(),
);
ctx.replace_current_token(new_stream);
ctx.advance_current_group();
self.visitor.after_replace_mut(&ctx)?;
}
StreamVisitorAction::Replace(new_stream) if new_stream.is_empty() => {
debug!(
"Replacing a token: \"{}\", with \"{}\", in \"{}\"",
ctx.current_token().unwrap(),
new_stream,
ctx.current_group().unwrap().to_token_stream(),
);
ctx.replace_current_token(new_stream);
self.visitor.after_replace_mut(&ctx)?;
}
_ => unreachable!(),
}
}
(_, _, _, Some(token)) => {
let action = match token {
TokenTree::Ident(ident) => self.visitor.visit_ident_mut(&ctx, ident)?,
TokenTree::Punct(punct) => self.visitor.visit_punct_mut(&ctx, punct)?,
TokenTree::Literal(literal) => {
self.visitor.visit_literal_mut(&ctx, literal)?
}
_ => unreachable!(),
};
match action {
StreamVisitorAction::Continue => {
ctx.advance_current_group();
}
StreamVisitorAction::Skip => {
ctx.remove_current_token();
}
StreamVisitorAction::Replace(new_stream) if !new_stream.is_empty() => {
debug!(
"Replacing a token: \"{}\", with \"{}\", in \"{}\"",
ctx.current_token().unwrap(),
new_stream,
ctx.current_group().unwrap().to_token_stream(),
);
ctx.replace_current_token(new_stream);
ctx.advance_current_group();
self.visitor.after_replace_mut(&ctx)?;
}
StreamVisitorAction::Replace(new_stream) if new_stream.is_empty() => {
debug!(
"Replacing a token: \"{}\", with \"{}\", in \"{}\"",
ctx.current_token().unwrap(),
new_stream,
ctx.current_group().unwrap().to_token_stream(),
);
ctx.replace_current_token(new_stream);
self.visitor.after_replace_mut(&ctx)?;
}
_ => unreachable!(),
};
}
_ => unreachable!(),
};
}
}
}
/// Traversal state shared with visitor callbacks.
///
/// Holds one `(cursor, tokens)` entry per nesting level: the root stream sits at the
/// front and the group currently being walked at the back. For every entry except the
/// back one, the cursor is the position of the group token that the next entry was
/// opened from.
pub struct VisitorCtx {
    stack: VecDeque<(usize, Vec<TokenTree>)>,
}

impl VisitorCtx {
    /// Wraps an already-built level stack.
    pub fn new(stack: VecDeque<(usize, Vec<TokenTree>)>) -> Self {
        Self { stack }
    }

    /// All levels, root first.
    fn stack(&self) -> &VecDeque<(usize, Vec<TokenTree>)> {
        &self.stack
    }

    /// Tokens of the innermost level, or `None` when the stack is empty.
    fn current_group(&self) -> Option<&[TokenTree]> {
        let (_, tokens) = self.stack.back()?;
        Some(tokens.as_slice())
    }

    /// Number of tokens in the innermost level, or `None` when the stack is empty.
    fn current_group_len(&self) -> Option<usize> {
        let (_, tokens) = self.stack.back()?;
        Some(tokens.len())
    }

    /// Whether the innermost cursor has moved past the last token of its level.
    fn is_current_group_exhausted(&self) -> Option<bool> {
        let (cursor, tokens) = self.stack.back()?;
        Some(*cursor >= tokens.len())
    }

    /// Token under the innermost cursor, if there is one.
    fn current_token(&self) -> Option<&TokenTree> {
        let (cursor, tokens) = self.stack.back()?;
        tokens.get(*cursor)
    }

    /// Swaps the token under the cursor for the tokens of `stream`.
    ///
    /// Afterwards the cursor rests on the last inserted token; for an empty `stream`
    /// it stays where it was. Does nothing when the stack is empty.
    fn replace_current_token(&mut self, stream: TokenStream) {
        let Some((cursor, tokens)) = self.stack.back_mut() else {
            return;
        };
        let replacement = stream.into_vec();
        let inserted = replacement.len();
        tokens.splice(*cursor..=*cursor, replacement);
        *cursor = (*cursor + inserted.saturating_sub(1)).min(tokens.len());
    }

    /// Deletes the token under the cursor, if any; the cursor itself is not moved.
    fn remove_current_token(&mut self) {
        match self.stack.back_mut() {
            Some((cursor, tokens)) if *cursor < tokens.len() => {
                tokens.remove(*cursor);
            }
            _ => {}
        }
    }

    /// Rebuilds the group at `parent_tokens[i]` so that it contains `group_tokens`,
    /// keeping the original delimiter and span.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds or the token at `i` is not a group.
    fn fold_group(group_tokens: &[TokenTree], parent_tokens: &mut [TokenTree], i: usize) {
        let (delimiter, span) = match &parent_tokens[i] {
            TokenTree::Group(original_group) => {
                (original_group.delimiter(), original_group.span())
            }
            other => panic!("Expected a group at index {}, found: {:?}", i, other),
        };
        let contents = group_tokens.iter().cloned().collect::<TokenStream>();
        let mut rebuilt = Group::new(delimiter, contents);
        rebuilt.set_span(span);
        parent_tokens[i] = TokenTree::Group(rebuilt);
    }

    /// Pops the innermost level and writes its tokens back into the group token that
    /// the parent's cursor points at.
    fn pop_fold(&mut self) {
        let Some((_, child_tokens)) = self.stack.pop_back() else {
            return;
        };
        if let Some((parent_cursor, parent_tokens)) = self.stack.back_mut() {
            Self::fold_group(&child_tokens, parent_tokens, *parent_cursor);
        }
    }

    /// Pops the innermost level and deletes the token under the parent's cursor.
    fn pop_remove(&mut self) {
        if self.stack.pop_back().is_none() {
            return;
        }
        if let Some((parent_cursor, parent_tokens)) = self.stack.back_mut() {
            parent_tokens.remove(*parent_cursor);
        }
    }

    /// Moves the innermost cursor one token forward, never past the end of the level.
    fn advance_current_group(&mut self) {
        if let Some((cursor, tokens)) = self.stack.back_mut() {
            *cursor = (*cursor + 1).min(tokens.len());
        }
    }

    /// Replaces every token of the innermost level with the tokens of `stream`; the
    /// cursor is left untouched.
    fn replace_current_group(&mut self, stream: TokenStream) {
        if let Some((_, tokens)) = self.stack.back_mut() {
            *tokens = stream.into_vec();
        }
    }

    /// Opens a new innermost level over `group`, with its cursor at the first token.
    fn push_group(&mut self, group: Vec<TokenTree>) {
        self.stack.push_back((0, group));
    }

    /// Reconstructs the complete stream as it currently stands.
    ///
    /// Starting from the innermost level, each level's tokens are wrapped back into
    /// the group token under its parent's cursor, on copies, so the traversal state
    /// itself is not modified. Returns an empty stream when the stack is empty.
    pub fn current_stream(&self) -> TokenStream {
        let mut levels = self.stack.iter().rev();
        let Some((_, innermost)) = levels.next() else {
            return TokenStream::new();
        };
        let mut folded = innermost.clone();
        for (cursor, tokens) in levels {
            let mut parent = tokens.clone();
            Self::fold_group(&folded, &mut parent, *cursor);
            folded = parent;
        }
        folded.into_iter().collect()
    }
}
/// What the walker should do with the token or group a visitor hook was just given.
// NOTE(review): `dead_code` is presumably allowed because some variants are only
// constructed by particular visitors or tests — confirm.
#[allow(dead_code)]
pub enum StreamVisitorAction {
/// Keep the item unchanged and carry on with the traversal.
Continue,
/// Remove the item from the output.
Skip,
/// Put the given stream in place of the item; the walker calls
/// `StreamVisitor::after_replace_mut` once the replacement has been applied.
Replace(TokenStream),
}
/// Callbacks invoked by `StreamWalker::walk` while it traverses a token stream.
///
/// Every hook has a default implementation, so an implementor only overrides the
/// events it cares about. The `visit_*`, `enter_*` and `exit_*` defaults return
/// `StreamVisitorAction::Continue`; the `after_replace_mut` default does nothing.
/// An `Err` returned from any hook aborts the walk and is returned by `walk`.
// The default bodies ignore their parameters, hence the lint allowance.
#[allow(unused_variables)]
pub trait StreamVisitor {
/// Called for each identifier the walker reaches.
///
/// `Skip` removes the identifier; `Replace` splices the stream in its place, and
/// the spliced tokens are not visited.
fn visit_ident_mut(
&mut self,
ctx: &VisitorCtx,
ident: &Ident,
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called for each punctuation token the walker reaches; actions behave as for
/// `visit_ident_mut`.
fn visit_punct_mut(
&mut self,
ctx: &VisitorCtx,
punct: &Punct,
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called for each literal the walker reaches; actions behave as for
/// `visit_ident_mut`.
fn visit_literal_mut(
&mut self,
ctx: &VisitorCtx,
literal: &Literal,
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called for a group token before its contents are walked.
///
/// `Continue` descends into the group, `Skip` removes the whole group token, and
/// `Replace` substitutes the group token (delimiters included) with the stream
/// without visiting the new tokens.
fn visit_group_mut(
&mut self,
ctx: &VisitorCtx,
group: &Group,
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called right after the walker has descended into a group; `group` holds the
/// group's inner tokens.
///
/// `Skip` removes the group from its parent. `Replace` swaps the inner tokens for
/// the stream, which is then walked in their place.
fn enter_group_mut(
&mut self,
ctx: &VisitorCtx,
group: &[TokenTree],
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called once every token of the current level has been processed; `group` holds
/// that level's tokens as they stand at that point.
///
/// This fires for nested groups and also for the root stream. For a nested group,
/// `Continue` writes the tokens back into the parent, `Skip` removes the group, and
/// `Replace` swaps the inner tokens for the stream without walking them. For the
/// root stream the action determines what `walk` returns: the tokens, an empty
/// stream, or the replacement.
fn exit_group_mut(
&mut self,
ctx: &VisitorCtx,
group: &[TokenTree],
) -> Result<StreamVisitorAction, Error> {
Ok(StreamVisitorAction::Continue)
}
/// Called after the walker has applied a `Replace` action from any other hook.
fn after_replace_mut(&mut self, ctx: &VisitorCtx) -> Result<(), Error> {
Ok(())
}
}
// Unit tests for the walker: callback order, each action kind, and error propagation.
#[cfg(test)]
mod tests {
use super::{StreamVisitor, StreamVisitorAction, StreamWalker};
use crate::error::Error;
use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, TokenStream, TokenTree};
use quote::quote;
use rstest::rstest;
/// One entry per visitor callback, recorded in the order the walker invoked them.
#[derive(Debug, PartialEq)]
enum LogEvent {
VisitIdent(String),
VisitPunct(char),
VisitLiteral(String),
VisitGroup(Delimiter),
EnterGroup,
ExitGroup,
AfterReplace,
}
/// Visitor that logs every callback and optionally delegates the decision for a
/// hook to a closure; hooks without a closure answer `Continue`.
struct TestVisitor {
log: Vec<LogEvent>,
on_ident: Option<Box<dyn FnMut(&Ident) -> Result<StreamVisitorAction, Error>>>,
on_punct: Option<Box<dyn FnMut(&Punct) -> Result<StreamVisitorAction, Error>>>,
on_literal: Option<Box<dyn FnMut(&Literal) -> Result<StreamVisitorAction, Error>>>,
on_visit_group: Option<Box<dyn FnMut(&Group) -> Result<StreamVisitorAction, Error>>>,
on_enter_group: Option<Box<dyn FnMut(&[TokenTree]) -> Result<StreamVisitorAction, Error>>>,
on_exit_group: Option<Box<dyn FnMut(&[TokenTree]) -> Result<StreamVisitorAction, Error>>>,
}
impl TestVisitor {
/// Creates a visitor with an empty log and no hook closures installed.
fn new() -> Self {
Self {
log: Vec::new(),
on_ident: None,
on_punct: None,
on_literal: None,
on_visit_group: None,
on_enter_group: None,
on_exit_group: None,
}
}
/// Number of `AfterReplace` events recorded so far.
fn after_replace_count(&self) -> usize {
self.log
.iter()
.filter(|e| matches!(e, LogEvent::AfterReplace))
.count()
}
}
// Each hook first records its event, then defers to the matching closure if one is set.
impl StreamVisitor for TestVisitor {
fn visit_ident_mut(
&mut self,
_ctx: &super::VisitorCtx,
ident: &Ident,
) -> Result<StreamVisitorAction, Error> {
self.log.push(LogEvent::VisitIdent(ident.to_string()));
if let Some(cb) = self.on_ident.as_mut() {
cb(ident)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn visit_punct_mut(
&mut self,
_ctx: &super::VisitorCtx,
punct: &Punct,
) -> Result<StreamVisitorAction, Error> {
self.log.push(LogEvent::VisitPunct(punct.as_char()));
if let Some(cb) = self.on_punct.as_mut() {
cb(punct)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn visit_literal_mut(
&mut self,
_ctx: &super::VisitorCtx,
lit: &Literal,
) -> Result<StreamVisitorAction, Error> {
self.log.push(LogEvent::VisitLiteral(lit.to_string()));
if let Some(cb) = self.on_literal.as_mut() {
cb(lit)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn visit_group_mut(
&mut self,
_ctx: &super::VisitorCtx,
group: &Group,
) -> Result<StreamVisitorAction, Error> {
self.log.push(LogEvent::VisitGroup(group.delimiter()));
if let Some(cb) = self.on_visit_group.as_mut() {
cb(group)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn enter_group_mut(
&mut self,
_ctx: &super::VisitorCtx,
group: &[TokenTree],
) -> Result<StreamVisitorAction, Error> {
let _ = group;
self.log.push(LogEvent::EnterGroup);
if let Some(cb) = self.on_enter_group.as_mut() {
cb(group)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn exit_group_mut(
&mut self,
_ctx: &super::VisitorCtx,
group: &[TokenTree],
) -> Result<StreamVisitorAction, Error> {
let _ = group;
self.log.push(LogEvent::ExitGroup);
if let Some(cb) = self.on_exit_group.as_mut() {
cb(group)
} else {
Ok(StreamVisitorAction::Continue)
}
}
fn after_replace_mut(&mut self, _ctx: &super::VisitorCtx) -> Result<(), Error> {
self.log.push(LogEvent::AfterReplace);
Ok(())
}
}
// With a pass-through visitor the stream is unchanged and the callbacks arrive in
// depth-first order.
#[rstest]
fn visit_order_simple_group() {
let input: TokenStream = quote!((a));
let mut visitor = TestVisitor::new();
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "(a)");
assert_eq!(
visitor.log,
vec![
LogEvent::VisitGroup(Delimiter::Parenthesis),
LogEvent::EnterGroup,
LogEvent::VisitIdent("a".into()),
// Two exits: one for the parenthesized group, one for the root stream.
LogEvent::ExitGroup,
LogEvent::ExitGroup,
]
);
}
// Replacing one identifier with two tokens splices both in and notifies once.
#[rstest]
fn replace_ident_with_multiple_tokens() {
let input: TokenStream = quote!(a);
let mut visitor = TestVisitor::new();
visitor.on_ident = Some(Box::new(|id: &Ident| {
if id == "a" {
Ok(StreamVisitorAction::Replace(quote!(x y)))
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "x y");
assert_eq!(visitor.after_replace_count(), 1);
}
// An empty replacement deletes the token but still counts as a replacement.
#[rstest]
fn replace_ident_with_empty_stream_triggers_after_replace() {
let input: TokenStream = quote!(a);
let mut visitor = TestVisitor::new();
visitor.on_ident = Some(Box::new(|id: &Ident| {
if id == "a" {
Ok(StreamVisitorAction::Replace(quote!()))
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "");
assert_eq!(visitor.after_replace_count(), 1);
}
// `Skip` drops the token and the walk continues with its next sibling.
#[rstest]
fn remove_ident_by_skipping() {
let input: TokenStream = quote!(a b);
let mut visitor = TestVisitor::new();
visitor.on_ident = Some(Box::new(|id: &Ident| {
if id == "a" {
Ok(StreamVisitorAction::Skip)
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "b");
}
// Replacing from `visit_group_mut` substitutes the whole group, delimiters included.
#[rstest]
fn replace_group_on_visit() {
let input: TokenStream = quote!((a));
let mut visitor = TestVisitor::new();
visitor.on_visit_group = Some(Box::new(|g: &Group| {
if g.delimiter() == Delimiter::Parenthesis {
Ok(StreamVisitorAction::Replace(quote!(x + y)))
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "x + y");
assert_eq!(visitor.after_replace_count(), 1);
}
// Skipping on exit removes the group from its parent; the closure matches on the
// group's contents so the root stream's own exit is left alone.
#[rstest]
fn skip_group_on_exit_removes_it() {
let input: TokenStream = quote!((a) c);
let mut visitor = TestVisitor::new();
visitor.on_exit_group = Some(Box::new(|group: &[TokenTree]| {
let ts: TokenStream = group.iter().cloned().collect();
if ts.to_string() == "a" {
Ok(StreamVisitorAction::Skip)
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let actual = walker.walk(input).unwrap();
assert_eq!(actual.to_string(), "c");
}
// An error returned by a hook aborts the walk and is returned to the caller.
#[rstest]
fn error_is_propagated() {
let input: TokenStream = quote!(boom);
let mut visitor = TestVisitor::new();
visitor.on_ident = Some(Box::new(|id: &Ident| {
if id == "boom" {
Err(Error::make_internal_error("boom".into()))
} else {
Ok(StreamVisitorAction::Continue)
}
}));
let mut walker = StreamWalker::new(&mut visitor);
let err = walker.walk(input).expect_err("expected error");
assert!(matches!(err, Error::InternalError(_)));
}
}