1#![deny(missing_docs, unsafe_code)]
2use std::mem;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12 parse::{Parse, ParseStream},
13 parse_macro_input, parse_quote, AttrStyle, Attribute, AttributeArgs, Error, Lit, Meta,
14 NestedMeta, Path, Result, Signature, Visibility,
15};
16
17#[derive(Default)]
18struct JobOptions {
19 proto: Option<Path>,
20 name: Option<String>,
21 channel_name: Option<String>,
22 retries: Option<u32>,
23 backoff_secs: Option<f64>,
24 ordered: Option<bool>,
25}
26
27enum OptionValue<'a> {
28 None,
29 Lit(&'a Lit),
30 Path(&'a Path),
31}
32
33fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
34 fn error(arg: NestedMeta) -> Result<()> {
35 Err(Error::new_spanned(arg, "Unexpected attribute argument"))
36 }
37 match &arg {
38 NestedMeta::Lit(Lit::Str(s)) if options.name.is_none() => {
39 options.name = Some(s.value());
40 }
41 NestedMeta::Meta(m) => {
42 if let Some(ident) = m.path().get_ident() {
43 let name = ident.to_string();
44 let value = match &m {
45 Meta::List(l) => {
46 if let NestedMeta::Meta(Meta::Path(p)) = &l.nested[0] {
47 OptionValue::Path(p)
48 } else {
49 return error(arg);
50 }
51 }
52 Meta::Path(_) => OptionValue::None,
53 Meta::NameValue(nvp) => OptionValue::Lit(&nvp.lit),
54 };
55 match (name.as_str(), value) {
56 ("proto", OptionValue::Path(p)) if options.proto.is_none() => {
57 options.proto = Some(p.clone());
58 }
59 ("name", OptionValue::Lit(Lit::Str(s))) if options.name.is_none() => {
60 options.name = Some(s.value());
61 }
62 ("channel_name", OptionValue::Lit(Lit::Str(s)))
63 if options.channel_name.is_none() =>
64 {
65 options.channel_name = Some(s.value());
66 }
67 ("retries", OptionValue::Lit(Lit::Int(n))) if options.retries.is_none() => {
68 options.retries = Some(n.base10_parse()?);
69 }
70 ("backoff_secs", OptionValue::Lit(Lit::Float(n)))
71 if options.backoff_secs.is_none() =>
72 {
73 options.backoff_secs = Some(n.base10_parse()?);
74 }
75 ("backoff_secs", OptionValue::Lit(Lit::Int(n)))
76 if options.backoff_secs.is_none() =>
77 {
78 options.backoff_secs = Some(n.base10_parse()?);
79 }
80 ("ordered", OptionValue::None) if options.ordered.is_none() => {
81 options.ordered = Some(true);
82 }
83 ("ordered", OptionValue::Lit(Lit::Bool(b))) if options.ordered.is_none() => {
84 options.ordered = Some(b.value);
85 }
86 _ => return error(arg),
87 }
88 }
89 }
90 _ => return error(arg),
91 }
92 Ok(())
93}
94
95#[derive(Clone)]
96struct MaybeItemFn {
97 attrs: Vec<Attribute>,
98 vis: Visibility,
99 sig: Signature,
100 block: TokenStream2,
101}
102
103impl Parse for MaybeItemFn {
106 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
107 let attrs = input.call(syn::Attribute::parse_outer)?;
108 let vis: Visibility = input.parse()?;
109 let sig: Signature = input.parse()?;
110 let block: TokenStream2 = input.parse()?;
111 Ok(Self {
112 attrs,
113 vis,
114 sig,
115 block,
116 })
117 }
118}
119
120impl ToTokens for MaybeItemFn {
121 fn to_tokens(&self, tokens: &mut TokenStream2) {
122 tokens.append_all(
123 self.attrs
124 .iter()
125 .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
126 );
127 self.vis.to_tokens(tokens);
128 self.sig.to_tokens(tokens);
129 self.block.to_tokens(tokens);
130 }
131}
132
133#[proc_macro_attribute]
222pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
223 let args = parse_macro_input!(attr as AttributeArgs);
224 let mut inner_fn = parse_macro_input!(item as MaybeItemFn);
225
226 let mut options = JobOptions::default();
227 let mut errors = Vec::new();
228 for arg in args {
229 if let Err(e) = interpret_job_arg(&mut options, arg) {
230 errors.push(e.into_compile_error());
231 }
232 }
233
234 let outer_docs = inner_fn
235 .attrs
236 .iter()
237 .filter(|attr| attr.path.is_ident("doc"));
238
239 let vis = mem::replace(&mut inner_fn.vis, Visibility::Inherited);
240 let name = mem::replace(&mut inner_fn.sig.ident, parse_quote! {inner});
241 let fq_name = if let Some(name) = options.name {
242 quote! { #name }
243 } else {
244 let name_str = name.to_string();
245 quote! { concat!(module_path!(), "::", #name_str) }
246 };
247
248 let mut chain = Vec::new();
249 if let Some(proto) = &options.proto {
250 chain.push(quote! {
251 .set_proto(#proto)
252 });
253 }
254 if let Some(channel_name) = &options.channel_name {
255 chain.push(quote! {
256 .set_channel_name(#channel_name)
257 });
258 }
259 if let Some(retries) = &options.retries {
260 chain.push(quote! {
261 .set_retries(#retries)
262 });
263 }
264 if let Some(backoff_secs) = &options.backoff_secs {
265 chain.push(quote! {
266 .set_retry_backoff(::std::time::Duration::from_secs_f64(#backoff_secs))
267 });
268 }
269 if let Some(ordered) = options.ordered {
270 chain.push(quote! {
271 .set_ordered(#ordered)
272 });
273 }
274
275 let extract_ctx: Vec<_> = inner_fn
276 .sig
277 .inputs
278 .iter()
279 .skip(1)
280 .map(|_| {
281 quote! {
282 registry.context()
283 }
284 })
285 .collect();
286
287 let expanded = quote! {
288 #(#errors)*
289 #(#outer_docs)*
290 #[allow(non_upper_case_globals)]
291 #vis static #name: &'static sqlxmq::NamedJob = &{
292 #inner_fn
293 sqlxmq::NamedJob::new_internal(
294 #fq_name,
295 sqlxmq::hidden::BuildFn(|builder| {
296 builder #(#chain)*
297 }),
298 sqlxmq::hidden::RunFn(|registry, current_job| {
299 registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*));
300 }),
301 )
302 };
303 };
304 TokenStream::from(expanded)
306}