1use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use syn::spanned::Spanned;
14use syn::visit::Visit;
15use syn::{
16 parse_macro_input, parse_quote, Data, DeriveInput, Expr, ExprMethodCall, Field, Fields,
17 FieldsNamed, FieldsUnnamed, Ident, ItemFn, ReturnType, Type,
18};
19
20fn marker_ident(fn_name: &str) -> syn::Ident {
21 syn::parse_str(&format!("__Jig_{fn_name}")).unwrap()
22}
23
24fn marker_path_for(name: &str) -> TokenStream2 {
25 let segs: Vec<&str> = name.split("::").collect();
26 let last_idx = segs.len() - 1;
27 let path_segs: Vec<TokenStream2> = segs
28 .iter()
29 .enumerate()
30 .map(|(i, s)| {
31 if i == last_idx {
32 let mi = marker_ident(s);
33 quote!(#mi)
34 } else if *s == "crate" {
35 quote!(crate)
36 } else if *s == "super" {
37 quote!(super)
38 } else if *s == "self" {
39 quote!(self)
40 } else {
41 let id: syn::Ident = syn::parse_str(s).unwrap();
42 quote!(#id)
43 }
44 })
45 .collect();
46 quote!(#(#path_segs)::*)
47}
48
49#[proc_macro_attribute]
50pub fn jig(_attr: TokenStream, item: TokenStream) -> TokenStream {
51 let input = parse_macro_input!(item as ItemFn);
52 let vis = &input.vis;
53 let attrs = &input.attrs;
54 let block = &input.block;
55 let name_str = input.sig.ident.to_string();
56 let marker = marker_ident(&name_str);
57 let input_type_str = first_arg_payload(&input.sig);
58 let output_type_str = return_payload(&input.sig.output);
59 let is_async = input.sig.asyncness.is_some();
60
61 let input_ty = first_arg_type(&input.sig);
62 let output_ty = return_type(&input.sig.output);
63 let kind_expr = classify_expr(output_ty.as_ref());
64 let input_expr = classify_expr(input_ty.as_ref());
65
66 let chain = collect_chain(&input.block);
67
68 let chain_tokens: Vec<TokenStream2> = chain
69 .iter()
70 .map(|(name, kind)| {
71 let kind_ident = match kind {
72 ChainKindTok::Then => quote!(::jigs::ChainKind::Then),
73 ChainKindTok::Fork => quote!(::jigs::ChainKind::Fork),
74 };
75 quote! { ::jigs::ChainStep { name: #name, kind: #kind_ident } }
76 })
77 .collect();
78
79 let chain_collect: Vec<TokenStream2> = chain
80 .iter()
81 .map(|(name, _kind)| {
82 let path = marker_path_for(name);
83 quote! { <#path as ::jigs::JigDef>::collect(out); }
84 })
85 .collect();
86
87 let marker_def = quote! {
88 #[allow(non_camel_case_types)]
89 #[doc(hidden)]
90 pub struct #marker;
91
92 impl ::jigs::JigDef for #marker {
93 const META: ::jigs::JigMeta = ::jigs::JigMeta {
94 name: #name_str,
95 file: file!(),
96 line: line!(),
97 kind: #kind_expr,
98 input: #input_expr,
99 input_type: #input_type_str,
100 output_type: #output_type_str,
101 is_async: #is_async,
102 module: module_path!(),
103 chain: &[#(#chain_tokens),*],
104 };
105
106 fn collect(out: &mut Vec<&'static ::jigs::JigMeta>) {
107 let meta = &<Self as ::jigs::JigDef>::META;
108 if out.iter().any(|m| ::std::ptr::eq(*m, meta)) {
109 return;
110 }
111 out.push(meta);
112 #(#chain_collect)*
113 }
114 }
115 };
116
117 let input_ident = first_arg_ident(&input.sig);
118
119 if input.sig.asyncness.is_some() {
120 let mut sig = input.sig.clone();
121 sig.asyncness = None;
122 let ret_ty = match &input.sig.output {
123 ReturnType::Default => quote!(()),
124 ReturnType::Type(_, ty) => quote!(#ty),
125 };
126 sig.output = parse_quote! {
127 -> ::jigs::Pending<impl ::core::future::Future<Output = #ret_ty>>
128 };
129
130 let body = async_body(block, &name_str, input_ident.as_ref());
131 return quote! { #marker_def #(#attrs)* #vis #sig { #body } }.into();
132 }
133
134 let sig = &input.sig;
135 let body = sync_body(block, &name_str, input_ident.as_ref());
136 quote! { #marker_def #(#attrs)* #vis #sig { #body } }.into()
137}
138
139#[proc_macro_derive(Request, attributes(req))]
140pub fn derive_request(input: TokenStream) -> TokenStream {
141 let parsed = parse_macro_input!(input as DeriveInput);
142 generate_req(&parsed).unwrap_or_else(|e| e.to_compile_error().into())
143}
144
145fn generate_req(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
146 let name = &input.ident;
147 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
148 let Data::Struct(data) = &input.data else {
149 return Err(syn::Error::new_spanned(
150 input,
151 "Request can only be derived for structs",
152 ));
153 };
154
155 let mut explicit_field: Option<Ident> = None;
156
157 for attr in &input.attrs {
158 if attr.path().is_ident("req") {
159 attr.parse_nested_meta(|meta| {
160 if meta.path.is_ident("field") {
161 let val = meta.value()?;
162 let lit: syn::LitStr = val.parse()?;
163 explicit_field = Some(syn::Ident::new(&lit.value(), lit.span()));
164 return Ok(());
165 }
166 Err(meta.error("unrecognized req attribute"))
167 })?;
168 }
169 }
170
171 let (payload_decl, payload_ref_expr, into_expr, from_expr) =
172 derive_req_field_info(data, explicit_field, input)?;
173
174 let mut merge_generics = input.generics.clone();
175 merge_generics
176 .params
177 .push(syn::GenericParam::Type(syn::TypeParam {
178 attrs: Vec::new(),
179 ident: parse_quote!(__R),
180 colon_token: Some(syn::Token)),
181 bounds: parse_quote!(::jigs::Response),
182 eq_token: None,
183 default: None,
184 }));
185 let (merge_impl_generics, _, merge_where_clause) = merge_generics.split_for_impl();
186
187 Ok(quote! {
188 impl #impl_generics ::jigs::__Classify for #name #type_generics #where_clause {
189 const KIND: &'static str = "Request";
190 }
191 impl #impl_generics ::jigs::Request for #name #type_generics #where_clause {
192 #payload_decl
193 fn payload(&self) -> &Self::Payload {
194 #payload_ref_expr
195 }
196 fn into_payload(self) -> Self::Payload {
197 #into_expr
198 }
199 fn from_payload(payload: Self::Payload) -> Self {
200 #from_expr
201 }
202 }
203 impl #merge_impl_generics ::jigs::Merge<__R> for #name #type_generics #merge_where_clause {
204 type Merged = ::jigs::Branch<#name #type_generics, __R>;
205 fn into_continue(self) -> Self::Merged {
206 ::jigs::Branch::Continue(self)
207 }
208 fn from_done(resp: __R) -> Self::Merged {
209 ::jigs::Branch::Done(resp)
210 }
211 }
212 impl #impl_generics ::jigs::Step for #name #type_generics #where_clause {
213 type Out = #name #type_generics;
214 type Fut = ::core::future::Ready<#name #type_generics>;
215 fn into_step(self) -> Self::Fut {
216 ::core::future::ready(self)
217 }
218 }
219 impl #impl_generics ::jigs::Status for #name #type_generics #where_clause {
220 fn succeeded(&self) -> bool {
221 true
222 }
223 fn error(&self) -> Option<String> {
224 None
225 }
226 }
227 }
228 .into())
229}
230
231fn derive_req_field_info(
232 data: &syn::DataStruct,
233 explicit_field: Option<Ident>,
234 input: &DeriveInput,
235) -> Result<(TokenStream2, TokenStream2, TokenStream2, TokenStream2), syn::Error> {
236 if let Some(field_ident) = explicit_field {
237 let field = find_field(data, &field_ident)?;
238 let payload_ty = &field.ty;
239 let payload_decl = quote! { type Payload = #payload_ty; };
240 let payload_ref = quote! { &self.#field_ident };
241 let into_expr = quote! {
242 let Self { #field_ident, .. } = self;
243 #field_ident
244 };
245 let from_expr = quote! { Self { #field_ident: payload, ..Default::default() } };
246 return Ok((payload_decl, payload_ref, into_expr, from_expr));
247 }
248
249 match &data.fields {
250 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
251 let field = unnamed.first().unwrap();
252 let payload_ty = &field.ty;
253 let payload_decl = quote! { type Payload = #payload_ty; };
254 let payload_ref = quote! { &self.0 };
255 let into_expr = quote! { self.0 };
256 let from_expr = quote! { Self(payload) };
257 Ok((payload_decl, payload_ref, into_expr, from_expr))
258 }
259 Fields::Named(FieldsNamed { named, .. }) if named.len() == 1 => {
260 let field = named.first().unwrap();
261 let field_ident = field.ident.as_ref().unwrap();
262 let payload_ty = &field.ty;
263 let payload_decl = quote! { type Payload = #payload_ty; };
264 let payload_ref = quote! { &self.#field_ident };
265 let into_expr = quote! { self.#field_ident };
266 let from_expr = quote! { Self { #field_ident: payload } };
267 Ok((payload_decl, payload_ref, into_expr, from_expr))
268 }
269 _ => Err(syn::Error::new_spanned(
270 input,
271 "Request derive requires either: one field, or #[req(field = \"name\")]",
272 )),
273 }
274}
275
276fn find_field<'a>(data: &'a syn::DataStruct, ident: &Ident) -> Result<&'a Field, syn::Error> {
277 for f in &data.fields {
278 if f.ident.as_ref() == Some(ident) {
279 return Ok(f);
280 }
281 }
282 Err(syn::Error::new(
283 proc_macro2::Span::call_site(),
284 format!("no field named `{ident}`"),
285 ))
286}
287
288#[proc_macro_derive(Response, attributes(resp))]
289pub fn derive_response(input: TokenStream) -> TokenStream {
290 let parsed = parse_macro_input!(input as DeriveInput);
291 generate_response(&parsed).unwrap_or_else(|e| e.to_compile_error().into())
292}
293
294fn generate_response(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
295 match &input.data {
296 Data::Struct(data) => generate_response_struct(input, data),
297 Data::Enum(data) => generate_response_enum(input, data),
298 Data::Union(_u) => Err(syn::Error::new_spanned(
299 input,
300 "Response cannot be derived for unions",
301 )),
302 }
303}
304
305fn generate_response_struct(
306 input: &DeriveInput,
307 data: &syn::DataStruct,
308) -> Result<TokenStream, syn::Error> {
309 let name = &input.ident;
310 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
311
312 match &data.fields {
313 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) if unnamed.len() == 1 => {
314 let f = unnamed.first().unwrap();
315 let ok_expr = quote! { Self(Ok(payload)) };
316 let err_expr = quote! { Self(Err(msg.into())) };
317 let is_ok_expr = quote! { self.0.is_ok() };
318 let into_result_expr = quote! { self.0 };
319 let error_msg_expr = quote! { self.0.as_ref().err().cloned() };
320 let payload_ty = extract_result_payload(&f.ty,
321 "Response derive on single-field structs expects `Result<Payload, String>`",
322 )?;
323 Ok(generate_response_impls(ResponseImplParts {
324 name,
325 impl_generics,
326 type_generics,
327 where_clause,
328 payload_ty: &payload_ty,
329 ok_expr,
330 err_expr,
331 is_ok_expr,
332 into_result_expr,
333 error_msg_expr,
334 }))
335 }
336 Fields::Named(FieldsNamed { named, .. }) if named.len() == 1 => {
337 let f = named.first().unwrap();
338 let field_ident = f.ident.as_ref().unwrap();
339 let payload_ty = extract_result_payload(
340 &f.ty,
341 "Response derive on single-field structs expects `Result<Payload, String>`",
342 )?;
343 let ok_expr = quote! { Self { #field_ident: Ok(payload) } };
344 let err_expr = quote! { Self { #field_ident: Err(msg.into()) } };
345 let is_ok_expr = quote! { self.#field_ident.is_ok() };
346 let into_result_expr = quote! { self.#field_ident };
347 let error_msg_expr = quote! { self.#field_ident.as_ref().err().cloned() };
348 Ok(generate_response_impls(ResponseImplParts {
349 name,
350 impl_generics,
351 type_generics,
352 where_clause,
353 payload_ty: &payload_ty,
354 ok_expr,
355 err_expr,
356 is_ok_expr,
357 into_result_expr,
358 error_msg_expr,
359 }))
360 }
361 Fields::Named(FieldsNamed { named, .. }) if named.len() == 2 => {
362 generate_response_two_fields(input, data, named, name, impl_generics, type_generics, where_clause)
363 }
364 _ => Err(syn::Error::new_spanned(
365 input,
366 "Response derive requires either: a single `Result<Payload, String>` field, or two fields",
367 )),
368 }
369}
370
371fn generate_response_two_fields(
372 input: &DeriveInput,
373 _data: &syn::DataStruct,
374 named: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
375 name: &Ident,
376 impl_generics: syn::ImplGenerics,
377 type_generics: syn::TypeGenerics,
378 where_clause: Option<&syn::WhereClause>,
379) -> Result<TokenStream, syn::Error> {
380 let mut ok_field_idx: Option<usize> = None;
381 let mut err_field_idx: Option<usize> = None;
382
383 for (i, f) in named.iter().enumerate() {
384 for attr in &f.attrs {
385 if attr.path().is_ident("resp") {
386 attr.parse_nested_meta(|meta| {
387 if meta.path.is_ident("ok") {
388 ok_field_idx = Some(i);
389 return Ok(());
390 }
391 if meta.path.is_ident("err") {
392 err_field_idx = Some(i);
393 return Ok(());
394 }
395 Err(meta.error("unrecognized resp attribute"))
396 })?;
397 }
398 }
399 }
400
401 let ok_idx = match ok_field_idx {
402 Some(i) => i,
403 None => err_field_idx.map_or(0, |e| 1 - e),
404 };
405 let err_idx = match err_field_idx {
406 Some(i) => i,
407 None => ok_field_idx.map_or(1, |o| 1 - o),
408 };
409
410 if ok_idx == err_idx {
411 return Err(syn::Error::new_spanned(
412 input,
413 "ok and err fields cannot be the same",
414 ));
415 }
416
417 let ok_field = &named[ok_idx];
418 let err_field = &named[err_idx];
419
420 let ok_ident = ok_field.ident.as_ref().unwrap();
421 let err_ident = err_field.ident.as_ref().unwrap();
422
423 let is_err_string = matches!(
424 syn_type_as_string(&err_field.ty).as_deref(),
425 Some(s) if s == "String",
426 );
427
428 if !is_err_string {
429 return Err(syn::Error::new_spanned(
430 input,
431 "Response derive with two fields requires the error field to be `String`",
432 ));
433 }
434
435 let payload_ty = extract_option_inner(
436 &ok_field.ty,
437 "Response derive with two fields expects the ok field to be `Option<Payload>`",
438 )?;
439 let ok_expr = quote! { Self { #ok_ident: Some(payload), #err_ident: "".to_string() } };
440 let err_expr = quote! { Self { #ok_ident: None, #err_ident: msg.into() } };
441 let is_ok_expr = quote! { self.#ok_ident.is_some() };
442 let into_result_expr = quote! {
443 match self.#ok_ident {
444 Some(v) => Ok(v),
445 None => Err(self.#err_ident),
446 }
447 };
448 let error_msg_expr = quote! {
449 if self.#ok_ident.is_some() { None } else { Some(self.#err_ident.clone()) }
450 };
451
452 Ok(generate_response_impls(ResponseImplParts {
453 name,
454 impl_generics,
455 type_generics,
456 where_clause,
457 payload_ty: &payload_ty,
458 ok_expr,
459 err_expr,
460 is_ok_expr,
461 into_result_expr,
462 error_msg_expr,
463 }))
464}
465
466struct ClassifiedVariant<'a> {
467 variant: &'a syn::Variant,
468 ident: syn::Ident,
469 fields: &'a syn::Fields,
470}
471
472fn classify_enum_variants<'a>(
473 data: &'a syn::DataEnum,
474 input: &'a DeriveInput,
475) -> Result<(ClassifiedVariant<'a>, ClassifiedVariant<'a>), syn::Error> {
476 if data.variants.len() != 2 {
477 return Err(syn::Error::new_spanned(
478 input,
479 "Response derive on enums requires exactly 2 variants",
480 ));
481 }
482
483 let mut ok_variant: Option<ClassifiedVariant<'_>> = None;
484 let mut err_variant: Option<ClassifiedVariant<'_>> = None;
485
486 for v in &data.variants {
487 let mut is_ok = false;
488 let mut is_err = false;
489 for attr in &v.attrs {
490 if attr.path().is_ident("resp") {
491 attr.parse_nested_meta(|meta| {
492 if meta.path.is_ident("ok") {
493 is_ok = true;
494 return Ok(());
495 }
496 if meta.path.is_ident("err") {
497 is_err = true;
498 return Ok(());
499 }
500 Err(meta.error("unrecognized resp attribute"))
501 })?;
502 }
503 }
504
505 if is_ok && is_err {
506 return Err(syn::Error::new_spanned(
507 v,
508 "variant cannot be both #[resp(ok)] and #[resp(err)]",
509 ));
510 }
511
512 let cv = ClassifiedVariant {
513 variant: v,
514 ident: v.ident.clone(),
515 fields: &v.fields,
516 };
517
518 if is_ok {
519 if ok_variant.is_some() {
520 return Err(syn::Error::new_spanned(
521 v,
522 "only one variant can be #[resp(ok)]",
523 ));
524 }
525 if v.fields.len() != 1 {
526 return Err(syn::Error::new_spanned(
527 v,
528 "ok variant must have exactly one field (the payload)",
529 ));
530 }
531 ok_variant = Some(cv);
532 } else if is_err {
533 if err_variant.is_some() {
534 return Err(syn::Error::new_spanned(
535 v,
536 "only one variant can be #[resp(err)]",
537 ));
538 }
539 if v.fields.len() > 1 {
540 return Err(syn::Error::new_spanned(
541 v,
542 "err variant must have 0 or 1 fields",
543 ));
544 }
545 err_variant = Some(cv);
546 } else if ok_variant.is_none() {
547 if v.fields.len() != 1 {
548 return Err(syn::Error::new_spanned(
549 v,
550 "ok variant must have exactly one field (the payload)",
551 ));
552 }
553 ok_variant = Some(cv);
554 } else if err_variant.is_none() {
555 if v.fields.len() > 1 {
556 return Err(syn::Error::new_spanned(
557 v,
558 "err variant must have 0 or 1 fields",
559 ));
560 }
561 err_variant = Some(cv);
562 }
563 }
564
565 let ok = ok_variant.ok_or_else(|| {
566 syn::Error::new_spanned(input, "Could not identify ok variant. Use #[resp(ok)]")
567 })?;
568 let err = err_variant.ok_or_else(|| {
569 syn::Error::new_spanned(input, "Could not identify err variant. Use #[resp(err)]")
570 })?;
571 Ok((ok, err))
572}
573
574struct VariantCodegen {
575 constructor: TokenStream2,
576 wild: TokenStream2,
577 pattern: TokenStream2,
578}
579
580fn variant_codegen(
581 name: &syn::Ident,
582 ident: &syn::Ident,
583 fields: &syn::Fields,
584 binding_name: &str,
585) -> VariantCodegen {
586 let b = syn::Ident::new(binding_name, name.span());
587 if fields.is_empty() {
588 let constructor = quote!(#name::#ident);
589 let wild = quote!(#name::#ident);
590 let pattern = quote!(#name::#ident);
591 VariantCodegen {
592 constructor,
593 wild,
594 pattern,
595 }
596 } else {
597 let unnamed = fields.iter().next().unwrap().ident.is_none();
598 let constructor = if unnamed {
599 quote!(#name::#ident(#b))
600 } else {
601 let f = fields.iter().next().unwrap().ident.as_ref().unwrap();
602 quote!(#name::#ident { #f: #b })
603 };
604 let wild = if unnamed {
605 quote! { #name::#ident(..) }
606 } else {
607 quote! { #name::#ident { .. } }
608 };
609 let pattern = if unnamed {
610 let b = syn::Ident::new(binding_name, name.span());
611 quote! { #name::#ident(#b) }
612 } else {
613 let f = fields.iter().next().unwrap().ident.as_ref().unwrap();
614 let b = syn::Ident::new(binding_name, name.span());
615 quote! { #name::#ident { #f: #b } }
616 };
617 VariantCodegen {
618 constructor,
619 wild,
620 pattern,
621 }
622 }
623}
624
625fn generate_response_enum(
626 input: &DeriveInput,
627 data: &syn::DataEnum,
628) -> Result<TokenStream, syn::Error> {
629 let name = &input.ident;
630 let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
631
632 let (ok, err) = classify_enum_variants(data, input)?;
633
634 let ok_ident = &ok.ident;
635 let err_ident = &err.ident;
636 let payload_ty = &ok.variant.fields.iter().next().unwrap().ty;
637
638 let ok_cg = variant_codegen(name, ok_ident, ok.fields, "__p");
639 let err_has_field = err.fields.len() == 1;
640 let err_cg = variant_codegen(name, err_ident, err.fields, "__e");
641 let VariantCodegen {
642 constructor: ok_constr,
643 wild: ok_wild,
644 pattern: ok_pattern,
645 } = ok_cg;
646 let VariantCodegen {
647 constructor: err_constr,
648 wild: err_wild,
649 pattern: err_pattern,
650 } = err_cg;
651
652 let ok_expr = quote! {
653 {
654 let __p = payload;
655 #ok_constr
656 }
657 };
658 let err_expr = if err_has_field {
659 quote! {
660 {
661 let __e = msg.into();
662 #err_constr
663 }
664 }
665 } else {
666 quote! { #name::#err_ident }
667 };
668
669 let is_ok_expr = quote! {
670 match self {
671 #ok_wild => true,
672 #err_wild => false,
673 }
674 };
675 let into_result_expr = if err_has_field {
676 quote! {
677 match self {
678 #ok_pattern => Ok(__p),
679 #err_pattern => Err(__e),
680 }
681 }
682 } else {
683 quote! {
684 match self {
685 #ok_pattern => Ok(__p),
686 #err_wild => Err("unknown error".to_string()),
687 }
688 }
689 };
690 let error_msg_expr = if err_has_field {
691 quote! {
692 match self {
693 #ok_wild => None,
694 #err_pattern => Some(__e.to_string()),
695 }
696 }
697 } else {
698 quote! {
699 match self {
700 #ok_wild => None,
701 #err_wild => Some("unknown error".to_string()),
702 }
703 }
704 };
705
706 Ok(generate_response_impls(ResponseImplParts {
707 name,
708 impl_generics,
709 type_generics,
710 where_clause,
711 payload_ty,
712 ok_expr,
713 err_expr,
714 is_ok_expr,
715 into_result_expr,
716 error_msg_expr,
717 }))
718}
719
720struct ResponseImplParts<'a> {
721 name: &'a syn::Ident,
722 impl_generics: syn::ImplGenerics<'a>,
723 type_generics: syn::TypeGenerics<'a>,
724 where_clause: Option<&'a syn::WhereClause>,
725 payload_ty: &'a Type,
726 ok_expr: TokenStream2,
727 err_expr: TokenStream2,
728 is_ok_expr: TokenStream2,
729 into_result_expr: TokenStream2,
730 error_msg_expr: TokenStream2,
731}
732
733fn generate_response_impls(parts: ResponseImplParts<'_>) -> proc_macro::TokenStream {
734 let ResponseImplParts {
735 name,
736 impl_generics,
737 type_generics,
738 where_clause,
739 payload_ty,
740 ok_expr,
741 err_expr,
742 is_ok_expr,
743 into_result_expr,
744 error_msg_expr,
745 } = parts;
746 quote! {
747 impl #impl_generics ::jigs::__Classify for #name #type_generics #where_clause {
748 const KIND: &'static str = "Response";
749 }
750 impl #impl_generics ::jigs::Response for #name #type_generics #where_clause {
751 type Payload = #payload_ty;
752 fn ok(payload: Self::Payload) -> Self {
753 #ok_expr
754 }
755 fn err(msg: impl Into<String>) -> Self {
756 #err_expr
757 }
758 fn is_ok(&self) -> bool {
759 #is_ok_expr
760 }
761 fn into_result(self) -> Result<Self::Payload, String> {
762 #into_result_expr
763 }
764 fn error_msg(&self) -> Option<String> {
765 #error_msg_expr
766 }
767 }
768 impl #impl_generics ::jigs::Merge<#name #type_generics> for #name #type_generics #where_clause {
769 type Merged = #name #type_generics;
770 fn into_continue(self) -> Self::Merged {
771 self
772 }
773 fn from_done(resp: #name #type_generics) -> Self::Merged {
774 resp
775 }
776 }
777 impl #impl_generics ::jigs::Step for #name #type_generics #where_clause {
778 type Out = #name #type_generics;
779 type Fut = ::core::future::Ready<#name #type_generics>;
780 fn into_step(self) -> Self::Fut {
781 ::core::future::ready(self)
782 }
783 }
784 impl #impl_generics ::jigs::Status for #name #type_generics #where_clause {
785 fn succeeded(&self) -> bool {
786 ::jigs::Response::is_ok(self)
787 }
788 fn error(&self) -> Option<String> {
789 ::jigs::Response::error_msg(self)
790 }
791 }
792 }
793 .into()
794}
795
796fn extract_result_payload(ty: &Type, msg: &str) -> Result<Type, syn::Error> {
797 if let Type::Path(p) = ty {
798 if let Some(seg) = p.path.segments.last() {
799 if seg.ident == "Result" {
800 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
801 if args.args.len() == 2 {
802 if let syn::GenericArgument::Type(t) = &args.args[0] {
803 if let syn::GenericArgument::Type(t2) = &args.args[1] {
804 let s = type_to_string(t2);
805 if s == "String" {
806 return Ok(t.clone());
807 }
808 }
809 }
810 }
811 }
812 }
813 }
814 }
815 Err(syn::Error::new_spanned(ty, msg))
816}
817
818fn extract_option_inner(ty: &Type, msg: &str) -> Result<Type, syn::Error> {
819 if let Type::Path(p) = ty {
820 if let Some(seg) = p.path.segments.last() {
821 if seg.ident == "Option" {
822 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
823 if let Some(syn::GenericArgument::Type(t)) = args.args.first() {
824 return Ok(t.clone());
825 }
826 }
827 }
828 }
829 }
830 Err(syn::Error::new_spanned(ty, msg))
831}
832
833fn syn_type_as_string(ty: &Type) -> Option<String> {
834 if let Type::Path(p) = ty {
835 Some(
836 p.path
837 .segments
838 .iter()
839 .map(|s| s.ident.to_string())
840 .collect::<Vec<_>>()
841 .join("::"),
842 )
843 } else {
844 None
845 }
846}
847
848#[proc_macro]
849pub fn jigs(input: TokenStream) -> TokenStream {
850 let entry: syn::Ident = parse_macro_input!(input);
851 let entry_marker = marker_ident(&entry.to_string());
852 quote! {
853 mod __jigs_registry {
854 pub fn all_jigs() -> impl Iterator<Item = &'static ::jigs::JigMeta> {
855 static CACHE: std::sync::OnceLock<Vec<&'static ::jigs::JigMeta>> = std::sync::OnceLock::new();
856 CACHE.get_or_init(|| {
857 let mut v = Vec::new();
858 <super::#entry_marker as ::jigs::JigDef>::collect(&mut v);
859 v
860 }).iter().copied()
861 }
862
863 pub fn find_jig(name: &str) -> Option<&'static ::jigs::JigMeta> {
864 all_jigs().find(|m| m.name == name)
865 }
866 }
867 pub use __jigs_registry::{all_jigs, find_jig};
868 }
869 .into()
870}
871
872fn first_arg_ident(sig: &syn::Signature) -> Option<syn::Ident> {
873 if let Some(syn::FnArg::Typed(pt)) = sig.inputs.first() {
874 if let syn::Pat::Ident(pi) = &*pt.pat {
875 return Some(pi.ident.clone());
876 }
877 }
878 None
879}
880
/// Statements spliced around a jig body when the `trace` feature is enabled.
#[cfg(feature = "trace")]
struct TraceParts {
    /// Runs after the body has bound `__jig_result`; ends by evaluating to it.
    post: TokenStream2,
    /// Runs before the body.
    pre: TokenStream2,
}
886
/// Builds the tracing prologue/epilogue for the jig named `name_str`.
///
/// `pre` does three things:
/// - records in `__jig_input_ok` whether `input_ident` (when given) reports
///   success via `::jigs::Status::succeeded` (`true` when there is no input binding);
/// - calls `::jigs::trace::enter` with this jig's `META`;
/// - starts a timer.
///
/// `post` reads success and error from `__jig_result`, calls
/// `::jigs::trace::exit` with the elapsed time, and evaluates to
/// `__jig_result`.
///
/// If both the input and the result report failure, the step is recorded as
/// ok with no error. Presumably this keeps a failure carried in from upstream
/// from being attributed to this jig.
#[cfg(feature = "trace")]
fn trace_instrument(name_str: &str, input_ident: Option<&syn::Ident>) -> TraceParts {
    let marker = marker_ident(name_str);
    let input_snapshot = match input_ident {
        Some(id) => quote! { let __jig_input_ok = ::jigs::Status::succeeded(&#id); },
        None => quote! { let __jig_input_ok = true; },
    };

    TraceParts {
        pre: quote! {
            #input_snapshot
            let __jig_idx = ::jigs::trace::enter(&<#marker as ::jigs::JigDef>::META);
            let __jig_start = ::std::time::Instant::now();
        },
        post: quote! {
            let mut __jig_ok = ::jigs::Status::succeeded(&__jig_result);
            let mut __jig_err = ::jigs::Status::error(&__jig_result);
            if !__jig_input_ok && !__jig_ok {
                __jig_ok = true;
                __jig_err = None;
            }
            ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
            __jig_result
        },
    }
}
912
/// Wraps a synchronous jig body with the tracing prologue/epilogue.
///
/// The block runs inside an immediately invoked `move` closure. An early
/// `return` in the body therefore yields `__jig_result` instead of skipping
/// the epilogue.
#[cfg(feature = "trace")]
fn sync_body(block: &syn::Block, name_str: &str, input_ident: Option<&syn::Ident>) -> TokenStream2 {
    let parts = trace_instrument(name_str, input_ident);
    let (before, after) = (parts.pre, parts.post);
    quote! {
        #before
        let __jig_result = (move || #block)();
        #after
    }
}
922
923#[cfg(not(feature = "trace"))]
924fn sync_body(
925 block: &syn::Block,
926 _name_str: &str,
927 _input_ident: Option<&syn::Ident>,
928) -> TokenStream2 {
929 quote! { #block }
930}
931
/// Wraps an async jig body in `::jigs::Pending(async move { ... })` with the
/// tracing prologue/epilogue around an awaited inner `async move` block. An
/// early `return` in the body therefore yields `__jig_result` instead of
/// skipping the epilogue.
#[cfg(feature = "trace")]
fn async_body(
    block: &syn::Block,
    name_str: &str,
    input_ident: Option<&syn::Ident>,
) -> TokenStream2 {
    let parts = trace_instrument(name_str, input_ident);
    let (before, after) = (parts.pre, parts.post);
    quote! {
        ::jigs::Pending(async move {
            #before
            let __jig_result = (async move #block).await;
            #after
        })
    }
}
947
948#[cfg(not(feature = "trace"))]
949fn async_body(
950 block: &syn::Block,
951 _name_str: &str,
952 _input_ident: Option<&syn::Ident>,
953) -> TokenStream2 {
954 quote! { ::jigs::Pending(async move #block) }
955}
956
957fn first_arg_type(sig: &syn::Signature) -> Option<Type> {
958 match sig.inputs.first() {
959 Some(syn::FnArg::Typed(pt)) => Some((*pt.ty).clone()),
960 _ => None,
961 }
962}
963
964fn return_type(ret: &ReturnType) -> Option<Type> {
965 match ret {
966 ReturnType::Type(_, t) => Some((**t).clone()),
967 _ => None,
968 }
969}
970
971fn classify_expr(ty: Option<&Type>) -> TokenStream2 {
972 match ty {
973 Some(t) => quote!(<#t as ::jigs::__Classify>::KIND),
974 None => quote!("Other"),
975 }
976}
977
978fn first_arg_payload(sig: &syn::Signature) -> String {
979 let ty = match sig.inputs.first() {
980 Some(syn::FnArg::Typed(pt)) => &*pt.ty,
981 _ => return "?".into(),
982 };
983 payload_type(ty)
984}
985
986fn return_payload(ret: &ReturnType) -> String {
987 let ty = match ret {
988 ReturnType::Default => return "?".into(),
989 ReturnType::Type(_, t) => t,
990 };
991 payload_type(ty)
992}
993
994fn payload_type(ty: &Type) -> String {
995 if let Type::Path(p) = ty {
996 if let Some(seg) = p.path.segments.last() {
997 let name = seg.ident.to_string();
998 match name.as_str() {
999 "Request" | "Response" | "Pending" => {
1000 if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
1001 return generic_args_string(ab);
1002 }
1003 }
1004 "Branch" => {
1005 if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
1006 return format!("Branch<{}>", generic_args_string(ab));
1007 }
1008 }
1009 _ => {}
1010 }
1011 }
1012 }
1013 type_to_string(ty)
1014}
1015
1016fn type_to_string(ty: &Type) -> String {
1017 quote::quote!(#ty).to_string().replace(' ', "")
1018}
1019
1020fn generic_args_string(args: &syn::AngleBracketedGenericArguments) -> String {
1021 let mut out = String::new();
1022 for (i, arg) in args.args.iter().enumerate() {
1023 if i > 0 {
1024 out.push(',');
1025 }
1026 match arg {
1027 syn::GenericArgument::Type(t) => out.push_str(&type_to_string(t)),
1028 syn::GenericArgument::Lifetime(l) => out.push_str(&l.ident.to_string()),
1029 other => out.push_str("e::quote!(#other).to_string().replace(' ', "")),
1030 }
1031 }
1032 out
1033}
1034
/// How a jig invokes another jig in its body, as discovered by `collect_chain`.
#[derive(Copy, Clone)]
enum ChainKindTok {
    /// Called as `.then(path)`.
    Then,
    /// Named as an arm or the default of `fork!(...)`.
    Fork,
}
1040
1041fn collect_chain(block: &syn::Block) -> Vec<(String, ChainKindTok)> {
1042 struct V(Vec<(String, ChainKindTok)>);
1043 impl V {
1044 fn push_unique(&mut self, name: String, kind: ChainKindTok) {
1045 if !self.0.iter().any(|(n, _)| n == &name) {
1046 self.0.push((name, kind));
1047 }
1048 }
1049 fn push_path(&mut self, p: &syn::Path, kind: ChainKindTok) {
1050 let name = p
1051 .segments
1052 .iter()
1053 .map(|s| s.ident.to_string())
1054 .collect::<Vec<_>>()
1055 .join("::");
1056 self.push_unique(name, kind);
1057 }
1058 }
1059 impl<'ast> Visit<'ast> for V {
1060 fn visit_expr_method_call(&mut self, m: &'ast ExprMethodCall) {
1061 syn::visit::visit_expr(self, &m.receiver);
1062 if m.method == "then" {
1063 if let Some(Expr::Path(p)) = m.args.first() {
1064 self.push_path(&p.path, ChainKindTok::Then);
1065 }
1066 }
1067 for a in &m.args {
1068 syn::visit::visit_expr(self, a);
1069 }
1070 }
1071 fn visit_macro(&mut self, mac: &'ast syn::Macro) {
1072 let last = mac
1073 .path
1074 .segments
1075 .last()
1076 .map(|s| s.ident.to_string())
1077 .unwrap_or_default();
1078 if last == "fork" {
1079 if let Ok(args) = syn::parse2::<ForkArgs>(mac.tokens.clone()) {
1080 for j in &args.arms {
1081 if let syn::Expr::Path(p) = j {
1082 self.push_path(&p.path, ChainKindTok::Fork);
1083 }
1084 }
1085 if let syn::Expr::Path(p) = &args.default {
1086 self.push_path(&p.path, ChainKindTok::Fork);
1087 }
1088 }
1089 }
1090 }
1091 }
1092 let mut v = V(Vec::new());
1093 v.visit_block(block);
1094 v.0
1095}
1096
1097struct ForkArgs {
1098 arms: Vec<syn::Expr>,
1099 default: syn::Expr,
1100}
1101
1102impl syn::parse::Parse for ForkArgs {
1103 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
1104 let _req: syn::Expr = input.parse()?;
1105 input.parse::<syn::Token![,]>()?;
1106 let mut arms = Vec::new();
1107 loop {
1108 if input.peek(syn::Token![_]) {
1109 input.parse::<syn::Token![_]>()?;
1110 input.parse::<syn::Token![=>]>()?;
1111 let default: syn::Expr = input.parse()?;
1112 let _: Option<syn::Token![,]> = input.parse().ok();
1113 return Ok(ForkArgs { arms, default });
1114 }
1115 let _pred: syn::Expr = input.parse()?;
1116 input.parse::<syn::Token![=>]>()?;
1117 let jig: syn::Expr = input.parse()?;
1118 input.parse::<syn::Token![,]>()?;
1119 arms.push(jig);
1120 }
1121 }
1122}