Skip to main content

libhotpatch_macros/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{ToTokens, quote};
3use syn::{
4    Abi, FnArg, ImplItemFn, LitByteStr, LitStr, Pat, PatWild, Token, parse_macro_input,
5    parse_quote, token::Extern,
6};
7
8use crate::{args::Args, hotpatch_fn::HotpatchFn};
9
10mod args;
11mod hotpatch_fn;
12
13#[proc_macro_attribute]
14pub fn hotpatch(
15    args: proc_macro::TokenStream,
16    input: proc_macro::TokenStream,
17) -> proc_macro::TokenStream {
18    let Args { is_checked } = parse_macro_input!(args as Args);
19    let hotpatch_fn = parse_macro_input!(input as HotpatchFn);
20
21    if is_checked {
22        hotpatch_checked(hotpatch_fn)
23    } else {
24        hotpatch_unchecked(hotpatch_fn)
25    }
26    .into()
27}
28
29fn hotpatch_checked(HotpatchFn { inner, outer }: HotpatchFn) -> TokenStream {
30    let ImplItemFn {
31        attrs,
32        vis,
33        defaultness,
34        sig,
35        ..
36    } = outer;
37
38    let sig_str = quote!(sig).to_string();
39    let sig_lit = LitByteStr::new(sig_str.as_bytes(), Span::call_site());
40
41    let outer_fn = &sig.ident;
42    let inner_fn = &inner.sig.ident;
43
44    let args = inner.sig.inputs.iter().map(|input| match input {
45        FnArg::Receiver(_) => parse_quote!(self),
46        FnArg::Typed(typed) => fn_input_pat_to_ts(&typed.pat),
47    });
48
49    let tuple_args_outer = args.clone();
50    let tuple_args_inner = args.clone();
51
52    quote! {
53        #(#attrs)*
54        #vis #defaultness #sig {
55            extern "C-unwind" fn checked_call(ptr: *const u8, len: usize) -> libhotpatch::BoxedSlice<u8> {
56                #inner
57                let (#(#tuple_args_inner,)*) = unsafe {
58                    libhotpatch::rmp_serde::from_slice(::std::slice::from_raw_parts(ptr, len))
59                        .expect("checked hot-patch input deserialization failed")
60                };
61                let output = unsafe {
62                    libhotpatch::rmp_serde::to_vec_named(&#inner_fn(#(#args,)*))
63                        .expect("checked hot-patch output serialization failed")
64                };
65                libhotpatch::BoxedSlice::new(&output)
66            }
67            fn type_of() -> (u128, &'static str) {
68                let name = ::std::any::type_name_of_val(&#outer_fn);
69                let mut hasher = libhotpatch::Xxh3::new();
70                ::std::hash::Hash::hash(#sig_lit, &mut hasher);
71                ::std::hash::Hash::hash(name.as_bytes(), &mut hasher);
72                (hasher.digest128(), name)
73            }
74            #[libhotpatch::distributed_slice(libhotpatch::HOTPATCH_FN)]
75            #[linkme(crate = libhotpatch::linkme)]
76            static HOTPATCH_FN: (
77                ::std::sync::atomic::AtomicPtr<()>,
78                libhotpatch::LibraryHandle,
79                fn() -> (u128, &'static str),
80            ) = (
81                ::std::sync::atomic::AtomicPtr::new(checked_call as *mut ()),
82                libhotpatch::LibraryHandle::null(),
83                type_of,
84            );
85            libhotpatch::Watcher::get().map(libhotpatch::Watcher::poll);
86            let library_handle = HOTPATCH_FN.1.clone();
87            let serialized = libhotpatch::rmp_serde::to_vec_named(&(#(#tuple_args_outer,)*))
88                .expect("checked hot-patch input serialization failed");
89            let serialized_output = unsafe {
90                ::std::mem::transmute::<_, extern "C-unwind" fn(_, _) -> libhotpatch::BoxedSlice<u8>>(
91                    HOTPATCH_FN.0.load(::std::sync::atomic::Ordering::Relaxed))
92                        (serialized.as_ptr(), serialized.len())
93            };
94            libhotpatch::rmp_serde::from_slice(&serialized_output)
95                .expect("checked hot-patch output deserialization failed")
96        }
97    }
98}
99
100fn hotpatch_unchecked(HotpatchFn { mut inner, outer }: HotpatchFn) -> TokenStream {
101    let ImplItemFn {
102        attrs,
103        vis,
104        defaultness,
105        sig,
106        ..
107    } = outer;
108
109    inner
110        .attrs
111        .push(parse_quote!(#[allow(improper_ctypes_definitions)]));
112
113    let abi = inner
114        .sig
115        .abi
116        .get_or_insert_with(|| Abi {
117            extern_token: Extern(Span::call_site()),
118            name: Some(LitStr::new("C-unwind", Span::call_site())),
119        })
120        .clone();
121
122    let sig_str = quote!(sig).to_string();
123    let sig_lit = LitByteStr::new(sig_str.as_bytes(), Span::call_site());
124
125    let outer_fn = &sig.ident;
126    let inner_fn = &inner.sig.ident;
127
128    let args = inner.sig.inputs.iter().map(|input| match input {
129        FnArg::Receiver(_) => parse_quote!(self),
130        FnArg::Typed(typed) => fn_input_pat_to_ts(&typed.pat),
131    });
132
133    let wild = inner.sig.inputs.iter().map(|_| PatWild {
134        attrs: vec![],
135        underscore_token: Token![_](Span::call_site()),
136    });
137
138    quote! {
139        #(#attrs)*
140        #vis #defaultness #sig {
141            #inner
142            fn type_of() -> (u128, &'static str) {
143                let name = ::std::any::type_name_of_val(&#outer_fn);
144                let mut hasher = libhotpatch::Xxh3::new();
145                ::std::hash::Hash::hash(#sig_lit, &mut hasher);
146                ::std::hash::Hash::hash(name.as_bytes(), &mut hasher);
147                (hasher.digest128(), name)
148            }
149            #[libhotpatch::distributed_slice(libhotpatch::HOTPATCH_FN)]
150            #[linkme(crate = libhotpatch::linkme)]
151            static HOTPATCH_FN: (
152                ::std::sync::atomic::AtomicPtr<()>,
153                libhotpatch::LibraryHandle,
154                fn() -> (u128, &'static str),
155            ) = (
156                ::std::sync::atomic::AtomicPtr::new(#inner_fn as *mut ()),
157                libhotpatch::LibraryHandle::null(),
158                type_of,
159            );
160            libhotpatch::Watcher::get().map(libhotpatch::Watcher::poll);
161            let library_handle = HOTPATCH_FN.1.clone();
162            unsafe {
163                ::std::mem::transmute::<_, #abi fn(#(#wild,)*) -> _>(
164                    HOTPATCH_FN.0.load(::std::sync::atomic::Ordering::Relaxed))
165                        (#(#args,)*)
166            }
167        }
168    }
169}
170
171fn fn_input_pat_to_ts(pat: &Pat) -> TokenStream {
172    match pat {
173        Pat::Ident(pat_ident) => pat_ident.ident.clone().to_token_stream(),
174        Pat::Paren(pat_paren) => fn_input_pat_to_ts(&pat_paren.pat),
175        Pat::Reference(pat_ref) => fn_input_pat_to_ts(&pat_ref.pat),
176        Pat::Tuple(pat_tuple) => {
177            let elems = &pat_tuple.elems;
178            parse_quote!((#elems))
179        }
180        Pat::Struct(pat_struct) => {
181            let path = &pat_struct.path;
182            let members = pat_struct.fields.iter().map(|field_pat| &field_pat.member);
183
184            parse_quote!(#path { #(#members,)* })
185        }
186        Pat::TupleStruct(pat_tstruct) => {
187            let path = &pat_tstruct.path;
188            let elems = pat_tstruct.elems.iter().map(fn_input_pat_to_ts);
189
190            parse_quote!(#path(#(#elems,)*))
191        }
192        _ => panic!("unsupported type pattern in function input position"),
193    }
194}