Skip to main content

memlink_msdk_macros/
lib.rs

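//! Procedural macros for the memlink SDK.
//!
//! Provides the `#[memlink_export]` attribute macro, which exports Rust functions as memlink
//! module methods with automatic serialization and FFI bindings, and the `#[memlink_module]`
//! attribute macro, which emits the module's `memlink_init` entry point.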
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote, ToTokens};
8use syn::parse::{Parse, ParseStream};
9use syn::{
10    parse_macro_input, FnArg, Ident, ItemFn, ItemStruct, Pat, Signature, Token,
11};
12
/// 32-bit FNV-1a hash of `bytes`, usable in `const` contexts.
///
/// Each byte is XOR-ed into the running state *before* multiplying by the FNV prime, which is
/// what distinguishes FNV-1a from FNV-1.
const fn fnv1a_hash_bytes(bytes: &[u8]) -> u32 {
    const OFFSET_BASIS: u32 = 0x811C_9DC5; // 2166136261
    const PRIME: u32 = 0x0100_0193; // 16777619

    let mut state = OFFSET_BASIS;
    let mut remaining = bytes;
    // `for` loops are not allowed in `const fn`, so peel one byte per iteration instead.
    while let [first, rest @ ..] = remaining {
        state = (state ^ *first as u32).wrapping_mul(PRIME);
        remaining = rest;
    }
    state
}
26
27const fn fnv1a_hash_str(s: &str) -> u32 {
28    fnv1a_hash_bytes(s.as_bytes())
29}
30
31struct ExportAttrs {
32    name: Option<String>,
33}
34
35impl Parse for ExportAttrs {
36    fn parse(input: ParseStream) -> syn::Result<Self> {
37        let mut name = None;
38
39        while !input.is_empty() {
40            let ident: Ident = input.parse()?;
41            input.parse::<Token![=]>()?;
42            let value: syn::LitStr = input.parse()?;
43
44            if ident == "name" {
45                name = Some(value.value());
46            }
47
48            if !input.is_empty() {
49                input.parse::<Token![,]>()?;
50            }
51        }
52
53        Ok(ExportAttrs { name })
54    }
55}
56
57#[proc_macro_attribute]
58pub fn memlink_export(args: TokenStream, input: TokenStream) -> TokenStream {
59    let attrs = parse_macro_input!(args as ExportAttrs);
60    let mut func = parse_macro_input!(input as ItemFn);
61
62    let method_name = attrs.name.unwrap_or_else(|| func.sig.ident.to_string());
63    let method_hash = fnv1a_hash_str(&method_name);
64
65    let expanded = generate_export_code(&mut func, &method_name, method_hash);
66
67    TokenStream::from(expanded)
68}
69
70fn generate_export_code(func: &mut ItemFn, _method_name: &str, method_hash: u32) -> proc_macro2::TokenStream {
71    let func_name = &func.sig.ident;
72    let _func_vis = &func.vis;
73    let sig = &func.sig;
74
75    let is_async = sig.asyncness.is_some();
76
77    let (_context_param, other_params) = extract_params(sig);
78
79    let args_struct = if !other_params.is_empty() {
80        generate_args_struct(func_name, other_params.clone())
81    } else {
82        quote! {}
83    };
84
85    let wrapper_name = format_ident!("__{}_wrapper", func_name);
86    let wrapper = generate_wrapper(func_name, &wrapper_name, other_params, is_async);
87
88    let ffi_name = format_ident!("__{}_ffi", func_name);
89    let ffi_func = generate_ffi_export(&wrapper_name, &ffi_name, method_hash, is_async);
90
91    let register_func = generate_registration(func_name, method_hash, is_async);
92
93    quote! {
94        #func
95        #args_struct
96        #wrapper
97        #ffi_func
98        #register_func
99    }
100}
101
102fn extract_params(sig: &Signature) -> (Option<&FnArg>, Vec<&FnArg>) {
103    let params = sig.inputs.iter();
104    let mut context_param = None;
105    let mut other_params = Vec::new();
106
107    for param in params {
108        match param {
109            FnArg::Typed(pat_type) => {
110                let type_str = pat_type.ty.to_token_stream().to_string();
111                if type_str.contains("CallContext") {
112                    context_param = Some(param);
113                } else {
114                    other_params.push(param);
115                }
116            }
117            FnArg::Receiver(_) => {
118                other_params.push(param);
119            }
120        }
121    }
122
123    (context_param, other_params)
124}
125
126fn generate_args_struct(func_name: &Ident, params: Vec<&FnArg>) -> proc_macro2::TokenStream {
127    let args_struct_name = format_ident!("__{}Args", func_name);
128
129    let fields: Vec<_> = params.iter().map(|param| {
130        if let FnArg::Typed(pat_type) = param {
131            let pat = &pat_type.pat;
132            let ty = &pat_type.ty;
133            if let Pat::Ident(ident) = pat.as_ref() {
134                let field_name = &ident.ident;
135                quote! { pub #field_name: #ty }
136            } else {
137                quote! {}
138            }
139        } else {
140            quote! {}
141        }
142    }).collect();
143
144    quote! {
145        #[derive(::serde::Serialize, ::serde::Deserialize)]
146        struct #args_struct_name {
147            #(#fields,)*
148        }
149    }
150}
151
152fn generate_wrapper(
153    func_name: &Ident,
154    wrapper_name: &Ident,
155    params: Vec<&FnArg>,
156    is_async: bool,
157) -> proc_macro2::TokenStream {
158    let args_struct_name = format_ident!("__{}Args", func_name);
159
160    let field_names: Vec<_> = params.iter().filter_map(|param| {
161        if let FnArg::Typed(pat_type) = param {
162            let pat = &pat_type.pat;
163            if let Pat::Ident(ident) = pat.as_ref() {
164                Some(&ident.ident)
165            } else {
166                None
167            }
168        } else {
169            None
170        }
171    }).collect();
172
173    let call_args = if field_names.is_empty() {
174        quote! { ctx }
175    } else {
176        let args_unpack = field_names.iter().map(|name| {
177            quote! { args.#name }
178        });
179        quote! { ctx, #(#args_unpack),* }
180    };
181
182    if is_async {
183        quote! {
184            async fn #wrapper_name(
185                ctx: &memlink_msdk::CallContext<'_>,
186                args_bytes: &[u8],
187            ) -> memlink_msdk::Result<Vec<u8>> {
188                let args: #args_struct_name = memlink_msdk::serialize::default_serializer()
189                    .deserialize(args_bytes)
190                    .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))?;
191
192                let result = #func_name(#call_args).await?;
193
194                memlink_msdk::serialize::default_serializer()
195                    .serialize(&result)
196                    .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))
197            }
198        }
199    } else {
200        quote! {
201            fn #wrapper_name(
202                ctx: &memlink_msdk::CallContext<'_>,
203                args_bytes: &[u8],
204            ) -> memlink_msdk::Result<Vec<u8>> {
205                let args: #args_struct_name = memlink_msdk::serialize::default_serializer()
206                    .deserialize(args_bytes)
207                    .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))?;
208
209                let result = #func_name(#call_args)?;
210
211                memlink_msdk::serialize::default_serializer()
212                    .serialize(&result)
213                    .map_err(|e| memlink_msdk::ModuleError::Serialize(e.to_string()))
214            }
215        }
216    }
217}
218
219fn generate_ffi_export(
220    wrapper_name: &Ident,
221    ffi_name: &Ident,
222    _method_hash: u32,
223    is_async: bool,
224) -> proc_macro2::TokenStream {
225    if is_async {
226        quote! {
227            #[no_mangle]
228            pub unsafe extern "C" fn #ffi_name(
229                ctx_ptr: *const memlink_msdk::CallContext<'static>,
230                args_ptr: *const u8,
231                args_len: usize,
232                out_ptr: *mut u8,
233                out_cap: usize,
234            ) -> i32 {
235                use memlink_msdk::panic::catch_module_panic;
236                use memlink_msdk::request::Response;
237
238                const CALL_SUCCESS: i32 = 0;
239                const CALL_FAILURE: i32 = -1;
240                const CALL_BUFFER_TOO_SMALL: i32 = -2;
241
242                if args_len > 0 && args_ptr.is_null() {
243                    return CALL_FAILURE;
244                }
245                if out_cap > 0 && out_ptr.is_null() {
246                    return CALL_FAILURE;
247                }
248
249                let result = catch_module_panic(|| {
250                    let ctx = unsafe { &*ctx_ptr };
251                    let args = if args_len > 0 {
252                        unsafe { std::slice::from_raw_parts(args_ptr, args_len) }.to_vec()
253                    } else {
254                        vec![]
255                    };
256
257                    let rt = tokio::runtime::Handle::current();
258                    let result = rt.block_on(#wrapper_name(ctx, &args));
259
260                    let response = match result {
261                        Ok(data) => Response::success(0, data),
262                        Err(_) => return CALL_FAILURE,
263                    };
264
265                    let response_bytes = match response.into_bytes() {
266                        Ok(bytes) => bytes,
267                        Err(_) => return CALL_FAILURE,
268                    };
269
270                    if response_bytes.len() > out_cap {
271                        return CALL_BUFFER_TOO_SMALL;
272                    }
273
274                    std::ptr::copy_nonoverlapping(
275                        response_bytes.as_ptr(),
276                        out_ptr,
277                        response_bytes.len(),
278                    );
279
280                    CALL_SUCCESS
281                });
282
283                match result {
284                    Ok(code) => code,
285                    Err(_) => CALL_FAILURE,
286                }
287            }
288        }
289    } else {
290        quote! {
291            #[no_mangle]
292            pub unsafe extern "C" fn #ffi_name(
293                ctx_ptr: *const memlink_msdk::CallContext<'static>,
294                args_ptr: *const u8,
295                args_len: usize,
296                out_ptr: *mut u8,
297                out_cap: usize,
298            ) -> i32 {
299                use memlink_msdk::panic::catch_module_panic;
300                use memlink_msdk::request::Response;
301
302                const CALL_SUCCESS: i32 = 0;
303                const CALL_FAILURE: i32 = -1;
304                const CALL_BUFFER_TOO_SMALL: i32 = -2;
305
306                if args_len > 0 && args_ptr.is_null() {
307                    return CALL_FAILURE;
308                }
309                if out_cap > 0 && out_ptr.is_null() {
310                    return CALL_FAILURE;
311                }
312
313                let result = catch_module_panic(|| {
314                    let ctx = unsafe { &*ctx_ptr };
315                    let args = if args_len > 0 {
316                        unsafe { std::slice::from_raw_parts(args_ptr, args_len) }.to_vec()
317                    } else {
318                        vec![]
319                    };
320
321                    let result = #wrapper_name(ctx, &args);
322
323                    let response = match result {
324                        Ok(data) => Response::success(0, data),
325                        Err(_) => return CALL_FAILURE,
326                    };
327
328                    let response_bytes = match response.into_bytes() {
329                        Ok(bytes) => bytes,
330                        Err(_) => return CALL_FAILURE,
331                    };
332
333                    if response_bytes.len() > out_cap {
334                        return CALL_BUFFER_TOO_SMALL;
335                    }
336
337                    std::ptr::copy_nonoverlapping(
338                        response_bytes.as_ptr(),
339                        out_ptr,
340                        response_bytes.len(),
341                    );
342
343                    CALL_SUCCESS
344                });
345
346                match result {
347                    Ok(code) => code,
348                    Err(_) => CALL_FAILURE,
349                }
350            }
351        }
352    }
353}
354
355fn generate_registration(
356    func_name: &Ident,
357    method_hash: u32,
358    _is_async: bool,
359) -> proc_macro2::TokenStream {
360    let register_func_name = format_ident!("__{}_register", func_name);
361
362    quote! {
363        #[used]
364        static #register_func_name: unsafe extern "C" fn() = {
365            unsafe extern "C" fn register() {
366            }
367            register
368        };
369
370        const _: () = {
371            const _HASH: u32 = #method_hash;
372        };
373    }
374}
375
376#[proc_macro_attribute]
377pub fn memlink_module(_args: TokenStream, input: TokenStream) -> TokenStream {
378    let item = parse_macro_input!(input as ItemStruct);
379    let _struct_name = &item.ident;
380
381    let expanded = quote! {
382        #item
383
384        #[no_mangle]
385        pub unsafe extern "C" fn memlink_init(
386            config_ptr: *const u8,
387            config_len: usize,
388            arena_ptr: *mut u8,
389            arena_capacity: usize,
390        ) -> i32 {
391            use memlink_msdk::exports::{init_arena, INIT_SUCCESS, INIT_FAILURE};
392
393            if !arena_ptr.is_null() && arena_capacity > 0 {
394                init_arena(arena_ptr, arena_capacity);
395            }
396
397            __register_all_methods();
398
399            INIT_SUCCESS
400        }
401
402        fn __register_all_methods() {
403        }
404    };
405
406    TokenStream::from(expanded)
407}