1use proc_macro::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use syn::parse::{Parse, ParseStream};
9use syn::{
10 parse_macro_input, FnArg, Ident, ItemFn, ItemStruct, Pat, Signature, Token,
11};
12
/// Computes the 32-bit FNV-1a hash of `bytes`.
///
/// Declared `const` so it can run in constant contexts.
const fn fnv1a_hash_bytes(bytes: &[u8]) -> u32 {
    // FNV-1a 32-bit offset basis (2166136261).
    const OFFSET_BASIS: u32 = 0x811c_9dc5;
    // FNV-1a 32-bit prime (16777619).
    const PRIME: u32 = 0x0100_0193;

    let mut acc = OFFSET_BASIS;
    let mut idx = 0;
    // `for` loops and iterators are not allowed in `const fn`, hence the manual index.
    while idx < bytes.len() {
        acc = (acc ^ bytes[idx] as u32).wrapping_mul(PRIME);
        idx += 1;
    }
    acc
}

/// Computes the 32-bit FNV-1a hash of the UTF-8 bytes of `s`.
const fn fnv1a_hash_str(s: &str) -> u32 {
    let utf8 = s.as_bytes();
    fnv1a_hash_bytes(utf8)
}
30
31struct ExportAttrs {
32 name: Option<String>,
33}
34
35impl Parse for ExportAttrs {
36 fn parse(input: ParseStream) -> syn::Result<Self> {
37 let mut name = None;
38
39 while !input.is_empty() {
40 let ident: Ident = input.parse()?;
41 input.parse::<Token![=]>()?;
42 let value: syn::LitStr = input.parse()?;
43
44 if ident == "name" {
45 name = Some(value.value());
46 }
47
48 if !input.is_empty() {
49 input.parse::<Token![,]>()?;
50 }
51 }
52
53 Ok(ExportAttrs { name })
54 }
55}
56
57#[proc_macro_attribute]
58pub fn memlink_export(args: TokenStream, input: TokenStream) -> TokenStream {
59 let attrs = parse_macro_input!(args as ExportAttrs);
60 let mut func = parse_macro_input!(input as ItemFn);
61
62 let method_name = attrs.name.unwrap_or_else(|| func.sig.ident.to_string());
63 let method_hash = fnv1a_hash_str(&method_name);
64
65 let expanded = generate_export_code(&mut func, &method_name, method_hash);
66
67 TokenStream::from(expanded)
68}
69
70fn generate_export_code(func: &mut ItemFn, _method_name: &str, method_hash: u32) -> proc_macro2::TokenStream {
71 let func_name = &func.sig.ident;
72 let _func_vis = &func.vis;
73 let sig = &func.sig;
74
75 let is_async = sig.asyncness.is_some();
76
77 let (_context_param, other_params) = extract_params(sig);
78
79 let args_struct = if !other_params.is_empty() {
80 generate_args_struct(func_name, other_params.clone())
81 } else {
82 quote! {}
83 };
84
85 let wrapper_name = format_ident!("__{}_wrapper", func_name);
86 let wrapper = generate_wrapper(func_name, &wrapper_name, other_params, is_async);
87
88 let ffi_name = format_ident!("__{}_ffi", func_name);
89 let ffi_func = generate_ffi_export(&wrapper_name, &ffi_name, method_hash, is_async);
90
91 let register_func = generate_registration(func_name, method_hash, is_async);
92
93 quote! {
94 #func
95 #args_struct
96 #wrapper
97 #ffi_func
98 #register_func
99 }
100}
101
102fn extract_params(sig: &Signature) -> (Option<&FnArg>, Vec<&FnArg>) {
103 let params = sig.inputs.iter();
104 let mut context_param = None;
105 let mut other_params = Vec::new();
106
107 for param in params {
108 match param {
109 FnArg::Typed(pat_type) => {
110 let type_str = pat_type.ty.to_token_stream().to_string();
111 if type_str.contains("CallContext") {
112 context_param = Some(param);
113 } else {
114 other_params.push(param);
115 }
116 }
117 FnArg::Receiver(_) => {
118 other_params.push(param);
119 }
120 }
121 }
122
123 (context_param, other_params)
124}
125
126fn generate_args_struct(func_name: &Ident, params: Vec<&FnArg>) -> proc_macro2::TokenStream {
127 let args_struct_name = format_ident!("__{}Args", func_name);
128
129 let fields: Vec<_> = params.iter().map(|param| {
130 if let FnArg::Typed(pat_type) = param {
131 let pat = &pat_type.pat;
132 let ty = &pat_type.ty;
133 if let Pat::Ident(ident) = pat.as_ref() {
134 let field_name = &ident.ident;
135 quote! { pub #field_name: #ty }
136 } else {
137 quote! {}
138 }
139 } else {
140 quote! {}
141 }
142 }).collect();
143
144 quote! {
145 #[derive(::serde::Serialize, ::serde::Deserialize)]
146 struct #args_struct_name {
147 #(#fields,)*
148 }
149 }
150}
151
152fn generate_wrapper(
153 func_name: &Ident,
154 wrapper_name: &Ident,
155 params: Vec<&FnArg>,
156 is_async: bool,
157) -> proc_macro2::TokenStream {
158 let args_struct_name = format_ident!("__{}Args", func_name);
159
160 let field_names: Vec<_> = params.iter().filter_map(|param| {
161 if let FnArg::Typed(pat_type) = param {
162 let pat = &pat_type.pat;
163 if let Pat::Ident(ident) = pat.as_ref() {
164 Some(&ident.ident)
165 } else {
166 None
167 }
168 } else {
169 None
170 }
171 }).collect();
172
173 let call_args = if field_names.is_empty() {
174 quote! { ctx }
175 } else {
176 let args_unpack = field_names.iter().map(|name| {
177 quote! { args.#name }
178 });
179 quote! { ctx, #(#args_unpack),* }
180 };
181
182 if is_async {
183 quote! {
184 async fn #wrapper_name(
185 ctx: &memlink_msdk::CallContext<'_>,
186 args_bytes: &[u8],
187 ) -> memlink_msdk::Result<Vec<u8>> {
188 let args: #args_struct_name = memlink_msdk::serialize::default_serializer()
189 .deserialize(args_bytes)
190 .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))?;
191
192 let result = #func_name(#call_args).await?;
193
194 memlink_msdk::serialize::default_serializer()
195 .serialize(&result)
196 .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))
197 }
198 }
199 } else {
200 quote! {
201 fn #wrapper_name(
202 ctx: &memlink_msdk::CallContext<'_>,
203 args_bytes: &[u8],
204 ) -> memlink_msdk::Result<Vec<u8>> {
205 let args: #args_struct_name = memlink_msdk::serialize::default_serializer()
206 .deserialize(args_bytes)
207 .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))?;
208
209 let result = #func_name(#call_args)?;
210
211 memlink_msdk::serialize::default_serializer()
212 .serialize(&result)
213 .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))
214 }
215 }
216 }
217}
218
219fn generate_ffi_export(
220 wrapper_name: &Ident,
221 ffi_name: &Ident,
222 _method_hash: u32,
223 is_async: bool,
224) -> proc_macro2::TokenStream {
225 if is_async {
226 quote! {
227 #[no_mangle]
228 pub unsafe extern "C" fn #ffi_name(
229 ctx_ptr: *const memlink_msdk::CallContext<'static>,
230 args_ptr: *const u8,
231 args_len: usize,
232 out_ptr: *mut u8,
233 out_cap: usize,
234 ) -> i32 {
235 use memlink_msdk::panic::catch_module_panic;
236 use memlink_msdk::request::Response;
237
238 const CALL_SUCCESS: i32 = 0;
239 const CALL_FAILURE: i32 = -1;
240 const CALL_BUFFER_TOO_SMALL: i32 = -2;
241
242 if args_len > 0 && args_ptr.is_null() {
243 return CALL_FAILURE;
244 }
245 if out_cap > 0 && out_ptr.is_null() {
246 return CALL_FAILURE;
247 }
248
249 let result = catch_module_panic(|| {
250 let ctx = unsafe { &*ctx_ptr };
251 let args = if args_len > 0 {
252 unsafe { std::slice::from_raw_parts(args_ptr, args_len) }.to_vec()
253 } else {
254 vec![]
255 };
256
257 let rt = tokio::runtime::Handle::current();
258 let result = rt.block_on(#wrapper_name(ctx, &args));
259
260 let response = match result {
261 Ok(data) => Response::success(0, data),
262 Err(_) => return CALL_FAILURE,
263 };
264
265 let response_bytes = match response.into_bytes() {
266 Ok(bytes) => bytes,
267 Err(_) => return CALL_FAILURE,
268 };
269
270 if response_bytes.len() > out_cap {
271 return CALL_BUFFER_TOO_SMALL;
272 }
273
274 std::ptr::copy_nonoverlapping(
275 response_bytes.as_ptr(),
276 out_ptr,
277 response_bytes.len(),
278 );
279
280 CALL_SUCCESS
281 });
282
283 match result {
284 Ok(code) => code,
285 Err(_) => CALL_FAILURE,
286 }
287 }
288 }
289 } else {
290 quote! {
291 #[no_mangle]
292 pub unsafe extern "C" fn #ffi_name(
293 ctx_ptr: *const memlink_msdk::CallContext<'static>,
294 args_ptr: *const u8,
295 args_len: usize,
296 out_ptr: *mut u8,
297 out_cap: usize,
298 ) -> i32 {
299 use memlink_msdk::panic::catch_module_panic;
300 use memlink_msdk::request::Response;
301
302 const CALL_SUCCESS: i32 = 0;
303 const CALL_FAILURE: i32 = -1;
304 const CALL_BUFFER_TOO_SMALL: i32 = -2;
305
306 if args_len > 0 && args_ptr.is_null() {
307 return CALL_FAILURE;
308 }
309 if out_cap > 0 && out_ptr.is_null() {
310 return CALL_FAILURE;
311 }
312
313 let result = catch_module_panic(|| {
314 let ctx = unsafe { &*ctx_ptr };
315 let args = if args_len > 0 {
316 unsafe { std::slice::from_raw_parts(args_ptr, args_len) }.to_vec()
317 } else {
318 vec![]
319 };
320
321 let result = #wrapper_name(ctx, &args);
322
323 let response = match result {
324 Ok(data) => Response::success(0, data),
325 Err(_) => return CALL_FAILURE,
326 };
327
328 let response_bytes = match response.into_bytes() {
329 Ok(bytes) => bytes,
330 Err(_) => return CALL_FAILURE,
331 };
332
333 if response_bytes.len() > out_cap {
334 return CALL_BUFFER_TOO_SMALL;
335 }
336
337 std::ptr::copy_nonoverlapping(
338 response_bytes.as_ptr(),
339 out_ptr,
340 response_bytes.len(),
341 );
342
343 CALL_SUCCESS
344 });
345
346 match result {
347 Ok(code) => code,
348 Err(_) => CALL_FAILURE,
349 }
350 }
351 }
352 }
353}
354
355fn generate_registration(
356 func_name: &Ident,
357 method_hash: u32,
358 _is_async: bool,
359) -> proc_macro2::TokenStream {
360 let register_func_name = format_ident!("__{}_register", func_name);
361
362 quote! {
363 #[used]
364 static #register_func_name: unsafe extern "C" fn() = {
365 unsafe extern "C" fn register() {
366 }
367 register
368 };
369
370 const _: () = {
371 const _HASH: u32 = #method_hash;
372 };
373 }
374}
375
376#[proc_macro_attribute]
377pub fn memlink_module(_args: TokenStream, input: TokenStream) -> TokenStream {
378 let item = parse_macro_input!(input as ItemStruct);
379 let _struct_name = &item.ident;
380
381 let expanded = quote! {
382 #item
383
384 #[no_mangle]
385 pub unsafe extern "C" fn memlink_init(
386 config_ptr: *const u8,
387 config_len: usize,
388 arena_ptr: *mut u8,
389 arena_capacity: usize,
390 ) -> i32 {
391 use memlink_msdk::exports::{init_arena, INIT_SUCCESS, INIT_FAILURE};
392
393 if !arena_ptr.is_null() && arena_capacity > 0 {
394 init_arena(arena_ptr, arena_capacity);
395 }
396
397 __register_all_methods();
398
399 INIT_SUCCESS
400 }
401
402 fn __register_all_methods() {
403 }
404 };
405
406 TokenStream::from(expanded)
407}