//! binmod-mdk-macros 0.1.4
//!
//! Binmod MDK for Rust
//! Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn, ItemForeignMod, ForeignItem, Expr, FnArg, Pat, ReturnType, Ident};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::Error;

struct ModuleFnMacroArgs {
    name: Option<Expr>,
}

impl Parse for ModuleFnMacroArgs {
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let mut name = None;
        let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
        
        for arg in args {
            if let syn::Meta::NameValue(nv) = arg {
                if nv.path.is_ident("name") {
                    name = Some(nv.value);
                }
            }
        }
        
        Ok(ModuleFnMacroArgs { name })
    }
}

struct HostFnMacroArgs {
    namespace: Option<Expr>,
}

impl Parse for HostFnMacroArgs {
    fn parse(input: ParseStream) -> Result<Self, Error> {
        let mut namespace = None;
        let args = Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated(input)?;
        
        for arg in args {
            if let syn::Meta::NameValue(nv) = arg {
                if nv.path.is_ident("namespace") {
                    namespace = Some(nv.value);
                }
            }
        }
        
        Ok(HostFnMacroArgs { namespace })
    }
}

/// This macro is used to define a module function for the ABI.
/// It takes an optional name argument to specify the exported function name.
/// 
/// # Examples
/// 
/// ```rust
/// #[mod_fn(name = "my_function")]
/// fn my_function(arg1: String, arg2: i32) -> FnResult<String, String> {
///    Ok(format!("{} {}", arg1, arg2))
/// }
/// ```
/// 
/// This will generate a function wrapped in a compatible interface for usage
/// with the module runtime.
#[proc_macro_attribute]
pub fn mod_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as ItemFn);
    let original_fn_name = &input_fn.sig.ident;
    
    let macro_args = match syn::parse::<ModuleFnMacroArgs>(attr) {
        Ok(args) => args,
        Err(e) => return e.to_compile_error().into(),
    };
    
    let generated_fn_name = match macro_args.name {
        Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => {
            Ident::new(&lit_str.value(), lit_str.span())
        },
        _ => original_fn_name.clone(),
    };

    let mut arg_idents = Vec::new();
    let mut fn_args = Vec::new();

    for (i, arg) in input_fn.sig.inputs.iter().enumerate() {
        if let FnArg::Typed(pat_type) = arg {
            if let Pat::Ident(pat_ident) = &*pat_type.pat {
                let arg_name = &pat_ident.ident;
                let arg_type = &pat_type.ty;

                arg_idents.push(arg_name.clone());
                
                fn_args.push(quote! {
                    let #arg_name = match input.get_arg::<#arg_type>(#i, stringify!(#arg_name)) {
                        Ok(val) => val,
                        Err(e) => return Err(::binmod_mdk::ModuleFnErr {
                            message: e.to_string(),
                            error_type: "ArgumentError".into(),
                        }),
                    };
                });
            }
        }
    }

    let fn_return_type = match &input_fn.sig.output {
        ReturnType::Default => quote! { ::binmod_mdk::ModuleFnResult::Data(
            ::binmod_mdk::ModuleFnReturn::empty()
        ) },
        _ => quote! { ::binmod_mdk::ModuleFnResult::Data(
            ::binmod_mdk::ModuleFnReturn::new_serialized(result).unwrap()
        ) },
    };

    let arg_idents_tokens = arg_idents
        .iter()
        .map(|ident| quote! { #ident });

    let expanded = quote! {
        #input_fn
        
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn #generated_fn_name(input_ptr: u32, input_len: u32) -> u64 {
            let input: ::binmod_mdk::ModuleFnInput = match ::binmod_mdk::deserialize_from_ptr(input_ptr, input_len) {
                Ok(input) => input,
                Err(e) => {
                    return match ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
                        ::binmod_mdk::ModuleFnErr {
                            message: e.to_string(),
                            error_type: "DeserializationError".into(),
                        }
                    )) {
                        Ok(ptr) => ptr,
                        Err(e) => ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
                            ::binmod_mdk::ModuleFnErr {
                                message: e.to_string(),
                                error_type: "SerializationError".into(),
                            }
                        )).unwrap_or(0),
                    }
                }
            };
                
            let result = std::panic::catch_unwind(|| -> ::binmod_mdk::FnResult<_> {
                #(#fn_args)*
                #original_fn_name(#(#arg_idents_tokens),*)
            });
            
            let response = match result {
                Ok(Ok(result)) => #fn_return_type,
                Ok(Err(e)) => ::binmod_mdk::ModuleFnResult::Error(e),
                Err(_) => ::binmod_mdk::ModuleFnResult::Error(
                    ::binmod_mdk::ModuleFnErr {
                        message: "Panic occurred".into(),
                        error_type: "PanicError".into(),
                    }
                ),
            };

            match ::binmod_mdk::serialize_to_ptr(response) {
                Ok(ptr) => ptr,
                Err(e) => {
                    ::binmod_mdk::serialize_to_ptr(::binmod_mdk::ModuleFnResult::<()>::Error(
                        ::binmod_mdk::ModuleFnErr {
                            message: e.to_string(),
                            error_type: "SerializationError".into(),
                        }
                    )).unwrap_or(0)
                }
            }
        }
    };
    
    TokenStream::from(expanded)
}

/// This macro is used to define the expected host functions that are accessible
/// to the module. It takes an optional namespace argument to specify the
/// namespace of the host functions. If not provided, it defaults to "env".
/// 
/// # Examples
/// 
/// ```rust
/// #[host_fns(namespace = "env")]
/// unsafe extern "host" {
///    fn host_log(message: String);
///    fn host_add(a: i32, b: i32) -> i32;
/// }
/// ```
/// 
/// This allows calling the host functions in the module code like this:
/// 
/// ```rust
/// #[mod_fn(name = "my_func")]
/// pub fn my_func() -> FnResult<()> {
///     unsafe { host_log("Hello from the plugin!".to_string()) }
///     let result = unsafe { host_add(1, 2) };
///     println!("Result from host_add: {}", result);
///     Ok(())
/// }
#[proc_macro_attribute]
pub fn host_fns(attr: TokenStream, item: TokenStream) -> TokenStream {
    let macro_args = match syn::parse::<HostFnMacroArgs>(attr) {
        Ok(args) => args,
        Err(e) => return e.to_compile_error().into(),
    };

    let namespace = match macro_args.namespace {
        Some(Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. })) => lit_str,
        _ => syn::LitStr::new("env", proc_macro2::Span::call_site()),
    };


    let item = parse_macro_input!(item as ItemForeignMod);
    let functions = item.items;

    if item.abi.name.is_none() || item.abi.name.unwrap().value() != "host" {
        panic!("Host functions must be in a foreign module with the `host` ABI");
    }

    let mut generated = quote! {};

    for function in functions {
        if let ForeignItem::Fn(func) = function {
            let func_name = &func.sig.ident;
            let raw_func_name = Ident::new(&format!("{}_raw", func_name), func_name.span());
            let link_name_lit = syn::LitStr::new(&func_name.to_string(), func_name.span());

            let params = func
                .sig
                .inputs
                .iter()
                .cloned()
                .collect::<Vec<_>>();

            let param_names = params
                .iter()
                .map(|param| {
                    if let FnArg::Typed(pat_type) = param {
                        if let Pat::Ident(pat_ident) = &*pat_type.pat {
                            &pat_ident.ident
                        } else {
                            panic!("Expected identifier in function argument");
                        }
                    } else {
                        panic!("Expected typed argument in function signature");
                    }
                })
                .collect::<Vec<_>>();

            let inner_return_type = match &func.sig.output {
                ReturnType::Default => quote! { () },
                ReturnType::Type(_, ty) => quote! { #ty },
            };

            let wrapper = quote! {
                #[allow(unused_unsafe)]
                pub unsafe fn #func_name(#(#params),*) -> ::binmod_mdk::FnResult<#inner_return_type> {
                    let mut input = ::binmod_mdk::ModuleFnInput::new();

                    #(
                        match input.add_arg(#param_names) {
                            Ok(_) => {},
                            Err(e) => {
                                return Err(::binmod_mdk::ModuleFnErr {
                                    message: e.to_string(),
                                    error_type: "ArgumentError".into(),
                                });
                            }
                        }
                    )*

                    let input_ptr = match ::binmod_mdk::serialize_to_ptr(input) {
                        Ok(ptr) => ptr,
                        Err(e) => {
                            return Err(::binmod_mdk::ModuleFnErr {
                                message: e.to_string(),
                                error_type: "SerializationError".into(),
                            });
                        }
                    };

                    let result = unsafe { #raw_func_name(input_ptr) };
                    let (result_ptr, result_len) = ::binmod_mdk::unpack_ptr(result);

                    let result: ::binmod_mdk::ModuleFnResult<#inner_return_type> = match ::binmod_mdk::deserialize_from_ptr(
                        result_ptr as u32,
                        result_len as u32,
                    ) {
                        Ok(res) => res,
                        Err(e) => {
                            unsafe {
                                host_dealloc(result_ptr as *mut u8, result_len as usize);
                            }

                            return Err(::binmod_mdk::ModuleFnErr {
                                message: e.to_string(),
                                error_type: "DeserializationError".into(),
                            });
                        }
                    };

                    unsafe {
                        host_dealloc(result_ptr as *mut u8, result_len as usize);
                    }

                    match result {
                        ::binmod_mdk::ModuleFnResult::Data(data) => {
                            match data.value {
                                Some(value) => Ok(value),
                                None => Ok(Default::default()),
                            }
                        },
                        ::binmod_mdk::ModuleFnResult::Error(err) => Err(err),
                    }
                }
            };

            generated.extend(wrapper);
            generated.extend(quote! {
                #[link(wasm_import_module = #namespace)]
                unsafe extern "C" {
                    #[link_name = #link_name_lit]
                    pub fn #raw_func_name(input_ptr: u64) -> u64;
                }
            });
        }
    }

    generated.extend(quote! {
        #[link(wasm_import_module = #namespace)]
        unsafe extern "C" {
            pub fn host_alloc(len: usize) -> *mut u8;
            pub fn host_dealloc(ptr: *mut u8, len: usize);
        }
    });

    generated.into()
}