1use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use syn::visit::Visit;
14use syn::{parse_macro_input, parse_quote, Expr, ExprMethodCall, ItemFn, ReturnType, Type};
15
16fn marker_ident(fn_name: &str) -> syn::Ident {
17 syn::parse_str(&format!("__Jig_{fn_name}")).unwrap()
18}
19
20fn marker_path_for(name: &str) -> TokenStream2 {
21 let segs: Vec<&str> = name.split("::").collect();
22 let last_idx = segs.len() - 1;
23 let path_segs: Vec<TokenStream2> = segs
24 .iter()
25 .enumerate()
26 .map(|(i, s)| {
27 if i == last_idx {
28 let mi = marker_ident(s);
29 quote!(#mi)
30 } else if *s == "crate" {
31 quote!(crate)
32 } else if *s == "super" {
33 quote!(super)
34 } else if *s == "self" {
35 quote!(self)
36 } else {
37 let id: syn::Ident = syn::parse_str(s).unwrap();
38 quote!(#id)
39 }
40 })
41 .collect();
42 quote!(#(#path_segs)::*)
43}
44
45#[proc_macro_attribute]
46pub fn jig(_attr: TokenStream, item: TokenStream) -> TokenStream {
47 let input = parse_macro_input!(item as ItemFn);
48 let vis = &input.vis;
49 let block = &input.block;
50 let name_str = input.sig.ident.to_string();
51 let marker = marker_ident(&name_str);
52 let kind_str = return_kind(&input.sig.output);
53 let input_str = input_kind(&input.sig);
54 let input_type_str = first_arg_payload(&input.sig);
55 let output_type_str = return_payload(&input.sig.output);
56 let is_async = input.sig.asyncness.is_some();
57
58 let chain_tokens: Vec<TokenStream2> = collect_chain(&input.block)
59 .into_iter()
60 .map(|(name, kind)| {
61 let kind_ident = match kind {
62 ChainKindTok::Then => quote!(::jigs::ChainKind::Then),
63 ChainKindTok::Fork => quote!(::jigs::ChainKind::Fork),
64 };
65 quote! { ::jigs::ChainStep { name: #name, kind: #kind_ident } }
66 })
67 .collect();
68
69 let chain_collect: Vec<TokenStream2> = collect_chain(&input.block)
70 .into_iter()
71 .map(|(name, _kind)| {
72 let path = marker_path_for(&name);
73 quote! { <#path as ::jigs::JigDef>::collect(out); }
74 })
75 .collect();
76
77 let marker_def = quote! {
78 #[allow(non_camel_case_types)]
79 #[doc(hidden)]
80 pub struct #marker;
81
82 impl ::jigs::JigDef for #marker {
83 const META: ::jigs::JigMeta = ::jigs::JigMeta {
84 name: #name_str,
85 file: file!(),
86 line: line!(),
87 kind: #kind_str,
88 input: #input_str,
89 input_type: #input_type_str,
90 output_type: #output_type_str,
91 is_async: #is_async,
92 module: module_path!(),
93 chain: &[#(#chain_tokens),*],
94 };
95
96 fn collect(out: &mut Vec<&'static ::jigs::JigMeta>) {
97 let name = <Self as ::jigs::JigDef>::META.name;
98 if out.iter().any(|m| m.name == name) {
99 return;
100 }
101 out.push(&<Self as ::jigs::JigDef>::META);
102 #(#chain_collect)*
103 }
104 }
105 };
106
107 let response_input_ident = if input_str == "Response" {
108 first_arg_ident(&input.sig)
109 } else {
110 None
111 };
112
113 if input.sig.asyncness.is_some() {
114 let mut sig = input.sig.clone();
115 sig.asyncness = None;
116 let ret_ty = match &input.sig.output {
117 ReturnType::Default => quote!(()),
118 ReturnType::Type(_, ty) => quote!(#ty),
119 };
120 sig.output = parse_quote! {
121 -> ::jigs::Pending<impl ::core::future::Future<Output = #ret_ty>>
122 };
123
124 let body = async_body(block, &name_str, response_input_ident.as_ref());
125 return quote! { #marker_def #vis #sig { #body } }.into();
126 }
127
128 let sig = &input.sig;
129 let body = sync_body(block, &name_str, response_input_ident.as_ref());
130 quote! { #marker_def #vis #sig { #body } }.into()
131}
132
133#[proc_macro]
134pub fn jigs(input: TokenStream) -> TokenStream {
135 let entry: syn::Ident = parse_macro_input!(input);
136 let entry_marker = marker_ident(&entry.to_string());
137 quote! {
138 mod __jigs_registry {
139 pub fn all_jigs() -> impl Iterator<Item = &'static ::jigs::JigMeta> {
140 static CACHE: std::sync::OnceLock<Vec<&'static ::jigs::JigMeta>> = std::sync::OnceLock::new();
141 CACHE.get_or_init(|| {
142 let mut v = Vec::new();
143 <super::#entry_marker as ::jigs::JigDef>::collect(&mut v);
144 v
145 }).iter().copied()
146 }
147
148 pub fn find_jig(name: &str) -> Option<&'static ::jigs::JigMeta> {
149 all_jigs().find(|m| m.name == name)
150 }
151 }
152 pub use __jigs_registry::{all_jigs, find_jig};
153 }
154 .into()
155}
156
157fn first_arg_ident(sig: &syn::Signature) -> Option<syn::Ident> {
158 if let Some(syn::FnArg::Typed(pt)) = sig.inputs.first() {
159 if let syn::Pat::Ident(pi) = &*pt.pat {
160 return Some(pi.ident.clone());
161 }
162 }
163 None
164}
165
/// Wraps a synchronous jig body with tracing (compiled with the `trace` feature).
///
/// The generated code, in order:
/// 1. records `__jig_input_ok`. This is `::jigs::Status::ok(&arg)` when `response_input`
///    names the jig's `Response` argument, otherwise `true`.
/// 2. opens a trace entry for the jig's marker `META` and starts a timer.
/// 3. runs the original block inside an immediately-invoked `move` closure, so an early
///    `return` or `?` in the body produces the result instead of skipping the bookkeeping.
/// 4. passes the elapsed time and the result's `Status::ok` / `Status::error` to
///    `::jigs::trace::exit`. If both the input and the result are not ok, it reports ok
///    with no error instead (presumably so the failure is attributed upstream).
/// 5. evaluates to the result.
#[cfg(feature = "trace")]
fn sync_body(
    block: &syn::Block,
    name_str: &str,
    response_input: Option<&syn::Ident>,
) -> TokenStream2 {
    let meta_owner = marker_ident(name_str);

    let input_snapshot = if let Some(arg) = response_input {
        quote! { let __jig_input_ok = ::jigs::Status::ok(&#arg); }
    } else {
        quote! { let __jig_input_ok = true; }
    };

    let prologue = quote! {
        #input_snapshot
        let __jig_idx = ::jigs::trace::enter(&<#meta_owner as ::jigs::JigDef>::META);
        let __jig_start = ::std::time::Instant::now();
    };

    let epilogue = quote! {
        let mut __jig_ok = ::jigs::Status::ok(&__jig_result);
        let mut __jig_err = ::jigs::Status::error(&__jig_result);
        if !__jig_input_ok && !__jig_ok {
            __jig_ok = true;
            __jig_err = None;
        }
        ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
        __jig_result
    };

    quote! {
        #prologue
        let __jig_result = (move || #block)();
        #epilogue
    }
}
192
193#[cfg(not(feature = "trace"))]
194fn sync_body(
195 block: &syn::Block,
196 _name_str: &str,
197 _response_input: Option<&syn::Ident>,
198) -> TokenStream2 {
199 quote! { #block }
200}
201
/// Wraps an `async` jig body with tracing (compiled with the `trace` feature).
///
/// Produces `::jigs::Pending(async move { .. })`. All tracing work therefore runs when the
/// returned future is polled, not when the function is called. Inside that future:
/// 1. it records `__jig_input_ok` (`Status::ok(&arg)` for a `Response` argument named by
///    `response_input`, otherwise `true`);
/// 2. it opens a trace entry for the marker's `META` and starts a timer;
/// 3. it awaits the original block as its own `async move` block;
/// 4. it reports the elapsed time and the result's status to `::jigs::trace::exit`. If
///    both the input and the result are not ok, it reports ok with no error instead;
/// 5. it yields the result.
#[cfg(feature = "trace")]
fn async_body(
    block: &syn::Block,
    name_str: &str,
    response_input: Option<&syn::Ident>,
) -> TokenStream2 {
    let meta_owner = marker_ident(name_str);

    let input_snapshot = if let Some(arg) = response_input {
        quote! { let __jig_input_ok = ::jigs::Status::ok(&#arg); }
    } else {
        quote! { let __jig_input_ok = true; }
    };

    let traced = quote! {
        #input_snapshot
        let __jig_idx = ::jigs::trace::enter(&<#meta_owner as ::jigs::JigDef>::META);
        let __jig_start = ::std::time::Instant::now();
        let __jig_result = (async move #block).await;
        let mut __jig_ok = ::jigs::Status::ok(&__jig_result);
        let mut __jig_err = ::jigs::Status::error(&__jig_result);
        if !__jig_input_ok && !__jig_ok {
            __jig_ok = true;
            __jig_err = None;
        }
        ::jigs::trace::exit(__jig_idx, __jig_start.elapsed(), __jig_ok, __jig_err);
        __jig_result
    };

    quote! {
        ::jigs::Pending(async move { #traced })
    }
}
230
231#[cfg(not(feature = "trace"))]
232fn async_body(
233 block: &syn::Block,
234 _name_str: &str,
235 _response_input: Option<&syn::Ident>,
236) -> TokenStream2 {
237 quote! { ::jigs::Pending(async move #block) }
238}
239
240fn return_kind(ret: &ReturnType) -> &'static str {
241 let ty = match ret {
242 ReturnType::Default => return "Other",
243 ReturnType::Type(_, t) => t,
244 };
245 match last_type_ident(ty).as_deref() {
246 Some("Request") => "Request",
247 Some("Response") => "Response",
248 Some("Branch") => "Branch",
249 Some("Pending") => "Pending",
250 _ => "Other",
251 }
252}
253
254fn input_kind(sig: &syn::Signature) -> &'static str {
255 let ty = match sig.inputs.first() {
256 Some(syn::FnArg::Typed(pt)) => &*pt.ty,
257 _ => return "Other",
258 };
259 match last_type_ident(ty).as_deref() {
260 Some("Request") => "Request",
261 Some("Response") => "Response",
262 _ => "Other",
263 }
264}
265
266fn first_arg_payload(sig: &syn::Signature) -> String {
267 let ty = match sig.inputs.first() {
268 Some(syn::FnArg::Typed(pt)) => &*pt.ty,
269 _ => return "?".into(),
270 };
271 payload_type(ty)
272}
273
274fn return_payload(ret: &ReturnType) -> String {
275 let ty = match ret {
276 ReturnType::Default => return "?".into(),
277 ReturnType::Type(_, t) => t,
278 };
279 payload_type(ty)
280}
281
282fn payload_type(ty: &Type) -> String {
283 if let Type::Path(p) = ty {
284 if let Some(seg) = p.path.segments.last() {
285 let name = seg.ident.to_string();
286 match name.as_str() {
287 "Request" | "Response" => {
288 if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
289 return generic_args_string(ab);
290 }
291 }
292 "Branch" => {
293 if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
294 return format!("Branch<{}>", generic_args_string(ab));
295 }
296 }
297 "Pending" => {
298 if let syn::PathArguments::AngleBracketed(ref ab) = seg.arguments {
299 return generic_args_string(ab);
300 }
301 }
302 _ => {}
303 }
304 }
305 }
306 type_to_string(ty)
307}
308
309fn type_to_string(ty: &Type) -> String {
310 quote::quote!(#ty).to_string().replace(' ', "")
311}
312
313fn generic_args_string(args: &syn::AngleBracketedGenericArguments) -> String {
314 let mut out = String::new();
315 for (i, arg) in args.args.iter().enumerate() {
316 if i > 0 {
317 out.push(',');
318 }
319 match arg {
320 syn::GenericArgument::Type(t) => out.push_str(&type_to_string(t)),
321 syn::GenericArgument::Lifetime(l) => out.push_str(&l.ident.to_string()),
322 other => out.push_str("e::quote!(#other).to_string().replace(' ', "")),
323 }
324 }
325 out
326}
327
328fn last_type_ident(ty: &Type) -> Option<String> {
329 if let Type::Path(p) = ty {
330 return Some(p.path.segments.last()?.ident.to_string());
331 }
332 None
333}
334
/// How one jig is referenced from another jig's body, as found by `collect_chain`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChainKindTok {
    /// Passed as the argument of a `.then(..)` method call.
    Then,
    /// Named as an arm or the default of a `fork!(..)` invocation.
    Fork,
}
340
341fn collect_chain(block: &syn::Block) -> Vec<(String, ChainKindTok)> {
342 struct V(Vec<(String, ChainKindTok)>);
343 impl V {
344 fn push_unique(&mut self, name: String, kind: ChainKindTok) {
345 if !self.0.iter().any(|(n, _)| n == &name) {
346 self.0.push((name, kind));
347 }
348 }
349 fn push_path(&mut self, p: &syn::Path, kind: ChainKindTok) {
350 let name = p
351 .segments
352 .iter()
353 .map(|s| s.ident.to_string())
354 .collect::<Vec<_>>()
355 .join("::");
356 self.push_unique(name, kind);
357 }
358 }
359 impl<'ast> Visit<'ast> for V {
360 fn visit_expr_method_call(&mut self, m: &'ast ExprMethodCall) {
361 syn::visit::visit_expr(self, &m.receiver);
362 if m.method == "then" {
363 if let Some(Expr::Path(p)) = m.args.first() {
364 self.push_path(&p.path, ChainKindTok::Then);
365 }
366 }
367 for a in &m.args {
368 syn::visit::visit_expr(self, a);
369 }
370 }
371 fn visit_macro(&mut self, mac: &'ast syn::Macro) {
372 let last = mac
373 .path
374 .segments
375 .last()
376 .map(|s| s.ident.to_string())
377 .unwrap_or_default();
378 if last == "fork" {
379 if let Ok(args) = syn::parse2::<ForkArgs>(mac.tokens.clone()) {
380 for j in &args.arms {
381 if let syn::Expr::Path(p) = j {
382 self.push_path(&p.path, ChainKindTok::Fork);
383 }
384 }
385 if let syn::Expr::Path(p) = &args.default {
386 self.push_path(&p.path, ChainKindTok::Fork);
387 }
388 }
389 }
390 }
391 }
392 let mut v = V(Vec::new());
393 v.visit_block(block);
394 v.0
395}
396
397struct ForkArgs {
398 arms: Vec<syn::Expr>,
399 default: syn::Expr,
400}
401
402impl syn::parse::Parse for ForkArgs {
403 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
404 let _req: syn::Expr = input.parse()?;
405 input.parse::<syn::Token![,]>()?;
406 let mut arms = Vec::new();
407 loop {
408 if input.peek(syn::Token![_]) {
409 input.parse::<syn::Token![_]>()?;
410 input.parse::<syn::Token![=>]>()?;
411 let default: syn::Expr = input.parse()?;
412 let _: Option<syn::Token![,]> = input.parse().ok();
413 return Ok(ForkArgs { arms, default });
414 }
415 let _pred: syn::Expr = input.parse()?;
416 input.parse::<syn::Token![=>]>()?;
417 let jig: syn::Expr = input.parse()?;
418 input.parse::<syn::Token![,]>()?;
419 arms.push(jig);
420 }
421 }
422}