1#![forbid(unsafe_code)]
27
28use proc_macro::TokenStream;
29use proc_macro2::TokenStream as TokenStream2;
30use quote::{format_ident, quote};
31use syn::{FnArg, ItemTrait, ReturnType, TraitItem, TraitItemFn, Type, parse_macro_input};
32
33struct ArgLowering {
36 owned_ty: TokenStream2,
38 call_expr: TokenStream2,
42 extra_binding: Option<TokenStream2>,
45}
46
47#[proc_macro_attribute]
58pub fn wagon(attrs: TokenStream, item: TokenStream) -> TokenStream {
59 let item_clone = item.clone();
60
61 let attrs2: TokenStream2 = attrs.into();
63 let identity_opt_out = attrs2.to_string().trim() == "identity";
64
65 if identity_opt_out {
66 return item_clone;
67 }
68
69 let parsed = parse_macro_input!(item as ItemTrait);
70
71 let Some(mode) = classify_trait(&parsed) else {
72 return item_clone;
74 };
75
76 match expand_trait(&parsed, mode) {
77 Ok(ts) => ts.into(),
78 Err(e) => e.to_compile_error().into(),
79 }
80}
81
/// Whether the generated client/server code calls the trait synchronously or via `.await`.
///
/// `Async` also makes the generated client impl carry the `async_trait` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TraitMode {
    /// Every method is a plain `fn`; the client uses the blocking dispatcher.
    Sync,
    /// Every method is an `async fn`; the client uses the async dispatcher.
    Async,
}
88
89fn classify_trait(item: &ItemTrait) -> Option<TraitMode> {
94 let has_async_trait_attr = item.attrs.iter().any(|a| a.path().is_ident("async_trait"));
95
96 let mut all_methods_async = true;
97 let mut any_method_async = false;
98
99 for trait_item in &item.items {
100 let TraitItem::Fn(m) = trait_item else {
101 continue;
102 };
103
104 if m.sig.asyncness.is_some() {
105 any_method_async = true;
106 } else {
107 all_methods_async = false;
108 }
109
110 for input in &m.sig.inputs {
112 let FnArg::Typed(pat_type) = input else {
113 continue;
114 };
115 let pat = quote! { __dummy };
116 lower_arg_type(&pat_type.ty, &pat)?;
117 }
118
119 if let ReturnType::Type(_, ty) = &m.sig.output
121 && contains_reference(ty)
122 {
123 return None;
124 }
125 }
126
127 if has_async_trait_attr || any_method_async {
131 if !all_methods_async {
132 return None;
134 }
135 Some(TraitMode::Async)
136 } else {
137 Some(TraitMode::Sync)
138 }
139}
140
141fn lower_arg_type(ty: &Type, name: &TokenStream2) -> Option<ArgLowering> {
151 if let Type::Reference(r) = ty {
152 let inner = &*r.elem;
153 if is_str_path(inner) {
155 return Some(ArgLowering {
156 owned_ty: quote! { ::std::string::String },
157 call_expr: quote! { &#name },
158 extra_binding: None,
159 });
160 }
161 if let Type::Slice(slice) = inner {
163 if let Type::Reference(inner_ref) = &*slice.elem
165 && is_str_path(&inner_ref.elem)
166 {
167 let borrowed_ident =
168 format_ident!("__caravan_{}_borrowed", name.to_string().replace(' ', ""));
169 return Some(ArgLowering {
170 owned_ty: quote! { ::std::vec::Vec<::std::string::String> },
171 call_expr: quote! { &#borrowed_ident },
172 extra_binding: Some(quote! {
173 let #borrowed_ident: ::std::vec::Vec<&str> =
174 #name.iter().map(::std::string::String::as_str).collect();
175 }),
176 });
177 }
178 let elem_ty = &slice.elem;
181 if !contains_reference(elem_ty) {
182 return Some(ArgLowering {
183 owned_ty: quote! { ::std::vec::Vec<#elem_ty> },
184 call_expr: quote! { &#name },
185 extra_binding: None,
186 });
187 }
188 return None;
189 }
190 return None;
191 }
192 if contains_reference(ty) {
195 return None;
198 }
199 Some(ArgLowering {
200 owned_ty: quote! { #ty },
201 call_expr: quote! { #name },
202 extra_binding: None,
203 })
204}
205
206fn is_str_path(ty: &Type) -> bool {
208 if let Type::Path(p) = ty
209 && p.qself.is_none()
210 && let Some(last) = p.path.segments.last()
211 {
212 return last.ident == "str";
213 }
214 false
215}
216
217fn contains_reference(ty: &Type) -> bool {
223 match ty {
224 Type::Reference(_) => true,
225 Type::Slice(_) => true,
226 Type::Array(arr) => contains_reference(&arr.elem),
227 Type::Tuple(t) => t.elems.iter().any(contains_reference),
228 Type::Path(path) => {
229 for segment in &path.path.segments {
230 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
231 for arg in &args.args {
232 if let syn::GenericArgument::Type(inner) = arg
233 && contains_reference(inner)
234 {
235 return true;
236 }
237 }
238 }
239 }
240 false
241 }
242 Type::Paren(p) => contains_reference(&p.elem),
243 Type::Group(g) => contains_reference(&g.elem),
244 _ => false,
245 }
246}
247
248fn expand_trait(item: &ItemTrait, mode: TraitMode) -> syn::Result<TokenStream2> {
253 let trait_ident = &item.ident;
254 let vis = &item.vis;
255 let interface_str = trait_ident.to_string();
256 let client_struct = format_ident!("{}HttpClient", trait_ident);
257 let router_fn = format_ident!("build_{}_router", to_snake_case(&interface_str));
258
259 let mut client_methods: Vec<TokenStream2> = Vec::new();
260 let mut handler_bindings: Vec<TokenStream2> = Vec::new();
261 let mut router_chain: Vec<TokenStream2> = Vec::new();
262
263 for trait_item in &item.items {
264 let TraitItem::Fn(m) = trait_item else {
265 continue;
266 };
267 client_methods.push(emit_client_method(m, &interface_str, mode)?);
268 let (binding, method_str) = emit_server_handler(m, trait_ident, mode)?;
269 handler_bindings.push(binding);
270 let handler_ident = format_ident!("__caravan_handler_{}", method_str);
271 router_chain.push(quote! { .add_method(#method_str, #handler_ident) });
272 }
273
274 let async_trait_attr = match mode {
279 TraitMode::Sync => quote! {},
280 TraitMode::Async => quote! { #[::caravan_rpc::__macro_support::async_trait::async_trait] },
281 };
282
283 let out = quote! {
284 #item
286
287 #vis struct #client_struct {
289 base_url: ::std::string::String,
290 }
291
292 impl #client_struct {
293 #vis fn new(base_url: impl ::std::convert::Into<::std::string::String>) -> Self {
294 Self { base_url: base_url.into() }
295 }
296 }
297
298 #async_trait_attr
299 impl #trait_ident for #client_struct {
300 #(#client_methods)*
301 }
302
303 #vis fn #router_fn(
307 impl_arc: ::std::sync::Arc<dyn #trait_ident>,
308 ) -> ::caravan_rpc::__macro_support::axum::Router {
309 #(#handler_bindings)*
310 ::caravan_rpc::server::RpcRouter::new(#interface_str)
311 #(#router_chain)*
312 .into_axum_router(::caravan_rpc::peers::shared_secret())
313 }
314
315 ::caravan_rpc::__macro_support::inventory::submit! {
319 ::caravan_rpc::HttpAdapterFactory {
320 interface_name: #interface_str,
321 type_id_fn: || ::std::any::TypeId::of::<dyn #trait_ident>(),
322 construct: |__url: ::std::string::String|
323 -> ::std::boxed::Box<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
324 let __adapter: ::std::sync::Arc<dyn #trait_ident> =
325 ::std::sync::Arc::new(#client_struct::new(__url));
326 ::std::boxed::Box::new(__adapter)
327 },
328 }
329 }
330
331 ::caravan_rpc::__macro_support::inventory::submit! {
337 ::caravan_rpc::HttpServerFactory {
338 interface_name: #interface_str,
339 build_router_from_registry: || {
340 let __impl = ::caravan_rpc::try_client::<dyn #trait_ident>()
341 .ok_or("no provide() call for this trait before run_or_serve")?;
342 Ok(#router_fn(__impl))
343 },
344 }
345 }
346 };
347
348 Ok(out)
349}
350
351fn emit_client_method(
355 m: &TraitItemFn,
356 interface: &str,
357 mode: TraitMode,
358) -> syn::Result<TokenStream2> {
359 let sig = &m.sig;
360 let method_str = sig.ident.to_string();
361 let mut arg_serializations: Vec<TokenStream2> = Vec::new();
362
363 for input in &sig.inputs {
364 if let FnArg::Typed(pat_type) = input {
365 let pat = &pat_type.pat;
366 arg_serializations.push(quote! {
367 ::caravan_rpc::__macro_support::serde_json::to_value(&#pat).expect("caravan-rpc: arg serialize")
368 });
369 }
370 }
371
372 let dispatch_call = match mode {
373 TraitMode::Sync => quote! {
374 ::caravan_rpc::dispatch::dispatch_sync(
375 &self.base_url, #interface, #method_str, __args
376 ).expect("caravan-rpc: dispatch_sync")
377 },
378 TraitMode::Async => quote! {
379 ::caravan_rpc::dispatch::dispatch_async(
380 &self.base_url, #interface, #method_str, __args
381 ).await.expect("caravan-rpc: dispatch_async")
382 },
383 };
384
385 let body = quote! {
386 let __args: ::std::vec::Vec<::caravan_rpc::__macro_support::serde_json::Value> = vec![ #(#arg_serializations),* ];
387 let __v = #dispatch_call;
388 ::caravan_rpc::__macro_support::serde_json::from_value(__v).expect("caravan-rpc: deserialize return")
389 };
390
391 let block: syn::Block = syn::parse2(quote! { { #body } })?;
392 let mut m = m.clone();
393 m.default = Some(block);
394 m.semi_token = None;
395 Ok(quote! { #m })
396}
397
398fn emit_server_handler(
402 m: &TraitItemFn,
403 trait_ident: &syn::Ident,
404 mode: TraitMode,
405) -> syn::Result<(TokenStream2, String)> {
406 let sig = &m.sig;
407 let method_ident = &sig.ident;
408 let method_str = method_ident.to_string();
409 let handler_ident = format_ident!("__caravan_handler_{}", method_str);
410
411 let mut decode_blocks: Vec<TokenStream2> = Vec::new();
416 let mut call_args: Vec<TokenStream2> = Vec::new();
417 let mut idx: usize = 0;
418 for input in &sig.inputs {
419 if let FnArg::Typed(pat_type) = input {
420 let pat = &pat_type.pat;
421 let pat_tokens = quote! { #pat };
422 let arg_name = pat_tokens.to_string();
423 let lowering =
424 lower_arg_type(&pat_type.ty, &pat_tokens).expect("is_sync_owned_trait gates this");
425 let owned_ty = &lowering.owned_ty;
426 let idx_lit = idx;
427 let extra = lowering.extra_binding.unwrap_or_default();
428 decode_blocks.push(quote! {
429 let #pat: #owned_ty = match __env.args.get(#idx_lit) {
430 ::std::option::Option::Some(__val) => {
431 match ::caravan_rpc::__macro_support::serde_json::from_value(__val.clone()) {
432 ::std::result::Result::Ok(__t) => __t,
433 ::std::result::Result::Err(__e) => {
434 return ::caravan_rpc::codec::Response::err(
435 format!("BadArg({})", #arg_name),
436 __e.to_string(),
437 );
438 }
439 }
440 }
441 ::std::option::Option::None => {
442 return ::caravan_rpc::codec::Response::err(
443 format!("MissingArg({})", #arg_name),
444 format!("expected args[{}]", #idx_lit),
445 );
446 }
447 };
448 #extra
449 });
450 call_args.push(lowering.call_expr);
451 idx += 1;
452 }
453 }
454
455 let impl_call = match mode {
456 TraitMode::Sync => quote! {
457 <dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*)
458 },
459 TraitMode::Async => quote! {
460 <dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*).await
461 },
462 };
463
464 let body = quote! {
465 let #handler_ident: ::caravan_rpc::server::MethodHandler = {
466 let __impl_arc = impl_arc.clone();
467 ::std::sync::Arc::new(move |__body: ::caravan_rpc::__macro_support::axum::body::Bytes| {
468 let __impl_arc = __impl_arc.clone();
469 ::std::boxed::Box::pin(async move {
470 let __env: ::caravan_rpc::codec::Request = match ::caravan_rpc::__macro_support::serde_json::from_slice(&__body) {
471 ::std::result::Result::Ok(__e) => __e,
472 ::std::result::Result::Err(__e) => {
473 return ::caravan_rpc::codec::Response::err(
474 "BadJSON",
475 __e.to_string(),
476 );
477 }
478 };
479 #(#decode_blocks)*
480 let __result = #impl_call;
481 match ::caravan_rpc::__macro_support::serde_json::to_value(&__result) {
482 ::std::result::Result::Ok(__v) => ::caravan_rpc::codec::Response::ok(__v),
483 ::std::result::Result::Err(__e) => ::caravan_rpc::codec::Response::err(
484 "EncodeError",
485 __e.to_string(),
486 ),
487 }
488 })
489 })
490 };
491 };
492
493 Ok((body, method_str))
494}
495
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// Each uppercase character is lowercased and, unless it is the first character, preceded
/// by `_`. Other characters are copied unchanged, so acronyms split per letter:
/// `HTTPServer` becomes `h_t_t_p_server`. The generated `build_*_router` names depend on
/// this exact output.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    for (position, c) in s.chars().enumerate() {
        if !c.is_uppercase() {
            snake.push(c);
            continue;
        }
        if position != 0 {
            snake.push('_');
        }
        // `to_lowercase` may yield more than one char for some Unicode letters.
        snake.extend(c.to_lowercase());
    }
    snake
}