Skip to main content

jigs_macros/
lib.rs

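//! Procedural macros for the `jigs` framework.
//!
//! `#[jig]` marks a function as a pipeline step. It emits a zero-sized
//! marker struct implementing `JigDef` alongside the (possibly
//! transformed) function. The marker struct is named `__Jig_<fn_name>`
//! so that it does not collide with the function itself (a unit struct
//! also occupies the value namespace). With the `trace` feature the body
//! is additionally wrapped in a thread-local trace recorder
//! (`::jigs::trace::enter` / `::jigs::trace::exit`).
//!
//! `#[derive(Request)]` and `#[derive(Response)]` implement the matching
//! `jigs` traits for user types, and `jigs!(entry)` generates `all_jigs()`
//! and `find_jig()` over every jig reachable from `entry`.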
9
10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::visit::Visit;
15use syn::{
16    parse_macro_input, parse_quote, Data, DeriveInput, Expr, ExprMethodCall, Field, Fields,
17    FieldsNamed, FieldsUnnamed, Ident, ItemFn, ReturnType, Type,
18};
19
20fn marker_ident(fn_name: &str) -> syn::Ident {
21    syn::parse_str(&format!("__Jig_{fn_name}")).unwrap()
22}
23
24fn marker_path_for(name: &str) -> TokenStream2 {
25    let segs: Vec<&str> = name.split("::").collect();
26    let last_idx = segs.len() - 1;
27    let path_segs: Vec<TokenStream2> = segs
28        .iter()
29        .enumerate()
30        .map(|(i, s)| {
31            if i == last_idx {
32                let mi = marker_ident(s);
33                quote!(#mi)
34            } else if *s == "crate" {
35                quote!(crate)
36            } else if *s == "super" {
37                quote!(super)
38            } else if *s == "self" {
39                quote!(self)
40            } else {
41                let id: syn::Ident = syn::parse_str(s).unwrap();
42                quote!(#id)
43            }
44        })
45        .collect();
46    quote!(#(#path_segs)::*)
47}
48
49#[proc_macro_attribute]
50pub fn jig(_attr: TokenStream, item: TokenStream) -> TokenStream {
51    let input = parse_macro_input!(item as ItemFn);
52    let vis = &input.vis;
53    let attrs = &input.attrs;
54    let block = &input.block;
55    let name_str = input.sig.ident.to_string();
56    let marker = marker_ident(&name_str);
57    let input_type_str = first_arg_payload(&input.sig);
58    let output_type_str = return_payload(&input.sig.output);
59    let is_async = input.sig.asyncness.is_some();
60
61    let input_ty = first_arg_type(&input.sig);
62    let output_ty = return_type(&input.sig.output);
63    let kind_expr = classify_expr(output_ty.as_ref());
64    let input_expr = classify_expr(input_ty.as_ref());
65
66    let chain = collect_chain(&input.block);
67
68    let chain_tokens: Vec<TokenStream2> = chain
69        .iter()
70        .map(|(name, kind)| {
71            let kind_ident = match kind {
72                ChainKindTok::Then => quote!(::jigs::ChainKind::Then),
73                ChainKindTok::Fork => quote!(::jigs::ChainKind::Fork),
74            };
75            quote! { ::jigs::ChainStep { name: #name, kind: #kind_ident } }
76        })
77        .collect();
78
79    let chain_collect: Vec<TokenStream2> = chain
80        .iter()
81        .map(|(name, _kind)| {
82            let path = marker_path_for(name);
83            quote! { <#path as ::jigs::JigDef>::collect(out); }
84        })
85        .collect();
86
87    let marker_def = quote! {
88        #[allow(non_camel_case_types)]
89        #[doc(hidden)]
90        pub struct #marker;
91
92        impl ::jigs::JigDef for #marker {
93            const META: ::jigs::JigMeta = ::jigs::JigMeta {
94                name: #name_str,
95                file: file!(),
96                line: line!(),
97                kind: #kind_expr,
98                input: #input_expr,
99                input_type: #input_type_str,
100                output_type: #output_type_str,
101                is_async: #is_async,
102                module: module_path!(),
103                chain: &[#(#chain_tokens),*],
104            };
105
106            fn collect(out: &mut Vec<&'static ::jigs::JigMeta>) {
107                let meta = &<Self as ::jigs::JigDef>::META;
108                if out.iter().any(|m| ::std::ptr::eq(*m, meta)) {
109                    return;
110                }
111                out.push(meta);
112                #(#chain_collect)*
113            }
114        }
115    };
116
117    let input_ident = first_arg_ident(&input.sig);
118
119    if input.sig.asyncness.is_some() {
120        let mut sig = input.sig.clone();
121        sig.asyncness = None;
122        let ret_ty = match &input.sig.output {
123            ReturnType::Default => quote!(()),
124            ReturnType::Type(_, ty) => quote!(#ty),
125        };
126        sig.output = parse_quote! {
127            -> ::jigs::Pending<impl ::core::future::Future<Output = #ret_ty>>
128        };
129
130        let body = async_body(block, &name_str, input_ident.as_ref());
131        return quote! { #marker_def #(#attrs)* #vis #sig { #body } }.into();
132    }
133
134    let sig = &input.sig;
135    let body = sync_body(block, &name_str, input_ident.as_ref());
136    quote! { #marker_def #(#attrs)* #vis #sig { #body } }.into()
137}
138
139#[proc_macro_derive(Request, attributes(req))]
140pub fn derive_request(input: TokenStream) -> TokenStream {
141    let parsed = parse_macro_input!(input as DeriveInput);
142    generate_req(&parsed).unwrap_or_else(|e| e.to_compile_error().into())
143}
144
145fn generate_req(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
146    let name = &input.ident;
147    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
148    let Data::Struct(data) = &input.data else {
149        return Err(syn::Error::new_spanned(
150            input,
151            "Request can only be derived for structs",
152        ));
153    };
154
155    let mut explicit_field: Option<Ident> = None;
156
157    for attr in &input.attrs {
158        if attr.path().is_ident("req") {
159            attr.parse_nested_meta(|meta| {
160                if meta.path.is_ident("field") {
161                    let val = meta.value()?;
162                    let lit: syn::LitStr = val.parse()?;
163                    explicit_field = Some(syn::Ident::new(&lit.value(), lit.span()));
164                    return Ok(());
165                }
166                Err(meta.error("unrecognized req attribute"))
167            })?;
168        }
169    }
170
171    let (payload_decl, payload_ref_expr, into_expr, from_expr) =
172        derive_req_field_info(data, explicit_field, input)?;
173
174    let mut merge_generics = input.generics.clone();
175    merge_generics
176        .params
177        .push(syn::GenericParam::Type(syn::TypeParam {
178            attrs: Vec::new(),
179            ident: parse_quote!(__R),
180            colon_token: Some(syn::Token![:](input.generics.span())),
181            bounds: parse_quote!(::jigs::Response),
182            eq_token: None,
183            default: None,
184        }));
185    let (merge_impl_generics, _, merge_where_clause) = merge_generics.split_for_impl();
186
187    Ok(quote! {
188        impl #impl_generics ::jigs::__Classify for #name #type_generics #where_clause {
189            const KIND: &'static str = "Request";
190        }
191        impl #impl_generics ::jigs::Request for #name #type_generics #where_clause {
192            #payload_decl
193            fn payload(&self) -> &Self::Payload {
194                #payload_ref_expr
195            }
196            fn into_payload(self) -> Self::Payload {
197                #into_expr
198            }
199            fn from_payload(payload: Self::Payload) -> Self {
200                #from_expr
201            }
202        }
203        impl #merge_impl_generics ::jigs::Merge<__R> for #name #type_generics #merge_where_clause {
204            type Merged = ::jigs::Branch<#name #type_generics, __R>;
205            fn into_continue(self) -> Self::Merged {
206                ::jigs::Branch::Continue(self)
207            }
208            fn from_done(resp: __R) -> Self::Merged {
209                ::jigs::Branch::Done(resp)
210            }
211        }
212        impl #impl_generics ::jigs::Step for #name #type_generics #where_clause {
213            type Out = #name #type_generics;
214            type Fut = ::core::future::Ready<#name #type_generics>;
215            fn into_step(self) -> Self::Fut {
216                ::core::future::ready(self)
217            }
218        }
219        impl #impl_generics ::jigs::Status for #name #type_generics #where_clause {
220            fn succeeded(&self) -> bool {
221                true
222            }
223            fn error(&self) -> Option<String> {
224                None
225            }
226        }
227    }
228    .into())
229}
230
231fn derive_req_field_info(
232    data: &syn::DataStruct,
233    explicit_field: Option<Ident>,
234    input: &DeriveInput,
235) -> Result<(TokenStream2, TokenStream2, TokenStream2, TokenStream2), syn::Error> {
236    if let Some(field_ident) = explicit_field {
237        let field = find_field(data, &field_ident)?;
238        let payload_ty = &field.ty;
239        let payload_decl = quote! { type Payload = #payload_ty; };
240        let payload_ref = quote! { &self.#field_ident };
241        let into_expr = quote! {
242            let Self { #field_ident, .. } = self;
243            #field_ident
244        };
245        let from_expr = quote! { Self { #field_ident: payload, ..Default::default() } };
246        return Ok((payload_decl, payload_ref, into_expr, from_expr));
247    }
248
249    match &data.fields {
250        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
251            let field = unnamed.first().unwrap();
252            let payload_ty = &field.ty;
253            let payload_decl = quote! { type Payload = #payload_ty; };
254            let payload_ref = quote! { &self.0 };
255            let into_expr = quote! { self.0 };
256            let from_expr = quote! { Self(payload) };
257            Ok((payload_decl, payload_ref, into_expr, from_expr))
258        }
259        Fields::Named(FieldsNamed { named, .. }) if named.len() == 1 => {
260            let field = named.first().unwrap();
261            let field_ident = field.ident.as_ref().unwrap();
262            let payload_ty = &field.ty;
263            let payload_decl = quote! { type Payload = #payload_ty; };
264            let payload_ref = quote! { &self.#field_ident };
265            let into_expr = quote! { self.#field_ident };
266            let from_expr = quote! { Self { #field_ident: payload } };
267            Ok((payload_decl, payload_ref, into_expr, from_expr))
268        }
269        _ => Err(syn::Error::new_spanned(
270            input,
271            "Request derive requires either: one field, or #[req(field = \"name\")]",
272        )),
273    }
274}
275
276fn find_field<'a>(data: &'a syn::DataStruct, ident: &Ident) -> Result<&'a Field, syn::Error> {
277    for f in &data.fields {
278        if f.ident.as_ref() == Some(ident) {
279            return Ok(f);
280        }
281    }
282    Err(syn::Error::new(
283        proc_macro2::Span::call_site(),
284        format!("no field named `{ident}`"),
285    ))
286}
287
288#[proc_macro_derive(Response, attributes(resp))]
289pub fn derive_response(input: TokenStream) -> TokenStream {
290    let parsed = parse_macro_input!(input as DeriveInput);
291    generate_response(&parsed).unwrap_or_else(|e| e.to_compile_error().into())
292}
293
294fn generate_response(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
295    match &input.data {
296        Data::Struct(data) => generate_response_struct(input, data),
297        Data::Enum(data) => generate_response_enum(input, data),
298        Data::Union(_u) => Err(syn::Error::new_spanned(
299            input,
300            "Response cannot be derived for unions",
301        )),
302    }
303}
304
305fn generate_response_struct(
306    input: &DeriveInput,
307    data: &syn::DataStruct,
308) -> Result<TokenStream, syn::Error> {
309    let name = &input.ident;
310    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
311
312    match &data.fields {
313        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
314            let f = unnamed.first().unwrap();
315            let ok_expr = quote! { Self(Ok(payload)) };
316            let err_expr = quote! { Self(Err(msg.into())) };
317            let is_ok_expr = quote! { self.0.is_ok() };
318            let into_result_expr = quote! { self.0 };
319            let error_msg_expr = quote! { self.0.as_ref().err().cloned() };
320            let payload_ty = extract_result_payload(&f.ty,
321                "Response derive on single-field structs expects `Result<Payload, String>`",
322            )?;
323            Ok(generate_response_impls(ResponseImplParts {
324                name,
325                impl_generics,
326                type_generics,
327                where_clause,
328                payload_ty: &payload_ty,
329                ok_expr,
330                err_expr,
331                is_ok_expr,
332                into_result_expr,
333                error_msg_expr,
334            }))
335        }
336        Fields::Named(FieldsNamed { named, .. }) if named.len() == 1 => {
337            let f = named.first().unwrap();
338            let field_ident = f.ident.as_ref().unwrap();
339            let payload_ty = extract_result_payload(
340                &f.ty,
341                "Response derive on single-field structs expects `Result<Payload, String>`",
342            )?;
343            let ok_expr = quote! { Self { #field_ident: Ok(payload) } };
344            let err_expr = quote! { Self { #field_ident: Err(msg.into()) } };
345            let is_ok_expr = quote! { self.#field_ident.is_ok() };
346            let into_result_expr = quote! { self.#field_ident };
347            let error_msg_expr = quote! { self.#field_ident.as_ref().err().cloned() };
348            Ok(generate_response_impls(ResponseImplParts {
349                name,
350                impl_generics,
351                type_generics,
352                where_clause,
353                payload_ty: &payload_ty,
354                ok_expr,
355                err_expr,
356                is_ok_expr,
357                into_result_expr,
358                error_msg_expr,
359            }))
360        }
361        Fields::Named(FieldsNamed { named, .. }) if named.len() == 2 => {
362            generate_response_two_fields(input, data, named, name, impl_generics, type_generics, where_clause)
363        }
364        _ => Err(syn::Error::new_spanned(
365            input,
366            "Response derive requires either: a single `Result<Payload, String>` field, or two fields",
367        )),
368    }
369}
370
371fn generate_response_two_fields(
372    input: &DeriveInput,
373    _data: &syn::DataStruct,
374    named: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
375    name: &Ident,
376    impl_generics: syn::ImplGenerics,
377    type_generics: syn::TypeGenerics,
378    where_clause: Option<&syn::WhereClause>,
379) -> Result<TokenStream, syn::Error> {
380    let mut ok_field_idx: Option<usize> = None;
381    let mut err_field_idx: Option<usize> = None;
382
383    for (i, f) in named.iter().enumerate() {
384        for attr in &f.attrs {
385            if attr.path().is_ident("resp") {
386                attr.parse_nested_meta(|meta| {
387                    if meta.path.is_ident("ok") {
388                        ok_field_idx = Some(i);
389                        return Ok(());
390                    }
391                    if meta.path.is_ident("err") {
392                        err_field_idx = Some(i);
393                        return Ok(());
394                    }
395                    Err(meta.error("unrecognized resp attribute"))
396                })?;
397            }
398        }
399    }
400
401    let ok_idx = match ok_field_idx {
402        Some(i) => i,
403        None => err_field_idx.map_or(0, |e| 1 - e),
404    };
405    let err_idx = match err_field_idx {
406        Some(i) => i,
407        None => ok_field_idx.map_or(1, |o| 1 - o),
408    };
409
410    if ok_idx == err_idx {
411        return Err(syn::Error::new_spanned(
412            input,
413            "ok and err fields cannot be the same",
414        ));
415    }
416
417    let ok_field = &named[ok_idx];
418    let err_field = &named[err_idx];
419
420    let ok_ident = ok_field.ident.as_ref().unwrap();
421    let err_ident = err_field.ident.as_ref().unwrap();
422
423    let is_err_string = matches!(
424        syn_type_as_string(&err_field.ty).as_deref(),
425        Some(s) if s == "String",
426    );
427
428    if !is_err_string {
429        return Err(syn::Error::new_spanned(
430            input,
431            "Response derive with two fields requires the error field to be `String`",
432        ));
433    }
434
435    let payload_ty = extract_option_inner(
436        &ok_field.ty,
437        "Response derive with two fields expects the ok field to be `Option<Payload>`",
438    )?;
439    let ok_expr = quote! { Self { #ok_ident: Some(payload), #err_ident: "".to_string() } };
440    let err_expr = quote! { Self { #ok_ident: None, #err_ident: msg.into() } };
441    let is_ok_expr = quote! { self.#ok_ident.is_some() };
442    let into_result_expr = quote! {
443        match self.#ok_ident {
444            Some(v) => Ok(v),
445            None => Err(self.#err_ident),
446        }
447    };
448    let error_msg_expr = quote! {
449        if self.#ok_ident.is_some() { None } else { Some(self.#err_ident.clone()) }
450    };
451
452    Ok(generate_response_impls(ResponseImplParts {
453        name,
454        impl_generics,
455        type_generics,
456        where_clause,
457        payload_ty: &payload_ty,
458        ok_expr,
459        err_expr,
460        is_ok_expr,
461        into_result_expr,
462        error_msg_expr,
463    }))
464}
465
466struct ClassifiedVariant<'a> {
467    variant: &'a syn::Variant,
468    ident: syn::Ident,
469    fields: &'a syn::Fields,
470}
471
472fn classify_enum_variants<'a>(
473    data: &'a syn::DataEnum,
474    input: &'a DeriveInput,
475) -> Result<(ClassifiedVariant<'a>, ClassifiedVariant<'a>), syn::Error> {
476    if data.variants.len() != 2 {
477        return Err(syn::Error::new_spanned(
478            input,
479            "Response derive on enums requires exactly 2 variants",
480        ));
481    }
482
483    let mut ok_variant: Option<ClassifiedVariant<'_>> = None;
484    let mut err_variant: Option<ClassifiedVariant<'_>> = None;
485
486    for v in &data.variants {
487        let mut is_ok = false;
488        let mut is_err = false;
489        for attr in &v.attrs {
490            if attr.path().is_ident("resp") {
491                attr.parse_nested_meta(|meta| {
492                    if meta.path.is_ident("ok") {
493                        is_ok = true;
494                        return Ok(());
495                    }
496                    if meta.path.is_ident("err") {
497                        is_err = true;
498                        return Ok(());
499                    }
500                    Err(meta.error("unrecognized resp attribute"))
501                })?;
502            }
503        }
504
505        if is_ok && is_err {
506            return Err(syn::Error::new_spanned(
507                v,
508                "variant cannot be both #[resp(ok)] and #[resp(err)]",
509            ));
510        }
511
512        let cv = ClassifiedVariant {
513            variant: v,
514            ident: v.ident.clone(),
515            fields: &v.fields,
516        };
517
518        if is_ok {
519            if ok_variant.is_some() {
520                return Err(syn::Error::new_spanned(
521                    v,
522                    "only one variant can be #[resp(ok)]",
523                ));
524            }
525            if v.fields.len() != 1 {
526                return Err(syn::Error::new_spanned(
527                    v,
528                    "ok variant must have exactly one field (the payload)",
529                ));
530            }
531            ok_variant = Some(cv);
532        } else if is_err {
533            if err_variant.is_some() {
534                return Err(syn::Error::new_spanned(
535                    v,
536                    "only one variant can be #[resp(err)]",
537                ));
538            }
539            if v.fields.len() > 1 {
540                return Err(syn::Error::new_spanned(
541                    v,
542                    "err variant must have 0 or 1 fields",
543                ));
544            }
545            err_variant = Some(cv);
546        } else if ok_variant.is_none() {
547            if v.fields.len() != 1 {
548                return Err(syn::Error::new_spanned(
549                    v,
550                    "ok variant must have exactly one field (the payload)",
551                ));
552            }
553            ok_variant = Some(cv);
554        } else if err_variant.is_none() {
555            if v.fields.len() > 1 {
556                return Err(syn::Error::new_spanned(
557                    v,
558                    "err variant must have 0 or 1 fields",
559                ));
560            }
561            err_variant = Some(cv);
562        }
563    }
564
565    let ok = ok_variant.ok_or_else(|| {
566        syn::Error::new_spanned(input, "Could not identify ok variant. Use #[resp(ok)]")
567    })?;
568    let err = err_variant.ok_or_else(|| {
569        syn::Error::new_spanned(input, "Could not identify err variant. Use #[resp(err)]")
570    })?;
571    Ok((ok, err))
572}
573
574struct VariantCodegen {
575    constructor: TokenStream2,
576    wild: TokenStream2,
577    pattern: TokenStream2,
578}
579
580fn variant_codegen(
581    name: &syn::Ident,
582    ident: &syn::Ident,
583    fields: &syn::Fields,
584    binding_name: &str,
585) -> VariantCodegen {
586    let b = syn::Ident::new(binding_name, name.span());
587    if fields.is_empty() {
588        let constructor = quote!(#name::#ident);
589        let wild = quote!(#name::#ident);
590        let pattern = quote!(#name::#ident);
591        VariantCodegen {
592            constructor,
593            wild,
594            pattern,
595        }
596    } else {
597        let unnamed = fields.iter().next().unwrap().ident.is_none();
598        let constructor = if unnamed {
599            quote!(#name::#ident(#b))
600        } else {
601            let f = fields.iter().next().unwrap().ident.as_ref().unwrap();
602            quote!(#name::#ident { #f: #b })
603        };
604        let wild = if unnamed {
605            quote! { #name::#ident(..) }
606        } else {
607            quote! { #name::#ident { .. } }
608        };
609        let pattern = if unnamed {
610            let b = syn::Ident::new(binding_name, name.span());
611            quote! { #name::#ident(#b) }
612        } else {
613            let f = fields.iter().next().unwrap().ident.as_ref().unwrap();
614            let b = syn::Ident::new(binding_name, name.span());
615            quote! { #name::#ident { #f: #b } }
616        };
617        VariantCodegen {
618            constructor,
619            wild,
620            pattern,
621        }
622    }
623}
624
625fn generate_response_enum(
626    input: &DeriveInput,
627    data: &syn::DataEnum,
628) -> Result<TokenStream, syn::Error> {
629    let name = &input.ident;
630    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
631
632    let (ok, err) = classify_enum_variants(data, input)?;
633
634    let ok_ident = &ok.ident;
635    let err_ident = &err.ident;
636    let payload_ty = &ok.variant.fields.iter().next().unwrap().ty;
637
638    let ok_cg = variant_codegen(name, ok_ident, ok.fields, "__p");
639    let err_has_field = err.fields.len() == 1;
640    let err_cg = variant_codegen(name, err_ident, err.fields, "__e");
641    let VariantCodegen {
642        constructor: ok_constr,
643        wild: ok_wild,
644        pattern: ok_pattern,
645    } = ok_cg;
646    let VariantCodegen {
647        constructor: err_constr,
648        wild: err_wild,
649        pattern: err_pattern,
650    } = err_cg;
651
652    let ok_expr = quote! {
653        {
654            let __p = payload;
655            #ok_constr
656        }
657    };
658    let err_expr = if err_has_field {
659        quote! {
660            {
661                let __e = msg.into();
662                #err_constr
663            }
664        }
665    } else {
666        quote! { #name::#err_ident }
667    };
668
669    let is_ok_expr = quote! {
670        match self {
671            #ok_wild => true,
672            #err_wild => false,
673        }
674    };
675    let into_result_expr = if err_has_field {
676        quote! {
677            match self {
678                #ok_pattern => Ok(__p),
679                #err_pattern => Err(__e),
680            }
681        }
682    } else {
683        quote! {
684            match self {
685                #ok_pattern => Ok(__p),
686                #err_wild => Err("unknown error".to_string()),
687            }
688        }
689    };
690    let error_msg_expr = if err_has_field {
691        quote! {
692            match self {
693                #ok_wild => None,
694                #err_pattern => Some(__e.to_string()),
695            }
696        }
697    } else {
698        quote! {
699            match self {
700                #ok_wild => None,
701                #err_wild => Some("unknown error".to_string()),
702            }
703        }
704    };
705
706    Ok(generate_response_impls(ResponseImplParts {
707        name,
708        impl_generics,
709        type_generics,
710        where_clause,
711        payload_ty,
712        ok_expr,
713        err_expr,
714        is_ok_expr,
715        into_result_expr,
716        error_msg_expr,
717    }))
718}
719
720struct ResponseImplParts<'a> {
721    name: &'a syn::Ident,
722    impl_generics: syn::ImplGenerics<'a>,
723    type_generics: syn::TypeGenerics<'a>,
724    where_clause: Option<&'a syn::WhereClause>,
725    payload_ty: &'a Type,
726    ok_expr: TokenStream2,
727    err_expr: TokenStream2,
728    is_ok_expr: TokenStream2,
729    into_result_expr: TokenStream2,
730    error_msg_expr: TokenStream2,
731}
732
733fn generate_response_impls(parts: ResponseImplParts<'_>) -> proc_macro::TokenStream {
734    let ResponseImplParts {
735        name,
736        impl_generics,
737        type_generics,
738        where_clause,
739        payload_ty,
740        ok_expr,
741        err_expr,
742        is_ok_expr,
743        into_result_expr,
744        error_msg_expr,
745    } = parts;
746    quote! {
747        impl #impl_generics ::jigs::__Classify for #name #type_generics #where_clause {
748            const KIND: &'static str = "Response";
749        }
750        impl #impl_generics ::jigs::Response for #name #type_generics #where_clause {
751            type Payload = #payload_ty;
752            fn ok(payload: Self::Payload) -> Self {
753                #ok_expr
754            }
755            fn err(msg: impl Into<String>) -> Self {
756                #err_expr
757            }
758            fn is_ok(&self) -> bool {
759                #is_ok_expr
760            }
761            fn into_result(self) -> Result<Self::Payload, String> {
762                #into_result_expr
763            }
764            fn error_msg(&self) -> Option<String> {
765                #error_msg_expr
766            }
767        }
768        impl #impl_generics ::jigs::Merge<#name #type_generics> for #name #type_generics #where_clause {
769            type Merged = #name #type_generics;
770            fn into_continue(self) -> Self::Merged {
771                self
772            }
773            fn from_done(resp: #name #type_generics) -> Self::Merged {
774                resp
775            }
776        }
777        impl #impl_generics ::jigs::Step for #name #type_generics #where_clause {
778            type Out = #name #type_generics;
779            type Fut = ::core::future::Ready<#name #type_generics>;
780            fn into_step(self) -> Self::Fut {
781                ::core::future::ready(self)
782            }
783        }
784        impl #impl_generics ::jigs::Status for #name #type_generics #where_clause {
785            fn succeeded(&self) -> bool {
786                ::jigs::Response::is_ok(self)
787            }
788            fn error(&self) -> Option<String> {
789                ::jigs::Response::error_msg(self)
790            }
791        }
792    }
793    .into()
794}
795
796fn extract_result_payload(ty: &Type, msg: &str) -> Result<Type, syn::Error> {
797    if let Type::Path(p) = ty {
798        if let Some(seg) = p.path.segments.last() {
799            if seg.ident == "Result" {
800                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
801                    if args.args.len() == 2 {
802                        if let syn::GenericArgument::Type(t) = &args.args[0] {
803                            if let syn::GenericArgument::Type(t2) = &args.args[1] {
804                                let s = type_to_string(t2);
805                                if s == "String" {
806                                    return Ok(t.clone());
807                                }
808                            }
809                        }
810                    }
811                }
812            }
813        }
814    }
815    Err(syn::Error::new_spanned(ty, msg))
816}
817
818fn extract_option_inner(ty: &Type, msg: &str) -> Result<Type, syn::Error> {
819    if let Type::Path(p) = ty {
820        if let Some(seg) = p.path.segments.last() {
821            if seg.ident == "Option" {
822                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
823                    if let Some(syn::GenericArgument::Type(t)) = args.args.first() {
824                        return Ok(t.clone());
825                    }
826                }
827            }
828        }
829    }
830    Err(syn::Error::new_spanned(ty, msg))
831}
832
833fn syn_type_as_string(ty: &Type) -> Option<String> {
834    if let Type::Path(p) = ty {
835        Some(
836            p.path
837                .segments
838                .iter()
839                .map(|s| s.ident.to_string())
840                .collect::<Vec<_>>()
841                .join("::"),
842        )
843    } else {
844        None
845    }
846}
847
848#[proc_macro]
849pub fn jigs(input: TokenStream) -> TokenStream {
850    let entry: syn::Ident = parse_macro_input!(input);
851    let entry_marker = marker_ident(&entry.to_string());
852    quote! {
853        mod __jigs_registry {
854            pub fn all_jigs() -> impl Iterator<Item = &'static ::jigs::JigMeta> {
855                static CACHE: std::sync::OnceLock<Vec<&'static ::jigs::JigMeta>> = std::sync::OnceLock::new();
856                CACHE.get_or_init(|| {
857                    let mut v = Vec::new();
858                    <super::#entry_marker as ::jigs::JigDef>::collect(&mut v);
859                    v
860                }).iter().copied()
861            }
862
863            pub fn find_jig(name: &str) -> Option<&'static ::jigs::JigMeta> {
864                all_jigs().find(|m| m.name == name)
865            }
866        }
867        pub use __jigs_registry::{all_jigs, find_jig};
868    }
869    .into()
870}
871
872fn first_arg_ident(sig: &syn::Signature) -> Option<syn::Ident> {
873    if let Some(syn::FnArg::Typed(pt)) = sig.inputs.first() {
874        if let syn::Pat::Ident(pi) = &*pt.pat {
875            return Some(pi.ident.clone());
876        }
877    }
878    None
879}
880
/// Statements spliced around the body of a traced jig (see `trace_instrument`).
#[cfg(feature = "trace")]
struct TraceParts {
    /// Runs after the body. It expects a `__jig_result` binding and evaluates to it.
    post: TokenStream2,
    /// Runs before the body. It snapshots the input status and opens the trace entry.
    pre: TokenStream2,
}
886
/// Builds the trace prologue and epilogue for the jig `name_str`.
///
/// The prologue records whether the first argument already reported success (or `true` if
/// there is no plain-identifier first argument). It then calls `::jigs::trace::enter` with the
/// jig's `META` and starts a timer.
///
/// The epilogue reads `Status::succeeded` / `Status::error` of `__jig_result` and passes them
/// to `::jigs::trace::exit` with the elapsed time. If both the input and the result report
/// failure, the step is recorded as succeeded with no error. Presumably this attributes a
/// propagated failure to the step that produced it (NOTE(review): confirm against
/// `jigs::trace`). Finally it evaluates to `__jig_result`.
#[cfg(feature = "trace")]
fn trace_instrument(name_str: &str, input_ident: Option<&syn::Ident>) -> TraceParts {
    let marker = marker_ident(name_str);
    let snapshot = match input_ident {
        Some(id) => quote! { let __jig_input_ok = ::jigs::Status::succeeded(&#id); },
        None => quote! { let __jig_input_ok = true; },
    };
    TraceParts {
        pre: quote! {
            #snapshot
            let __jig_idx = ::jigs::trace::enter(&<#marker as ::jigs::JigDef>::META);
            let __jig_start = ::std::time::Instant::now();
        },
        post: quote! {
            let mut __jig_ok = ::jigs::Status::succeeded(&__jig_result);
            let mut __jig_err = ::jigs::Status::error(&__jig_result);
            if !__jig_input_ok && !__jig_ok {
                __jig_ok = true;
                __jig_err = None;
            }
            ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
            __jig_result
        },
    }
}
912
/// Wrap a synchronous jig body with tracing.
///
/// The original block becomes the body of an immediately invoked `move`
/// closure. A `return` inside it therefore leaves only the closure, and its
/// value still lands in `__jig_result` and passes through the epilogue.
#[cfg(feature = "trace")]
fn sync_body(
    block: &syn::Block,
    name_str: &str,
    input_ident: Option<&syn::Ident>,
) -> TokenStream2 {
    let parts = trace_instrument(name_str, input_ident);
    let prologue = &parts.pre;
    let epilogue = &parts.post;
    quote! {
        #prologue
        let __jig_result = (move || #block)();
        #epilogue
    }
}
922
923#[cfg(not(feature = "trace"))]
924fn sync_body(
925    block: &syn::Block,
926    _name_str: &str,
927    _input_ident: Option<&syn::Ident>,
928) -> TokenStream2 {
929    quote! { #block }
930}
931
/// Wrap an `async` jig body with tracing and return it as a `::jigs::Pending`.
///
/// The prologue and epilogue run inside an outer `async move` block. The
/// original body is awaited as its own inner `async move` block, so a `return`
/// inside it still reaches the epilogue.
#[cfg(feature = "trace")]
fn async_body(
    block: &syn::Block,
    name_str: &str,
    input_ident: Option<&syn::Ident>,
) -> TokenStream2 {
    let TraceParts { pre: prologue, post: epilogue } = trace_instrument(name_str, input_ident);
    let traced = quote! {
        #prologue
        let __jig_result = (async move #block).await;
        #epilogue
    };
    quote! { ::jigs::Pending(async move { #traced }) }
}
947
948#[cfg(not(feature = "trace"))]
949fn async_body(
950    block: &syn::Block,
951    _name_str: &str,
952    _input_ident: Option<&syn::Ident>,
953) -> TokenStream2 {
954    quote! { ::jigs::Pending(async move #block) }
955}
956
957fn first_arg_type(sig: &syn::Signature) -> Option<Type> {
958    match sig.inputs.first() {
959        Some(syn::FnArg::Typed(pt)) => Some((*pt.ty).clone()),
960        _ => None,
961    }
962}
963
964fn return_type(ret: &ReturnType) -> Option<Type> {
965    match ret {
966        ReturnType::Type(_, t) => Some((**t).clone()),
967        _ => None,
968    }
969}
970
971fn classify_expr(ty: Option<&Type>) -> TokenStream2 {
972    match ty {
973        Some(t) => quote!(<#t as ::jigs::__Classify>::KIND),
974        None => quote!("Other"),
975    }
976}
977
978fn first_arg_payload(sig: &syn::Signature) -> String {
979    let ty = match sig.inputs.first() {
980        Some(syn::FnArg::Typed(pt)) => &*pt.ty,
981        _ => return "?".into(),
982    };
983    payload_type(ty)
984}
985
986fn return_payload(ret: &ReturnType) -> String {
987    let ty = match ret {
988        ReturnType::Default => return "?".into(),
989        ReturnType::Type(_, t) => t,
990    };
991    payload_type(ty)
992}
993
994fn payload_type(ty: &Type) -> String {
995    if let Type::Path(p) = ty {
996        if let Some(seg) = p.path.segments.last() {
997            let name = seg.ident.to_string();
998            match name.as_str() {
999                "Request" | "Response" | "Pending" => {
1000                    if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
1001                        return generic_args_string(ab);
1002                    }
1003                }
1004                "Branch" => {
1005                    if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
1006                        return format!("Branch<{}>", generic_args_string(ab));
1007                    }
1008                }
1009                _ => {}
1010            }
1011        }
1012    }
1013    type_to_string(ty)
1014}
1015
1016fn type_to_string(ty: &Type) -> String {
1017    quote::quote!(#ty).to_string().replace(' ', "")
1018}
1019
1020fn generic_args_string(args: &syn::AngleBracketedGenericArguments) -> String {
1021    let mut out = String::new();
1022    for (i, arg) in args.args.iter().enumerate() {
1023        if i > 0 {
1024            out.push(',');
1025        }
1026        match arg {
1027            syn::GenericArgument::Type(t) => out.push_str(&type_to_string(t)),
1028            syn::GenericArgument::Lifetime(l) => out.push_str(&l.ident.to_string()),
1029            other => out.push_str(&quote::quote!(#other).to_string().replace(' ', "")),
1030        }
1031    }
1032    out
1033}
1034
/// How `collect_chain` found a jig being referenced from another jig's body.
#[derive(Copy, Clone)]
enum ChainKindTok {
    /// Passed as the first argument of a `.then(...)` method call.
    Then,
    /// Named as a predicate arm or the `_` default arm of a `fork!` macro.
    Fork,
}
1040
1041fn collect_chain(block: &syn::Block) -> Vec<(String, ChainKindTok)> {
1042    struct V(Vec<(String, ChainKindTok)>);
1043    impl V {
1044        fn push_unique(&mut self, name: String, kind: ChainKindTok) {
1045            if !self.0.iter().any(|(n, _)| n == &name) {
1046                self.0.push((name, kind));
1047            }
1048        }
1049        fn push_path(&mut self, p: &syn::Path, kind: ChainKindTok) {
1050            let name = p
1051                .segments
1052                .iter()
1053                .map(|s| s.ident.to_string())
1054                .collect::<Vec<_>>()
1055                .join("::");
1056            self.push_unique(name, kind);
1057        }
1058    }
1059    impl<'ast> Visit<'ast> for V {
1060        fn visit_expr_method_call(&mut self, m: &'ast ExprMethodCall) {
1061            syn::visit::visit_expr(self, &m.receiver);
1062            if m.method == "then" {
1063                if let Some(Expr::Path(p)) = m.args.first() {
1064                    self.push_path(&p.path, ChainKindTok::Then);
1065                }
1066            }
1067            for a in &m.args {
1068                syn::visit::visit_expr(self, a);
1069            }
1070        }
1071        fn visit_macro(&mut self, mac: &'ast syn::Macro) {
1072            let last = mac
1073                .path
1074                .segments
1075                .last()
1076                .map(|s| s.ident.to_string())
1077                .unwrap_or_default();
1078            if last == "fork" {
1079                if let Ok(args) = syn::parse2::<ForkArgs>(mac.tokens.clone()) {
1080                    for j in &args.arms {
1081                        if let syn::Expr::Path(p) = j {
1082                            self.push_path(&p.path, ChainKindTok::Fork);
1083                        }
1084                    }
1085                    if let syn::Expr::Path(p) = &args.default {
1086                        self.push_path(&p.path, ChainKindTok::Fork);
1087                    }
1088                }
1089            }
1090        }
1091    }
1092    let mut v = V(Vec::new());
1093    v.visit_block(block);
1094    v.0
1095}
1096
1097struct ForkArgs {
1098    arms: Vec<syn::Expr>,
1099    default: syn::Expr,
1100}
1101
1102impl syn::parse::Parse for ForkArgs {
1103    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1104        let _req: syn::Expr = input.parse()?;
1105        input.parse::<syn::Token![,]>()?;
1106        let mut arms = Vec::new();
1107        loop {
1108            if input.peek(syn::Token![_]) {
1109                input.parse::<syn::Token![_]>()?;
1110                input.parse::<syn::Token![=>]>()?;
1111                let default: syn::Expr = input.parse()?;
1112                let _: Option<syn::Token![,]> = input.parse().ok();
1113                return Ok(ForkArgs { arms, default });
1114            }
1115            let _pred: syn::Expr = input.parse()?;
1116            input.parse::<syn::Token![=>]>()?;
1117            let jig: syn::Expr = input.parse()?;
1118            input.parse::<syn::Token![,]>()?;
1119            arms.push(jig);
1120        }
1121    }
1122}