use heck::{ToLowerCamelCase, ToSnakeCase};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, quote_spanned};
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
FnArg, Ident, ItemFn, Lit, Meta, Pat, Token, Visibility,
};
/// Options parsed from the arguments of the command attribute.
struct WrapperAttributes {
/// Whether the wrapper body is generated by `body_async` or `body_blocking`.
execution_context: ExecutionContext,
/// Case conversion applied to argument names when building their lookup keys.
argument_case: ArgumentCase,
}
impl Parse for WrapperAttributes {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut wrapper_attributes = WrapperAttributes {
execution_context: ExecutionContext::Blocking,
argument_case: ArgumentCase::Camel,
};
loop {
match input.parse::<Meta>() {
Ok(Meta::List(_)) => {}
Ok(Meta::NameValue(v)) => {
if v.path.is_ident("rename_all") {
if let Lit::Str(s) = v.lit {
wrapper_attributes.argument_case = match s.value().as_str() {
"snake_case" => ArgumentCase::Snake,
"camelCase" => ArgumentCase::Camel,
_ => {
return Err(syn::Error::new(
s.span(),
"expected \"camelCase\" or \"snake_case\"",
))
}
};
}
}
}
Ok(Meta::Path(p)) => {
if p.is_ident("async") {
wrapper_attributes.execution_context = ExecutionContext::Async;
} else {
return Err(syn::Error::new(p.span(), "expected `async`"));
}
}
Err(_e) => {
break;
}
}
let lookahead = input.lookahead1();
if lookahead.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
}
Ok(wrapper_attributes)
}
}
/// Selects which generator builds the wrapper body: `body_async` or `body_blocking`.
enum ExecutionContext {
/// Set by the `async` attribute argument, and forced whenever the function itself is `async`.
Async,
/// Used when neither of the above applies.
Blocking,
}
/// Case conversion applied to an argument's name to form the `key` of its `CommandItem`.
#[derive(Copy, Clone)]
enum ArgumentCase {
/// Selected with `rename_all = "snake_case"`.
Snake,
/// Selected with `rename_all = "camelCase"`; this is the default.
Camel,
}
/// Identifiers the generated wrapper binds the fields of `::tauri::Invoke` to.
struct Invoke {
/// Local name bound to the invoke's `message` field.
message: Ident,
/// Local name bound to the invoke's `resolver` field.
resolver: Ident,
}
pub fn wrapper(attributes: TokenStream, item: TokenStream) -> TokenStream {
let function = parse_macro_input!(item as ItemFn);
let wrapper = super::format_command_wrapper(&function.sig.ident);
let visibility = &function.vis;
let maybe_macro_export = match &function.vis {
Visibility::Public(_) => quote!(#[macro_export]),
_ => Default::default(),
};
let invoke = Invoke {
message: format_ident!("__tauri_message__"),
resolver: format_ident!("__tauri_resolver__"),
};
let mut async_command_check = TokenStream2::new();
if function.sig.asyncness.is_some() {
let mut ref_argument_span = None;
for arg in &function.sig.inputs {
if let syn::FnArg::Typed(pat) = arg {
match &*pat.ty {
syn::Type::Reference(_) => {
ref_argument_span = Some(pat.span());
}
syn::Type::Path(path) => {
let last = path.path.segments.last().unwrap();
if let syn::PathArguments::AngleBracketed(args) = &last.arguments {
if args
.args
.iter()
.any(|arg| matches!(arg, syn::GenericArgument::Lifetime(_)))
{
ref_argument_span = Some(pat.span());
}
}
}
_ => {}
}
if let Some(span) = ref_argument_span {
if let syn::ReturnType::Type(_, return_type) = &function.sig.output {
async_command_check = quote_spanned! {return_type.span() =>
#[allow(unreachable_code, clippy::diverging_sub_expression)]
const _: () = if false {
trait AsyncCommandMustReturnResult {}
impl<A, B> AsyncCommandMustReturnResult for ::std::result::Result<A, B> {}
let _check: #return_type = unreachable!();
let _: &dyn AsyncCommandMustReturnResult = &_check;
};
};
} else {
return quote_spanned! {
span => compile_error!("async commands that contain references as inputs must return a `Result`");
}.into();
}
}
}
}
}
let (body, attributes) = syn::parse::<WrapperAttributes>(attributes)
.map(|mut attrs| {
if function.sig.asyncness.is_some() {
attrs.execution_context = ExecutionContext::Async;
}
attrs
})
.and_then(|attrs| {
let body = match attrs.execution_context {
ExecutionContext::Async => body_async(&function, &invoke, attrs.argument_case),
ExecutionContext::Blocking => body_blocking(&function, &invoke, attrs.argument_case),
};
body.map(|b| (b, Some(attrs)))
})
.unwrap_or_else(|e| (syn::Error::into_compile_error(e), None));
let Invoke { message, resolver } = invoke;
let kind = match attributes.as_ref().map(|a| &a.execution_context) {
Some(ExecutionContext::Async) if function.sig.asyncness.is_none() => "sync_threadpool",
Some(ExecutionContext::Async) => "async",
Some(ExecutionContext::Blocking) => "sync",
_ => "sync",
};
let loc = function.span().start();
let line = loc.line;
let col = loc.column;
let maybe_span = if cfg!(feature = "tracing") {
quote!({
let _span = tracing::debug_span!(
"ipc::request::handler",
cmd = #message.command(),
kind = #kind,
loc.line = #line,
loc.col = #col,
is_internal = false,
)
.entered();
})
} else {
quote!()
};
quote!(
#async_command_check
#function
#maybe_macro_export
#[doc(hidden)]
macro_rules! #wrapper {
($path:path, $invoke:ident) => {{
#[allow(unused_imports)]
use ::tauri::command::private::*;
#[allow(unused_variables)]
let ::tauri::Invoke { message: #message, resolver: #resolver } = $invoke;
#maybe_span
#body
}};
}
#[allow(unused_imports)]
#visibility use #wrapper;
)
.into()
}
fn body_async(function: &ItemFn, invoke: &Invoke, case: ArgumentCase) -> syn::Result<TokenStream2> {
let Invoke { message, resolver } = invoke;
parse_args(function, message, case).map(|args| {
#[cfg(feature = "tracing")]
quote! {
use tracing::Instrument;
let span = tracing::debug_span!("ipc::request::run");
#resolver.respond_async_serialized(async move {
let result = $path(#(#args?),*);
let kind = (&result).async_kind();
kind.future(result).await
}
.instrument(span));
}
#[cfg(not(feature = "tracing"))]
quote! {
#resolver.respond_async_serialized(async move {
let result = $path(#(#args?),*);
let kind = (&result).async_kind();
kind.future(result).await
});
}
})
}
fn body_blocking(
function: &ItemFn,
invoke: &Invoke,
case: ArgumentCase,
) -> syn::Result<TokenStream2> {
let Invoke { message, resolver } = invoke;
let args = parse_args(function, message, case)?;
let match_body = quote!({
Ok(arg) => arg,
Err(err) => return #resolver.invoke_error(err),
});
let maybe_span = if cfg!(feature = "tracing") {
quote!(let _span = tracing::debug_span!("ipc::request::run").entered();)
} else {
quote!()
};
Ok(quote! {
#maybe_span
let result = $path(#(match #args #match_body),*);
let kind = (&result).blocking_kind();
kind.block(result, #resolver);
})
}
/// Builds one argument-extraction expression per parameter of `function`, in declaration order.
///
/// # Errors
/// Stops at, and returns, the first error reported by `parse_arg`.
fn parse_args(
    function: &ItemFn,
    message: &Ident,
    case: ArgumentCase,
) -> syn::Result<Vec<TokenStream2>> {
    let command = &function.sig.ident;
    let mut extractors = Vec::with_capacity(function.sig.inputs.len());
    for input in &function.sig.inputs {
        extractors.push(parse_arg(command, input, message, case)?);
    }
    Ok(extractors)
}
/// Builds the expression that extracts one command argument from the invoke message.
///
/// The key is derived from the argument pattern and then converted to the requested `case`:
/// - a plain binding uses its identifier (raw `r#` prefix removed);
/// - a struct or tuple struct pattern uses the identifier returned by `super::path_to_command`
///   for the pattern's path;
/// - `_` uses an empty string.
///
/// The generated expression calls `::tauri::command::CommandArg::from_command` with a
/// `CommandItem` holding the command name, the key, and a reference to the message.
///
/// # Errors
/// Returns an error for a `self` receiver (in either syntax), and for any pattern other than an
/// identifier, wildcard, struct, or tuple struct.
fn parse_arg(
    command: &Ident,
    arg: &FnArg,
    message: &Ident,
    case: ArgumentCase,
) -> syn::Result<TokenStream2> {
    // Work on a local copy of the pattern, since `super::path_to_command` takes the path mutably.
    let mut arg = match arg {
        FnArg::Typed(arg) => arg.pat.as_ref().clone(),
        FnArg::Receiver(arg) => {
            return Err(syn::Error::new(
                arg.span(),
                "unable to use self as a command function parameter",
            ))
        }
    };

    let key = match &mut arg {
        Pat::Ident(arg) => arg.ident.unraw().to_string(),
        // A wildcard has no name, so it gets an empty key.
        Pat::Wild(_) => "".into(),
        Pat::Struct(s) => super::path_to_command(&mut s.path).ident.to_string(),
        Pat::TupleStruct(s) => super::path_to_command(&mut s.path).ident.to_string(),
        err => {
            return Err(syn::Error::new(
                err.span(),
                "only named, wildcard, struct, and tuple struct arguments allowed",
            ))
        }
    };

    // Also catch `self` written with typed syntax (e.g. `self: Box<Self>`), which may arrive as
    // `FnArg::Typed` rather than `FnArg::Receiver`. Point the error at the argument pattern:
    // `key` is a `String`, so `key.span()` would only give a call-site span.
    if key == "self" {
        return Err(syn::Error::new(
            arg.span(),
            "unable to use self as a command function parameter",
        ));
    }

    let key = match case {
        ArgumentCase::Camel => key.to_lower_camel_case(),
        ArgumentCase::Snake => key.to_snake_case(),
    };

    Ok(quote!(::tauri::command::CommandArg::from_command(
        ::tauri::command::CommandItem {
            name: stringify!(#command),
            key: #key,
            message: &#message,
        }
    )))
}