use crate::DataType;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
pub fn construct_type_check(
provider_name: &str,
probe_name: &str,
use_statements: &[syn::ItemUse],
types: &[DataType],
) -> TokenStream {
if types.is_empty() {
return quote! {
let _: () = ($args_lambda)();
};
}
let type_check_params = types
.iter()
.map(|typ| match typ {
DataType::Serializable(ty) => {
match ty {
syn::Type::Reference(reference) => {
if let Some(elem) = shared_slice_elem_type(reference) {
quote! { _: impl AsRef<[#elem]> }
} else {
let elem = &*reference.elem;
quote! { _: impl ::std::borrow::Borrow<#elem> }
}
}
syn::Type::Slice(slice) => {
let elem = &*slice.elem;
quote! { _: impl AsRef<[#elem]> }
}
syn::Type::Array(array) => {
let elem = &*array.elem;
quote! { _: impl AsRef<[#elem]> }
}
syn::Type::Path(_) => {
quote! { _: impl ::std::borrow::Borrow<#ty> }
}
_ => {
quote! { _: #ty }
}
}
}
DataType::Native(dtrace_parser::DataType::String) => quote! { _: impl AsRef<str> },
_ => {
let arg = typ.to_rust_type();
quote! { _: impl ::std::borrow::Borrow<#arg> }
}
})
.collect::<Vec<_>>();
let type_check_args = (0..types.len())
.map(|i| {
let index = syn::Index::from(i);
quote! { args.#index }
})
.collect::<Vec<_>>();
let type_check_fn = format_ident!("__usdt_private_{}_{}_type_check", provider_name, probe_name);
quote! {
#[allow(unused_imports)]
#(#use_statements)*
#[allow(non_snake_case)]
fn #type_check_fn(#(#type_check_params),*) {}
let _ = || { #type_check_fn(#(#type_check_args),*); };
}
}
fn shared_slice_elem_type(reference: &syn::TypeReference) -> Option<&syn::Type> {
if let syn::Type::Slice(slice) = &*reference.elem {
Some(&*slice.elem)
} else {
None
}
}
pub fn construct_probe_args(types: &[DataType]) -> (TokenStream, TokenStream) {
#[cfg(target_arch = "x86_64")]
let abi_regs = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
#[cfg(target_arch = "aarch64")]
let abi_regs = ["x0", "x1", "x2", "x3", "x4", "x5"];
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
compile_error!("USDT only supports x86_64 and ARM64 architectures");
assert!(
types.len() <= abi_regs.len(),
"Up to 6 probe arguments are currently supported"
);
let (unpacked_args, in_regs): (Vec<_>, Vec<_>) = types
.iter()
.zip(&abi_regs)
.enumerate()
.map(|(i, (typ, reg))| {
let arg = format_ident!("arg_{}", i);
let index = syn::Index::from(i);
let input = quote! { args.#index };
let (value, at_use) = asm_type_convert(typ, input);
let destructured_arg = quote! {
let #arg = #value;
};
let register_arg = quote! { in(#reg) (#arg #at_use) };
(destructured_arg, register_arg)
})
.unzip();
let arg_lambda = call_argument_closure(types);
let unpacked_args = quote! {
#arg_lambda
#(#unpacked_args)*
};
let in_regs = quote! { #(#in_regs,)* };
(unpacked_args, in_regs)
}
pub fn call_argument_closure(types: &[DataType]) -> TokenStream {
match types.len() {
0 => quote! {},
1 => quote! { let args = (($args_lambda)(),); },
_ => quote! { let args = ($args_lambda)(); },
}
}
fn asm_type_convert(typ: &DataType, input: TokenStream) -> (TokenStream, TokenStream) {
match typ {
DataType::Serializable(_) => (
quote! {
[
match ::usdt::to_json(&#input) {
Ok(json) => format!("{{\"ok\":{}}}", json),
Err(e) => format!("{{\"err\":\"{}\"}}", e.to_string()),
}.as_bytes(),
&[0_u8]
].concat()
},
quote! { .as_ptr() as usize },
),
DataType::Native(dtrace_parser::DataType::String) => (
quote! {
[(#input.as_ref() as &str).as_bytes(), &[0_u8]].concat()
},
quote! { .as_ptr() as usize },
),
DataType::Native(_) => {
let ty = typ.to_rust_type();
(
quote! { (*<_ as ::std::borrow::Borrow<#ty>>::borrow(&#input) as usize) },
quote! {},
)
}
DataType::UniqueId => (quote! { #input.as_u64() as usize }, quote! {}),
}
}
pub(crate) fn build_probe_macro(
config: &crate::CompileProvidersConfig,
probe_name: &str,
types: &[DataType],
impl_block: TokenStream,
) -> TokenStream {
let module = config.module_ident();
let macro_name = config.probe_ident(probe_name);
let no_args_match = if types.is_empty() {
quote! { () => { crate::#module::#macro_name!(|| ()) }; }
} else {
quote! {}
};
quote! {
#[allow(unused_macros)]
macro_rules! #macro_name {
#no_args_match
($tree:tt) => {
compile_error!("USDT probe macros should be invoked with a closure returning the arguments");
};
($args_lambda:expr) => {
{
#impl_block
}
};
}
#[allow(unused_imports)]
pub(crate) use #macro_name;
}
}
#[cfg(test)]
mod tests {
use super::*;
use dtrace_parser::BitWidth;
use dtrace_parser::DataType as DType;
use dtrace_parser::Integer;
use dtrace_parser::Sign;
#[test]
fn test_construct_type_check_empty() {
let expected = quote! {
let _ : () = ($args_lambda)();
};
let block = construct_type_check("", "", &[], &[]);
assert_eq!(block.to_string(), expected.to_string());
}
#[test]
fn test_construct_type_check_native() {
let provider = "provider";
let probe = "probe";
let types = &[
DataType::Native(DType::Integer(Integer {
sign: Sign::Unsigned,
width: BitWidth::Bit8,
})),
DataType::Native(DType::Integer(Integer {
sign: Sign::Signed,
width: BitWidth::Bit64,
})),
];
let expected = quote! {
#[allow(unused_imports)]
#[allow(non_snake_case)]
fn __usdt_private_provider_probe_type_check(
_: impl ::std::borrow::Borrow<u8>,
_: impl ::std::borrow::Borrow<i64>
) { }
let _ = || {
__usdt_private_provider_probe_type_check(args.0, args.1);
};
};
let block = construct_type_check(provider, probe, &[], types);
assert_eq!(block.to_string(), expected.to_string());
}
#[test]
fn test_construct_type_check_with_string() {
let provider = "provider";
let probe = "probe";
let types = &[DataType::Native(dtrace_parser::DataType::String)];
let use_statements = vec![];
let expected = quote! {
#[allow(unused_imports)]
#[allow(non_snake_case)]
fn __usdt_private_provider_probe_type_check(_: impl AsRef<str>) { }
let _ = || {
__usdt_private_provider_probe_type_check(args.0);
};
};
let block = construct_type_check(provider, probe, &use_statements, types);
assert_eq!(block.to_string(), expected.to_string());
}
#[test]
fn test_construct_type_check_with_shared_slice() {
let provider = "provider";
let probe = "probe";
let types = &[DataType::Serializable(syn::parse_str("&[u8]").unwrap())];
let use_statements = vec![];
let expected = quote! {
#[allow(unused_imports)]
#[allow(non_snake_case)]
fn __usdt_private_provider_probe_type_check(_: impl AsRef<[u8]>) { }
let _ = || {
__usdt_private_provider_probe_type_check(args.0);
};
};
let block = construct_type_check(provider, probe, &use_statements, types);
assert_eq!(block.to_string(), expected.to_string());
}
#[test]
fn test_construct_type_check_with_custom_type() {
let provider = "provider";
let probe = "probe";
let types = &[DataType::Serializable(syn::parse_str("MyType").unwrap())];
let use_statements = vec![syn::parse2(quote! { use my_module::MyType; }).unwrap()];
let expected = quote! {
#[allow(unused_imports)]
use my_module::MyType;
#[allow(non_snake_case)]
fn __usdt_private_provider_probe_type_check(_: impl ::std::borrow::Borrow<MyType>) { }
let _ = || {
__usdt_private_provider_probe_type_check(args.0);
};
};
let block = construct_type_check(provider, probe, &use_statements, types);
assert_eq!(block.to_string(), expected.to_string());
}
#[test]
fn test_construct_probe_args() {
let types = &[
DataType::Native(DType::Pointer(Integer {
sign: Sign::Unsigned,
width: BitWidth::Bit8,
})),
DataType::Native(dtrace_parser::DataType::String),
];
#[cfg(target_arch = "x86_64")]
let registers = ["rdi", "rsi"];
#[cfg(target_arch = "aarch64")]
let registers = ["x0", "x1"];
let (args, regs) = construct_probe_args(types);
let expected = quote! {
let args = ($args_lambda)();
let arg_0 = (*<_ as ::std::borrow::Borrow<*const u8>>::borrow(&args.0) as usize);
let arg_1 = [(args.1.as_ref() as &str).as_bytes(), &[0_u8]].concat();
};
assert_eq!(args.to_string(), expected.to_string());
for (i, (expected, actual)) in registers
.iter()
.zip(regs.to_string().split(','))
.enumerate()
{
let reg = actual.replace(' ', "");
let expected = format!("in(\"{}\")(arg_{}", expected, i);
assert!(
reg.starts_with(&expected),
"reg: {}; expected {}",
reg,
expected,
);
}
}
#[test]
fn test_asm_type_convert() {
use std::str::FromStr;
let (out, post) = asm_type_convert(
&DataType::Native(DType::Integer(Integer {
sign: Sign::Unsigned,
width: BitWidth::Bit8,
})),
TokenStream::from_str("foo").unwrap(),
);
assert_eq!(
out.to_string(),
quote! {(*<_ as ::std::borrow::Borrow<u8>>::borrow(&foo) as usize)}.to_string()
);
assert_eq!(post.to_string(), quote! {}.to_string());
let (out, post) = asm_type_convert(
&DataType::Native(dtrace_parser::DataType::String),
TokenStream::from_str("foo").unwrap(),
);
assert_eq!(
out.to_string(),
quote! { [(foo.as_ref() as &str).as_bytes(), &[0_u8]].concat() }.to_string()
);
assert_eq!(post.to_string(), quote! { .as_ptr() as usize }.to_string());
}
}