use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use syn::{Ident, ItemFn, parse_macro_input};
pub fn inner_com_thread(attr: TokenStream, item: TokenStream) -> TokenStream {
let model_kind_str = if attr.is_empty() || attr.to_string().trim().is_empty() {
"STA"
} else {
let ident: syn::Ident = match syn::parse(attr.clone()) {
Ok(i) => i,
Err(_) => {
return syn::Error::new(
proc_macro2::Span::call_site(),
"invalid attribute syntax; expected STA or MTA without quotes",
)
.to_compile_error()
.into();
}
};
match ident.to_string().to_uppercase().as_str() {
"MTA" | "MULTI" | "MULTITHREADED" => "MTA",
"STA" | "APARTMENT" | "APARTMENTTHREADED" => "STA",
_ => {
return syn::Error::new_spanned(ident, "invalid COM model, expected STA or MTA")
.to_compile_error()
.into();
}
}
};
let func = parse_macro_input!(item as ItemFn);
let vis = &func.vis;
let sig = &func.sig;
let block = &func.block;
let inputs = &sig.inputs;
let output = &sig.output;
let arg_types: Vec<_> = inputs
.iter()
.filter_map(|arg| match arg {
syn::FnArg::Typed(pat) => Some(&*pat.ty),
_ => None,
})
.collect();
let is_async = sig.asyncness.is_some();
let mut assert_bounds = Vec::new();
for (idx, arg_type) in arg_types.iter().enumerate() {
let assert_fn_name = Ident::new(
&format!("_assert_param_{}_is_send_static", idx),
Span::call_site(),
);
assert_bounds.push(quote! {
const fn #assert_fn_name() {
const fn require<T: Send + 'static>() {}
const fn check() { require::<#arg_type>(); }
}
let _ = #assert_fn_name;
});
}
let ret_type_for_assert = match output {
syn::ReturnType::Default => quote! { () },
syn::ReturnType::Type(_, ty) => quote! { #ty },
};
let assert_ret_fn = Ident::new("_assert_return_is_send_static", Span::call_site());
assert_bounds.push(quote! {
const fn #assert_ret_fn() {
const fn require<T: Send + 'static>() {}
const fn check() { require::<#ret_type_for_assert>(); }
}
let _ = #assert_ret_fn;
});
let compile_time_checks = if assert_bounds.is_empty() {
quote! {}
} else {
quote! { #(#assert_bounds)* }
};
let runtime_model_token = if model_kind_str == "MTA" {
quote! { ::callcomapi::__runtime::ComModel::MTA }
} else {
quote! { ::callcomapi::__runtime::ComModel::STA }
};
let expanded = if is_async {
quote! {
#vis #sig {
#compile_time_checks
::callcomapi::__runtime::call_async(#runtime_model_token, move || {
::callcomapi::__runtime::block_on(async move { #block })
}).await
}
}
} else {
quote! {
#vis #sig {
#compile_time_checks
::callcomapi::__runtime::call_sync(#runtime_model_token, move || { (|| #block)() })
}
}
};
expanded.into()
}