Skip to main content

caravan_rpc_macros/
lib.rs

//! Procedural macros for the Caravan RPC SDK.
//!
//! The single attribute macro `#[wagon]` marks a trait as a seam interface.
//! Behaviour depends on the trait shape:
//!
//! * **Sync trait** (no `#[async_trait]`, no `async fn`) — args may be
//!   owned (`String`, `Vec<T>`) or borrowed in the narrow set the macro
//!   knows how to lower (`&str` → `String`, `&[T]` → `Vec<T>`, `&[&str]`
//!   → `Vec<String>`). Emits `<Trait>HttpClient`, `impl <Trait> for
//!   <Trait>HttpClient` (calls go over HTTP via `dispatch::dispatch_sync`),
//!   the `build_<trait_snake>_router(impl_arc)` axum router builder, and
//!   `inventory` registrations for the client and server factories.
//!
//! * **Async trait** (`#[async_trait]`, every method `async fn`) — same
//!   argument lowering and generated items, but the `<Trait>HttpClient`
//!   impl is wrapped in `#[async_trait]` and each call awaits
//!   `dispatch::dispatch_async`.
//!
//! * **Anything else** (mixed sync/async methods, exotic arg types,
//!   borrowed return types) → expand to the trait unchanged (identity
//!   behaviour, same as the B0p macro). `#[wagon(identity)]` forces this
//!   regardless of shape.
//!
//! Code-rag's traits as of Session 4-narrow (before the async path above
//! landed):
//! * Embedder (sync + `&str` + `&[&str]`) → full codegen, suitable for
//!   `dev-split-light`'s `Embedder: container` mode flip.
//! * Reranker (sync + `&str` + `Vec<String>`) → would be full codegen
//!   except for the third-party `fastembed::RerankResult` return type
//!   lacking serde; left identity until M5.
//! * LlmClient + VectorReader (`#[async_trait]`) → identity until Session
//!   4-async lands.
25
26#![forbid(unsafe_code)]
27
28use proc_macro::TokenStream;
29use proc_macro2::TokenStream as TokenStream2;
30use quote::{format_ident, quote};
31use syn::{FnArg, ItemTrait, ReturnType, TraitItem, TraitItemFn, Type, parse_macro_input};
32
33/// Result of analyzing a method arg's type — what owned form to decode it
34/// as on the server, and how to re-borrow when calling the impl method.
35struct ArgLowering {
36    /// Owned form used as the local variable type in the server handler.
37    owned_ty: TokenStream2,
38    /// Expression used when calling the impl method. Either `name` for
39    /// pass-by-value, `&name` for `&T`, or a custom expression like
40    /// `&borrowed` for the `&[&str]` case.
41    call_expr: TokenStream2,
42    /// Optional extra binding emitted *before* the impl call. Used by
43    /// `&[&str]` to build `Vec<&str>` from the decoded `Vec<String>`.
44    extra_binding: Option<TokenStream2>,
45}
46
47/// Mark a trait as a Caravan RPC seam interface.
48///
49/// Accepts one optional argument:
50/// * `#[wagon]` — default: full HTTP codegen if the trait shape is
51///   supported (sync, supported arg/return types). Otherwise identity.
52/// * `#[wagon(identity)]` — explicit opt-out: emit identity regardless.
53///   Use for traits whose types aren't yet wire-ready (e.g.,
54///   third-party non-serde return types like `fastembed::RerankResult`).
55///   Transitional — the goal is to remove this flag once all wagon
56///   traits in the project are wire-ready.
57#[proc_macro_attribute]
58pub fn wagon(attrs: TokenStream, item: TokenStream) -> TokenStream {
59    let item_clone = item.clone();
60
61    // Parse opt-out attribute: `#[wagon(identity)]`.
62    let attrs2: TokenStream2 = attrs.into();
63    let identity_opt_out = attrs2.to_string().trim() == "identity";
64
65    if identity_opt_out {
66        return item_clone;
67    }
68
69    let parsed = parse_macro_input!(item as ItemTrait);
70
71    let Some(mode) = classify_trait(&parsed) else {
72        // Fallback: identity expansion. The trait is emitted unchanged.
73        return item_clone;
74    };
75
76    match expand_trait(&parsed, mode) {
77        Ok(ts) => ts.into(),
78        Err(e) => e.to_compile_error().into(),
79    }
80}
81
/// Trait shape recognized by the macro for full HTTP codegen.
///
/// The whole trait is classified at once: the wire dispatcher is chosen per
/// trait, not per method.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TraitMode {
    /// Every method is a plain `fn`; the client calls `dispatch_sync`.
    Sync,
    /// Every method is `async fn`. The client impl is wrapped in
    /// `#[async_trait]` and awaits `dispatch_async`.
    Async,
}
88
89/// Classify a trait. Returns `Some(Sync)` if every method is sync and
90/// types lower correctly; `Some(Async)` if `#[async_trait]` is present
91/// (every method then expected to be `async fn`) and types lower
92/// correctly; `None` otherwise (identity fallback).
93fn classify_trait(item: &ItemTrait) -> Option<TraitMode> {
94    let has_async_trait_attr = item.attrs.iter().any(|a| a.path().is_ident("async_trait"));
95
96    let mut all_methods_async = true;
97    let mut any_method_async = false;
98
99    for trait_item in &item.items {
100        let TraitItem::Fn(m) = trait_item else {
101            continue;
102        };
103
104        if m.sig.asyncness.is_some() {
105            any_method_async = true;
106        } else {
107            all_methods_async = false;
108        }
109
110        // Every arg type must be lowerable.
111        for input in &m.sig.inputs {
112            let FnArg::Typed(pat_type) = input else {
113                continue;
114            };
115            let pat = quote! { __dummy };
116            lower_arg_type(&pat_type.ty, &pat)?;
117        }
118
119        // No borrowed types in return.
120        if let ReturnType::Type(_, ty) = &m.sig.output
121            && contains_reference(ty)
122        {
123            return None;
124        }
125    }
126
127    // Decide sync vs async. The wire dispatcher (`dispatch_sync` vs
128    // `dispatch_async`) is selected per the whole trait — mixing sync
129    // and async methods in one #[wagon] trait isn't supported.
130    if has_async_trait_attr || any_method_async {
131        if !all_methods_async {
132            // Mixed shape — bail to identity.
133            return None;
134        }
135        Some(TraitMode::Async)
136    } else {
137        Some(TraitMode::Sync)
138    }
139}
140
141/// Decide how to decode a method arg from the wire and how to pass it to
142/// the impl method. Returns `None` for shapes the Session-4-narrow macro
143/// doesn't support (e.g., `&CustomStruct`, function pointers).
144///
145/// Supported lowerings:
146/// * `&str` → decode `String`, call `&name`
147/// * `&[&str]` → decode `Vec<String>`, then build `Vec<&str>`, call `&name_ref`
148/// * `&[T]` (T owned) → decode `Vec<T>`, call `&name`
149/// * Otherwise (owned T) → decode `T`, call `name`
150fn lower_arg_type(ty: &Type, name: &TokenStream2) -> Option<ArgLowering> {
151    if let Type::Reference(r) = ty {
152        let inner = &*r.elem;
153        // `&str` case.
154        if is_str_path(inner) {
155            return Some(ArgLowering {
156                owned_ty: quote! { ::std::string::String },
157                call_expr: quote! { &#name },
158                extra_binding: None,
159            });
160        }
161        // `&[T]` case — the referenced type is a slice.
162        if let Type::Slice(slice) = inner {
163            // Special-case `&[&str]` → decode Vec<String>, build Vec<&str>.
164            if let Type::Reference(inner_ref) = &*slice.elem
165                && is_str_path(&inner_ref.elem)
166            {
167                let borrowed_ident =
168                    format_ident!("__caravan_{}_borrowed", name.to_string().replace(' ', ""));
169                return Some(ArgLowering {
170                    owned_ty: quote! { ::std::vec::Vec<::std::string::String> },
171                    call_expr: quote! { &#borrowed_ident },
172                    extra_binding: Some(quote! {
173                        let #borrowed_ident: ::std::vec::Vec<&str> =
174                            #name.iter().map(::std::string::String::as_str).collect();
175                    }),
176                });
177            }
178            // `&[T]` where T isn't itself a reference. The owned form is
179            // `Vec<T>`, and we pass it via `&name` (deref coerces to `&[T]`).
180            let elem_ty = &slice.elem;
181            if !contains_reference(elem_ty) {
182                return Some(ArgLowering {
183                    owned_ty: quote! { ::std::vec::Vec<#elem_ty> },
184                    call_expr: quote! { &#name },
185                    extra_binding: None,
186                });
187            }
188            return None;
189        }
190        return None;
191    }
192    // Owned type — no borrow logic needed. The owned form is the type
193    // itself; the call expression is just the name.
194    if contains_reference(ty) {
195        // e.g., `Vec<&str>` as an owned-looking type that nevertheless
196        // borrows — can't deserialize without lifetimes.
197        return None;
198    }
199    Some(ArgLowering {
200        owned_ty: quote! { #ty },
201        call_expr: quote! { #name },
202        extra_binding: None,
203    })
204}
205
206/// Whether a type is `str` (the unsized variant of `&str`).
207fn is_str_path(ty: &Type) -> bool {
208    if let Type::Path(p) = ty
209        && p.qself.is_none()
210        && let Some(last) = p.path.segments.last()
211    {
212        return last.ident == "str";
213    }
214    false
215}
216
217/// Recursively check whether a type contains a reference (`&T`, `&mut T`).
218/// We only descend into generic arguments via `Type::Path` since that's the
219/// common case (`Result<&str, _>`, `Vec<&[u8]>`, etc.); other oddball types
220/// (function pointers, trait objects with explicit lifetimes) are rare in
221/// seam trait signatures and not worth handling at Session 3.
222fn contains_reference(ty: &Type) -> bool {
223    match ty {
224        Type::Reference(_) => true,
225        Type::Slice(_) => true,
226        Type::Array(arr) => contains_reference(&arr.elem),
227        Type::Tuple(t) => t.elems.iter().any(contains_reference),
228        Type::Path(path) => {
229            for segment in &path.path.segments {
230                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
231                    for arg in &args.args {
232                        if let syn::GenericArgument::Type(inner) = arg
233                            && contains_reference(inner)
234                        {
235                            return true;
236                        }
237                    }
238                }
239            }
240            false
241        }
242        Type::Paren(p) => contains_reference(&p.elem),
243        Type::Group(g) => contains_reference(&g.elem),
244        _ => false,
245    }
246}
247
248/// Expand a wagon trait into trait + HttpClient + router builder.
249/// Behavior varies by `mode`:
250/// * Sync — `impl Trait for <Trait>HttpClient { fn ... }` using `dispatch_sync`.
251/// * Async — `#[async_trait] impl Trait for <Trait>HttpClient { async fn ... }` using `dispatch_async`.
252fn expand_trait(item: &ItemTrait, mode: TraitMode) -> syn::Result<TokenStream2> {
253    let trait_ident = &item.ident;
254    let vis = &item.vis;
255    let interface_str = trait_ident.to_string();
256    let client_struct = format_ident!("{}HttpClient", trait_ident);
257    let router_fn = format_ident!("build_{}_router", to_snake_case(&interface_str));
258
259    let mut client_methods: Vec<TokenStream2> = Vec::new();
260    let mut handler_bindings: Vec<TokenStream2> = Vec::new();
261    let mut router_chain: Vec<TokenStream2> = Vec::new();
262
263    for trait_item in &item.items {
264        let TraitItem::Fn(m) = trait_item else {
265            continue;
266        };
267        client_methods.push(emit_client_method(m, &interface_str, mode)?);
268        let (binding, method_str) = emit_server_handler(m, trait_ident, mode)?;
269        handler_bindings.push(binding);
270        let handler_ident = format_ident!("__caravan_handler_{}", method_str);
271        router_chain.push(quote! { .add_method(#method_str, #handler_ident) });
272    }
273
274    // For async traits the impl needs `#[async_trait::async_trait]` so
275    // each `async fn` becomes a regular `fn -> Pin<Box<...>>`. We pull
276    // the macro from the `__macro_support` re-export so the user
277    // doesn't need an explicit `async-trait` dep.
278    let async_trait_attr = match mode {
279        TraitMode::Sync => quote! {},
280        TraitMode::Async => quote! { #[::caravan_rpc::__macro_support::async_trait::async_trait] },
281    };
282
283    let out = quote! {
284        // Original trait, emitted unchanged.
285        #item
286
287        // HTTP-client adapter: dispatches each method call over the wire.
288        #vis struct #client_struct {
289            base_url: ::std::string::String,
290        }
291
292        impl #client_struct {
293            #vis fn new(base_url: impl ::std::convert::Into<::std::string::String>) -> Self {
294                Self { base_url: base_url.into() }
295            }
296        }
297
298        #async_trait_attr
299        impl #trait_ident for #client_struct {
300            #(#client_methods)*
301        }
302
303        // Builder: wraps a registered impl into an axum Router for the peer
304        // service to serve. Reads CARAVAN_RPC_SHARED_SECRET at call time so
305        // the bearer-auth check matches what the client side sends.
306        #vis fn #router_fn(
307            impl_arc: ::std::sync::Arc<dyn #trait_ident>,
308        ) -> ::caravan_rpc::__macro_support::axum::Router {
309            #(#handler_bindings)*
310            ::caravan_rpc::server::RpcRouter::new(#interface_str)
311                #(#router_chain)*
312                .into_axum_router(::caravan_rpc::peers::shared_secret())
313        }
314
315        // Inventory registration: lets `caravan_rpc::client::<dyn Trait>()`
316        // discover this trait's HttpClient constructor at runtime when the
317        // peer table marks the interface as http-mode.
318        ::caravan_rpc::__macro_support::inventory::submit! {
319            ::caravan_rpc::HttpAdapterFactory {
320                interface_name: #interface_str,
321                type_id_fn: || ::std::any::TypeId::of::<dyn #trait_ident>(),
322                construct: |__url: ::std::string::String|
323                    -> ::std::boxed::Box<dyn ::std::any::Any + ::std::marker::Send + ::std::marker::Sync> {
324                    let __adapter: ::std::sync::Arc<dyn #trait_ident> =
325                        ::std::sync::Arc::new(#client_struct::new(__url));
326                    ::std::boxed::Box::new(__adapter)
327                },
328            }
329        }
330
331        // Server-side inventory registration: lets
332        // `caravan_rpc::run_or_serve` discover this trait's server router
333        // builder at runtime when CARAVAN_RPC_ROLE=peer-<Trait> is set.
334        // The closure does the trait-erased work: registry lookup + router
335        // build with the macro-emitted `build_<trait>_router`.
336        ::caravan_rpc::__macro_support::inventory::submit! {
337            ::caravan_rpc::HttpServerFactory {
338                interface_name: #interface_str,
339                build_router_from_registry: || {
340                    let __impl = ::caravan_rpc::try_client::<dyn #trait_ident>()
341                        .ok_or("no provide() call for this trait before run_or_serve")?;
342                    Ok(#router_fn(__impl))
343                },
344            }
345        }
346    };
347
348    Ok(out)
349}
350
351/// Emit one method body for the HttpClient's `impl Trait for` block.
352/// Body shape depends on `mode`: sync uses blocking `dispatch_sync`,
353/// async uses `dispatch_async(...).await`.
354fn emit_client_method(
355    m: &TraitItemFn,
356    interface: &str,
357    mode: TraitMode,
358) -> syn::Result<TokenStream2> {
359    let sig = &m.sig;
360    let method_str = sig.ident.to_string();
361    let mut arg_serializations: Vec<TokenStream2> = Vec::new();
362
363    for input in &sig.inputs {
364        if let FnArg::Typed(pat_type) = input {
365            let pat = &pat_type.pat;
366            arg_serializations.push(quote! {
367                ::caravan_rpc::__macro_support::serde_json::to_value(&#pat).expect("caravan-rpc: arg serialize")
368            });
369        }
370    }
371
372    let dispatch_call = match mode {
373        TraitMode::Sync => quote! {
374            ::caravan_rpc::dispatch::dispatch_sync(
375                &self.base_url, #interface, #method_str, __args
376            ).expect("caravan-rpc: dispatch_sync")
377        },
378        TraitMode::Async => quote! {
379            ::caravan_rpc::dispatch::dispatch_async(
380                &self.base_url, #interface, #method_str, __args
381            ).await.expect("caravan-rpc: dispatch_async")
382        },
383    };
384
385    let body = quote! {
386        let __args: ::std::vec::Vec<::caravan_rpc::__macro_support::serde_json::Value> = vec![ #(#arg_serializations),* ];
387        let __v = #dispatch_call;
388        ::caravan_rpc::__macro_support::serde_json::from_value(__v).expect("caravan-rpc: deserialize return")
389    };
390
391    let block: syn::Block = syn::parse2(quote! { { #body } })?;
392    let mut m = m.clone();
393    m.default = Some(block);
394    m.semi_token = None;
395    Ok(quote! { #m })
396}
397
398/// Emit one MethodHandler binding for the server-side router builder.
399/// Returns the `let __caravan_handler_<method> = ...;` token stream and
400/// the method name (as the string used in path routing + .add_method).
401fn emit_server_handler(
402    m: &TraitItemFn,
403    trait_ident: &syn::Ident,
404    mode: TraitMode,
405) -> syn::Result<(TokenStream2, String)> {
406    let sig = &m.sig;
407    let method_ident = &sig.ident;
408    let method_str = method_ident.to_string();
409    let handler_ident = format_ident!("__caravan_handler_{}", method_str);
410
411    // For each typed arg, emit a decode block (decoding into the OWNED
412    // form, even if the trait method takes a borrowed type) plus a call
413    // expression that re-borrows where needed. `lower_arg_type` owns this
414    // translation; we just call it here.
415    let mut decode_blocks: Vec<TokenStream2> = Vec::new();
416    let mut call_args: Vec<TokenStream2> = Vec::new();
417    let mut idx: usize = 0;
418    for input in &sig.inputs {
419        if let FnArg::Typed(pat_type) = input {
420            let pat = &pat_type.pat;
421            let pat_tokens = quote! { #pat };
422            let arg_name = pat_tokens.to_string();
423            let lowering =
424                lower_arg_type(&pat_type.ty, &pat_tokens).expect("is_sync_owned_trait gates this");
425            let owned_ty = &lowering.owned_ty;
426            let idx_lit = idx;
427            let extra = lowering.extra_binding.unwrap_or_default();
428            decode_blocks.push(quote! {
429                let #pat: #owned_ty = match __env.args.get(#idx_lit) {
430                    ::std::option::Option::Some(__val) => {
431                        match ::caravan_rpc::__macro_support::serde_json::from_value(__val.clone()) {
432                            ::std::result::Result::Ok(__t) => __t,
433                            ::std::result::Result::Err(__e) => {
434                                return ::caravan_rpc::codec::Response::err(
435                                    format!("BadArg({})", #arg_name),
436                                    __e.to_string(),
437                                );
438                            }
439                        }
440                    }
441                    ::std::option::Option::None => {
442                        return ::caravan_rpc::codec::Response::err(
443                            format!("MissingArg({})", #arg_name),
444                            format!("expected args[{}]", #idx_lit),
445                        );
446                    }
447                };
448                #extra
449            });
450            call_args.push(lowering.call_expr);
451            idx += 1;
452        }
453    }
454
455    let impl_call = match mode {
456        TraitMode::Sync => quote! {
457            <dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*)
458        },
459        TraitMode::Async => quote! {
460            <dyn #trait_ident>::#method_ident(&*__impl_arc #(, #call_args)*).await
461        },
462    };
463
464    let body = quote! {
465        let #handler_ident: ::caravan_rpc::server::MethodHandler = {
466            let __impl_arc = impl_arc.clone();
467            ::std::sync::Arc::new(move |__body: ::caravan_rpc::__macro_support::axum::body::Bytes| {
468                let __impl_arc = __impl_arc.clone();
469                ::std::boxed::Box::pin(async move {
470                    let __env: ::caravan_rpc::codec::Request = match ::caravan_rpc::__macro_support::serde_json::from_slice(&__body) {
471                        ::std::result::Result::Ok(__e) => __e,
472                        ::std::result::Result::Err(__e) => {
473                            return ::caravan_rpc::codec::Response::err(
474                                "BadJSON",
475                                __e.to_string(),
476                            );
477                        }
478                    };
479                    #(#decode_blocks)*
480                    let __result = #impl_call;
481                    match ::caravan_rpc::__macro_support::serde_json::to_value(&__result) {
482                        ::std::result::Result::Ok(__v) => ::caravan_rpc::codec::Response::ok(__v),
483                        ::std::result::Result::Err(__e) => ::caravan_rpc::codec::Response::err(
484                            "EncodeError",
485                            __e.to_string(),
486                        ),
487                    }
488                })
489            })
490        };
491    };
492
493    Ok((body, method_str))
494}
495
/// Convert a PascalCase / camelCase identifier to snake_case for the router
/// builder's function name (`Embedder` → `embedder`, `VectorReader` →
/// `vector_reader`).
///
/// Every uppercase character except a leading one is preceded by `_` and
/// lowercased. Runs of capitals are therefore split letter by letter
/// (`HTTPClient` → `h_t_t_p_client`). Other characters pass through unchanged.
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len() + 4), |mut snake, (pos, c)| {
            if c.is_uppercase() {
                if pos != 0 {
                    snake.push('_');
                }
                snake.extend(c.to_lowercase());
            } else {
                snake.push(c);
            }
            snake
        })
}