#![forbid(unsafe_code)]
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{FnArg, ItemTrait, ReturnType, TraitItem, TraitItemFn, Type, parse_macro_input};
/// Describes how one typed trait-method argument is handled on the server side:
/// which owned type it is decoded into and how it is handed to the trait method.
struct ArgLowering {
/// Owned type used in the generated `let <pat>: <owned_ty> = ...` decode binding.
owned_ty: TokenStream2,
/// Expression placed in the generated call to the trait method for this argument.
call_expr: TokenStream2,
/// Optional statement(s) emitted right after the decode binding (for example the
/// `Vec<&str>` view built for `&[&str]` arguments); `None` when nothing extra is needed.
extra_binding: Option<TokenStream2>,
}
#[proc_macro_attribute]
pub fn wagon(attrs: TokenStream, item: TokenStream) -> TokenStream {
let item_clone = item.clone();
let attrs2: TokenStream2 = attrs.into();
let identity_opt_out = attrs2.to_string().trim() == "identity";
if identity_opt_out {
return item_clone;
}
let parsed = parse_macro_input!(item as ItemTrait);
let Some(mode) = classify_trait(&parsed) else {
return item_clone;
};
match expand_trait(&parsed, mode) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Flavor of code generated for a trait, as decided by `classify_trait`.
#[derive(Clone, Copy, PartialEq, Eq)]
enum TraitMode {
/// Client methods call `dispatch_sync`; handlers call the trait method directly.
Sync,
/// Client methods call `dispatch_async(..).await`; handlers `.await` the trait
/// method and the generated client impl carries the `async_trait` attribute.
Async,
}
/// Decides whether `item` can be expanded and, if so, in which mode.
///
/// Returns `None` when:
/// * any typed argument is rejected by [`lower_arg_type`],
/// * any method's return type is flagged by [`contains_reference`], or
/// * the trait is async-flavored (has an `async_trait` attribute or at least
///   one `async fn`) but also declares a non-async method.
///
/// Otherwise returns `Some(TraitMode::Async)` for async-flavored traits and
/// `Some(TraitMode::Sync)` for everything else.
fn classify_trait(item: &ItemTrait) -> Option<TraitMode> {
    // The binding name is irrelevant here; only lowerability is being probed.
    let placeholder = quote! { __dummy };
    let mut saw_async = false;
    let mut saw_sync = false;

    let methods = item.items.iter().filter_map(|entry| match entry {
        TraitItem::Fn(method) => Some(method),
        _ => None,
    });

    for method in methods {
        match method.sig.asyncness {
            Some(_) => saw_async = true,
            None => saw_sync = true,
        }

        for arg in &method.sig.inputs {
            if let FnArg::Typed(typed) = arg {
                lower_arg_type(&typed.ty, &placeholder)?;
            }
        }

        let borrowed_return = match &method.sig.output {
            ReturnType::Type(_, ret_ty) => contains_reference(ret_ty),
            ReturnType::Default => false,
        };
        if borrowed_return {
            return None;
        }
    }

    let tagged_async_trait = item
        .attrs
        .iter()
        .any(|attr| attr.path().is_ident("async_trait"));

    if !tagged_async_trait && !saw_async {
        return Some(TraitMode::Sync);
    }
    // Async-flavored traits must be async throughout.
    (!saw_sync).then_some(TraitMode::Async)
}
/// Computes how an argument of type `ty`, bound to `name`, is decoded and
/// forwarded by the generated server handler.
///
/// Supported shapes:
/// * `&str`   -> decoded as `String`, passed as `&name`;
/// * `&[&str]` -> decoded as `Vec<String>`, with an extra `Vec<&str>` binding
///   that is passed by reference;
/// * `&[T]` where `T` is not flagged by [`contains_reference`] -> decoded as
///   `Vec<T>`, passed as `&name`;
/// * any non-reference type not flagged by [`contains_reference`] -> decoded
///   and passed as-is.
///
/// Returns `None` for every other shape. Mutable references (`&mut str`,
/// `&mut [T]`, `&[&mut str]`) are rejected explicitly: the lowering always
/// produces a shared borrow of the decoded owned value, which cannot satisfy a
/// `&mut` parameter and would otherwise yield generated code that fails to
/// type-check.
fn lower_arg_type(ty: &Type, name: &TokenStream2) -> Option<ArgLowering> {
    let Type::Reference(r) = ty else {
        if contains_reference(ty) {
            return None;
        }
        return Some(ArgLowering {
            owned_ty: quote! { #ty },
            call_expr: quote! { #name },
            extra_binding: None,
        });
    };

    // Only shared borrows can be reconstructed from a decoded owned value.
    if r.mutability.is_some() {
        return None;
    }

    let inner = &*r.elem;
    if is_str_path(inner) {
        return Some(ArgLowering {
            owned_ty: quote! { ::std::string::String },
            call_expr: quote! { &#name },
            extra_binding: None,
        });
    }

    let Type::Slice(slice) = inner else {
        return None;
    };

    if let Type::Reference(inner_ref) = &*slice.elem
        && inner_ref.mutability.is_none()
        && is_str_path(&inner_ref.elem)
    {
        // `&[&str]`: decode owned strings, then build a borrowed view to pass on.
        let borrowed_ident =
            format_ident!("__caravan_{}_borrowed", name.to_string().replace(' ', ""));
        return Some(ArgLowering {
            owned_ty: quote! { ::std::vec::Vec<::std::string::String> },
            call_expr: quote! { &#borrowed_ident },
            extra_binding: Some(quote! {
                let #borrowed_ident: ::std::vec::Vec<&str> =
                    #name.iter().map(::std::string::String::as_str).collect();
            }),
        });
    }

    let elem_ty = &slice.elem;
    if contains_reference(elem_ty) {
        return None;
    }
    Some(ArgLowering {
        owned_ty: quote! { ::std::vec::Vec<#elem_ty> },
        call_expr: quote! { &#name },
        extra_binding: None,
    })
}
/// Returns `true` when `ty` is a path type without a qualified-self part whose
/// last segment is the identifier `str`.
fn is_str_path(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    if type_path.qself.is_some() {
        return false;
    }
    type_path
        .path
        .segments
        .last()
        .is_some_and(|segment| segment.ident == "str")
}
/// Syntactic check for borrowed or slice types anywhere inside `ty`.
///
/// References and bare slices are always flagged. Arrays, tuples, parenthesized
/// and grouped types are searched recursively, as are the type arguments in the
/// angle-bracketed generics of every path segment. All other type forms are
/// treated as reference-free.
fn contains_reference(ty: &Type) -> bool {
    match ty {
        Type::Reference(_) | Type::Slice(_) => true,
        Type::Array(array) => contains_reference(&array.elem),
        Type::Paren(paren) => contains_reference(&paren.elem),
        Type::Group(group) => contains_reference(&group.elem),
        Type::Tuple(tuple) => tuple.elems.iter().any(contains_reference),
        Type::Path(type_path) => type_path
            .path
            .segments
            .iter()
            .filter_map(|segment| match &segment.arguments {
                syn::PathArguments::AngleBracketed(generics) => Some(&generics.args),
                _ => None,
            })
            .flatten()
            .any(|generic_arg| {
                matches!(
                    generic_arg,
                    syn::GenericArgument::Type(inner) if contains_reference(inner)
                )
            }),
        _ => false,
    }
}
/// Generates the full expansion for a classified trait.
///
/// The output consists of:
/// * the original trait, unchanged;
/// * a `<Trait>HttpClient` struct holding a `base_url`, with a `new` constructor;
/// * an impl of the trait for that client (one method per trait method, built
///   by [`emit_client_method`]), carrying the `async_trait` attribute in
///   `TraitMode::Async`;
/// * a `build_<snake_case_trait>_router` function that registers one handler
///   per method (built by [`emit_server_handler`]);
/// * two `inventory::submit!` registrations: an `HttpAdapterFactory` and an
///   `HttpServerFactory` for the trait.
///
/// Errors from the per-method emitters are propagated.
fn expand_trait(item: &ItemTrait, mode: TraitMode) -> syn::Result<TokenStream2> {
    let vis = &item.vis;
    let trait_ident = &item.ident;
    let interface_str = trait_ident.to_string();
    let client_ty = format_ident!("{}HttpClient", trait_ident);
    let router_ident = format_ident!("build_{}_router", to_snake_case(&interface_str));

    let mut client_methods: Vec<TokenStream2> = Vec::new();
    let mut handler_bindings: Vec<TokenStream2> = Vec::new();
    let mut router_chain: Vec<TokenStream2> = Vec::new();

    let methods = item.items.iter().filter_map(|entry| match entry {
        TraitItem::Fn(method) => Some(method),
        _ => None,
    });
    for method in methods {
        client_methods.push(emit_client_method(method, &interface_str, mode)?);

        let (binding, method_str) = emit_server_handler(method, trait_ident, mode)?;
        handler_bindings.push(binding);

        // Must match the binding name introduced by `emit_server_handler`.
        let handler_ident = format_ident!("__caravan_handler_{}", method_str);
        router_chain.push(quote! { .add_method(#method_str, #handler_ident) });
    }

    let impl_attr = if mode == TraitMode::Async {
        quote! { #[::caravan_rpc::__macro_support::async_trait::async_trait] }
    } else {
        quote! {}
    };

    Ok(quote! {
        #item
        #vis struct #client_ty {
            base_url: ::std::string::String,
        }
        impl #client_ty {
            #vis fn new(base_url: impl ::std::convert::Into<::std::string::String>) -> Self {
                Self { base_url: base_url.into() }
            }
        }
        #impl_attr
        impl #trait_ident for #client_ty {
            #(#client_methods)*
        }
        #vis fn #router_ident(
            impl_arc: ::std::sync::Arc<dyn #trait_ident>,
        ) -> ::caravan_rpc::__macro_support::axum::Router {
            #(#handler_bindings)*
            ::caravan_rpc::server::RpcRouter::new(#interface_str)
                #(#router_chain)*
                .into_axum_router(::caravan_rpc::peers::shared_secret())
        }
        ::caravan_rpc::__macro_support::inventory::submit! {
            ::caravan_rpc::HttpAdapterFactory {
                interface_name: #interface_str,
                type_id_fn: || ::std::any::TypeId::of::<dyn #trait_ident>(),
                construct: |__url: ::std::string::String|
                    -> ::std::boxed::Box<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
                    let __adapter: ::std::sync::Arc<dyn #trait_ident> =
                        ::std::sync::Arc::new(#client_ty::new(__url));
                    ::std::boxed::Box::new(__adapter)
                },
            }
        }
        ::caravan_rpc::__macro_support::inventory::submit! {
            ::caravan_rpc::HttpServerFactory {
                interface_name: #interface_str,
                build_router_from_registry: || {
                    let __impl = ::caravan_rpc::try_client::<dyn #trait_ident>()
                        .ok_or("no provide() call for this trait before run_or_serve")?;
                    Ok(#router_ident(__impl))
                },
            }
        }
    })
}
/// Builds the client-side implementation of one trait method.
///
/// The generated body serializes every typed argument to a JSON value,
/// dispatches the call (`dispatch_sync`, or `dispatch_async(..).await` in
/// `TraitMode::Async`) against `self.base_url` with the interface and method
/// names, and deserializes the returned value. Each of those steps uses
/// `expect` in the generated code.
///
/// The returned tokens are the original method item (attributes and signature
/// preserved) with its body replaced by the generated block.
///
/// # Errors
/// Returns an error if the generated body fails to parse as a block.
fn emit_client_method(
    m: &TraitItemFn,
    interface: &str,
    mode: TraitMode,
) -> syn::Result<TokenStream2> {
    let method_str = m.sig.ident.to_string();

    let arg_serializations: Vec<TokenStream2> = m
        .sig
        .inputs
        .iter()
        .filter_map(|input| {
            // The receiver is not part of the wire arguments.
            let FnArg::Typed(typed) = input else {
                return None;
            };
            let pat = &typed.pat;
            Some(quote! {
                ::caravan_rpc::__macro_support::serde_json::to_value(&#pat).expect("caravan-rpc: arg serialize")
            })
        })
        .collect();

    let dispatch_call = match mode {
        TraitMode::Async => quote! {
            ::caravan_rpc::dispatch::dispatch_async(
                &self.base_url, #interface, #method_str, __args
            ).await.expect("caravan-rpc: dispatch_async")
        },
        TraitMode::Sync => quote! {
            ::caravan_rpc::dispatch::dispatch_sync(
                &self.base_url, #interface, #method_str, __args
            ).expect("caravan-rpc: dispatch_sync")
        },
    };

    let generated_body: syn::Block = syn::parse2(quote! {
        {
            let __args: ::std::vec::Vec<::caravan_rpc::__macro_support::serde_json::Value> = vec![ #(#arg_serializations),* ];
            let __v = #dispatch_call;
            ::caravan_rpc::__macro_support::serde_json::from_value(__v).expect("caravan-rpc: deserialize return")
        }
    })?;

    // Reuse the trait's own method item so the signature matches exactly,
    // swapping the `;` (or default body) for the generated block.
    let mut client_fn = m.clone();
    client_fn.default = Some(generated_body);
    client_fn.semi_token = None;
    Ok(quote! { #client_fn })
}
fn emit_server_handler(
m: &TraitItemFn,
trait_ident: &syn::Ident,
mode: TraitMode,
) -> syn::Result<(TokenStream2, String)> {
let sig = &m.sig;
let method_ident = &sig.ident;
let method_str = method_ident.to_string();
let handler_ident = format_ident!("__caravan_handler_{}", method_str);
let mut decode_blocks: Vec<TokenStream2> = Vec::new();
let mut call_args: Vec<TokenStream2> = Vec::new();
let mut idx: usize = 0;
for input in &sig.inputs {
if let FnArg::Typed(pat_type) = input {
let pat = &pat_type.pat;
let pat_tokens = quote! { #pat };
let arg_name = pat_tokens.to_string();
let lowering =
lower_arg_type(&pat_type.ty, &pat_tokens).expect("is_sync_owned_trait gates this");
let owned_ty = &lowering.owned_ty;
let idx_lit = idx;
let extra = lowering.extra_binding.unwrap_or_default();
decode_blocks.push(quote! {
let #pat: #owned_ty = match __env.args.get(#idx_lit) {
::std::option::Option::Some(__val) => {
match ::caravan_rpc::__macro_support::serde_json::from_value(__val.clone()) {
::std::result::Result::Ok(__t) => __t,
::std::result::Result::Err(__e) => {
return ::caravan_rpc::codec::Response::err(
format!("BadArg({})", #arg_name),
__e.to_string(),
);
}
}
}
::std::option::Option::None => {
return ::caravan_rpc::codec::Response::err(
format!("MissingArg({})", #arg_name),
format!("expected args[{}]", #idx_lit),
);
}
};
#extra
});
call_args.push(lowering.call_expr);
idx += 1;
}
}
let impl_call = match mode {
TraitMode::Sync => quote! {
<dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*)
},
TraitMode::Async => quote! {
<dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*).await
},
};
let body = quote! {
let #handler_ident: ::caravan_rpc::server::MethodHandler = {
let __impl_arc = impl_arc.clone();
::std::sync::Arc::new(move |__body: ::caravan_rpc::__macro_support::axum::body::Bytes| {
let __impl_arc = __impl_arc.clone();
::std::boxed::Box::pin(async move {
let __env: ::caravan_rpc::codec::Request = match ::caravan_rpc::__macro_support::serde_json::from_slice(&__body) {
::std::result::Result::Ok(__e) => __e,
::std::result::Result::Err(__e) => {
return ::caravan_rpc::codec::Response::err(
"BadJSON",
__e.to_string(),
);
}
};
#(#decode_blocks)*
let __result = #impl_call;
match ::caravan_rpc::__macro_support::serde_json::to_value(&__result) {
::std::result::Result::Ok(__v) => ::caravan_rpc::codec::Response::ok(__v),
::std::result::Result::Err(__e) => ::caravan_rpc::codec::Response::err(
"EncodeError",
__e.to_string(),
),
}
})
})
};
};
Ok((body, method_str))
}
/// Converts an identifier to snake case by lowercasing every uppercase
/// character and inserting `_` before each one except at the very start.
///
/// Every uppercase character starts a new word, so runs of capitals are split
/// letter by letter (`"HTTPServer"` becomes `"h_t_t_p_server"`). All other
/// characters, including existing underscores, are copied through unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut at_start = true;
    for ch in s.chars() {
        if !ch.is_uppercase() {
            snake.push(ch);
        } else {
            if !at_start {
                snake.push('_');
            }
            // A single character may lowercase to several characters.
            snake.extend(ch.to_lowercase());
        }
        at_start = false;
    }
    snake
}