//! rline_macro 1.0.0
//!
//! A Rust procedural macro for generating WebAssembly stubs with customizable serialization formats.
//! Documentation
extern crate proc_macro;

use proc_macro::TokenStream;

use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::parse_macro_input;
use syn::spanned::Spanned;

/// A macro to generate a WebAssembly stub around a `Row -> Result<Row, String>` function.
///
/// The annotated function is renamed `run_<name>`, and an exported
/// `#[no_mangle] extern "C" fn <name>()` is generated under the original name. The stub
/// reads all of stdin, deserializes it into a `Row`, calls `run_<name>`, then writes the
/// byte `3` followed by the serialized result to stdout. A deserialization failure is
/// forwarded as an `Err` result instead of calling the wrapped function.
///
/// You can choose a serialization format among "bincode" and "json" (matched
/// case-insensitively). Default is bincode.
///
/// The generated stub refers to `Row` unqualified, so `Row` must be in scope where the
/// attribute is used.
///
/// # Compile errors
///
/// A compile error is emitted when the attribute argument is not a supported format, when
/// the item is not a function, or when the function does not take exactly one `Row`
/// parameter and return `Result<Row, String>`.
///
/// # Examples
///
/// ```
/// use rline_api::row::Row;
/// use rline_macro::rline_bindgen;
///
/// /// Use the default format for serialization.
/// #[rline_bindgen]
/// pub fn identity(row: Row) -> Result<Row, String> {
///     Ok(row)
/// }
///
/// /// Explicitly use the bincode format for serialization.
/// #[rline_bindgen(bincode)]
/// pub fn identity_bincode(row: Row) -> Result<Row, String> {
///     Ok(row)
/// }
///
/// /// Explicitly use the json format for serialization.
/// #[rline_bindgen(json)]
/// pub fn identity_json(row: Row) -> Result<Row, String> {
///     Ok(row)
/// }
/// ```
#[proc_macro_attribute]
pub fn rline_bindgen(metadata: TokenStream, item: TokenStream) -> TokenStream {
    // No attribute argument selects the default format.
    let serialization_format = if metadata.is_empty() {
        SerializationFormat::Bincode
    } else {
        parse_macro_input!(metadata as SerializationFormat)
    };

    // `parse_macro_input!` turns a non-function item into a compile error pointing at the
    // item, instead of panicking inside the macro as `unwrap()` would.
    let mut ast = parse_macro_input!(item as syn::ItemFn);

    // Rename the wrapped function by prefixing its name with "run_".
    // This lets the wrapper function take over the wrapped function's original name.
    // For example, the function "fun" is renamed "run_fun" and is wrapped by a function "fun".
    let func_ident = ast.sig.ident;
    let run_ident_name: String = format!("run_{}", func_ident);
    let run_ident = Ident::new(run_ident_name.as_str(), Span::call_site());
    ast.sig.ident = run_ident.clone();

    if let Err(e) = check_param(&ast) {
        return e;
    };
    if let Err(e) = check_return(&ast) {
        return e;
    };

    // Per-function aliases for the std::io items used by the stub, so that several
    // annotated functions in one module do not emit conflicting `use` declarations.
    let stdin_import = Ident::new(
        format!("stdin_{}", run_ident_name).as_str(),
        Span::call_site(),
    );
    let stdout_import = Ident::new(
        format!("stdout_{}", run_ident_name).as_str(),
        Span::call_site(),
    );
    let read_import = Ident::new(
        format!("Read_{}", run_ident_name).as_str(),
        Span::call_site(),
    );
    let write_import = Ident::new(
        format!("Write_{}", run_ident_name).as_str(),
        Span::call_site(),
    );

    // `Row` methods used to serialize the result and deserialize the input.
    let (ser_method, de_method) = match serialization_format {
        SerializationFormat::Bincode => (
            Ident::new("to_bytes_bincode_result", Span::call_site()),
            Ident::new("from_bytes_bincode", Span::call_site()),
        ),
        SerializationFormat::Json => (
            Ident::new("to_bytes_json_result", Span::call_site()),
            Ident::new("from_bytes_json", Span::call_site()),
        ),
    };

    let wrapper = quote! {
        use std::io::{stdin as #stdin_import, stdout as #stdout_import, Read as #read_import, Write as #write_import};

        #[no_mangle]
        pub unsafe extern "C" fn #func_ident() {
            // Read
            let mut buf = Vec::new();
            #stdin_import()
                .read_to_end(&mut buf)
                .expect("Wasm input read error");
            // Deserialize
            let result = match Row::#de_method(&buf) {
                Ok(row) => #run_ident(row), // Compute
                Err(e) => Err(format!("Wasm deserialization error : {}", e)),
            };
            // Write
            #stdout_import().write_all(&[3]).unwrap(); // End of text
            let output_table = Row::#ser_method(result).expect("Wasm result serialization error");
            #stdout_import()
                .write_all(&output_table)
                .expect("Wasm result write error");
            #stdout_import().flush().expect("Wasm result I/O error");
        }
    };

    // Emit the wrapper followed by the renamed user function. Appending tokens (instead of
    // round-tripping through `String`) keeps the user's original spans, so diagnostics
    // inside the function body point at the right source location.
    let mut expanded = wrapper;
    ast.to_tokens(&mut expanded);
    expanded.into()
}

/// Checks that the wrapped function is declared as returning `Result<Row, String>`.
///
/// Only the last segment of each type path is compared. Qualified spellings such as
/// `std::result::Result<Row, String>` or `Result<rline_api::row::Row, String>` are
/// therefore accepted. The check is purely syntactic: type aliases are not resolved.
///
/// Returns `Err` holding a compile-error token stream, spanned at the offending type,
/// when the signature does not match. A missing return type (`()`) is rejected, as is
/// a `Result` that does not have exactly two generic arguments.
fn check_return(ast: &syn::ItemFn) -> Result<(), TokenStream> {
    /// Returns the last segment of `ty` when it is a plain type path (`Row` in `a::b::Row`).
    fn last_segment(ty: &syn::Type) -> Option<&syn::PathSegment> {
        match ty {
            syn::Type::Path(type_path) => type_path.path.segments.last(),
            _ => None,
        }
    }

    let invalid_type_error: fn(Span) -> TokenStream = |span: Span| {
        syn::Error::new(span, "Unsupported return type, must be Result<Row, String>")
            .to_compile_error()
            .into()
    };

    // Checks that a generic argument is a type path whose last segment is `expected`.
    let check_arg = |arg: &syn::GenericArgument, expected: &str| -> Result<(), TokenStream> {
        let seg = match arg {
            syn::GenericArgument::Type(ty) => last_segment(ty),
            _ => None,
        };
        match seg {
            Some(seg) if seg.ident == expected => Ok(()),
            Some(seg) => Err(invalid_type_error(seg.ident.span())),
            None => Err(invalid_type_error(arg.span())),
        }
    };

    let rt = match &ast.sig.output {
        syn::ReturnType::Type(_, rt) => rt,
        // No `-> ...` means `()`, which the generated stub cannot serialize.
        syn::ReturnType::Default => return Err(invalid_type_error(ast.sig.span())),
    };

    let result_seg = match last_segment(rt) {
        Some(seg) if seg.ident == "Result" => seg,
        Some(seg) => return Err(invalid_type_error(seg.ident.span())),
        None => return Err(invalid_type_error(ast.sig.output.span())),
    };

    // Require exactly `Result<_, _>`. A bare `Result` or a one-argument alias such as
    // `io::Result<Row>` cannot be validated syntactically.
    let args = match &result_seg.arguments {
        syn::PathArguments::AngleBracketed(generics) if generics.args.len() == 2 => {
            &generics.args
        }
        _ => return Err(invalid_type_error(result_seg.ident.span())),
    };

    check_arg(&args[0], "Row")?;
    check_arg(&args[1], "String")
}

/// Checks that the wrapped function takes exactly one parameter of type `Row`.
///
/// Only the last segment of the parameter's type path is compared, so
/// `rline_api::row::Row` is accepted as well. `self` receivers, references and other
/// non-path types are rejected.
///
/// Returns `Err` holding a compile-error token stream when the signature does not match.
fn check_param(ast: &syn::ItemFn) -> Result<(), TokenStream> {
    let invalid_type_error: fn(Span) -> TokenStream = |span: Span| {
        syn::Error::new(span, "Unsupported parameter type, must be Row")
            .to_compile_error()
            .into()
    };

    let inputs = &ast.sig.inputs;

    if inputs.len() != 1 {
        return Err(syn::Error::new(
            inputs.span(),
            "Wrong number of parameters, must be one",
        )
        .to_compile_error()
        .into());
    };

    match inputs.first() {
        Some(syn::FnArg::Typed(param_type)) => match &*param_type.ty {
            syn::Type::Path(type_path) => match type_path.path.segments.last() {
                Some(seg) if seg.ident == "Row" => Ok(()),
                Some(seg) => Err(invalid_type_error(seg.ident.span())),
                None => Err(invalid_type_error(param_type.ty.span())),
            },
            _ => Err(invalid_type_error(param_type.ty.span())),
        },
        // A `self` receiver cannot be fed a deserialized `Row` by the generated stub.
        _ => Err(invalid_type_error(inputs.span())),
    }
}

/// Serialization format used by the generated stub to decode its input `Row` and
/// encode its `Result<Row, String>` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SerializationFormat {
    /// Selects `Row::from_bytes_bincode` / `Row::to_bytes_bincode_result` (the default).
    Bincode,
    /// Selects `Row::from_bytes_json` / `Row::to_bytes_json_result`.
    Json,
}

impl Parse for SerializationFormat {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let ident = Ident::parse(input)?;
        match ident.to_string().to_uppercase().as_str() {
            "BINCODE" => Ok(SerializationFormat::Bincode),
            "JSON" => Ok(SerializationFormat::Json),
            _ => Err(syn::Error::new(
                ident.span(),
                "Unsupported argument, must be one of bincode, json",
            )),
        }
    }
}