#![deny(unsafe_code)]
#![warn(
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_import_braces,
unused_qualifications
)]
extern crate proc_macro;
use proc_macro::TokenStream;
use quote::quote;
use syn::{self, parse_macro_input, Error, Ident, ItemFn, ReturnType};
#[proc_macro_attribute]
pub fn node(_args: TokenStream, item: TokenStream) -> TokenStream {
let mut input_function = parse_macro_input!(item as ItemFn);
if input_function.sig.asyncness.is_none() {
let message = "a function with attribute '#[ockam::node]' must be declared as 'async'";
let token = input_function.sig.fn_token;
return Error::new_spanned(token, message).to_compile_error().into();
}
if input_function.sig.inputs.len() != 1 {
let message = "a function with '#[ockam::node]' must have exactly one argument";
let token = input_function.sig.fn_token;
return Error::new_spanned(token, message).to_compile_error().into();
}
let ctx_ident: &Ident;
let function_arg = &input_function.sig.inputs.first().unwrap();
if let syn::FnArg::Typed(syn::PatType {
attrs: _,
pat,
colon_token: _,
ty,
}) = function_arg
{
if let syn::Pat::Ident(syn::PatIdent {
attrs: _,
by_ref: _,
mutability: _,
ident,
subpat: _,
}) = &**pat
{
ctx_ident = ident;
} else {
let message = format!("Expected an identifier, found `{}`", quote! {#pat});
return Error::new_spanned(pat, message).to_compile_error().into();
};
if let syn::Type::Path(syn::TypePath { qself: _, path }) = &**ty {
let ident = path.segments.last();
if ident.is_none() {
let message = "Input argument should be of type `ockam::Context`";
return Error::new_spanned(path, message).to_compile_error().into();
} else {
let type_ident = quote! {#ident}.to_string();
if type_ident != "Context" {
let path_ident = quote! {#path}.to_string().replace(' ', "");
let message = format!("Expected `ockam::Context` found `{}`", path_ident);
return Error::new_spanned(path, message).to_compile_error().into();
}
}
}
if input_function.block.stmts.is_empty() {
let fn_ident = input_function.sig.ident;
let message = "Function body Cannot be Empty.";
return Error::new_spanned(fn_ident, message)
.to_compile_error()
.into();
}
let mut ctx_used = false;
for st in &input_function.block.stmts {
let stmt_str = quote! {#st}.to_string().replace(' ', "");
if stmt_str.contains(&ctx_ident.to_string()) {
ctx_used = true;
}
}
if !ctx_used {
let message = format!(
"Unused `{}`. Passed `ockam::Context` should be used.",
ctx_ident,
);
return Error::new_spanned(ctx_ident, message)
.to_compile_error()
.into();
}
} else {
let message = "Input argument should be of type `ockam::Context`";
return Error::new_spanned(function_arg, message)
.to_compile_error()
.into();
};
let output_fn_ident = Ident::new("trampoline", input_function.sig.ident.span());
input_function.sig.ident = output_fn_ident.clone();
let returns_unit = input_function.sig.output == ReturnType::Default;
let input_function_call = if returns_unit {
quote! {
#output_fn_ident(#ctx_ident).await;
}
} else {
quote! {
#output_fn_ident(#ctx_ident).await.unwrap();
}
};
#[cfg(not(feature = "no_main"))]
let output_function = quote! {
#[inline(always)]
#input_function
fn main() -> ockam::Result<()> {
let (#ctx_ident, mut executor) = ockam::start_node();
executor.execute(async move {
#input_function_call
})
}
};
#[cfg(feature = "no_main")]
let output_function = quote! {
#[inline(always)]
#input_function
fn ockam_async_main() -> ockam::Result<()> {
let (#ctx_ident, mut executor) = ockam::start_node();
executor.execute(async move {
#input_function_call
})
}
ockam_async_main().unwrap();
};
TokenStream::from(output_function)
}