use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashSet;
use syn::{
Block, Expr, Ident, ItemFn, Pat, Path, Result, Type,
parse::Parse,
parse::ParseStream,
parse_macro_input, parse_quote,
visit::{self, Visit},
visit_mut::{self, VisitMut},
};
/// Settings consumed by `make_stateful`.
///
/// `F` is a callback that receives the identifier chosen for the injected
/// state parameter and returns the expression placed at the head of the
/// generated `with_state!` invocation.
struct StatefulArgs<F: Fn(Ident) -> Expr> {
/// Type annotation given to the injected state parameter.
state_type: Type,
/// Builds the state expression from the injected parameter's identifier.
state_expr_builder: F,
}
/// Parsed attribute arguments consisting of a single type.
struct StatefulAttr {
/// The type parsed from the attribute arguments.
state_type: Type,
}
impl Parse for StatefulAttr {
fn parse(input: ParseStream) -> Result<Self> {
let state_type = input.parse()?;
Ok(StatefulAttr { state_type })
}
}
/// Parsed attribute arguments of the form `<type>, <expr>`.
struct StatefulExprAttr {
/// The type that appears before the comma.
state_type: Type,
/// The expression that appears after the comma.
state_expr: Expr,
}
impl Parse for StatefulExprAttr {
    /// Parses `<type> , <expr>` from `input`.
    ///
    /// The three pieces are read in order; the first failure is returned
    /// and nothing further is consumed.
    fn parse(input: ParseStream) -> Result<Self> {
        let ty = input.parse::<Type>()?;
        // The separator carries no information; it only has to be present.
        let _separator: syn::Token![,] = input.parse()?;
        let expr = input.parse::<Expr>()?;
        Ok(Self {
            state_type: ty,
            state_expr: expr,
        })
    }
}
/// Mutable AST visitor that substitutes the bare identifier `state` in an
/// expression with `new_path`.
struct PathReplacer {
/// Path written in place of every bare `state` path expression.
new_path: Path,
}
impl PathReplacer {
fn new(new_path: Path) -> Self {
PathReplacer { new_path }
}
}
impl VisitMut for PathReplacer {
    /// Rewrites single-segment path expressions.
    ///
    /// A bare `state` is replaced with a clone of `self.new_path`; any other
    /// single-segment path aborts with a panic. Paths with more than one
    /// segment are left untouched. Sub-expressions are then visited through
    /// the default traversal.
    ///
    /// # Panics
    ///
    /// Panics when a single-segment path other than `state` is encountered.
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        match expr {
            Expr::Path(candidate) if candidate.path.segments.len() == 1 => {
                let ident = &candidate.path.segments[0].ident;
                if ident != "state" {
                    panic!("Invalid identifier '{}' in state expression. Only 'state' is allowed.", ident);
                }
                candidate.path = self.new_path.clone();
            }
            _ => {}
        }
        // Descend into children (method-call receivers, arguments, ...).
        visit_mut::visit_expr_mut(self, expr);
    }
}
/// Returns an identifier based on `prefix` that is not present in
/// `used_idents`.
///
/// If `prefix` itself is unused it is returned as-is. Otherwise underscores
/// are appended until the result is one character longer than the longest
/// used name that starts with `prefix`; being longer than every such name,
/// it cannot collide with any of them.
fn fresh_name(used_idents: &HashSet<String>, prefix: &str) -> Ident {
    if !used_idents.contains(prefix) {
        return format_ident!("{}", prefix);
    }

    // Length of the longest tail following `prefix` among the used names.
    // Computed lazily over the set; no intermediate collection or string
    // clones are needed just to find a maximum.
    let longest_tail = used_idents
        .iter()
        .filter(|name| name.starts_with(prefix))
        .map(|name| name.len().saturating_sub(prefix.len()))
        .max()
        .unwrap_or(0);

    let suffix = "_".repeat(longest_tail.saturating_add(1));
    format_ident!("{}{}", prefix, suffix)
}
/// AST visitor that accumulates the textual names of identifiers it
/// encounters.
struct IdentVisitor {
/// String form of every identifier recorded so far.
identifiers: HashSet<String>,
}
impl IdentVisitor {
fn new() -> Self {
IdentVisitor {
identifiers: HashSet::new(),
}
}
fn add_ident(&mut self, ident: &Ident) {
self.identifiers.insert(ident.to_string());
}
}
impl<'ast> Visit<'ast> for IdentVisitor {
    /// Records the name bound by an identifier pattern, then continues into
    /// nested patterns.
    fn visit_pat(&mut self, pat: &'ast Pat) {
        match pat {
            Pat::Ident(binding) => self.add_ident(&binding.ident),
            _ => {}
        }
        visit::visit_pat(self, pat);
    }

    /// Records the leading segment of a path expression, then continues into
    /// sub-expressions.
    // NOTE(review): `visit_path` below records every segment, so this hook
    // looks redundant if the default traversal reaches it; kept as-is.
    fn visit_expr(&mut self, expr: &'ast Expr) {
        let leading = match expr {
            Expr::Path(path_expr) => path_expr.path.segments.first(),
            _ => None,
        };
        if let Some(segment) = leading {
            self.add_ident(&segment.ident);
        }
        visit::visit_expr(self, expr);
    }

    /// Records every segment identifier of `path`, then continues into the
    /// path's children.
    fn visit_path(&mut self, path: &'ast Path) {
        path.segments
            .iter()
            .for_each(|segment| self.add_ident(&segment.ident));
        visit::visit_path(self, path);
    }
}
fn make_stateful<F: Fn(Ident) -> Expr>(
item: TokenStream,
StatefulArgs {
state_type,
state_expr_builder,
}: StatefulArgs<F>,
) -> TokenStream {
let mut input_fn = parse_macro_input!(item as ItemFn);
let fn_block = input_fn.block;
let mut visitor = IdentVisitor::new();
visitor.visit_block(&fn_block);
for param in &input_fn.sig.inputs {
if let syn::FnArg::Typed(pat_type) = param {
visitor.visit_pat(&pat_type.pat);
}
}
let state_ident = fresh_name(&visitor.identifiers, "state");
let state_expr = state_expr_builder(state_ident.clone());
let mut new_inputs = Vec::new();
new_inputs.push(syn::parse_quote! { #state_ident: #state_type });
new_inputs.extend(input_fn.sig.inputs.iter().cloned());
input_fn.sig.inputs = syn::punctuated::Punctuated::from_iter(new_inputs);
let fn_body = fn_block.stmts;
let new_block: Block = syn::parse_quote! {
{
::state_macro::with_state! { #state_expr;
#(#fn_body)*
}
}
};
input_fn.block = Box::new(new_block);
TokenStream::from(quote! { #input_fn })
}
pub fn stateful(attr: TokenStream, item: TokenStream) -> TokenStream {
let StatefulAttr { state_type } = parse_macro_input!(attr as StatefulAttr);
let args = StatefulArgs {
state_type,
state_expr_builder: |state_ident| parse_quote!(#state_ident),
};
make_stateful(item, args)
}
pub fn stateful_cloned(attr: TokenStream, item: TokenStream) -> TokenStream {
let StatefulAttr { state_type } = parse_macro_input!(attr as StatefulAttr);
let args = StatefulArgs {
state_type,
state_expr_builder: |state_ident| parse_quote!(#state_ident.clone()),
};
make_stateful(item, args)
}
pub fn stateful_expr(attr: TokenStream, item: TokenStream) -> TokenStream {
let StatefulExprAttr {
state_type,
state_expr,
} = parse_macro_input!(attr as StatefulExprAttr);
let args = StatefulArgs {
state_type,
state_expr_builder: |state_ident| {
let new_path: Path = parse_quote!(#state_ident);
let mut expr_copy = state_expr.clone();
let mut replacer = PathReplacer::new(new_path);
replacer.visit_expr_mut(&mut expr_copy);
expr_copy
},
};
make_stateful(item, args)
}