use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
/// Returns the binding name of the first argument in `signature` whose type
/// is exactly the single identifier `type_name`.
///
/// Receivers (`self`) and arguments whose type is not a plain path are skipped.
///
/// # Panics
///
/// Panics if the first matching argument is bound with anything other than a
/// plain identifier pattern, or if no argument has the requested type.
fn find_arg_name<'a>(signature: &'a syn::Signature, type_name: &'a syn::Ident) -> &'a syn::Ident {
    // Locate the first typed argument declared as `<pattern>: type_name`.
    let matching = signature.inputs.iter().find_map(|input| match input {
        syn::FnArg::Typed(typed) => match typed.ty.as_ref() {
            syn::Type::Path(ty) if ty.path.is_ident(type_name) => Some(typed),
            _ => None,
        },
        _ => None,
    });

    match matching.map(|typed| typed.pat.as_ref()) {
        Some(syn::Pat::Ident(binding)) => &binding.ident,
        Some(_) => panic!("Argument is an unsupported pattern."),
        None => panic!("Could not find argument with type {type_name}."),
    }
}
/// Returns the comma-separated type parameters of `signature` that `bound`
/// depends on, directly or transitively.
///
/// A type parameter is a direct dependency when its identifier appears
/// anywhere in the path of `bound` (including the argument and return types of
/// a parenthesized `Fn(..) -> ..` path). It is a transitive dependency when it
/// appears in an inline trait bound of another dependent parameter, e.g. `U`
/// in `T: Into<U>` when `bound` is `Fn(T)`.
///
/// Only inline trait bounds (`<T: Bound>`) are followed; `where` clauses are
/// not inspected. The parameters are emitted in declaration order, with their
/// inline bounds, and without surrounding angle brackets.
fn find_dependent_generics(
    signature: &syn::Signature,
    bound: &syn::TraitBound,
) -> proc_macro2::TokenStream {
    use std::collections::{HashMap, HashSet};

    /// Calls `f` on every identifier found in the path `p`, recursing into
    /// generic arguments and parenthesized (`Fn`-style) inputs and output.
    fn visit_path_idents(p: &syn::Path, f: &mut impl FnMut(&syn::Ident)) {
        for segment in &p.segments {
            f(&segment.ident);
            match &segment.arguments {
                syn::PathArguments::AngleBracketed(arguments) => {
                    for argument in &arguments.args {
                        match argument {
                            syn::GenericArgument::Type(t) => visit_type_idents(t, f),
                            syn::GenericArgument::AssocType(t) => visit_type_idents(&t.ty, f),
                            // Lifetimes, const arguments, associated consts and
                            // constraints are not inspected.
                            _ => (),
                        }
                    }
                }
                syn::PathArguments::Parenthesized(arguments) => {
                    for t in &arguments.inputs {
                        visit_type_idents(t, f);
                    }
                    if let syn::ReturnType::Type(_, t) = &arguments.output {
                        visit_type_idents(t, f);
                    }
                }
                _ => (),
            }
        }
    }

    /// Calls `f` on every identifier found in the type `t`. Type kinds not
    /// listed here (bare functions, trait objects, pointers, ...) are ignored.
    fn visit_type_idents(t: &syn::Type, f: &mut impl FnMut(&syn::Ident)) {
        match t {
            syn::Type::Paren(t) => visit_type_idents(&t.elem, f),
            syn::Type::Path(p) => {
                if let Some(qs) = &p.qself {
                    visit_type_idents(&qs.ty, f);
                }
                visit_path_idents(&p.path, f)
            }
            syn::Type::Array(a) => visit_type_idents(&a.elem, f),
            syn::Type::Group(g) => visit_type_idents(&g.elem, f),
            syn::Type::Reference(r) => visit_type_idents(&r.elem, f),
            syn::Type::Slice(s) => visit_type_idents(&s.elem, f),
            syn::Type::Tuple(t) => {
                for t in &t.elems {
                    visit_type_idents(t, f);
                }
            }
            _ => (),
        }
    }

    // Every type parameter of the signature, mapped to its inline trait bounds.
    let param_bounds: HashMap<&syn::Ident, Vec<&syn::TraitBound>> = signature
        .generics
        .type_params()
        .map(|param| {
            let traits = param
                .bounds
                .iter()
                .filter_map(|candidate| match candidate {
                    syn::TypeParamBound::Trait(trait_bound) => Some(trait_bound),
                    _ => None,
                })
                .collect();
            (&param.ident, traits)
        })
        .collect();

    // Worklist traversal: start from `bound`; each time a type parameter is
    // seen for the first time, also scan that parameter's own bounds. The
    // `relevant` set guarantees termination even with mutually dependent
    // parameters.
    let mut relevant: HashSet<&syn::Ident> = HashSet::new();
    let mut pending: Vec<&syn::TraitBound> = vec![bound];
    while let Some(current) = pending.pop() {
        visit_path_idents(&current.path, &mut |ident| {
            if let Some((param, traits)) = param_bounds.get_key_value(ident) {
                if relevant.insert(*param) {
                    pending.extend(traits.iter().copied());
                }
            }
        });
    }

    // Emit in declaration order so the generated code is deterministic.
    let generics: Vec<_> = signature
        .generics
        .type_params()
        .filter(|t| relevant.contains(&t.ident))
        .collect();

    quote! {
        #(#generics),*
    }
}
/// Finds the callback parameter of `signature`: the first argument whose type
/// is bounded by `Fn`, `FnMut` or `FnOnce`.
///
/// Three declaration styles are searched, in this order:
///
/// 1. an inline bound on a type parameter (`fn f<F: Fn()>(cb: F)`),
/// 2. a `where` predicate (`fn f<F>(cb: F) where F: Fn()`),
/// 3. an `impl Trait` argument (`fn f(cb: impl Fn())`).
///
/// Returns the matching trait bound, the name of the argument carrying the
/// callback, and the identifier of the generic type parameter (`None` for the
/// `impl Trait` form, which has no named parameter). Returns `None` when no
/// such bound exists.
///
/// # Panics
///
/// Panics if a matching `where` predicate does not bound a single-identifier
/// type, if the callback argument uses a non-identifier pattern, or (for the
/// first two styles) if no argument has the bounded type.
fn find_fn_generic(
    signature: &syn::Signature,
) -> Option<(&syn::TraitBound, &syn::Ident, Option<&syn::Ident>)> {
    /// Returns the trait bound if its last path segment names one of the
    /// `Fn`-family traits.
    fn as_fn_bound(bound: &syn::TypeParamBound) -> Option<&syn::TraitBound> {
        let bound = match bound {
            syn::TypeParamBound::Trait(bound) => bound,
            _ => return None,
        };
        let segment = bound.path.segments.last()?;
        if segment.ident == "Fn" || segment.ident == "FnMut" || segment.ident == "FnOnce" {
            Some(bound)
        } else {
            None
        }
    }

    // Style 1: inline bound on a type parameter.
    for param in signature.generics.type_params() {
        if let Some(bound) = param.bounds.iter().find_map(as_fn_bound) {
            let arg_name = find_arg_name(signature, &param.ident);
            return Some((bound, arg_name, Some(&param.ident)));
        }
    }

    // Style 2: `where` predicate. The clause is optional; iterating over the
    // `Option` (rather than using `?`) keeps style 3 reachable when there is
    // no `where` clause at all.
    let predicates = signature
        .generics
        .where_clause
        .iter()
        .flat_map(|clause| &clause.predicates);
    for predicate in predicates {
        let predicate = match predicate {
            syn::WherePredicate::Type(predicate) => predicate,
            _ => continue,
        };
        if let Some(bound) = predicate.bounds.iter().find_map(as_fn_bound) {
            let ident = match &predicate.bounded_ty {
                syn::Type::Path(path) => path
                    .path
                    .get_ident()
                    .expect("expected single-ident for this type"),
                _ => panic!("expected generic type for this bound"),
            };
            let arg_name = find_arg_name(signature, ident);
            return Some((bound, arg_name, Some(ident)));
        }
    }

    // Style 3: `impl Trait` directly in argument position.
    for arg in &signature.inputs {
        let arg = match arg {
            syn::FnArg::Typed(arg) => arg,
            _ => continue,
        };
        let impl_trait = match arg.ty.as_ref() {
            syn::Type::ImplTrait(impl_trait) => impl_trait,
            _ => continue,
        };
        if let Some(bound) = impl_trait.bounds.iter().find_map(as_fn_bound) {
            let arg_name = match arg.pat.as_ref() {
                syn::Pat::Ident(ident) => &ident.ident,
                _ => panic!("Argument is an unsupported pattern."),
            };
            return Some((bound, arg_name, None));
        }
    }

    None
}
/// Builds the tokens for the shared, type-erased form of `bound`:
/// `::std::sync::Arc<dyn #bound + Send + Sync>`.
fn bound_to_dyn(bound: &syn::TraitBound) -> proc_macro2::TokenStream {
    quote! {
        ::std::sync::Arc<dyn #bound + Send + Sync>
    }
}
/// Returns how many arguments the `Fn(..)`-style `bound` takes, i.e. the
/// number of inputs in the parenthesized arguments of its last path segment.
///
/// # Panics
///
/// Panics if the path has no segments, or if its last segment does not use
/// parenthesized arguments.
fn get_arity(bound: &syn::TraitBound) -> usize {
    let last_segment = bound.path.segments.iter().last().unwrap();
    if let syn::PathArguments::Parenthesized(fn_args) = &last_segment.arguments {
        fn_args.inputs.len()
    } else {
        panic!("Expected Fn trait arguments")
    }
}
/// Expands a callback-taking method into itself plus two helper methods.
///
/// `item` is parsed as a `syn::ImplItemFn`. For a method `set_foo` taking a
/// callback argument, the output contains:
///
/// * the original method, unchanged;
/// * a "maker" (`foo_cb`) that wraps a closure in
///   `::std::sync::Arc<dyn Fn(..) + Send + Sync>`;
/// * a "setter" (`set_foo_cb`) with the same arguments as the original, except
///   that the callback argument has the `Arc<dyn Fn ..>` type; it forwards to
///   the original method.
///
/// For a method `foo` without the `set_` prefix, the helpers are named
/// `foo_cb` and `foo_with_cb` respectively.
///
/// # Panics
///
/// Panics if the method has no argument bounded by `Fn`, `FnMut` or `FnOnce`,
/// or if that bound does not use the parenthesized `Fn(..)` syntax.
pub fn callback_helpers(item: TokenStream) -> TokenStream {
    let input = syn::parse_macro_input!(item as syn::ImplItemFn);

    let (fn_bound, cb_arg_name, type_ident) =
        find_fn_generic(&input.sig).expect("Could not find function-like generic parameter.");

    // The stored callback is always an `Arc<dyn Fn ..>`, so normalize
    // `FnMut`/`FnOnce` bounds to `Fn`.
    let mut fn_bound = fn_bound.clone();
    fn_bound.path.segments.last_mut().unwrap().ident = syn::Ident::new("Fn", Span::call_site());
    let dyn_type = bound_to_dyn(&fn_bound);

    let fn_ident = &input.sig.ident;
    let fn_name = fn_ident.to_string();
    let (maker_name, setter_name) = match fn_name.strip_prefix("set_") {
        Some(base) => (format!("{base}_cb"), format!("{fn_name}_cb")),
        None => (format!("{fn_name}_cb"), format!("{fn_name}_with_cb")),
    };

    // --- Maker: wraps a closure into the `Arc<dyn Fn ..>` type. ---
    let maker_ident = syn::Ident::new(&maker_name, Span::call_site());
    let maker_generics = find_dependent_generics(&input.sig, &fn_bound);
    let maker_bounds = quote!();
    let maker_doc = format!(
        r#"Helper method to store a callback of the correct type for [`Self::{fn_ident}`].
This is mostly useful when using this view in a template."#
    );
    let maker_fn = quote! {
        #[doc = #maker_doc]
        pub fn #maker_ident
            <F: #fn_bound + 'static + Send + Sync, #maker_generics>
            ( #cb_arg_name: F ) -> #dyn_type
            where #maker_bounds
        {
            ::std::sync::Arc::new(#cb_arg_name)
        }
    };

    // --- Setter: same as the original, but takes the `Arc<dyn Fn ..>`. ---
    let setter_ident = syn::Ident::new(&setter_name, Span::call_site());
    let return_type = &input.sig.output;

    // Declared arguments: identical to the original, except for the callback
    // argument, whose type becomes the `Arc<dyn Fn ..>`.
    let args_signature: Vec<_> = input
        .sig
        .inputs
        .iter()
        .map(|arg| {
            if let syn::FnArg::Typed(arg) = arg {
                if let syn::Pat::Ident(ident) = arg.pat.as_ref() {
                    if &ident.ident == cb_arg_name {
                        return quote! { #cb_arg_name: #dyn_type };
                    }
                }
            }
            quote! { #arg }
        })
        .collect();
    let args_signature = quote! {
        #(#args_signature),*
    };

    // Closure parameters `a0, a1, ...` matching the arity of the callback.
    let n_args = get_arity(&fn_bound);
    let cb_args: Vec<_> = (0..n_args).map(|i| quote::format_ident!("a{i}")).collect();
    let cb_args = quote! {
        #(#cb_args),*
    };

    // Forwarded arguments: the callback is re-wrapped in a closure calling
    // through the `Arc`; everything else is passed along as-is.
    let args_call: Vec<_> = input
        .sig
        .inputs
        .iter()
        .map(|arg| match arg {
            syn::FnArg::Receiver(_) => quote! { self },
            syn::FnArg::Typed(arg) => match arg.pat.as_ref() {
                syn::Pat::Ident(pat) if &pat.ident == cb_arg_name => quote! {
                    move |#cb_args| { (*#cb_arg_name)(#cb_args) }
                },
                // Forward by name only: binding modifiers such as `mut` are
                // not valid in expression position.
                syn::Pat::Ident(pat) => {
                    let name = &pat.ident;
                    quote! { #name }
                }
                other => quote! { #other },
            },
        })
        .collect();
    let args_call = quote! {
        #(#args_call),*
    };

    // Generic parameters of the setter: everything except the callback type,
    // which has been replaced by the concrete `Arc<dyn Fn ..>`.
    let generics: Vec<_> = input
        .sig
        .generics
        .params
        .iter()
        .filter(|param| match param {
            syn::GenericParam::Type(type_param) => Some(&type_param.ident) != type_ident,
            _ => true,
        })
        .collect();
    let generics = quote! {
        < #(#generics),* >
    };

    // Where predicates of the setter: everything except the predicates that
    // bound the (now removed) callback type parameter. When the callback was
    // declared with `impl Trait` there is no such parameter, so all
    // predicates are kept.
    let where_clause: Vec<_> = input
        .sig
        .generics
        .where_clause
        .iter()
        .flat_map(|clause| &clause.predicates)
        .filter(|predicate| match (predicate, type_ident) {
            (
                syn::WherePredicate::Type(syn::PredicateType {
                    bounded_ty: syn::Type::Path(path),
                    ..
                }),
                Some(ident),
            ) => !path.path.is_ident(ident),
            _ => true,
        })
        .collect();
    let where_clause = quote! {
        where #(#where_clause),*
    };

    let setter_doc = format!(
        r#"Helper method to call [`Self::{fn_ident}`] with a variable from a config.
This is mostly useful when writing a cursive blueprint for this view."#
    );
    let setter_fn = quote! {
        #[doc = #setter_doc]
        pub fn #setter_ident #generics (#args_signature) #return_type #where_clause {
            Self::#fn_ident(#args_call)
        }
    };

    TokenStream::from(quote! {
        #input
        #maker_fn
        #setter_fn
    })
}