//! state-macro 0.1.1
//!
//! Syntax sugar for stateful functions.
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashSet;
use syn::{
    Block, Expr, Ident, ItemFn, Pat, Path, Result, Type,
    parse::Parse,
    parse::ParseStream,
    parse_macro_input, parse_quote,
    visit::{self, Visit},
    visit_mut::{self, VisitMut},
};

/// Settings that the `stateful*` entry points hand to `make_stateful`.
///
/// The `Fn(Ident) -> Expr` bound stays on the struct on purpose: it lets the
/// closures written at the construction sites have their signatures inferred.
struct StatefulArgs<F>
where
    F: Fn(Ident) -> Expr,
{
    /// Type given to the injected state parameter.
    state_type: Type,

    /// Turns the injected parameter's identifier into the expression that is
    /// passed to `with_state!`.
    state_expr_builder: F,
}

/// Arguments of `#[stateful(T)]` and `#[stateful_cloned(T)]`: a single type,
/// used as the type of the injected state parameter.
struct StatefulAttr {
    state_type: Type,
}

impl Parse for StatefulAttr {
    /// Parses exactly one `syn::Type` from the attribute input.
    fn parse(input: ParseStream) -> Result<Self> {
        input.parse().map(|state_type| Self { state_type })
    }
}

/// Arguments of `#[stateful_expr(T, expr)]`.
///
/// `T` is the type of the injected state parameter. `expr` is an expression in
/// which `state` stands for that parameter.
struct StatefulExprAttr {
    state_type: Type,
    state_expr: Expr,
}

impl Parse for StatefulExprAttr {
    /// Parses `Type , Expr` in that order.
    fn parse(input: ParseStream) -> Result<Self> {
        let state_type = input.parse::<Type>()?;
        let _separator: syn::Token![,] = input.parse()?;
        let state_expr = input.parse::<Expr>()?;
        Ok(Self {
            state_type,
            state_expr,
        })
    }
}

/// A visitor that replaces all occurrences of 'state' paths with a new path
/// Panics if any path identifier is not 'state'
struct PathReplacer {
    new_path: Path,
}

impl PathReplacer {
    fn new(new_path: Path) -> Self {
        PathReplacer { new_path }
    }
}

impl VisitMut for PathReplacer {
    fn visit_expr_mut(&mut self, expr: &mut Expr) {
        if let Expr::Path(expr_path) = expr {
            // Check if this is a simple identifier path
            if expr_path.path.segments.len() == 1 {
                let ident = &expr_path.path.segments.first().unwrap().ident;
                if ident == "state" {
                    expr_path.path = self.new_path.clone();
                } else {
                    panic!("Invalid identifier '{}' in state expression. Only 'state' is allowed.", ident);
                }
            }
        }

        // Continue visiting the expression
        visit_mut::visit_expr_mut(self, expr);
    }
}

/// Generate a fresh identifier that doesn't collide with existing identifiers.
///
/// * `used_idents` - Set of identifiers that are already in use
/// * `prefix` - Base prefix for the generated name
///
/// Returns `prefix` itself when it is unused. Otherwise it returns `prefix`
/// followed by the *shortest* run of underscores that gives a name not in
/// `used_idents`. For example, if only `state` is taken, the result is
/// `state_`.
///
/// The loop always terminates. Each iteration makes the candidate one
/// character longer and `used_idents` is finite, so at the latest a candidate
/// longer than every element is reached.
fn fresh_name(used_idents: &HashSet<String>, prefix: &str) -> Ident {
    let mut candidate = String::from(prefix);
    while used_idents.contains(candidate.as_str()) {
        candidate.push('_');
    }
    format_ident!("{}", candidate)
}

/// A visitor that collects all identifiers used in a syntax tree
struct IdentVisitor {
    identifiers: HashSet<String>,
}

impl IdentVisitor {
    fn new() -> Self {
        IdentVisitor {
            identifiers: HashSet::new(),
        }
    }

    /// Add an identifier to the set
    fn add_ident(&mut self, ident: &Ident) {
        self.identifiers.insert(ident.to_string());
    }
}

impl<'ast> Visit<'ast> for IdentVisitor {
    // Collect identifiers from patterns (like in let statements and function params)
    fn visit_pat(&mut self, pat: &'ast Pat) {
        // Check for pattern identifiers like in `let x = 1;`
        if let Pat::Ident(pat_ident) = pat {
            self.add_ident(&pat_ident.ident);
        }

        // Continue visiting the pattern
        visit::visit_pat(self, pat);
    }

    // Collect identifiers from expressions
    fn visit_expr(&mut self, expr: &'ast Expr) {
        // Check for path expressions like `x` or `foo::bar`
        if let Expr::Path(expr_path) = expr {
            if let Some(segment) = expr_path.path.segments.first() {
                self.add_ident(&segment.ident);
            }
        }

        // Continue visiting the expression
        visit::visit_expr(self, expr);
    }

    // Collect identifiers from paths
    fn visit_path(&mut self, path: &'ast Path) {
        // Add each segment of the path
        for segment in &path.segments {
            self.add_ident(&segment.ident);
        }

        // Continue visiting the path
        visit::visit_path(self, path);
    }
}

fn make_stateful<F: Fn(Ident) -> Expr>(
    item: TokenStream,
    StatefulArgs {
        state_type,
        state_expr_builder,
    }: StatefulArgs<F>,
) -> TokenStream {
    // Parse the function to transform
    let mut input_fn = parse_macro_input!(item as ItemFn);
    let fn_block = input_fn.block;

    // Collect all identifiers used in the function body
    let mut visitor = IdentVisitor::new();
    visitor.visit_block(&fn_block);

    // Also collect identifiers from function parameters
    for param in &input_fn.sig.inputs {
        if let syn::FnArg::Typed(pat_type) = param {
            visitor.visit_pat(&pat_type.pat);
        }
    }

    // Generate a fresh state parameter name that doesn't collide
    // Use the state expr unchanged; it's up to the caller to ensure safety.
    let state_ident = fresh_name(&visitor.identifiers, "state");

    // Generate the state expression from the state param
    let state_expr = state_expr_builder(state_ident.clone());

    // Add the state parameter to the function signature
    let mut new_inputs = Vec::new();
    new_inputs.push(syn::parse_quote! { #state_ident: #state_type });
    new_inputs.extend(input_fn.sig.inputs.iter().cloned());
    input_fn.sig.inputs = syn::punctuated::Punctuated::from_iter(new_inputs);

    // Create the new function block wrapped in with_state!
    let fn_body = fn_block.stmts;
    let new_block: Block = syn::parse_quote! {
        {
            ::state_macro::with_state! { #state_expr;
                #(#fn_body)*
            }
        }
    };

    // Update the function with the new block
    input_fn.block = Box::new(new_block);

    // Return the transformed function
    TokenStream::from(quote! { #input_fn })
}

/// Implements the #[stateful(StateType)] attribute macro
pub fn stateful(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Parse the attribute input (the state type)
    let StatefulAttr { state_type } = parse_macro_input!(attr as StatefulAttr);

    let args = StatefulArgs {
        state_type,
        state_expr_builder: |state_ident| parse_quote!(#state_ident),
    };

    make_stateful(item, args)
}

/// Implements the #[stateful_cloned(StateType)] attribute macro
pub fn stateful_cloned(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Parse the attribute input (the state type)
    let StatefulAttr { state_type } = parse_macro_input!(attr as StatefulAttr);

    let args = StatefulArgs {
        state_type,
        state_expr_builder: |state_ident| parse_quote!(#state_ident.clone()),
    };

    make_stateful(item, args)
}

/// Implements the #[stateful_expr(StateType, expr)] attribute macro
pub fn stateful_expr(attr: TokenStream, item: TokenStream) -> TokenStream {
    let StatefulExprAttr {
        state_type,
        state_expr,
    } = parse_macro_input!(attr as StatefulExprAttr);

    let args = StatefulArgs {
        state_type,
        state_expr_builder: |state_ident| {
            let new_path: Path = parse_quote!(#state_ident);
            let mut expr_copy = state_expr.clone();
            let mut replacer = PathReplacer::new(new_path);
            replacer.visit_expr_mut(&mut expr_copy);
            expr_copy
        },
    };

    make_stateful(item, args)
}