use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{FnArg, ItemFn, Pat, PatType, parse_macro_input, punctuated::Punctuated, token::Comma};
/// Splits the parameter list of a `#[piperize]`d function into its first parameter
/// (which becomes the `self` receiver of the generated trait method) and the rest
/// (which are forwarded unchanged, in order).
///
/// # Panics
///
/// Panics if `inputs` is empty, or if any parameter is a `self` receiver.
fn args_to_split(inputs: &Punctuated<FnArg, Comma>) -> (&PatType, Vec<&FnArg>) {
    assert!(
        !inputs.is_empty(),
        "Piperize function cannot take no arguments"
    );
    let takes_receiver = inputs
        .iter()
        .any(|arg| matches!(arg, FnArg::Receiver(_)));
    assert!(!takes_receiver, "Function arguments cannot be \"self\"");

    let mut params = inputs.iter();
    match params.next() {
        // `params` has already yielded the first parameter, so collecting it
        // gathers exactly the remaining ones.
        Some(FnArg::Typed(first_param)) => (first_param, params.collect()),
        // Not reachable after the asserts above; kept as a defensive fallback.
        _ => panic!(
            "Invalid function arguments\n example of a valid function signature: fn foo(a: i32) -> i32"
        ),
    }
}
#[proc_macro_attribute]
pub fn piperize(_: TokenStream, item: TokenStream) -> TokenStream {
let input_fn = parse_macro_input!(item as ItemFn);
let fn_name = &input_fn.sig.ident;
let output = &input_fn.sig.output;
let visibility = &input_fn.vis;
let trait_name = to_piperize_trait_name(fn_name);
let arg_split = args_to_split(&input_fn.sig.inputs);
let first_arg = arg_split.0;
let first_arg_type = &first_arg.ty;
if let syn::Type::ImplTrait(_) = first_arg_type.as_ref() {
todo!(
"Cannot have impl Trait as first argument's type yet\n instead use a generic type parameter to constrain your function signature:\n---\n fn foo(a: impl SomeTrait) ==> fn foo<T: SomeTrait>(a: T)\n---"
)
}
let rest_args = arg_split.1;
let mut rest = Punctuated::<&FnArg, Comma>::new();
let mut rest_inputs = Punctuated::<&Pat, Comma>::new();
for rest_arg in rest_args {
rest.push(rest_arg);
match rest_arg {
FnArg::Receiver(_) => unreachable!("shouldn't pass self as second arg"),
FnArg::Typed(pat_type) => {
rest_inputs.push(&pat_type.pat);
}
}
}
let generics = &input_fn.sig.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let asyncness = input_fn.sig.asyncness;
let call = if asyncness.is_some() {
quote! { #fn_name(self, #rest_inputs).await}
} else {
quote! { #fn_name(self, #rest_inputs)}
};
let expanded = quote! {
#input_fn
#visibility trait #trait_name #generics #where_clause {
#asyncness fn #fn_name(self, #rest) #output;
}
impl #impl_generics #trait_name #ty_generics for #first_arg_type #where_clause {
#asyncness fn #fn_name(self, #rest) #output {
#call
}
}
};
expanded.into()
}
/// Builds the name of the trait generated for the function `name`.
///
/// The name is the camel-cased function name followed by `__PiperizeTrait`,
/// e.g. `foo_bar` becomes `FooBar__PiperizeTrait`. It is created with a
/// `Span::mixed_site()` span.
fn to_piperize_trait_name(name: &syn::Ident) -> syn::Ident {
    let mut name_str = name.to_string();
    to_camel_case(&mut name_str);
    // A function such as `_2d` camel-cases to `2d`. An identifier cannot start with a
    // digit, and `Ident::new` would panic, so prefix an underscore to keep it valid.
    if name_str.starts_with(char::is_numeric) {
        name_str.insert(0, '_');
    }
    name_str.push_str("__PiperizeTrait");
    syn::Ident::new(&name_str, Span::mixed_site())
}
/// Converts a `snake_case` identifier to `UpperCamelCase`, in place.
///
/// Rules:
///
/// * Alphanumeric characters are kept.
/// * Every other character (for an identifier, typically `_`) is dropped, and the next
///   kept character is upper-cased.
/// * Upper-case letters, digits and letters without case are copied unchanged.
///
/// Examples: `foo_bar2_baz` becomes `FooBar2Baz`, and `__foo` becomes `Foo`.
///
/// Non-ASCII identifiers are supported. Lower-case letters are upper-cased with
/// [`char::to_uppercase`], which may expand to several characters (`ß` becomes `SS`).
/// For ASCII input, the result matches a plain byte-wise conversion.
fn to_camel_case(s: &mut String) {
    let mut camel = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c.is_alphanumeric() {
            if capitalize_next {
                camel.extend(c.to_uppercase());
            } else {
                camel.push(c);
            }
            capitalize_next = false;
        } else {
            capitalize_next = true;
        }
    }
    *s = camel;
}