1use proc_macro::TokenStream;
8use quote::{ToTokens, format_ident, quote};
9use syn::parse::Parse;
10use syn::{ItemFn, LitStr, parse_macro_input};
11struct HookableProcArgs {
12 name: LitStr,
13 }
15
16impl Parse for HookableProcArgs {
17 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18 let name = input.parse::<LitStr>()?;
19 Ok(HookableProcArgs {
21 name,
22 })
24 }
25}
26
27fn gen_args_name_list(f: &ItemFn) -> proc_macro2::TokenStream {
28 let mut args = Vec::new();
30 for arg in f.sig.inputs.iter() {
31 if let syn::FnArg::Typed(pat_type) = arg {
32 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
33 args.push(pat_ident.ident.clone());
34 } else {
35 panic!("Argument pattern is not supported");
36 }
37 }
38 }
39 quote! {
40 #(#args),*
41 }
42}
43
44fn get_hookable_lifetime(f: &ItemFn) -> Option<proc_macro2::TokenStream> {
45 if f.sig.generics.where_clause.is_some() {
46 panic!("Where clause is not supported");
47 }
48 match f.sig.generics.params.iter().count() {
49 0 => None,
50 1 => {
51 let g = f.sig.generics.params.iter().next().unwrap();
52 if let syn::GenericParam::Lifetime(lifetime) = g {
53 Some(quote! { #lifetime })
54 } else {
55 panic!(
56 "Hookable cannot be used with generic '{}'",
57 g.to_token_stream()
58 );
59 }
60 }
61 _ => panic!(
62 "Hookable cannot be used with more than one generics <{}>",
63 f.sig.generics.params.to_token_stream()
64 ),
65 }
66}
67#[proc_macro_attribute]
83pub fn hookable(args: TokenStream, input: TokenStream) -> TokenStream {
84 let args = parse_macro_input!(args as HookableProcArgs);
85 let input_fn = parse_macro_input!(input as ItemFn);
86
87 let input_fn_ident = input_fn.sig.ident.clone();
88
89 let _ = get_hookable_lifetime(&input_fn);
90 let generics = input_fn.sig.generics.clone();
91
92 let input_type = input_fn
93 .sig
94 .inputs
95 .iter()
96 .map(|arg| match arg {
97 syn::FnArg::Typed(pat_type) => {
98 if !matches!(&*pat_type.pat, syn::Pat::Ident(_)) {
99 panic!("Argument pattern is not supported");
100 }
101 pat_type.ty.clone()
102 }
103 syn::FnArg::Receiver(_) => panic!("Method receiver (self) is not supported"),
104 })
105 .collect::<Vec<_>>();
106 let input_type_with_static_lifetime = input_type
107 .iter()
108 .map(|ty| {
109 if let syn::Type::Reference(ref_ty) = &**ty {
110 let mut ref_ty = ref_ty.clone();
111 ref_ty.lifetime = Some(syn::Lifetime::new("'static", proc_macro2::Span::call_site()));
112 quote! { #ref_ty }
113 } else {
114 quote! { #ty }
115 }
116 })
117 .collect::<Vec<_>>();
118
119 let ret_type = match &input_fn.sig.output {
120 syn::ReturnType::Default => quote! { () },
121 syn::ReturnType::Type(_, ty) => quote! { #ty },
122 };
123
124 let func_type = quote! {
125 fn(#(#input_type),*) -> #ret_type
126 };
127
128 let hookable_name = args.name;
129
130 let args_name_list = gen_args_name_list(&input_fn);
131
132 let mut inner_fn = input_fn.clone();
133 inner_fn.sig.ident = format_ident!("__hookable_inner");
134 let fn_sig = &input_fn.sig;
135
136 let unpack_list: proc_macro2::TokenStream = (0..input_fn.sig.inputs.len())
137 .map(|i| {
138 let idx = syn::Index::from(i);
139 quote! { args.#idx, }
140 })
141 .collect();
142
143 let generated = quote! {
145 #fn_sig {
146 #inner_fn
147
148 use ::safe_hook::HookableFuncMetadata;
149 use ::core::sync::atomic::AtomicBool;
150 use ::std::sync::LazyLock;
151 use ::std::sync::atomic::Ordering;
152
153 type SelfFunc #generics = #func_type;
154
155 static FLAG: AtomicBool = AtomicBool::new(false);
156 static META: LazyLock<HookableFuncMetadata> = LazyLock::new(|| {
157 let metadata = unsafe {
158 HookableFuncMetadata::new(
159 #hookable_name.to_string(),
160 #input_fn_ident as *const (),
161 (
162 std::any::TypeId::of::<#ret_type>(),
163 std::any::TypeId::of::<(#(#input_type_with_static_lifetime),*)>(),
164 ),
165 &FLAG,
166 )
167 };
168 metadata
169 });
170 ::safe_hook::inventory::submit! {
171 ::safe_hook::HookableFuncRegistry::new(&META)
172 }
173 if !FLAG.load(Ordering::Acquire) {
174 return __hookable_inner(#args_name_list);
175 }
176 ::safe_hook::call_with_hook::<#ret_type, (#(#input_type),*)>(|args| __hookable_inner(#unpack_list), &META, (#args_name_list))
177 }
178 };
179 generated.into()
180}