libhotpatch_macros/
lib.rs1use proc_macro2::{Span, TokenStream};
2use quote::{ToTokens, quote};
3use syn::{
4 Abi, FnArg, ImplItemFn, LitByteStr, LitStr, Pat, PatWild, Token, parse_macro_input,
5 parse_quote, token::Extern,
6};
7
8use crate::{args::Args, hotpatch_fn::HotpatchFn};
9
10mod args;
11mod hotpatch_fn;
12
13#[proc_macro_attribute]
14pub fn hotpatch(
15 args: proc_macro::TokenStream,
16 input: proc_macro::TokenStream,
17) -> proc_macro::TokenStream {
18 let Args { is_checked } = parse_macro_input!(args as Args);
19 let hotpatch_fn = parse_macro_input!(input as HotpatchFn);
20
21 if is_checked {
22 hotpatch_checked(hotpatch_fn)
23 } else {
24 hotpatch_unchecked(hotpatch_fn)
25 }
26 .into()
27}
28
29fn hotpatch_checked(HotpatchFn { inner, outer }: HotpatchFn) -> TokenStream {
30 let ImplItemFn {
31 attrs,
32 vis,
33 defaultness,
34 sig,
35 ..
36 } = outer;
37
38 let sig_str = quote!(sig).to_string();
39 let sig_lit = LitByteStr::new(sig_str.as_bytes(), Span::call_site());
40
41 let outer_fn = &sig.ident;
42 let inner_fn = &inner.sig.ident;
43
44 let args = inner.sig.inputs.iter().map(|input| match input {
45 FnArg::Receiver(_) => parse_quote!(self),
46 FnArg::Typed(typed) => fn_input_pat_to_ts(&typed.pat),
47 });
48
49 let tuple_args_outer = args.clone();
50 let tuple_args_inner = args.clone();
51
52 quote! {
53 #(#attrs)*
54 #vis #defaultness #sig {
55 extern "C-unwind" fn checked_call(ptr: *const u8, len: usize) -> libhotpatch::BoxedSlice<u8> {
56 #inner
57 let (#(#tuple_args_inner,)*) = unsafe {
58 libhotpatch::rmp_serde::from_slice(::std::slice::from_raw_parts(ptr, len))
59 .expect("checked hot-patch input deserialization failed")
60 };
61 let output = unsafe {
62 libhotpatch::rmp_serde::to_vec_named(&#inner_fn(#(#args,)*))
63 .expect("checked hot-patch output serialization failed")
64 };
65 libhotpatch::BoxedSlice::new(&output)
66 }
67 fn type_of() -> (u128, &'static str) {
68 let name = ::std::any::type_name_of_val(&#outer_fn);
69 let mut hasher = libhotpatch::Xxh3::new();
70 ::std::hash::Hash::hash(#sig_lit, &mut hasher);
71 ::std::hash::Hash::hash(name.as_bytes(), &mut hasher);
72 (hasher.digest128(), name)
73 }
74 #[libhotpatch::distributed_slice(libhotpatch::HOTPATCH_FN)]
75 #[linkme(crate = libhotpatch::linkme)]
76 static HOTPATCH_FN: (
77 ::std::sync::atomic::AtomicPtr<()>,
78 libhotpatch::LibraryHandle,
79 fn() -> (u128, &'static str),
80 ) = (
81 ::std::sync::atomic::AtomicPtr::new(checked_call as *mut ()),
82 libhotpatch::LibraryHandle::null(),
83 type_of,
84 );
85 libhotpatch::Watcher::get().map(libhotpatch::Watcher::poll);
86 let library_handle = HOTPATCH_FN.1.clone();
87 let serialized = libhotpatch::rmp_serde::to_vec_named(&(#(#tuple_args_outer,)*))
88 .expect("checked hot-patch input serialization failed");
89 let serialized_output = unsafe {
90 ::std::mem::transmute::<_, extern "C-unwind" fn(_, _) -> libhotpatch::BoxedSlice<u8>>(
91 HOTPATCH_FN.0.load(::std::sync::atomic::Ordering::Relaxed))
92 (serialized.as_ptr(), serialized.len())
93 };
94 libhotpatch::rmp_serde::from_slice(&serialized_output)
95 .expect("checked hot-patch output deserialization failed")
96 }
97 }
98}
99
100fn hotpatch_unchecked(HotpatchFn { mut inner, outer }: HotpatchFn) -> TokenStream {
101 let ImplItemFn {
102 attrs,
103 vis,
104 defaultness,
105 sig,
106 ..
107 } = outer;
108
109 inner
110 .attrs
111 .push(parse_quote!(#[allow(improper_ctypes_definitions)]));
112
113 let abi = inner
114 .sig
115 .abi
116 .get_or_insert_with(|| Abi {
117 extern_token: Extern(Span::call_site()),
118 name: Some(LitStr::new("C-unwind", Span::call_site())),
119 })
120 .clone();
121
122 let sig_str = quote!(sig).to_string();
123 let sig_lit = LitByteStr::new(sig_str.as_bytes(), Span::call_site());
124
125 let outer_fn = &sig.ident;
126 let inner_fn = &inner.sig.ident;
127
128 let args = inner.sig.inputs.iter().map(|input| match input {
129 FnArg::Receiver(_) => parse_quote!(self),
130 FnArg::Typed(typed) => fn_input_pat_to_ts(&typed.pat),
131 });
132
133 let wild = inner.sig.inputs.iter().map(|_| PatWild {
134 attrs: vec![],
135 underscore_token: Token),
136 });
137
138 quote! {
139 #(#attrs)*
140 #vis #defaultness #sig {
141 #inner
142 fn type_of() -> (u128, &'static str) {
143 let name = ::std::any::type_name_of_val(&#outer_fn);
144 let mut hasher = libhotpatch::Xxh3::new();
145 ::std::hash::Hash::hash(#sig_lit, &mut hasher);
146 ::std::hash::Hash::hash(name.as_bytes(), &mut hasher);
147 (hasher.digest128(), name)
148 }
149 #[libhotpatch::distributed_slice(libhotpatch::HOTPATCH_FN)]
150 #[linkme(crate = libhotpatch::linkme)]
151 static HOTPATCH_FN: (
152 ::std::sync::atomic::AtomicPtr<()>,
153 libhotpatch::LibraryHandle,
154 fn() -> (u128, &'static str),
155 ) = (
156 ::std::sync::atomic::AtomicPtr::new(#inner_fn as *mut ()),
157 libhotpatch::LibraryHandle::null(),
158 type_of,
159 );
160 libhotpatch::Watcher::get().map(libhotpatch::Watcher::poll);
161 let library_handle = HOTPATCH_FN.1.clone();
162 unsafe {
163 ::std::mem::transmute::<_, #abi fn(#(#wild,)*) -> _>(
164 HOTPATCH_FN.0.load(::std::sync::atomic::Ordering::Relaxed))
165 (#(#args,)*)
166 }
167 }
168 }
169}
170
171fn fn_input_pat_to_ts(pat: &Pat) -> TokenStream {
172 match pat {
173 Pat::Ident(pat_ident) => pat_ident.ident.clone().to_token_stream(),
174 Pat::Paren(pat_paren) => fn_input_pat_to_ts(&pat_paren.pat),
175 Pat::Reference(pat_ref) => fn_input_pat_to_ts(&pat_ref.pat),
176 Pat::Tuple(pat_tuple) => {
177 let elems = &pat_tuple.elems;
178 parse_quote!((#elems))
179 }
180 Pat::Struct(pat_struct) => {
181 let path = &pat_struct.path;
182 let members = pat_struct.fields.iter().map(|field_pat| &field_pat.member);
183
184 parse_quote!(#path { #(#members,)* })
185 }
186 Pat::TupleStruct(pat_tstruct) => {
187 let path = &pat_tstruct.path;
188 let elems = pat_tstruct.elems.iter().map(fn_input_pat_to_ts);
189
190 parse_quote!(#path(#(#elems,)*))
191 }
192 _ => panic!("unsupported type pattern in function input position"),
193 }
194}