1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, ItemFn, ItemForeignMod, ForeignItem, Expr, FnArg, Pat, ReturnType, Ident};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::Error;
7
8struct ModuleFnMacroArgs {
9 name: Option<Expr>,
10}
11
12impl Parse for ModuleFnMacroArgs {
13 fn parse(input: ParseStream) -> Result<Self, Error> {
14 let mut name = None;
15 let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
16
17 for arg in args {
18 if let syn::Meta::NameValue(nv) = arg {
19 if nv.path.is_ident("name") {
20 name = Some(nv.value);
21 }
22 }
23 }
24
25 Ok(ModuleFnMacroArgs { name })
26 }
27}
28
29struct HostFnMacroArgs {
30 namespace: Option<Expr>,
31}
32
33impl Parse for HostFnMacroArgs {
34 fn parse(input: ParseStream) -> Result<Self, Error> {
35 let mut namespace = None;
36 let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
37
38 for arg in args {
39 if let syn::Meta::NameValue(nv) = arg {
40 if nv.path.is_ident("namespace") {
41 namespace = Some(nv.value);
42 }
43 }
44 }
45
46 Ok(HostFnMacroArgs { namespace })
47 }
48}
49
50#[proc_macro_attribute]
65pub fn mod_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
66 let input_fn = parse_macro_input!(item as ItemFn);
67 let original_fn_name = &input_fn.sig.ident;
68
69 let macro_args = match syn::parse::<ModuleFnMacroArgs>(attr) {
70 Ok(args) => args,
71 Err(e) => return e.to_compile_error().into(),
72 };
73
74 let generated_fn_name = match macro_args.name {
75 Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => {
76 Ident::new(&lit_str.value(), lit_str.span())
77 },
78 _ => original_fn_name.clone(),
79 };
80
81 let mut arg_idents = Vec::new();
82 let mut fn_args = Vec::new();
83
84 for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
85 if let FnArg::Typed(pat_type) = arg {
86 if let Pat::Ident(pat_ident) = &*pat_type.pat {
87 let arg_name = &pat_ident.ident;
88 let arg_type = &pat_type.ty;
89
90 arg_idents.push(arg_name.clone());
91
92 fn_args.push(quote! {
93 let #arg_name = match input.get_arg::<#arg_type>(#i, stringify!(#arg_name)) {
94 Ok(val) => val,
95 Err(e) => return Err(::binmod_mdk::ModuleFnErr {
96 message: e.to_string(),
97 error_type: "ArgumentError".into(),
98 }),
99 };
100 });
101 }
102 }
103 }
104
105 let fn_return_type = match &input_fn.sig.output {
106 ReturnType::Default => quote! { ::binmod_mdk::ModuleFnResult::Data(
107 ::binmod_mdk::ModuleFnReturn::empty()
108 ) },
109 _ => quote! { ::binmod_mdk::ModuleFnResult::Data(
110 ::binmod_mdk::ModuleFnReturn::new_serialized(result).unwrap()
111 ) },
112 };
113
114 let arg_idents_tokens = arg_idents
115 .iter()
116 .map(|ident| quote! { #ident });
117
118 let expanded = quote! {
119 #input_fn
120
121 #[unsafe(no_mangle)]
122 pub unsafe extern "C" fn #generated_fn_name(input_ptr: u32, input_len: u32) -> u64 {
123 let input: ::binmod_mdk::ModuleFnInput = match ::binmod_mdk::deserialize_from_ptr(input_ptr, input_len) {
124 Ok(input) => input,
125 Err(e) => {
126 return match ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
127 ::binmod_mdk::ModuleFnErr {
128 message: e.to_string(),
129 error_type: "DeserializationError".into(),
130 }
131 )) {
132 Ok(ptr) => ptr,
133 Err(e) => ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
134 ::binmod_mdk::ModuleFnErr {
135 message: e.to_string(),
136 error_type: "SerializationError".into(),
137 }
138 )).unwrap_or(0),
139 }
140 }
141 };
142
143 let result = std::panic::catch_unwind(|| -> ::binmod_mdk::FnResult<_> {
144 #(#fn_args)*
145 #original_fn_name(#(#arg_idents_tokens),*)
146 });
147
148 let response = match result {
149 Ok(Ok(result)) => #fn_return_type,
150 Ok(Err(e)) => ::binmod_mdk::ModuleFnResult::Error(e),
151 Err(_) => ::binmod_mdk::ModuleFnResult::Error(
152 ::binmod_mdk::ModuleFnErr {
153 message: "Panic occurred".into(),
154 error_type: "PanicError".into(),
155 }
156 ),
157 };
158
159 match ::binmod_mdk::serialize_to_ptr(response) {
160 Ok(ptr) => ptr,
161 Err(e) => {
162 ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
163 ::binmod_mdk::ModuleFnErr {
164 message: e.to_string(),
165 error_type: "SerializationError".into(),
166 }
167 )).unwrap_or(0)
168 }
169 }
170 }
171 };
172
173 TokenStream::from(expanded)
174}
175
176#[proc_macro_attribute]
201pub fn host_fns(attr: TokenStream, item: TokenStream) -> TokenStream {
202 let macro_args = match syn::parse::<HostFnMacroArgs>(attr) {
203 Ok(args) => args,
204 Err(e) => return e.to_compile_error().into(),
205 };
206
207 let namespace = match macro_args.namespace {
208 Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => lit_str,
209 _ => syn::LitStr::new("env", proc_macro2::Span::call_site()),
210 };
211
212
213 let item = parse_macro_input!(item as ItemForeignMod);
214 let functions = item.items;
215
216 if item.abi.name.is_none() || item.abi.name.unwrap().value() != "host" {
217 panic!("Host functions must be in a foreign module with the `host` ABI");
218 }
219
220 let mut generated = quote! {};
221
222 for function in functions {
223 if let ForeignItem::Fn(func) = function {
224 let func_name = &func.sig.ident;
225 let raw_func_name = Ident::new(&format!("{}_raw", func_name), func_name.span());
226 let link_name_lit = syn::LitStr::new(&func_name.to_string(), func_name.span());
227
228 let params = func
229 .sig
230 .inputs
231 .iter()
232 .cloned()
233 .collect::<Vec<_>>();
234
235 let param_names = params
236 .iter()
237 .map(|param| {
238 if let FnArg::Typed(pat_type) = param {
239 if let Pat::Ident(pat_ident) = &*pat_type.pat {
240 &pat_ident.ident
241 } else {
242 panic!("Expected identifier in function argument");
243 }
244 } else {
245 panic!("Expected typed argument in function signature");
246 }
247 })
248 .collect::<Vec<_>>();
249
250 let inner_return_type = match &func.sig.output {
251 ReturnType::Default => quote! { () },
252 ReturnType::Type(_, ty) => quote! { #ty },
253 };
254
255 let wrapper = quote! {
256 #[allow(unused_unsafe)]
257 pub unsafe fn #func_name(#(#params),*) -> ::binmod_mdk::FnResult<#inner_return_type> {
258 let mut input = ::binmod_mdk::ModuleFnInput::new();
259
260 #(
261 match input.add_arg(#param_names) {
262 Ok(_) => {},
263 Err(e) => {
264 return Err(::binmod_mdk::ModuleFnErr {
265 message: e.to_string(),
266 error_type: "ArgumentError".into(),
267 });
268 }
269 }
270 )*
271
272 let input_ptr = match ::binmod_mdk::serialize_to_ptr(input) {
273 Ok(ptr) => ptr,
274 Err(e) => {
275 return Err(::binmod_mdk::ModuleFnErr {
276 message: e.to_string(),
277 error_type: "SerializationError".into(),
278 });
279 }
280 };
281
282 let result = unsafe { #raw_func_name(input_ptr) };
283 let (result_ptr, result_len) = ::binmod_mdk::unpack_ptr(result);
284
285 let result: ::binmod_mdk::ModuleFnResult<#inner_return_type> = match ::binmod_mdk::deserialize_from_ptr(
286 result_ptr as u32,
287 result_len as u32,
288 ) {
289 Ok(res) => res,
290 Err(e) => {
291 unsafe {
292 host_dealloc(result_ptr as *mut u8, result_len as usize);
293 }
294
295 return Err(::binmod_mdk::ModuleFnErr {
296 message: e.to_string(),
297 error_type: "DeserializationError".into(),
298 });
299 }
300 };
301
302 unsafe {
303 host_dealloc(result_ptr as *mut u8, result_len as usize);
304 }
305
306 match result {
307 ::binmod_mdk::ModuleFnResult::Data(data) => {
308 match data.value {
309 Some(value) => Ok(value),
310 None => Ok(Default::default()),
311 }
312 },
313 ::binmod_mdk::ModuleFnResult::Error(err) => Err(err),
314 }
315 }
316 };
317
318 generated.extend(wrapper);
319 generated.extend(quote! {
320 #[link(wasm_import_module = #namespace)]
321 unsafe extern "C" {
322 #[link_name = #link_name_lit]
323 pub fn #raw_func_name(input_ptr: u64) -> u64;
324 }
325 });
326 }
327 }
328
329 generated.extend(quote! {
330 #[link(wasm_import_module = #namespace)]
331 unsafe extern "C" {
332 pub fn host_alloc(len: usize) -> *mut u8;
333 pub fn host_dealloc(ptr: *mut u8, len: usize);
334 }
335 });
336
337 generated.into()
338}