cheers-macros 0.1.0-alpha.1

Procedural macros for Cheers.
use std::collections::BTreeSet;

use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{
    Attribute, GenericParam, Generics, Ident, Lifetime, Path, Signature, Token, Type, Visibility,
    WherePredicate, braced,
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
    visit::{Visit, visit_path, visit_where_predicate},
};

/// A function item whose body is kept as raw, unparsed tokens.
///
/// The attributes, visibility and signature are parsed with `syn`. The body
/// between the braces is stored as a [`TokenStream`] and is not validated as a
/// `syn::Block`.
#[derive(Debug, Clone)]
pub struct MaybeItemFn {
    /// Outer attributes (`#[...]`) written before the function.
    pub outer_attrs: Vec<Attribute>,
    /// Inner attributes (`#![...]`) written at the very start of the body.
    pub inner_attrs: Vec<Attribute>,
    /// Visibility of the function (may be inherited / empty).
    pub vis: Visibility,
    /// The full function signature.
    pub sig: Signature,
    /// The body tokens between the braces, excluding the braces themselves
    /// and excluding the leading inner attributes captured in `inner_attrs`.
    pub block: TokenStream,
}

impl Parse for MaybeItemFn {
    /// Parses `#[outer]* vis fn sig { #![inner]* body }`.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let outer_attrs = input.call(Attribute::parse_outer)?;
        let vis: Visibility = input.parse()?;
        let sig: Signature = input.parse()?;

        let body;
        let _ = braced!(body in input);
        // Inner attributes are only legal at the start of the body, so they
        // must be peeled off *inside* the braces. Looking for them between the
        // signature and the `{` could never match.
        let inner_attrs = body.call(Attribute::parse_inner)?;
        let block: TokenStream = body.parse()?;

        Ok(Self {
            outer_attrs,
            inner_attrs,
            vis,
            sig,
            block,
        })
    }
}

impl ToTokens for MaybeItemFn {
    /// Re-emits the function as `#[outer]* vis sig { #![inner]* block }`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        for attr in &self.outer_attrs {
            attr.to_tokens(tokens);
        }
        self.vis.to_tokens(tokens);
        self.sig.to_tokens(tokens);
        syn::token::Brace::default().surround(tokens, |tokens| {
            // Inner attributes must lead the body to remain valid Rust.
            for attr in &self.inner_attrs {
                attr.to_tokens(tokens);
            }
            self.block.to_tokens(tokens);
        });
    }
}

pub fn parse_named_type(
    input: ParseStream<'_>,
    missing_type_error: &'static str,
) -> syn::Result<(Ident, Type)> {
    let name = input.parse()?;
    input
        .parse::<Token![:]>()
        .map_err(|_| syn::Error::new_spanned(&name, missing_type_error))?;
    let ty = input.parse()?;

    Ok((name, ty))
}

/// Syntax-tree visitor that records every name which could refer to a generic
/// parameter: the first segment of each relative path and the identifier of
/// each lifetime (without the leading apostrophe).
#[derive(Default)]
struct GenericNameVisitor {
    used: BTreeSet<String>,
}

impl<'ast> Visit<'ast> for GenericNameVisitor {
    fn visit_path(&mut self, path: &'ast Path) {
        // A relative path can only name a generic parameter through its first
        // segment: `T`, `T::Item`, or `N` in `[u8; N]`. Recording only
        // single-ident paths would miss parameters used solely through an
        // associated-type projection such as `T::Item`.
        if path.leading_colon.is_none() {
            if let Some(first) = path.segments.first() {
                self.used.insert(first.ident.to_string());
            }
        }
        // Keep descending so paths nested in generic arguments are seen too.
        visit_path(self, path);
    }

    fn visit_lifetime(&mut self, lifetime: &'ast Lifetime) {
        self.used.insert(lifetime.ident.to_string());
    }
}

/// Collects the names mentioned in `types` that could refer to generic
/// parameters (see [`GenericNameVisitor`]).
///
/// The result over-approximates: it also contains names like `Vec` or
/// `Clone`. Callers only test it for membership against parameter names.
fn collect_used_generic_names<'a>(types: impl IntoIterator<Item = &'a Type>) -> BTreeSet<String> {
    let mut visitor = GenericNameVisitor::default();
    for ty in types {
        visitor.visit_type(ty);
    }
    visitor.used
}

/// Collects the names mentioned anywhere in `predicate` that could refer to
/// generic parameters. This covers the bounded type, every bound and every
/// lifetime.
fn collect_predicate_generic_names(predicate: &WherePredicate) -> BTreeSet<String> {
    let mut visitor = GenericNameVisitor::default();
    visit_where_predicate(&mut visitor, predicate);
    visitor.used
}

/// Returns the name of `param` as a string.
///
/// Lifetimes are named by their identifier without the apostrophe. When
/// `remove_lifetimes` is `true`, lifetime parameters yield `None` instead of a
/// name. Type and const parameters always yield `Some`.
fn generic_param_name(param: &GenericParam, remove_lifetimes: bool) -> Option<String> {
    let ident = match param {
        GenericParam::Lifetime(_) if remove_lifetimes => return None,
        GenericParam::Lifetime(lifetime_param) => &lifetime_param.lifetime.ident,
        GenericParam::Type(type_param) => &type_param.ident,
        GenericParam::Const(const_param) => &const_param.ident,
    };
    Some(ident.to_string())
}

pub fn filter_generics<'a>(
    mut generics: Generics,
    types: impl IntoIterator<Item = &'a Type>,
    remove_lifetimes: bool,
) -> Generics {
    let used_names = collect_used_generic_names(types);

    let removed_names = generics
        .params
        .iter()
        .filter_map(|param| {
            let name = generic_param_name(param, false)?;
            let keep = generic_param_name(param, remove_lifetimes)
                .is_some_and(|name| used_names.contains(&name));

            (!keep).then_some(name)
        })
        .collect::<BTreeSet<_>>();

    let mut filtered_params = Punctuated::<GenericParam, Token![,]>::new();
    for param in generics.params.into_iter().filter(|param| {
        generic_param_name(param, remove_lifetimes).is_some_and(|name| used_names.contains(&name))
    }) {
        filtered_params.push(param);
    }
    generics.params = filtered_params;

    if let Some(mut where_clause) = generics.where_clause.take() {
        let mut filtered_predicates = Punctuated::new();
        for predicate in where_clause.predicates.into_iter().filter(|predicate| {
            collect_predicate_generic_names(predicate)
                .iter()
                .all(|name| !removed_names.contains(name))
        }) {
            filtered_predicates.push(predicate);
        }
        where_clause.predicates = filtered_predicates;

        if !where_clause.predicates.is_empty() {
            generics.where_clause = Some(where_clause);
        }
    }

    generics
}

/// Converts a `snake_case` name to `PascalCase`.
///
/// Underscores are removed. The first character, and the first character
/// after each run of underscores, is upper-cased using Unicode rules. All
/// other characters are copied unchanged, so `"hello_world"` becomes
/// `"HelloWorld"` and `"_a__b_"` becomes `"AB"`.
///
/// A leading raw-identifier prefix `r#` is stripped first. Stringifying a
/// raw identifier such as `r#type` keeps that prefix, but it is not part of
/// the name, and `#` cannot appear in the resulting identifier.
pub fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;

    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` handles non-ASCII letters (e.g. `ü` -> `Ü`); it may
            // yield more than one char, hence `extend`.
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }

    result
}

#[cfg(test)]
mod test {
    use quote::quote;
    use syn::parse_quote;

    use super::filter_generics;

    /// Dropping `U` must also drop every `where` predicate that mentions it,
    /// while predicates that only mention the surviving `T` are kept.
    #[test]
    fn filter_generics_removes_stale_where_predicates() {
        let syn::ItemStruct { generics, .. } = parse_quote! {
            struct Example<T, U>
            where
                T: Clone,
                U: Clone,
                T: Into<U>,
            {
                value: T,
            }
        };
        let field_ty: syn::Type = parse_quote!(T);

        let filtered = filter_generics(generics, [&field_ty], false);

        let rendered_params = quote!(#filtered).to_string();
        assert_eq!(rendered_params, quote!(<T>).to_string());

        let remaining_where = filtered.where_clause.as_ref();
        let rendered_where = quote!(#remaining_where).to_string();
        assert_eq!(rendered_where, quote!(where T: Clone).to_string());
    }
}