binmod_mdk_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, ItemFn, ItemForeignMod, ForeignItem, Expr, FnArg, Pat, ReturnType, Ident};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::Error;
7
8struct ModuleFnMacroArgs {
9    name: Option<Expr>,
10}
11
12impl Parse for ModuleFnMacroArgs {
13    fn parse(input: ParseStream) -> Result<Self, Error> {
14        let mut name = None;
15        let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
16        
17        for arg in args {
18            if let syn::Meta::NameValue(nv) = arg {
19                if nv.path.is_ident("name") {
20                    name = Some(nv.value);
21                }
22            }
23        }
24        
25        Ok(ModuleFnMacroArgs { name })
26    }
27}
28
29struct HostFnMacroArgs {
30    namespace: Option<Expr>,
31}
32
33impl Parse for HostFnMacroArgs {
34    fn parse(input: ParseStream) -> Result<Self, Error> {
35        let mut namespace = None;
36        let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
37        
38        for arg in args {
39            if let syn::Meta::NameValue(nv) = arg {
40                if nv.path.is_ident("namespace") {
41                    namespace = Some(nv.value);
42                }
43            }
44        }
45        
46        Ok(HostFnMacroArgs { namespace })
47    }
48}
49
50/// This macro is used to define a module function for the ABI.
51/// It takes an optional name argument to specify the exported function name.
52/// 
53/// # Examples
54/// 
55/// ```rust
56/// #[mod_fn(name = "my_function")]
57/// fn my_function(arg1: String, arg2: i32) -> FnResult<String, String> {
58///    Ok(format!("{} {}", arg1, arg2))
59/// }
60/// ```
61/// 
62/// This will generate a function wrapped in a compatible interface for usage
63/// with the module runtime.
64#[proc_macro_attribute]
65pub fn mod_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
66    let input_fn = parse_macro_input!(item as ItemFn);
67    let original_fn_name = &input_fn.sig.ident;
68    
69    let macro_args = match syn::parse::<ModuleFnMacroArgs>(attr) {
70        Ok(args) => args,
71        Err(e) => return e.to_compile_error().into(),
72    };
73    
74    let generated_fn_name = match macro_args.name {
75        Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => {
76            Ident::new(&lit_str.value(), lit_str.span())
77        },
78        _ => original_fn_name.clone(),
79    };
80
81    let mut arg_idents = Vec::new();
82    let mut fn_args = Vec::new();
83
84    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
85        if let FnArg::Typed(pat_type) = arg {
86            if let Pat::Ident(pat_ident) = &*pat_type.pat {
87                let arg_name = &pat_ident.ident;
88                let arg_type = &pat_type.ty;
89
90                arg_idents.push(arg_name.clone());
91                
92                fn_args.push(quote! {
93                    let #arg_name = match input.get_arg::<#arg_type>(#i, stringify!(#arg_name)) {
94                        Ok(val) => val,
95                        Err(e) => return Err(::binmod_mdk::ModuleFnErr {
96                            message: e.to_string(),
97                            error_type: "ArgumentError".into(),
98                        }),
99                    };
100                });
101            }
102        }
103    }
104
105    let fn_return_type = match &input_fn.sig.output {
106        ReturnType::Default => quote! { ::binmod_mdk::ModuleFnResult::Data(
107            ::binmod_mdk::ModuleFnReturn::empty()
108        ) },
109        _ => quote! { ::binmod_mdk::ModuleFnResult::Data(
110            ::binmod_mdk::ModuleFnReturn::new_serialized(result).unwrap()
111        ) },
112    };
113
114    let arg_idents_tokens = arg_idents
115        .iter()
116        .map(|ident| quote! { #ident });
117
118    let expanded = quote! {
119        #input_fn
120        
121        #[unsafe(no_mangle)]
122        pub unsafe extern "C" fn #generated_fn_name(input_ptr: u32, input_len: u32) -> u64 {
123            let input: ::binmod_mdk::ModuleFnInput = match ::binmod_mdk::deserialize_from_ptr(input_ptr, input_len) {
124                Ok(input) => input,
125                Err(e) => {
126                    return match ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
127                        ::binmod_mdk::ModuleFnErr {
128                            message: e.to_string(),
129                            error_type: "DeserializationError".into(),
130                        }
131                    )) {
132                        Ok(ptr) => ptr,
133                        Err(e) => ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
134                            ::binmod_mdk::ModuleFnErr {
135                                message: e.to_string(),
136                                error_type: "SerializationError".into(),
137                            }
138                        )).unwrap_or(0),
139                    }
140                }
141            };
142                
143            let result = std::panic::catch_unwind(|| -> ::binmod_mdk::FnResult<_> {
144                #(#fn_args)*
145                #original_fn_name(#(#arg_idents_tokens),*)
146            });
147            
148            let response = match result {
149                Ok(Ok(result)) => #fn_return_type,
150                Ok(Err(e)) => ::binmod_mdk::ModuleFnResult::Error(e),
151                Err(_) => ::binmod_mdk::ModuleFnResult::Error(
152                    ::binmod_mdk::ModuleFnErr {
153                        message: "Panic occurred".into(),
154                        error_type: "PanicError".into(),
155                    }
156                ),
157            };
158
159            match ::binmod_mdk::serialize_to_ptr(response) {
160                Ok(ptr) => ptr,
161                Err(e) => {
162                    ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
163                        ::binmod_mdk::ModuleFnErr {
164                            message: e.to_string(),
165                            error_type: "SerializationError".into(),
166                        }
167                    )).unwrap_or(0)
168                }
169            }
170        }
171    };
172    
173    TokenStream::from(expanded)
174}
175
176/// This macro is used to define the expected host functions that are accessible
177/// to the module. It takes an optional namespace argument to specify the
178/// namespace of the host functions. If not provided, it defaults to "env".
179/// 
180/// # Examples
181/// 
182/// ```rust
183/// #[host_fns(namespace = "env")]
184/// unsafe extern "host" {
185///    fn host_log(message: String);
186///    fn host_add(a: i32, b: i32) -> i32;
187/// }
188/// ```
189/// 
190/// This allows calling the host functions in the module code like this:
191/// 
192/// ```rust
193/// #[mod_fn(name = "my_func")]
194/// pub fn my_func() -> FnResult<()> {
195///     unsafe { host_log("Hello from the plugin!".to_string()) }
196///     let result = unsafe { host_add(1, 2) };
197///     println!("Result from host_add: {}", result);
198///     Ok(())
199/// }
200#[proc_macro_attribute]
201pub fn host_fns(attr: TokenStream, item: TokenStream) -> TokenStream {
202    let macro_args = match syn::parse::<HostFnMacroArgs>(attr) {
203        Ok(args) => args,
204        Err(e) => return e.to_compile_error().into(),
205    };
206
207    let namespace = match macro_args.namespace {
208        Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => lit_str,
209        _ => syn::LitStr::new("env", proc_macro2::Span::call_site()),
210    };
211
212
213    let item = parse_macro_input!(item as ItemForeignMod);
214    let functions = item.items;
215
216    if item.abi.name.is_none() || item.abi.name.unwrap().value() != "host" {
217        panic!("Host functions must be in a foreign module with the `host` ABI");
218    }
219
220    let mut generated = quote! {};
221
222    for function in functions {
223        if let ForeignItem::Fn(func) = function {
224            let func_name = &func.sig.ident;
225            let raw_func_name = Ident::new(&format!("{}_raw", func_name), func_name.span());
226            let link_name_lit = syn::LitStr::new(&func_name.to_string(), func_name.span());
227
228            let params = func
229                .sig
230                .inputs
231                .iter()
232                .cloned()
233                .collect::<Vec<_>>();
234
235            let param_names = params
236                .iter()
237                .map(|param| {
238                    if let FnArg::Typed(pat_type) = param {
239                        if let Pat::Ident(pat_ident) = &*pat_type.pat {
240                            &pat_ident.ident
241                        } else {
242                            panic!("Expected identifier in function argument");
243                        }
244                    } else {
245                        panic!("Expected typed argument in function signature");
246                    }
247                })
248                .collect::<Vec<_>>();
249
250            let inner_return_type = match &func.sig.output {
251                ReturnType::Default => quote! { () },
252                ReturnType::Type(_, ty) => quote! { #ty },
253            };
254
255            let wrapper = quote! {
256                #[allow(unused_unsafe)]
257                pub unsafe fn #func_name(#(#params),*) -> ::binmod_mdk::FnResult<#inner_return_type> {
258                    let mut input = ::binmod_mdk::ModuleFnInput::new();
259
260                    #(
261                        match input.add_arg(#param_names) {
262                            Ok(_) => {},
263                            Err(e) => {
264                                return Err(::binmod_mdk::ModuleFnErr {
265                                    message: e.to_string(),
266                                    error_type: "ArgumentError".into(),
267                                });
268                            }
269                        }
270                    )*
271
272                    let input_ptr = match ::binmod_mdk::serialize_to_ptr(input) {
273                        Ok(ptr) => ptr,
274                        Err(e) => {
275                            return Err(::binmod_mdk::ModuleFnErr {
276                                message: e.to_string(),
277                                error_type: "SerializationError".into(),
278                            });
279                        }
280                    };
281
282                    let result = unsafe { #raw_func_name(input_ptr) };
283                    let (result_ptr, result_len) = ::binmod_mdk::unpack_ptr(result);
284
285                    let result: ::binmod_mdk::ModuleFnResult<#inner_return_type> = match ::binmod_mdk::deserialize_from_ptr(
286                        result_ptr as u32,
287                        result_len as u32,
288                    ) {
289                        Ok(res) => res,
290                        Err(e) => {
291                            unsafe {
292                                host_dealloc(result_ptr as *mut u8, result_len as usize);
293                            }
294
295                            return Err(::binmod_mdk::ModuleFnErr {
296                                message: e.to_string(),
297                                error_type: "DeserializationError".into(),
298                            });
299                        }
300                    };
301
302                    unsafe {
303                        host_dealloc(result_ptr as *mut u8, result_len as usize);
304                    }
305
306                    match result {
307                        ::binmod_mdk::ModuleFnResult::Data(data) => {
308                            match data.value {
309                                Some(value) => Ok(value),
310                                None => Ok(Default::default()),
311                            }
312                        },
313                        ::binmod_mdk::ModuleFnResult::Error(err) => Err(err),
314                    }
315                }
316            };
317
318            generated.extend(wrapper);
319            generated.extend(quote! {
320                #[link(wasm_import_module = #namespace)]
321                unsafe extern "C" {
322                    #[link_name = #link_name_lit]
323                    pub fn #raw_func_name(input_ptr: u64) -> u64;
324                }
325            });
326        }
327    }
328
329    generated.extend(quote! {
330        #[link(wasm_import_module = #namespace)]
331        unsafe extern "C" {
332            pub fn host_alloc(len: usize) -> *mut u8;
333            pub fn host_dealloc(ptr: *mut u8, len: usize);
334        }
335    });
336
337    generated.into()
338}