sqlxmq_macros/
lib.rs

1#![deny(missing_docs, unsafe_code)]
2//! # sqlxmq_macros
3//!
4//! Provides procedural macros for the `sqlxmq` crate.
5
6use std::mem;
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12    parse::{Parse, ParseStream},
13    parse_macro_input, parse_quote, AttrStyle, Attribute, AttributeArgs, Error, Lit, Meta,
14    NestedMeta, Path, Result, Signature, Visibility,
15};
16
17#[derive(Default)]
18struct JobOptions {
19    proto: Option<Path>,
20    name: Option<String>,
21    channel_name: Option<String>,
22    retries: Option<u32>,
23    backoff_secs: Option<f64>,
24    ordered: Option<bool>,
25}
26
27enum OptionValue<'a> {
28    None,
29    Lit(&'a Lit),
30    Path(&'a Path),
31}
32
33fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
34    fn error(arg: NestedMeta) -> Result<()> {
35        Err(Error::new_spanned(arg, "Unexpected attribute argument"))
36    }
37    match &arg {
38        NestedMeta::Lit(Lit::Str(s)) if options.name.is_none() => {
39            options.name = Some(s.value());
40        }
41        NestedMeta::Meta(m) => {
42            if let Some(ident) = m.path().get_ident() {
43                let name = ident.to_string();
44                let value = match &m {
45                    Meta::List(l) => {
46                        if let NestedMeta::Meta(Meta::Path(p)) = &l.nested[0] {
47                            OptionValue::Path(p)
48                        } else {
49                            return error(arg);
50                        }
51                    }
52                    Meta::Path(_) => OptionValue::None,
53                    Meta::NameValue(nvp) => OptionValue::Lit(&nvp.lit),
54                };
55                match (name.as_str(), value) {
56                    ("proto", OptionValue::Path(p)) if options.proto.is_none() => {
57                        options.proto = Some(p.clone());
58                    }
59                    ("name", OptionValue::Lit(Lit::Str(s))) if options.name.is_none() => {
60                        options.name = Some(s.value());
61                    }
62                    ("channel_name", OptionValue::Lit(Lit::Str(s)))
63                        if options.channel_name.is_none() =>
64                    {
65                        options.channel_name = Some(s.value());
66                    }
67                    ("retries", OptionValue::Lit(Lit::Int(n))) if options.retries.is_none() => {
68                        options.retries = Some(n.base10_parse()?);
69                    }
70                    ("backoff_secs", OptionValue::Lit(Lit::Float(n)))
71                        if options.backoff_secs.is_none() =>
72                    {
73                        options.backoff_secs = Some(n.base10_parse()?);
74                    }
75                    ("backoff_secs", OptionValue::Lit(Lit::Int(n)))
76                        if options.backoff_secs.is_none() =>
77                    {
78                        options.backoff_secs = Some(n.base10_parse()?);
79                    }
80                    ("ordered", OptionValue::None) if options.ordered.is_none() => {
81                        options.ordered = Some(true);
82                    }
83                    ("ordered", OptionValue::Lit(Lit::Bool(b))) if options.ordered.is_none() => {
84                        options.ordered = Some(b.value);
85                    }
86                    _ => return error(arg),
87                }
88            }
89        }
90        _ => return error(arg),
91    }
92    Ok(())
93}
94
95#[derive(Clone)]
96struct MaybeItemFn {
97    attrs: Vec<Attribute>,
98    vis: Visibility,
99    sig: Signature,
100    block: TokenStream2,
101}
102
103/// This parses a `TokenStream` into a `MaybeItemFn`
104/// (just like `ItemFn`, but skips parsing the body).
105impl Parse for MaybeItemFn {
106    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
107        let attrs = input.call(syn::Attribute::parse_outer)?;
108        let vis: Visibility = input.parse()?;
109        let sig: Signature = input.parse()?;
110        let block: TokenStream2 = input.parse()?;
111        Ok(Self {
112            attrs,
113            vis,
114            sig,
115            block,
116        })
117    }
118}
119
120impl ToTokens for MaybeItemFn {
121    fn to_tokens(&self, tokens: &mut TokenStream2) {
122        tokens.append_all(
123            self.attrs
124                .iter()
125                .filter(|attr| matches!(attr.style, AttrStyle::Outer)),
126        );
127        self.vis.to_tokens(tokens);
128        self.sig.to_tokens(tokens);
129        self.block.to_tokens(tokens);
130    }
131}
132
133/// Marks a function as being a background job.
134///
135/// The first argument to the function must have type `CurrentJob`.
136/// Additional arguments can be used to access context from the job
137/// registry. Context is accessed based on the type of the argument.
138/// Context arguments must be `Send + Sync + Clone + 'static`.
139///
140/// The function should be async or return a future.
141///
142/// The async result must be a `Result<(), E>` type, where `E` is convertible
143/// to a `Box<dyn Error + Send + Sync + 'static>`, which is the case for most
144/// error types.
145///
146/// Several options can be provided to the `#[job]` attribute:
147///
148/// # Name
149///
150/// ```ignore
151/// #[job("example")]
152/// #[job(name="example")]
153/// ```
154///
155/// This overrides the name for this job. If unspecified, the fully-qualified
156/// name of the function is used. If you move a job to a new module or rename
157/// the function, you may which to override the job name to prevent it from
158/// changing.
159///
160/// # Channel name
161///
162/// ```ignore
163/// #[job(channel_name="foo")]
164/// ```
165///
166/// This sets the default channel name on which the job will be spawned.
167///
168/// # Retries
169///
170/// ```ignore
171/// #[job(retries = 3)]
172/// ```
173///
174/// This sets the default number of retries for the job.
175///
176/// # Retry backoff
177///
178/// ```ignore
179/// #[job(backoff_secs=1.5)]
180/// #[job(backoff_secs=2)]
181/// ```
182///
183/// This sets the default initial retry backoff for the job in seconds.
184///
185/// # Ordered
186///
187/// ```ignore
188/// #[job(ordered)]
189/// #[job(ordered=true)]
190/// #[job(ordered=false)]
191/// ```
192///
193/// This sets whether the job will be strictly ordered by default.
194///
195/// # Prototype
196///
197/// ```ignore
198/// fn my_proto<'a, 'b>(
199///     builder: &'a mut JobBuilder<'b>
200/// ) -> &'a mut JobBuilder<'b> {
201///     builder.set_channel_name("bar")
202/// }
203///
204/// #[job(proto(my_proto))]
205/// ```
206///
207/// This allows setting several job options at once using the specified function,
208/// and can be convient if you have several jobs which should have similar
209/// defaults.
210///
211/// # Combinations
212///
213/// Multiple job options can be combined. The order is not important, but the
214/// prototype will always be applied first so that explicit options can override it.
215/// Each option can only be provided once in the attribute.
216///
217/// ```ignore
218/// #[job("my_job", proto(my_proto), retries=0, ordered)]
219/// ```
220///
221#[proc_macro_attribute]
222pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
223    let args = parse_macro_input!(attr as AttributeArgs);
224    let mut inner_fn = parse_macro_input!(item as MaybeItemFn);
225
226    let mut options = JobOptions::default();
227    let mut errors = Vec::new();
228    for arg in args {
229        if let Err(e) = interpret_job_arg(&mut options, arg) {
230            errors.push(e.into_compile_error());
231        }
232    }
233
234    let outer_docs = inner_fn
235        .attrs
236        .iter()
237        .filter(|attr| attr.path.is_ident("doc"));
238
239    let vis = mem::replace(&mut inner_fn.vis, Visibility::Inherited);
240    let name = mem::replace(&mut inner_fn.sig.ident, parse_quote! {inner});
241    let fq_name = if let Some(name) = options.name {
242        quote! { #name }
243    } else {
244        let name_str = name.to_string();
245        quote! { concat!(module_path!(), "::", #name_str) }
246    };
247
248    let mut chain = Vec::new();
249    if let Some(proto) = &options.proto {
250        chain.push(quote! {
251            .set_proto(#proto)
252        });
253    }
254    if let Some(channel_name) = &options.channel_name {
255        chain.push(quote! {
256            .set_channel_name(#channel_name)
257        });
258    }
259    if let Some(retries) = &options.retries {
260        chain.push(quote! {
261            .set_retries(#retries)
262        });
263    }
264    if let Some(backoff_secs) = &options.backoff_secs {
265        chain.push(quote! {
266            .set_retry_backoff(::std::time::Duration::from_secs_f64(#backoff_secs))
267        });
268    }
269    if let Some(ordered) = options.ordered {
270        chain.push(quote! {
271            .set_ordered(#ordered)
272        });
273    }
274
275    let extract_ctx: Vec<_> = inner_fn
276        .sig
277        .inputs
278        .iter()
279        .skip(1)
280        .map(|_| {
281            quote! {
282                registry.context()
283            }
284        })
285        .collect();
286
287    let expanded = quote! {
288        #(#errors)*
289        #(#outer_docs)*
290        #[allow(non_upper_case_globals)]
291        #vis static #name: &'static sqlxmq::NamedJob = &{
292            #inner_fn
293            sqlxmq::NamedJob::new_internal(
294                #fq_name,
295                sqlxmq::hidden::BuildFn(|builder| {
296                    builder #(#chain)*
297                }),
298                sqlxmq::hidden::RunFn(|registry, current_job| {
299                    registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*));
300                }),
301            )
302        };
303    };
304    // Hand the output tokens back to the compiler.
305    TokenStream::from(expanded)
306}