use std::{env::var, sync::OnceLock};
use heck::{ToLowerCamelCase, ToSnakeCase};
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::{format_ident, quote, quote_spanned};
use syn::{
ext::IdentExt,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Pat, Token, Visibility,
};
use tauri_utils::acl::REMOVE_UNUSED_COMMANDS_ENV_VAR;
/// One comma-separated item inside `#[command(...)]`.
// `Meta` is much larger than the unit `Async` variant; values of this enum only live while the
// attribute list is being parsed, so the size difference is accepted instead of boxing.
#[allow(clippy::large_enum_variant)]
enum WrapperAttributeKind {
  /// The bare `async` keyword.
  Async,
  /// An attribute-style item: `path`, `path(...)` or `path = value`.
  Meta(Meta),
}
impl Parse for WrapperAttributeKind {
  /// Parses either the bare `async` keyword or a [`Meta`] item.
  ///
  /// `async` is checked with a peek before attempting `Meta`: syn does not rewind the
  /// stream when a parse fails, so retrying after a failed `Meta` parse could run on a
  /// partially consumed stream (e.g. `foo::async` would otherwise be accepted as `async`).
  fn parse(input: ParseStream) -> syn::Result<Self> {
    if input.peek(Token![async]) {
      input.parse::<Token![async]>()?;
      Ok(Self::Async)
    } else {
      input.parse::<Meta>().map(Self::Meta)
    }
  }
}
/// Settings collected from the arguments of `#[command(...)]`.
struct WrapperAttributes {
  /// How argument keys are cased (`rename_all`); defaults to camelCase.
  argument_case: ArgumentCase,
  /// Whether the command keeps its function name or uses a `rename` value.
  rename: RenamePolicy,
  /// Blocking unless `async` is given or the function itself is `async`.
  execution_context: ExecutionContext,
  /// Path prefix used for `ipc` items in generated code (`root`); defaults to `::tauri`.
  root: TokenStream2,
}
impl Parse for WrapperAttributes {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut wrapper_attributes = WrapperAttributes {
root: quote!(::tauri),
execution_context: ExecutionContext::Blocking,
argument_case: ArgumentCase::Camel,
rename: RenamePolicy::Keep,
};
let attrs = Punctuated::<WrapperAttributeKind, Token![,]>::parse_terminated(input)?;
for attr in attrs {
match attr {
WrapperAttributeKind::Meta(Meta::List(_)) => {
return Err(syn::Error::new(input.span(), "unexpected list input"));
}
WrapperAttributeKind::Meta(Meta::NameValue(v)) => {
if v.path.is_ident("rename_all") {
if let Expr::Lit(ExprLit {
lit: Lit::Str(s),
attrs: _,
}) = v.value
{
wrapper_attributes.argument_case = match s.value().as_str() {
"snake_case" => ArgumentCase::Snake,
"camelCase" => ArgumentCase::Camel,
_ => {
return Err(syn::Error::new(
s.span(),
"expected \"camelCase\" or \"snake_case\"",
))
}
};
}
} else if v.path.is_ident("rename") {
if let Expr::Lit(ExprLit {
lit: Lit::Str(s), ..
}) = v.value
{
let lit = s.value();
wrapper_attributes.rename = RenamePolicy::Rename(quote!(#lit));
} else {
return Err(syn::Error::new(
v.span(),
"expected string literal for rename",
));
}
} else if v.path.is_ident("root") {
if let Expr::Lit(ExprLit {
lit: Lit::Str(s),
attrs: _,
}) = v.value
{
let lit = s.value();
wrapper_attributes.root = if lit == "crate" {
quote!($crate)
} else {
let ident = Ident::new(&lit, Span::call_site());
quote!(#ident)
};
}
}
}
WrapperAttributeKind::Meta(Meta::Path(_)) => {
return Err(syn::Error::new(
input.span(),
"unexpected input, expected one of `rename_all`, `rename`, `root`, `async`",
));
}
WrapperAttributeKind::Async => {
wrapper_attributes.execution_context = ExecutionContext::Async;
}
}
}
Ok(wrapper_attributes)
}
}
/// How the generated handler runs the command.
enum ExecutionContext {
  /// The command result is handled through the `blocking_kind` path (reported as `sync`).
  Blocking,
  /// The command is driven through `respond_async_serialized`; reported as `async` for an
  /// `async fn`, or as `sync_threadpool` when a plain fn opts in with the `async` attribute.
  Async,
}
/// Case conversion applied to argument names to form their lookup keys (`rename_all`).
#[derive(Clone, Copy)]
enum ArgumentCase {
  /// `lowerCamelCase` keys (the default).
  Camel,
  /// `snake_case` keys.
  Snake,
}
/// Which name the command is exposed under.
enum RenamePolicy {
  /// Use a custom name; holds the string-literal tokens from `rename = "..."`.
  Rename(TokenStream2),
  /// Use the function's own identifier.
  Keep,
}
/// Identifiers bound to the fields of the incoming `Invoke` inside the generated wrapper.
struct Invoke {
  /// Binding for the invoke `acl` field.
  acl: Ident,
  /// Binding for the invoke `message` field.
  message: Ident,
  /// Binding for the invoke `resolver` field.
  resolver: Ident,
}
/// Expands `#[command]` applied to `item`.
///
/// Alongside the original function, this emits:
/// - for `async fn` commands with a borrowed input, a const item that only type-checks when
///   the return type is a `Result` (or a `compile_error!` when there is no return type);
/// - a hidden `__tauri_command_name_<fn>` macro expanding to the command's (possibly renamed)
///   name;
/// - a hidden wrapper macro (named by `super::format_command_wrapper`) taking `($path, $invoke)`
///   that destructures the `Invoke`, extracts every argument and runs the command.
pub fn wrapper(attributes: TokenStream, item: TokenStream) -> TokenStream {
  let mut attrs = parse_macro_input!(attributes as WrapperAttributes);
  let function = parse_macro_input!(item as ItemFn);
  let wrapper = super::format_command_wrapper(&function.sig.ident);
  let visibility = &function.vis;

  // An `async fn` always uses the async execution context, with or without the attribute.
  if function.sig.asyncness.is_some() {
    attrs.execution_context = ExecutionContext::Async;
  }

  // `pub` and `pub(...)` commands get `#[macro_export]` on both generated macros.
  let maybe_macro_export = match &function.vis {
    Visibility::Public(_) | Visibility::Restricted(_) => {
      quote!(#[macro_export])
    }
    _ => TokenStream2::default(),
  };

  let invoke = Invoke {
    message: format_ident!("__tauri_message__"),
    resolver: format_ident!("__tauri_resolver__"),
    acl: format_ident!("__tauri_acl__"),
  };

  let async_command_check = match build_async_command_check(&function) {
    Ok(check) => check,
    Err(compile_error) => return compile_error.into(),
  };

  // Inside a `tauri-plugin-*` crate, the plugin name (prefix stripped) is recorded.
  let plugin_name = var("CARGO_PKG_NAME")
    .expect("missing `CARGO_PKG_NAME` environment variable")
    .strip_prefix("tauri-plugin-")
    .map(|name| quote!(::core::option::Option::Some(#name)))
    .unwrap_or_else(|| quote!(::core::option::Option::None));

  let body = match attrs.execution_context {
    ExecutionContext::Async => body_async(&plugin_name, &function, &invoke, &attrs)
      .unwrap_or_else(syn::Error::into_compile_error),
    ExecutionContext::Blocking => body_blocking(&plugin_name, &function, &invoke, &attrs)
      .unwrap_or_else(syn::Error::into_compile_error),
  };

  let Invoke {
    message,
    resolver,
    acl,
  } = invoke;

  let root = attrs.root;

  let kind = match attrs.execution_context {
    ExecutionContext::Async if function.sig.asyncness.is_none() => "sync_threadpool",
    ExecutionContext::Async => "async",
    ExecutionContext::Blocking => "sync",
  };

  let loc = function.span().start();
  let line = loc.line;
  let col = loc.column;

  let maybe_span = if cfg!(feature = "tracing") {
    quote!({
      let _span = tracing::debug_span!(
        "ipc::request::handler",
        cmd = #message.command(),
        kind = #kind,
        loc.line = #line,
        loc.col = #col,
        is_internal = false,
      )
      .entered();
    })
  } else {
    quote!()
  };

  let maybe_allow_unused = if var(REMOVE_UNUSED_COMMANDS_ENV_VAR).is_ok() {
    quote!(#[allow(unused)])
  } else {
    TokenStream2::default()
  };

  let command_name_macro_ident = format_ident!("__tauri_command_name_{}", function.sig.ident);
  let command_name_value = if let RenamePolicy::Rename(ref rename) = attrs.rename {
    quote!(#rename)
  } else {
    let ident = &function.sig.ident;
    quote!(stringify!(#ident))
  };

  quote!(
    #async_command_check

    #maybe_allow_unused
    #function

    #maybe_allow_unused
    #maybe_macro_export
    #[doc(hidden)]
    macro_rules! #command_name_macro_ident {
      () => {
        #command_name_value
      };
    }

    #maybe_allow_unused
    #maybe_macro_export
    #[doc(hidden)]
    macro_rules! #wrapper {
      ($path:path, $invoke:ident) => {
        {
          move || {
            #[allow(unused_imports)]
            use #root::ipc::private::*;
            #[allow(unused_variables)]
            let #root::ipc::Invoke { message: #message, resolver: #resolver, acl: #acl } = $invoke;

            #maybe_span

            #body
          }
        }()
      };
    }

    #[allow(unused_imports)]
    #visibility use {#wrapper, #command_name_macro_ident};
  )
  .into()
}

/// Builds the compile-time `Result` check for `async fn` commands that take a borrowed input.
///
/// Returns `Ok` with an empty stream when the function is not `async` or has no borrowed input.
/// Returns `Err` holding a `compile_error!` (spanned at the first borrowed input) when such a
/// command declares no return type at all.
fn build_async_command_check(function: &ItemFn) -> Result<TokenStream2, TokenStream2> {
  if function.sig.asyncness.is_none() {
    return Ok(TokenStream2::new());
  }

  // Only the first borrowed input matters: it anchors the error span.
  let Some(span) = function.sig.inputs.iter().find_map(borrowed_input_span) else {
    return Ok(TokenStream2::new());
  };

  let syn::ReturnType::Type(_, return_type) = &function.sig.output else {
    return Err(quote_spanned! {
      span => compile_error!("async commands that contain references as inputs must return a `Result`");
    });
  };

  // `#[diagnostic::on_unimplemented]` is only emitted when the local rustc is new enough.
  let diagnostic = if is_rustc_at_least(1, 78) {
    quote!(#[diagnostic::on_unimplemented(message = "async commands that contain references as inputs must return a `Result`")])
  } else {
    quote!()
  };

  Ok(quote_spanned! {return_type.span() =>
    #[allow(unreachable_code, clippy::diverging_sub_expression, clippy::used_underscore_binding)]
    const _: () = if false {
      #diagnostic
      trait AsyncCommandMustReturnResult {}
      impl<A, B> AsyncCommandMustReturnResult for ::std::result::Result<A, B> {}
      let _check: #return_type = unreachable!();
      let _: &dyn AsyncCommandMustReturnResult = &_check;
    };
  })
}

/// Returns the span of a typed argument whose type is a reference (`&T`) or a path type
/// whose last segment carries a lifetime generic argument (e.g. `Foo<'a, T>`).
fn borrowed_input_span(arg: &FnArg) -> Option<Span> {
  let FnArg::Typed(pat) = arg else {
    return None;
  };
  let borrows = match &*pat.ty {
    syn::Type::Reference(_) => true,
    syn::Type::Path(path) => path.path.segments.last().is_some_and(|last| {
      matches!(
        &last.arguments,
        syn::PathArguments::AngleBracketed(args)
          if args
            .args
            .iter()
            .any(|arg| matches!(arg, syn::GenericArgument::Lifetime(_)))
      )
    }),
    _ => false,
  };
  borrows.then(|| pat.span())
}
/// Generates the wrapper body for commands in the async execution context.
///
/// The command is called inside an `async move` block handed to
/// `respond_async_serialized`; each extracted argument is propagated with `?`. When the
/// macro crate is built with the `tracing` feature, the future is instrumented with an
/// `ipc::request::run` span.
fn body_async(
  plugin_name: &TokenStream2,
  function: &ItemFn,
  invoke: &Invoke,
  attributes: &WrapperAttributes,
) -> syn::Result<TokenStream2> {
  let Invoke {
    message,
    resolver,
    acl,
  } = invoke;

  let args = parse_args(plugin_name, function, message, acl, attributes)?;

  // Shared between both variants below so the command call is written once.
  let command_future = quote! {
    async move {
      let result = $path(#(#args?),*);
      let kind = (&result).async_kind();
      kind.future(result).await
    }
  };

  let body = if cfg!(feature = "tracing") {
    quote! {
      use tracing::Instrument;

      let span = tracing::debug_span!("ipc::request::run");
      #resolver.respond_async_serialized(#command_future.instrument(span));
      return true;
    }
  } else {
    quote! {
      #resolver.respond_async_serialized(#command_future);
      return true;
    }
  };

  Ok(body)
}
fn body_blocking(
plugin_name: &TokenStream2,
function: &ItemFn,
invoke: &Invoke,
attributes: &WrapperAttributes,
) -> syn::Result<TokenStream2> {
let Invoke {
message,
resolver,
acl,
} = invoke;
let args = parse_args(plugin_name, function, message, acl, attributes)?;
let match_body = quote!({
Ok(arg) => arg,
Err(err) => { #resolver.invoke_error(err); return true },
});
let maybe_span = if cfg!(feature = "tracing") {
quote!(let _span = tracing::debug_span!("ipc::request::run").entered();)
} else {
quote!()
};
Ok(quote! {
#maybe_span
let result = $path(#(match #args #match_body),*);
let kind = (&result).blocking_kind();
kind.block(result, #resolver);
return true;
})
}
/// Builds one argument-extraction expression per function input, in declaration order.
///
/// Stops at, and returns, the first error produced by `parse_arg`.
fn parse_args(
  plugin_name: &TokenStream2,
  function: &ItemFn,
  message: &Ident,
  acl: &Ident,
  attributes: &WrapperAttributes,
) -> syn::Result<Vec<TokenStream2>> {
  let command = &function.sig.ident;
  let mut extracted = Vec::with_capacity(function.sig.inputs.len());
  for input in &function.sig.inputs {
    extracted.push(parse_arg(
      plugin_name,
      command,
      input,
      message,
      acl,
      attributes,
    )?);
  }
  Ok(extracted)
}
/// Builds the expression that extracts a single command argument.
///
/// The generated code calls `#root::ipc::CommandArg::from_command` with a `CommandItem`
/// holding the plugin name, the command name, the argument key (the pattern's name, cased
/// per `rename_all`), and references to the message and ACL bindings.
///
/// # Errors
/// Fails for `self` parameters and for patterns other than identifiers, wildcards, structs
/// and tuple structs.
fn parse_arg(
  plugin_name: &TokenStream2,
  command: &Ident,
  arg: &FnArg,
  message: &Ident,
  acl: &Ident,
  attributes: &WrapperAttributes,
) -> syn::Result<TokenStream2> {
  let mut arg = match arg {
    FnArg::Typed(arg) => arg.pat.as_ref().clone(),
    FnArg::Receiver(arg) => {
      return Err(syn::Error::new(
        arg.span(),
        "unable to use self as a command function parameter",
      ))
    }
  };

  let mut key = match &mut arg {
    Pat::Ident(arg) => arg.ident.unraw().to_string(),
    Pat::Wild(_) => "".into(),
    Pat::Struct(s) => super::path_to_command(&mut s.path).ident.to_string(),
    Pat::TupleStruct(s) => super::path_to_command(&mut s.path).ident.to_string(),
    err => {
      return Err(syn::Error::new(
        err.span(),
        "only named, wildcard, struct, and tuple struct arguments allowed",
      ))
    }
  };

  if key == "self" {
    // Span the offending pattern rather than the synthesized key string.
    return Err(syn::Error::new(
      arg.span(),
      "unable to use self as a command function parameter",
    ));
  }

  key = match attributes.argument_case {
    ArgumentCase::Camel => key.to_lower_camel_case(),
    ArgumentCase::Snake => key.to_snake_case(),
  };

  let root = &attributes.root;

  // A `rename` value is already a string-literal token, so it is used directly; wrapping it
  // in `stringify!` would yield the literal's source text with the quotes included.
  let command_name = match &attributes.rename {
    RenamePolicy::Rename(name) => quote!(#name),
    RenamePolicy::Keep => quote!(stringify!(#command)),
  };

  Ok(quote!(#root::ipc::CommandArg::from_command(
    #root::ipc::CommandItem {
      plugin: #plugin_name,
      name: #command_name,
      key: #key,
      message: &#message,
      acl: &#acl,
    }
  )))
}
/// Returns `true` when the local `rustc` reports a version of at least `major.minor`.
///
/// Versions are compared lexicographically, so e.g. `2.0` satisfies `1.78`.
fn is_rustc_at_least(major: u32, minor: u32) -> bool {
  version_at_least(*rustc_version(), major, minor)
}

/// Lexicographic `(major, minor)` comparison: `version >= (major, minor)`.
fn version_at_least(version: (u32, u32), major: u32, minor: u32) -> bool {
  version >= (major, minor)
}

/// Returns the `(major, minor)` version of `rustc`, queried once via `rustc -V` and cached
/// in a process-wide `OnceLock`. Falls back to `(1, 0)` if the command cannot be run or its
/// output cannot be parsed.
fn rustc_version() -> &'static (u32, u32) {
  static RUSTC_VERSION: OnceLock<(u32, u32)> = OnceLock::new();
  RUSTC_VERSION.get_or_init(|| {
    cross_command("rustc")
      .arg("-V")
      .output()
      .ok()
      .and_then(|output| parse_rustc_version(&String::from_utf8_lossy(&output.stdout)))
      .unwrap_or((1, 0))
  })
}

/// Extracts `(major, minor)` from `rustc -V` output such as `rustc 1.78.0 (9b00956e5 2024-04-29)`.
///
/// Takes the second space-separated word and parses its first two dot-separated parts;
/// returns `None` unless both parse as `u32`.
fn parse_rustc_version(output: &str) -> Option<(u32, u32)> {
  let parts = output
    .trim()
    .split(' ')
    .nth(1)
    .unwrap_or_default()
    .split('.')
    .take(2)
    .filter_map(|part| part.parse::<u32>().ok())
    .collect::<Vec<_>>();
  match parts.as_slice() {
    [major, minor, ..] => Some((*major, *minor)),
    _ => None,
  }
}

/// Creates a `Command` for `bin`; on Windows it is launched through `cmd /c`.
fn cross_command(bin: &str) -> std::process::Command {
  #[cfg(target_os = "windows")]
  let cmd = {
    let mut cmd = std::process::Command::new("cmd");
    cmd.arg("/c").arg(bin);
    cmd
  };
  #[cfg(not(target_os = "windows"))]
  let cmd = std::process::Command::new(bin);
  cmd
}