use proc_macro::TokenStream;
use quote::{ToTokens, quote};
use syn::spanned::Spanned;
use syn::{
FnArg, Ident, ItemFn, Pat, PatType, ReturnType, Signature, Type, TypePath, parse_macro_input,
};
fn extract_previous_info(sig: &Signature) -> syn::Result<(Pat, TypePath)> {
if sig.inputs.len() != 1 {
return Err(syn::Error::new(
sig.inputs.span(),
"Chain function must have exactly one parameter",
));
}
let arg = &sig.inputs[0];
match arg {
FnArg::Typed(PatType { pat, ty, .. }) => {
let param_pat = (**pat).clone();
match &**ty {
Type::Path(type_path) => Ok((param_pat, type_path.clone())),
_ => Err(syn::Error::new(
ty.span(),
"Parameter type must be a type path",
)),
}
}
FnArg::Receiver(_) => Err(syn::Error::new(
arg.span(),
"Chain function cannot have self parameter",
)),
}
}
fn extract_return_type(sig: &Signature) -> syn::Result<TypePath> {
match &sig.output {
ReturnType::Type(_, ty) => match &**ty {
Type::Path(type_path) => Ok(type_path.clone()),
_ => Err(syn::Error::new(
ty.span(),
"Return type must be a type path",
)),
},
ReturnType::Default => Err(syn::Error::new(
sig.span(),
"Chain function must have a return type",
)),
}
}
/// Expands a chain function into a marker struct plus a `::mingling::Chain` impl.
///
/// * `attr` — either empty, in which case the group identifier defaults to
///   `ThisProgram`, or a single identifier naming the group.
/// * `item` — a function with exactly one typed parameter and a return type
///   path whose last segment is `NextProcess`.
///
/// The expansion consists of:
///
/// 1. a unit struct named by `just_fmt::pascal_case!` applied to the function
///    name (marked `#[doc(hidden)]` when the default group is used),
/// 2. a `::mingling::macros::register_chain!(PreviousType, StructName)` call,
/// 3. `impl ::mingling::Chain<Group> for StructName`, whose `proc` forwards to
///    the function and converts the result with `.into()`,
/// 4. the function itself, re-emitted with the return type
///    `impl Into<::mingling::ChainProcess<Group>>`.
///
/// Every validation failure is returned as a compile-error token stream.
/// With the `async` feature disabled an `async fn` is rejected; with it
/// enabled `proc` is always `async` and awaits the function only when the
/// function itself is `async`.
pub fn chain_attr(attr: TokenStream, item: TokenStream) -> TokenStream {
    // An empty attribute selects the default group identifier.
    let (group_name, use_crate_prefix) = if attr.is_empty() {
        (
            Ident::new("ThisProgram", proc_macro2::Span::call_site()),
            true,
        )
    } else {
        (parse_macro_input!(attr as Ident), false)
    };
    let input_fn = parse_macro_input!(item as ItemFn);

    #[cfg(not(feature = "async"))]
    {
        if input_fn.sig.asyncness.is_some() {
            return syn::Error::new(
                input_fn.sig.span(),
                "Chain function cannot be async when async feature is disabled",
            )
            .to_compile_error()
            .into();
        }
    }

    let (prev_param, previous_type) = match extract_previous_info(&input_fn.sig) {
        Ok(info) => info,
        Err(e) => return e.to_compile_error().into(),
    };
    let return_type = match extract_return_type(&input_fn.sig) {
        Ok(ty) => ty,
        Err(e) => return e.to_compile_error().into(),
    };

    // Only the final path segment is inspected, so both `NextProcess` and a
    // fully qualified path to it are accepted. No panic on an empty path.
    let returns_next_process = matches!(
        return_type.path.segments.last(),
        Some(segment) if segment.ident == "NextProcess"
    );
    if !returns_next_process {
        return syn::Error::new(
            return_type.span(),
            "Return type must be 'mingling::marker::NextProcess'",
        )
        .to_compile_error()
        .into();
    }

    let fn_body = &input_fn.block;
    // The `chain` attribute itself must not be re-emitted on the outputs.
    let fn_attrs: Vec<_> = input_fn
        .attrs
        .iter()
        .filter(|attr| !attr.path().is_ident("chain"))
        .collect();
    let vis = &input_fn.vis;
    let asyncness = &input_fn.sig.asyncness;
    let fn_name = &input_fn.sig.ident;
    let pascal_case_name = just_fmt::pascal_case!(fn_name.to_string());
    let struct_name = Ident::new(&pascal_case_name, fn_name.span());

    // `proc` uses its own parameter identifier instead of re-using the user's
    // parameter pattern: a pattern such as `_`, `mut x` or a destructuring is
    // valid in parameter position but not as the argument expression of the
    // forwarding call. Its return type names the selected group so that it
    // agrees with the `Chain<#group_name>` impl and with the re-emitted
    // function's `impl Into<ChainProcess<#group_name>>`.
    #[cfg(feature = "async")]
    let proc_fn = {
        // Only an `async fn` produces a future that has to be awaited.
        let await_suffix = if asyncness.is_some() {
            quote! { .await }
        } else {
            quote! {}
        };
        quote! {
            async fn proc(__chain_previous: Self::Previous) ->
                ::mingling::ChainProcess<#group_name>
            {
                let _ = NextProcess;
                #fn_name(__chain_previous) #await_suffix .into()
            }
        }
    };
    #[cfg(not(feature = "async"))]
    let proc_fn = quote! {
        fn proc(__chain_previous: Self::Previous) ->
            ::mingling::ChainProcess<#group_name>
        {
            let _ = NextProcess;
            #fn_name(__chain_previous).into()
        }
    };

    // The original function, re-emitted with an `impl Into<...>` return type.
    // `#asyncness` expands to `async` or to nothing; without the `async`
    // feature it is always empty because async functions were rejected above.
    let origin_proc_fn = quote! {
        #(#fn_attrs)*
        #vis #asyncness fn #fn_name(#prev_param: #previous_type)
            -> impl Into<::mingling::ChainProcess<#group_name>>
        {
            #fn_body
        }
    };

    // The marker struct is hidden from docs only for the default group.
    let hidden_attr = if use_crate_prefix {
        quote! { #[doc(hidden)] }
    } else {
        quote! {}
    };

    let expanded = quote! {
        #(#fn_attrs)*
        #hidden_attr
        #vis struct #struct_name;
        ::mingling::macros::register_chain!(#previous_type, #struct_name);
        impl ::mingling::Chain<#group_name> for #struct_name {
            type Previous = #previous_type;
            #proc_fn
        }
        #origin_proc_fn
    };
    expanded.into()
}
/// Builds the tokens `<struct_name> => <previous_type>,` for one registered chain.
///
/// `register_chain` stores the string form of these tokens in `crate::CHAINS`.
pub fn build_chain_arm(struct_name: &Ident, previous_type: &TypePath) -> proc_macro2::TokenStream {
    quote!(#struct_name => #previous_type,)
}
/// Builds the tokens `Self::<previous_type> => true,` for one registered chain.
///
/// `register_chain` stores the string form of these tokens in
/// `crate::CHAINS_EXIST`.
pub fn build_chain_exist_arm(previous_type: &TypePath) -> proc_macro2::TokenStream {
    quote!(Self::#previous_type => true,)
}
pub fn register_chain(input: TokenStream) -> TokenStream {
let input_parsed = syn::parse_macro_input!(input with syn::punctuated::Punctuated<syn::Expr, syn::Token![,]>::parse_terminated);
if input_parsed.len() != 2 {
return syn::Error::new(
input_parsed.span(),
"Expected exactly two comma-separated arguments: `PreviousType, StructName`",
)
.to_compile_error()
.into();
}
let previous_type_expr = &input_parsed[0];
let struct_name_expr = &input_parsed[1];
let previous_type = match syn::parse2::<TypePath>(previous_type_expr.to_token_stream()) {
Ok(ty) => ty,
Err(e) => return e.to_compile_error().into(),
};
let struct_name = match syn::parse2::<syn::Ident>(struct_name_expr.to_token_stream()) {
Ok(ident) => ident,
Err(e) => return e.to_compile_error().into(),
};
let chain_entry = build_chain_arm(&struct_name, &previous_type);
let chain_exist_entry = build_chain_exist_arm(&previous_type);
let mut chains = crate::CHAINS.lock().unwrap();
let mut chain_exist = crate::CHAINS_EXIST.lock().unwrap();
let chain_entry_str = chain_entry.to_string();
let chain_exist_entry_str = chain_exist_entry.to_string();
chains.insert(chain_entry_str);
chain_exist.insert(chain_exist_entry_str);
quote! {}.into()
}