use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::{
parse::Parse, parse::ParseStream, punctuated::Punctuated, FnArg, ItemFn, Meta, Pat, Token,
};
/// The comma-separated options written inside `#[instrument_ipc(...)]` or
/// `#[auditaur_command(...)]`.
///
/// They are kept as raw [`Meta`] items so that everything except `skip`/`fields` can be
/// forwarded verbatim to `#[tracing::instrument]`.
struct InstrumentArgs {
    /// One entry per option, e.g. `err`, `skip(app)` or `fields(command = "load_user")`.
    metas: Punctuated<Meta, Token![,]>,
}

impl Parse for InstrumentArgs {
    /// Parses the whole attribute input as a (possibly empty) comma-separated `Meta` list;
    /// a trailing comma is accepted.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
        Ok(InstrumentArgs { metas })
    }
}
#[proc_macro_attribute]
pub fn instrument_ipc(attr: TokenStream, item: TokenStream) -> TokenStream {
expand_instrument_ipc(attr.into(), item.into()).into()
}
#[proc_macro_attribute]
pub fn auditaur_command(attr: TokenStream, item: TokenStream) -> TokenStream {
expand_auditaur_command(attr.into(), item.into()).into()
}
fn expand_instrument_ipc(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
let args = match syn::parse2::<InstrumentArgs>(attr) {
Ok(args) => args,
Err(error) => return error.to_compile_error(),
};
let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
let tracing_crate = dependency_crate_path("tracing", "tracing");
let traceparent_field = quote! {
traceparent = #auditaur_crate::ipc_traceparent(auditaur_trace_context.as_ref())
};
let instrument_args = match merge_instrument_args(
args,
quote!(auditaur_trace_context),
quote!(#traceparent_field),
) {
Ok(args) => args,
Err(error) => return error.to_compile_error(),
};
quote! {
#[#tracing_crate::instrument(#(#instrument_args),*)]
#item
}
}
fn expand_auditaur_command(attr: TokenStream2, item: TokenStream2) -> TokenStream2 {
let args = match syn::parse2::<InstrumentArgs>(attr) {
Ok(args) => args,
Err(error) => return error.to_compile_error(),
};
let mut function = match syn::parse2::<ItemFn>(item) {
Ok(function) => function,
Err(error) => return error.to_compile_error(),
};
let auditaur_crate = dependency_crate_path("tauri-plugin-auditaur", "tauri_plugin_auditaur");
let tauri_crate = dependency_crate_path("tauri", "tauri");
let tracing_crate = dependency_crate_path("tracing", "tracing");
if function_has_argument(&function, "auditaur_trace_context") {
return syn::Error::new_spanned(
&function.sig.ident,
"`#[tauri_plugin_auditaur::auditaur_command]` reserves the `auditaur_trace_context` argument; remove it or use `#[tauri::command]` with `#[tauri_plugin_auditaur::instrument_ipc]` instead",
)
.to_compile_error();
}
let request_ident = unique_argument_ident(&function, "auditaur_request");
let request_arg: FnArg = syn::parse_quote! {
#request_ident: #tauri_crate::ipc::Request<'_>
};
function.sig.inputs.push(request_arg);
let context_arg: FnArg = syn::parse_quote! {
auditaur_trace_context: Option<#auditaur_crate::IpcTraceContext>
};
function.sig.inputs.push(context_arg);
let traceparent_field = quote! {
traceparent = #auditaur_crate::ipc_traceparent_from_request_or_context(
&#request_ident,
auditaur_trace_context.as_ref()
)
};
let injected_skip_args = quote!(#request_ident, auditaur_trace_context);
let instrument_args = match merge_instrument_args(
args,
quote!(#injected_skip_args),
quote!(#traceparent_field),
) {
Ok(args) => args,
Err(error) => return error.to_compile_error(),
};
quote! {
#[#tauri_crate::command]
#[#tracing_crate::instrument(#(#instrument_args),*)]
#function
}
}
fn merge_instrument_args(
args: InstrumentArgs,
injected_skip_arg: TokenStream2,
traceparent_field: TokenStream2,
) -> syn::Result<Vec<TokenStream2>> {
let mut instrument_args = Vec::new();
let mut skip_args = None;
let mut fields_args = None;
let mut skip_all = false;
for meta in args.metas {
if meta.path().is_ident("skip") {
let Meta::List(list) = meta else {
return Err(syn::Error::new_spanned(meta, "expected skip(...)"));
};
let existing = list.tokens;
skip_args = Some(if existing.is_empty() {
quote!(#injected_skip_arg)
} else if token_stream_mentions(&existing, &injected_skip_arg.to_string()) {
quote!(#existing)
} else {
quote!(#existing, #injected_skip_arg)
});
} else if meta.path().is_ident("skip_all") {
skip_all = true;
instrument_args.push(quote!(#meta));
} else if meta.path().is_ident("fields") {
let Meta::List(list) = meta else {
return Err(syn::Error::new_spanned(meta, "expected fields(...)"));
};
let existing = list.tokens;
fields_args = Some(if existing.is_empty() {
traceparent_field.clone()
} else if token_stream_mentions(&existing, "traceparent") {
quote!(#existing)
} else {
quote!(#existing, #traceparent_field)
});
} else {
instrument_args.push(quote!(#meta));
}
}
let skip_args = skip_args.unwrap_or_else(|| quote!(#injected_skip_arg));
let fields_args = fields_args.unwrap_or_else(|| quote!(#traceparent_field));
if !skip_all {
instrument_args.push(quote!(skip(#skip_args)));
}
instrument_args.push(quote!(fields(#fields_args)));
Ok(instrument_args)
}
/// Picks a name for an injected argument that does not clash with `function`'s own
/// arguments.
///
/// Returns `base` itself if it is free, otherwise the first free `base_2`, `base_3`, ….
fn unique_argument_ident(function: &ItemFn, base: &str) -> syn::Ident {
    let is_free = |ident: &syn::Ident| !function_has_argument(function, &ident.to_string());
    let plain = format_ident!("{base}");
    if is_free(&plain) {
        return plain;
    }
    // Terminates: the signature declares only finitely many argument names.
    let mut n = 2;
    loop {
        let numbered = format_ident!("{base}_{n}");
        if is_free(&numbered) {
            return numbered;
        }
        n += 1;
    }
}
/// Reports whether `function` declares an argument bound by a plain identifier pattern
/// called `name`, such as `name: T` or `mut name: T`.
///
/// Receivers (`self`) and destructuring patterns never match.
fn function_has_argument(function: &ItemFn, name: &str) -> bool {
    function
        .sig
        .inputs
        .iter()
        .filter_map(|input| match input {
            FnArg::Typed(typed) => Some(typed.pat.as_ref()),
            FnArg::Receiver(_) => None,
        })
        .any(|pat| matches!(pat, Pat::Ident(binding) if binding.ident == name))
}
fn token_stream_mentions(tokens: &TokenStream2, needle: &str) -> bool {
tokens
.to_string()
.split_whitespace()
.any(|part| part == needle)
}
/// Resolves the path generated code should use to reach `package_name`.
///
/// * `crate` when the macro is expanded inside that package itself;
/// * `::<name>` with the name reported by `proc_macro_crate::crate_name` otherwise;
/// * `::<fallback_name>` when that lookup fails.
fn dependency_crate_path(package_name: &str, fallback_name: &str) -> TokenStream2 {
    let name = match crate_name(package_name) {
        Ok(FoundCrate::Itself) => return quote!(crate),
        Ok(FoundCrate::Name(name)) => name,
        Err(_) => fallback_name.to_owned(),
    };
    let ident = format_ident!("{}", name);
    quote!(::#ident)
}
#[cfg(test)]
mod tests {
    use super::{expand_auditaur_command, expand_instrument_ipc};
    use proc_macro2::TokenStream as TokenStream2;
    use quote::quote;

    /// Runs the `#[instrument_ipc]` expansion and renders it for substring checks.
    fn render_ipc(attr: TokenStream2, item: TokenStream2) -> String {
        expand_instrument_ipc(attr, item).to_string()
    }

    /// Runs the `#[auditaur_command]` expansion and renders it for substring checks.
    fn render_command(attr: TokenStream2, item: TokenStream2) -> String {
        expand_auditaur_command(attr, item).to_string()
    }

    /// Asserts that every fragment occurs in `rendered`.
    fn assert_mentions(rendered: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(rendered.contains(fragment), "expected {fragment:?} in {rendered}");
        }
    }

    /// Asserts that no fragment occurs in `rendered`.
    fn assert_omits(rendered: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(!rendered.contains(fragment), "unexpected {fragment:?} in {rendered}");
        }
    }

    #[test]
    fn adds_traceparent_field_and_context_skip() {
        let rendered = render_ipc(
            quote!(),
            quote! {
                fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
            },
        );
        assert_mentions(
            &rendered,
            &[
                "skip (auditaur_trace_context)",
                "fields (traceparent =",
                "ipc_traceparent",
            ],
        );
    }

    #[test]
    fn preserves_tracing_options_and_merges_skip_fields() {
        let rendered = render_ipc(
            quote!(err, skip(app), fields(command = "emit_backend_event")),
            quote! {
                fn emit_backend_event(
                    app: tauri::AppHandle,
                    auditaur_trace_context: Option<IpcTraceContext>,
                ) {}
            },
        );
        assert_mentions(
            &rendered,
            &[
                "err",
                "skip (app , auditaur_trace_context)",
                "command = \"emit_backend_event\"",
                "traceparent =",
            ],
        );
    }

    #[test]
    fn respects_skip_all_without_adding_skip() {
        let rendered = render_ipc(
            quote!(skip_all),
            quote! {
                fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
            },
        );
        assert_mentions(&rendered, &["skip_all", "fields (traceparent ="]);
        assert_omits(&rendered, &["skip (auditaur_trace_context)"]);
    }

    #[test]
    fn avoids_duplicate_injected_arguments() {
        let rendered = render_ipc(
            quote!(skip(auditaur_trace_context), fields(traceparent = "custom")),
            quote! {
                fn load_user(id: String, auditaur_trace_context: Option<IpcTraceContext>) {}
            },
        );
        assert_mentions(
            &rendered,
            &["skip (auditaur_trace_context)", "traceparent", "custom"],
        );
        assert_omits(
            &rendered,
            &[
                "auditaur_trace_context , auditaur_trace_context",
                "ipc_traceparent",
            ],
        );
    }

    #[test]
    fn auditaur_command_wraps_tauri_command_and_injects_request_traceparent() {
        let rendered = render_command(
            quote!(err, skip(app), fields(command = "load_user")),
            quote! {
                async fn load_user(app: tauri::AppHandle, id: String) -> Result<String, String> {
                    Ok(id)
                }
            },
        );
        assert_mentions(
            &rendered,
            &[
                "tauri :: command",
                "tracing :: instrument",
                "auditaur_request :",
                "tauri :: ipc :: Request < '_ >",
                "auditaur_trace_context : Option",
                "skip (app , auditaur_request , auditaur_trace_context)",
                "ipc_traceparent_from_request_or_context",
                "& auditaur_request",
                "auditaur_trace_context . as_ref",
                "command = \"load_user\"",
                "err",
            ],
        );
    }

    #[test]
    fn auditaur_command_avoids_request_argument_name_collision() {
        let rendered = render_command(
            quote!(),
            quote! {
                fn load_user(auditaur_request: String) -> String {
                    auditaur_request
                }
            },
        );
        assert_mentions(
            &rendered,
            &[
                "auditaur_request_2 :",
                "skip (auditaur_request_2 , auditaur_trace_context)",
                "ipc_traceparent_from_request_or_context",
                "& auditaur_request_2",
            ],
        );
    }

    #[test]
    fn auditaur_command_rejects_reserved_trace_context_argument() {
        let rendered = render_command(
            quote!(),
            quote! {
                fn load_user(auditaur_trace_context: Option<IpcTraceContext>) {}
            },
        );
        assert_mentions(
            &rendered,
            &["reserves the `auditaur_trace_context` argument"],
        );
    }
}