extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::parse_macro_input;
use syn::spanned::Spanned;
/// Attribute macro that turns a `fn(Row) -> Result<Row, String>` into an exported
/// `#[no_mangle] extern "C"` entry point.
///
/// The annotated function `foo` is renamed to `run_foo`, and a new function `foo` is
/// generated that:
/// 1. reads all of stdin into a buffer,
/// 2. deserializes it with `Row::from_bytes_bincode` / `Row::from_bytes_json`,
/// 3. calls `run_foo` (a deserialization failure becomes an `Err(String)` instead),
/// 4. writes a single `3` byte to stdout, followed by the result serialized with
///    `Row::to_bytes_bincode_result` / `Row::to_bytes_json_result`, then flushes.
///
/// The optional attribute argument selects the format, case-insensitively:
/// `#[rline_bindgen]` / `#[rline_bindgen(bincode)]` or `#[rline_bindgen(json)]`.
///
/// Malformed attribute arguments, a non-function item, or an unsupported signature are
/// reported as compile errors. At runtime the generated wrapper panics (`expect`/`unwrap`)
/// on stdin/stdout I/O failures and on serialization failure.
#[proc_macro_attribute]
pub fn rline_bindgen(metadata: TokenStream, item: TokenStream) -> TokenStream {
    let serialization_format = if metadata.is_empty() {
        SerializationFormat::Bincode
    } else {
        parse_macro_input!(metadata as SerializationFormat)
    };
    // Reports a non-function item as a compile error rather than panicking the macro.
    let mut ast = parse_macro_input!(item as syn::ItemFn);

    // The user's function keeps its body but is renamed to `run_<name>`; the original
    // name is reused for the exported wrapper below.
    let func_ident = ast.sig.ident;
    let run_ident_name = format!("run_{}", func_ident);
    let run_ident = Ident::new(&run_ident_name, Span::call_site());
    ast.sig.ident = run_ident.clone();

    if let Err(e) = check_param(&ast) {
        return e;
    }
    if let Err(e) = check_return(&ast) {
        return e;
    }

    // The wrapper emits a module-level `use std::io::{...}`. Aliasing each item with a
    // per-function suffix keeps two annotated functions in the same module from
    // importing the same name twice.
    let io_alias = |name: &str| {
        Ident::new(&format!("{}_{}", name, run_ident_name), Span::call_site())
    };
    let stdin_import = io_alias("stdin");
    let stdout_import = io_alias("stdout");
    let read_import = io_alias("Read");
    let write_import = io_alias("Write");

    let (ser_name, de_name) = match serialization_format {
        SerializationFormat::Bincode => ("to_bytes_bincode_result", "from_bytes_bincode"),
        SerializationFormat::Json => ("to_bytes_json_result", "from_bytes_json"),
    };
    let ser_method = Ident::new(ser_name, Span::call_site());
    let de_method = Ident::new(de_name, Span::call_site());

    let wrapper = quote! {
        use std::io::{stdin as #stdin_import, stdout as #stdout_import, Read as #read_import, Write as #write_import};
        #[no_mangle]
        pub unsafe extern "C" fn #func_ident() {
            let mut buf = Vec::new();
            #stdin_import()
                .read_to_end(&mut buf)
                .expect("Wasm input read error");
            let result = match Row::#de_method(&buf) {
                Ok(row) => #run_ident(row),
                Err(e) => Err(format!("Wasm deserialization error : {}", e)),
            };
            // A single `3` byte precedes the payload; presumably a marker the host
            // reader expects — confirm against the consumer before changing it.
            #stdout_import().write_all(&[3]).unwrap();
            let output_table = Row::#ser_method(result).expect("Wasm result serialization error");
            #stdout_import()
                .write_all(&output_table)
                .expect("Wasm result write error");
            #stdout_import().flush().expect("Wasm result I/O error");
        }
    };

    // Append the renamed user function as tokens (not via a to_string/parse round-trip)
    // so its original spans survive and diagnostics point at the user's code.
    let mut expanded = wrapper;
    ast.to_tokens(&mut expanded);
    expanded.into()
}
/// Checks that the user function is declared as returning `Result<Row, String>`.
///
/// The generated wrapper hands the function's result straight to a
/// `Row::to_bytes_*_result` serializer, so the return type must match. A proc macro
/// cannot resolve types, so types are compared by the name of their last path segment
/// only. This means qualified spellings such as `std::result::Result<Row, String>` or
/// `crate::Row` are accepted.
///
/// A bare `Result` with no generic arguments (e.g. a user type alias) cannot be
/// inspected further and is accepted without more checks.
///
/// # Errors
///
/// Returns a `compile_error!` token stream, spanned at the offending token, when:
/// - the function has no explicit return type;
/// - the return type is not a path type named `Result`;
/// - `Result` does not have exactly two type arguments named `Row` and `String`.
fn check_return(ast: &syn::ItemFn) -> Result<(), TokenStream> {
    // Last segment of `ty` when it is a plain path type (`Row` for `crate::Row`).
    fn last_path_segment(ty: &syn::Type) -> Option<&syn::PathSegment> {
        match ty {
            syn::Type::Path(type_path) => type_path.path.segments.last(),
            _ => None,
        }
    }

    let invalid_type_error = |span: Span| -> TokenStream {
        syn::Error::new(span, "Unsupported return type, must be Result<Row, String>")
            .to_compile_error()
            .into()
    };

    let return_type = match &ast.sig.output {
        syn::ReturnType::Type(_, ty) => &**ty,
        // Without a declared return type the wrapper would try to serialize `()`,
        // producing a confusing type error inside the generated code.
        syn::ReturnType::Default => return Err(invalid_type_error(ast.sig.span())),
    };

    let result_seg = match last_path_segment(return_type) {
        Some(seg) if seg.ident == "Result" => seg,
        Some(seg) => return Err(invalid_type_error(seg.ident.span())),
        None => return Err(invalid_type_error(ast.sig.output.span())),
    };

    let args = match &result_seg.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => &bracketed.args,
        // Bare `Result`: its arguments are not visible here, so accept it unchecked.
        syn::PathArguments::None => return Ok(()),
        syn::PathArguments::Parenthesized(_) => {
            return Err(invalid_type_error(result_seg.ident.span()))
        }
    };

    // Checking the count up front also rules out `Result<>`, which has no argument to inspect.
    if args.len() != 2 {
        return Err(invalid_type_error(result_seg.ident.span()));
    }

    for (arg, expected) in args.iter().zip(["Row", "String"].iter()) {
        let ty = match arg {
            syn::GenericArgument::Type(ty) => ty,
            // Lifetimes, consts, bindings, ... are never valid here.
            _ => return Err(invalid_type_error(result_seg.ident.span())),
        };
        match last_path_segment(ty) {
            Some(seg) if seg.ident == *expected => {}
            Some(seg) => return Err(invalid_type_error(seg.ident.span())),
            None => return Err(invalid_type_error(result_seg.ident.span())),
        }
    }
    Ok(())
}
/// Checks that the user function takes exactly one parameter, of type `Row`.
///
/// The generated wrapper calls the function with the deserialized `Row`, so no other
/// signature can work. Only the name of the type's last path segment is compared, so
/// `crate::Row` is accepted as well as `Row`.
///
/// # Errors
///
/// Returns a `compile_error!` token stream when:
/// - the function does not have exactly one parameter;
/// - that parameter is a `self` receiver;
/// - its type is not a path type named `Row`.
fn check_param(ast: &syn::ItemFn) -> Result<(), TokenStream> {
    let invalid_type_error = |span: Span| -> TokenStream {
        syn::Error::new(span, "Unsupported parameter type, must be Row")
            .to_compile_error()
            .into()
    };

    let inputs = &ast.sig.inputs;
    let param = match inputs.first() {
        Some(param) if inputs.len() == 1 => param,
        _ => {
            return Err(syn::Error::new(
                inputs.span(),
                "Wrong number of parameters, must be one",
            )
            .to_compile_error()
            .into())
        }
    };

    match param {
        syn::FnArg::Typed(pat_type) => match &*pat_type.ty {
            syn::Type::Path(type_path) => match type_path.path.segments.last() {
                Some(seg) if seg.ident == "Row" => Ok(()),
                Some(seg) => Err(invalid_type_error(seg.ident.span())),
                None => Err(invalid_type_error(pat_type.ty.span())),
            },
            _ => Err(invalid_type_error(pat_type.ty.span())),
        },
        // The wrapper passes a `Row` positionally; a `self` receiver cannot receive it.
        syn::FnArg::Receiver(receiver) => Err(invalid_type_error(receiver.span())),
    }
}
/// Serialization format the generated wrapper uses to decode its stdin input into a
/// `Row` and to encode the function's result to stdout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SerializationFormat {
    /// Uses `Row::from_bytes_bincode` / `Row::to_bytes_bincode_result`. This is the
    /// format chosen when the attribute has no argument.
    Bincode,
    /// Uses `Row::from_bytes_json` / `Row::to_bytes_json_result`.
    Json,
}
impl Parse for SerializationFormat {
fn parse(input: ParseStream) -> syn::Result<Self> {
let ident = Ident::parse(input)?;
match ident.to_string().to_uppercase().as_str() {
"BINCODE" => Ok(SerializationFormat::Bincode),
"JSON" => Ok(SerializationFormat::Json),
_ => Err(syn::Error::new(
ident.span(),
"Unsupported argument, must be one of bincode, json",
)),
}
}
}