Skip to main content

jigs_macros/
lib.rs

//! Procedural macros for the `jigs` framework.
//!
//! `#[jig]` marks a function as a pipeline step. It emits a zero-sized
//! marker struct implementing `JigDef` alongside the (possibly
//! transformed) function body. The marker struct is named
//! `__Jig_<fn_name>` to avoid namespace collisions with the function
//! itself. With the `trace` feature it additionally wraps the body in a
//! thread-local trace recorder.
//!
//! `jigs!(entry)` generates `all_jigs()` and `find_jig()` registry
//! functions covering `entry` and every jig reachable from it through
//! `.then(..)` calls and `fork!(..)` arms.

10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use syn::visit::Visit;
14use syn::{parse_macro_input, parse_quote, Expr, ExprMethodCall, ItemFn, ReturnType, Type};
15
16fn marker_ident(fn_name: &str) -> syn::Ident {
17    syn::parse_str(&format!("__Jig_{fn_name}")).unwrap()
18}
19
20fn marker_path_for(name: &str) -> TokenStream2 {
21    let segs: Vec<&str> = name.split("::").collect();
22    let last_idx = segs.len() - 1;
23    let path_segs: Vec<TokenStream2> = segs
24        .iter()
25        .enumerate()
26        .map(|(i, s)| {
27            if i == last_idx {
28                let mi = marker_ident(s);
29                quote!(#mi)
30            } else if *s == "crate" {
31                quote!(crate)
32            } else if *s == "super" {
33                quote!(super)
34            } else if *s == "self" {
35                quote!(self)
36            } else {
37                let id: syn::Ident = syn::parse_str(s).unwrap();
38                quote!(#id)
39            }
40        })
41        .collect();
42    quote!(#(#path_segs)::*)
43}
44
45#[proc_macro_attribute]
46pub fn jig(_attr: TokenStream, item: TokenStream) -> TokenStream {
47    let input = parse_macro_input!(item as ItemFn);
48    let vis = &input.vis;
49    let block = &input.block;
50    let name_str = input.sig.ident.to_string();
51    let marker = marker_ident(&name_str);
52    let kind_str = return_kind(&input.sig.output);
53    let input_str = input_kind(&input.sig);
54    let input_type_str = first_arg_payload(&input.sig);
55    let output_type_str = return_payload(&input.sig.output);
56    let is_async = input.sig.asyncness.is_some();
57
58    let chain_tokens: Vec<TokenStream2> = collect_chain(&input.block)
59        .into_iter()
60        .map(|(name, kind)| {
61            let kind_ident = match kind {
62                ChainKindTok::Then => quote!(::jigs::ChainKind::Then),
63                ChainKindTok::Fork => quote!(::jigs::ChainKind::Fork),
64            };
65            quote! { ::jigs::ChainStep { name: #name, kind: #kind_ident } }
66        })
67        .collect();
68
69    let chain_collect: Vec<TokenStream2> = collect_chain(&input.block)
70        .into_iter()
71        .map(|(name, _kind)| {
72            let path = marker_path_for(&name);
73            quote! { <#path as ::jigs::JigDef>::collect(out); }
74        })
75        .collect();
76
77    let marker_def = quote! {
78        #[allow(non_camel_case_types)]
79        #[doc(hidden)]
80        pub struct #marker;
81
82        impl ::jigs::JigDef for #marker {
83            const META: ::jigs::JigMeta = ::jigs::JigMeta {
84                name: #name_str,
85                file: file!(),
86                line: line!(),
87                kind: #kind_str,
88                input: #input_str,
89                input_type: #input_type_str,
90                output_type: #output_type_str,
91                is_async: #is_async,
92                module: module_path!(),
93                chain: &[#(#chain_tokens),*],
94            };
95
96            fn collect(out: &mut Vec<&'static ::jigs::JigMeta>) {
97                let name = <Self as ::jigs::JigDef>::META.name;
98                if out.iter().any(|m| m.name == name) {
99                    return;
100                }
101                out.push(&<Self as ::jigs::JigDef>::META);
102                #(#chain_collect)*
103            }
104        }
105    };
106
107    let response_input_ident = if input_str == "Response" {
108        first_arg_ident(&input.sig)
109    } else {
110        None
111    };
112
113    if input.sig.asyncness.is_some() {
114        let mut sig = input.sig.clone();
115        sig.asyncness = None;
116        let ret_ty = match &input.sig.output {
117            ReturnType::Default => quote!(()),
118            ReturnType::Type(_, ty) => quote!(#ty),
119        };
120        sig.output = parse_quote! {
121            -> ::jigs::Pending<impl ::core::future::Future<Output = #ret_ty>>
122        };
123
124        let body = async_body(block, &name_str, response_input_ident.as_ref());
125        return quote! { #marker_def #vis #sig { #body } }.into();
126    }
127
128    let sig = &input.sig;
129    let body = sync_body(block, &name_str, response_input_ident.as_ref());
130    quote! { #marker_def #vis #sig { #body } }.into()
131}
132
133#[proc_macro]
134pub fn jigs(input: TokenStream) -> TokenStream {
135    let entry: syn::Ident = parse_macro_input!(input);
136    let entry_marker = marker_ident(&entry.to_string());
137    quote! {
138        mod __jigs_registry {
139            pub fn all_jigs() -> impl Iterator<Item = &'static ::jigs::JigMeta> {
140                static CACHE: std::sync::OnceLock<Vec<&'static ::jigs::JigMeta>> = std::sync::OnceLock::new();
141                CACHE.get_or_init(|| {
142                    let mut v = Vec::new();
143                    <super::#entry_marker as ::jigs::JigDef>::collect(&mut v);
144                    v
145                }).iter().copied()
146            }
147
148            pub fn find_jig(name: &str) -> Option<&'static ::jigs::JigMeta> {
149                all_jigs().find(|m| m.name == name)
150            }
151        }
152        pub use __jigs_registry::{all_jigs, find_jig};
153    }
154    .into()
155}
156
157fn first_arg_ident(sig: &syn::Signature) -> Option<syn::Ident> {
158    if let Some(syn::FnArg::Typed(pt)) = sig.inputs.first() {
159        if let syn::Pat::Ident(pi) = &*pt.pat {
160            return Some(pi.ident.clone());
161        }
162    }
163    None
164}
165
/// Tracing variant of the synchronous jig body.
///
/// The generated code:
/// 1. Snapshots whether the `Response` input (if any) is ok.
/// 2. Registers the call with `::jigs::trace::enter`.
/// 3. Runs `block` inside an immediately-invoked `move` closure, so a
///    `return` in the body still flows through the bookkeeping below.
/// 4. Reports the elapsed time plus `::jigs::Status::ok` / `error` of the
///    result to `::jigs::trace::exit`.
///
/// When the input was already not-ok and the result is not-ok too, the call
/// is reported as ok with no error. Presumably this is so a failure that
/// entered from upstream is not attributed to this jig.
#[cfg(feature = "trace")]
fn sync_body(
    block: &syn::Block,
    name_str: &str,
    response_input: Option<&syn::Ident>,
) -> TokenStream2 {
    let meta_owner = marker_ident(name_str);
    let input_ok = response_input.map_or_else(
        || quote! { let __jig_input_ok = true; },
        |id| quote! { let __jig_input_ok = ::jigs::Status::ok(&#id); },
    );
    let timed_run = quote! {
        let __jig_idx = ::jigs::trace::enter(&<#meta_owner as ::jigs::JigDef>::META);
        let __jig_start = ::std::time::Instant::now();
        let __jig_result = (move || #block)();
    };
    let report = quote! {
        let mut __jig_ok = ::jigs::Status::ok(&__jig_result);
        let mut __jig_err = ::jigs::Status::error(&__jig_result);
        if !__jig_input_ok && !__jig_ok {
            __jig_ok = true;
            __jig_err = None;
        }
        ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
        __jig_result
    };
    quote! {
        #input_ok
        #timed_run
        #report
    }
}

193#[cfg(not(feature = "trace"))]
194fn sync_body(
195    block: &syn::Block,
196    _name_str: &str,
197    _response_input: Option<&syn::Ident>,
198) -> TokenStream2 {
199    quote! { #block }
200}
201
/// Tracing variant of the asynchronous jig body.
///
/// Produces `::jigs::Pending(async move { .. })`. Inside it, the generated
/// code:
/// 1. Snapshots the `Response` input status (if any).
/// 2. Calls `::jigs::trace::enter`.
/// 3. Awaits the original block wrapped in its own `async move` block, so a
///    `return` in the body still reaches the bookkeeping.
/// 4. Reports the elapsed time and `::jigs::Status` of the result to
///    `::jigs::trace::exit`.
///
/// All of this runs inside the outer future, so the snapshot and timing start
/// when that future is first polled, not when the jig function is called.
/// A not-ok result for an input that was already not-ok is reported as ok
/// with no error (presumably to avoid blaming this jig for an upstream
/// failure).
#[cfg(feature = "trace")]
fn async_body(
    block: &syn::Block,
    name_str: &str,
    response_input: Option<&syn::Ident>,
) -> TokenStream2 {
    let meta_owner = marker_ident(name_str);
    let input_ok = response_input.map_or_else(
        || quote! { let __jig_input_ok = true; },
        |id| quote! { let __jig_input_ok = ::jigs::Status::ok(&#id); },
    );
    let timed_run = quote! {
        let __jig_idx = ::jigs::trace::enter(&<#meta_owner as ::jigs::JigDef>::META);
        let __jig_start = ::std::time::Instant::now();
        let __jig_result = (async move #block).await;
    };
    let report = quote! {
        let mut __jig_ok = ::jigs::Status::ok(&__jig_result);
        let mut __jig_err = ::jigs::Status::error(&__jig_result);
        if !__jig_input_ok && !__jig_ok {
            __jig_ok = true;
            __jig_err = None;
        }
        ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
        __jig_result
    };
    quote! {
        ::jigs::Pending(async move {
            #input_ok
            #timed_run
            #report
        })
    }
}

231#[cfg(not(feature = "trace"))]
232fn async_body(
233    block: &syn::Block,
234    _name_str: &str,
235    _response_input: Option<&syn::Ident>,
236) -> TokenStream2 {
237    quote! { ::jigs::Pending(async move #block) }
238}
239
240fn return_kind(ret: &ReturnType) -> &'static str {
241    let ty = match ret {
242        ReturnType::Default => return "Other",
243        ReturnType::Type(_, t) => t,
244    };
245    match last_type_ident(ty).as_deref() {
246        Some("Request") => "Request",
247        Some("Response") => "Response",
248        Some("Branch") => "Branch",
249        Some("Pending") => "Pending",
250        _ => "Other",
251    }
252}
253
254fn input_kind(sig: &syn::Signature) -> &'static str {
255    let ty = match sig.inputs.first() {
256        Some(syn::FnArg::Typed(pt)) => &*pt.ty,
257        _ => return "Other",
258    };
259    match last_type_ident(ty).as_deref() {
260        Some("Request") => "Request",
261        Some("Response") => "Response",
262        _ => "Other",
263    }
264}
265
266fn first_arg_payload(sig: &syn::Signature) -> String {
267    let ty = match sig.inputs.first() {
268        Some(syn::FnArg::Typed(pt)) => &*pt.ty,
269        _ => return "?".into(),
270    };
271    payload_type(ty)
272}
273
274fn return_payload(ret: &ReturnType) -> String {
275    let ty = match ret {
276        ReturnType::Default => return "?".into(),
277        ReturnType::Type(_, t) => t,
278    };
279    payload_type(ty)
280}
281
282fn payload_type(ty: &Type) -> String {
283    if let Type::Path(p) = ty {
284        if let Some(seg) = p.path.segments.last() {
285            let name = seg.ident.to_string();
286            match name.as_str() {
287                "Request" | "Response" => {
288                    if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
289                        return generic_args_string(ab);
290                    }
291                }
292                "Branch" => {
293                    if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
294                        return format!("Branch<{}>", generic_args_string(ab));
295                    }
296                }
297                "Pending" => {
298                    if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
299                        return generic_args_string(ab);
300                    }
301                }
302                _ => {}
303            }
304        }
305    }
306    type_to_string(ty)
307}
308
309fn type_to_string(ty: &Type) -> String {
310    quote::quote!(#ty).to_string().replace(' ', "")
311}
312
313fn generic_args_string(args: &syn::AngleBracketedGenericArguments) -> String {
314    let mut out = String::new();
315    for (i, arg) in args.args.iter().enumerate() {
316        if i > 0 {
317            out.push(',');
318        }
319        match arg {
320            syn::GenericArgument::Type(t) => out.push_str(&type_to_string(t)),
321            syn::GenericArgument::Lifetime(l) => out.push_str(&l.ident.to_string()),
322            other => out.push_str(&quote::quote!(#other).to_string().replace(' ', "")),
323        }
324    }
325    out
326}
327
328fn last_type_ident(ty: &Type) -> Option<String> {
329    if let Type::Path(p) = ty {
330        return Some(p.path.segments.last()?.ident.to_string());
331    }
332    None
333}
334
/// How one jig refers to another inside its body. When the marker's
/// `META.chain` is generated, this is mapped to `::jigs::ChainKind`.
#[derive(Copy, Clone)]
enum ChainKindTok {
    /// Found as the (path) argument of a `.then(..)` method call.
    Then,
    /// Found as an arm target of a `fork!(..)` invocation.
    Fork,
}

341fn collect_chain(block: &syn::Block) -> Vec<(String, ChainKindTok)> {
342    struct V(Vec<(String, ChainKindTok)>);
343    impl V {
344        fn push_unique(&mut self, name: String, kind: ChainKindTok) {
345            if !self.0.iter().any(|(n, _)| n == &name) {
346                self.0.push((name, kind));
347            }
348        }
349        fn push_path(&mut self, p: &syn::Path, kind: ChainKindTok) {
350            let name = p
351                .segments
352                .iter()
353                .map(|s| s.ident.to_string())
354                .collect::<Vec<_>>()
355                .join("::");
356            self.push_unique(name, kind);
357        }
358    }
359    impl<'ast> Visit<'ast> for V {
360        fn visit_expr_method_call(&mut self, m: &'ast ExprMethodCall) {
361            syn::visit::visit_expr(self, &m.receiver);
362            if m.method == "then" {
363                if let Some(Expr::Path(p)) = m.args.first() {
364                    self.push_path(&p.path, ChainKindTok::Then);
365                }
366            }
367            for a in &m.args {
368                syn::visit::visit_expr(self, a);
369            }
370        }
371        fn visit_macro(&mut self, mac: &'ast syn::Macro) {
372            let last = mac
373                .path
374                .segments
375                .last()
376                .map(|s| s.ident.to_string())
377                .unwrap_or_default();
378            if last == "fork" {
379                if let Ok(args) = syn::parse2::<ForkArgs>(mac.tokens.clone()) {
380                    for j in &args.arms {
381                        if let syn::Expr::Path(p) = j {
382                            self.push_path(&p.path, ChainKindTok::Fork);
383                        }
384                    }
385                    if let syn::Expr::Path(p) = &args.default {
386                        self.push_path(&p.path, ChainKindTok::Fork);
387                    }
388                }
389            }
390        }
391    }
392    let mut v = V(Vec::new());
393    v.visit_block(block);
394    v.0
395}
396
397struct ForkArgs {
398    arms: Vec<syn::Expr>,
399    default: syn::Expr,
400}
401
402impl syn::parse::Parse for ForkArgs {
403    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
404        let _req: syn::Expr = input.parse()?;
405        input.parse::<syn::Token![,]>()?;
406        let mut arms = Vec::new();
407        loop {
408            if input.peek(syn::Token![_]) {
409                input.parse::<syn::Token![_]>()?;
410                input.parse::<syn::Token![=>]>()?;
411                let default: syn::Expr = input.parse()?;
412                let _: Option<syn::Token![,]> = input.parse().ok();
413                return Ok(ForkArgs { arms, default });
414            }
415            let _pred: syn::Expr = input.parse()?;
416            input.parse::<syn::Token![=>]>()?;
417            let jig: syn::Expr = input.parse()?;
418            input.parse::<syn::Token![,]>()?;
419            arms.push(jig);
420        }
421    }
422}