1use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::parse::{Parse, ParseStream};
9use syn::punctuated::Punctuated;
10use syn::{
11 parse_macro_input, FnArg, GenericArgument, Ident, ItemFn, LitBool, LitInt, LitStr, Pat,
12 PathArguments, ReturnType, Token, Type,
13};
14
15#[proc_macro_attribute]
77pub fn task(attr: TokenStream, item: TokenStream) -> TokenStream {
78 let attrs = parse_macro_input!(attr as TaskAttrs);
79 let func = parse_macro_input!(item as ItemFn);
80
81 match expand_task(attrs, func) {
82 Ok(tokens) => tokens.into(),
83 Err(err) => err.to_compile_error().into(),
84 }
85}
86
/// Parsed arguments of `#[rustvello::task(...)]`.
///
/// Each field is `None` when the corresponding key was not written; only the
/// `Some` fields are applied on top of `TaskConfig::default()` when the task
/// configuration is generated.
struct TaskAttrs {
    // --- identity ---
    /// `module = "..."`: overrides `module_path!()` in the generated `TaskId`.
    module: Option<String>,

    // --- retries ---
    /// `max_retries = <int>`.
    max_retries: Option<u32>,
    /// `retry_for_errors = ["...", ...]`.
    retry_for_errors: Option<Vec<String>>,

    // --- concurrency (values validated to unlimited/task/argument/none) ---
    /// `concurrency = "..."`.
    concurrency: Option<String>,
    /// `registration_concurrency = "..."`.
    registration_concurrency: Option<String>,
    /// `key_arguments = ["...", ...]`: must name function parameters.
    key_arguments: Option<Vec<String>>,
    /// `on_diff_non_key_args_raise = <bool>`.
    on_diff_non_key_args_raise: Option<bool>,
    /// `reroute_on_cc = <bool>`.
    reroute_on_cc: Option<bool>,

    // --- caching ---
    /// `cache_results = <bool>`.
    cache_results: Option<bool>,
    /// `disable_cache_args = ["...", ...]`: must name function parameters.
    disable_cache_args: Option<Vec<String>>,

    // --- execution ---
    /// `parallel_batch_size = <int>`.
    parallel_batch_size: Option<usize>,
    /// `force_new_workflow = <bool>`.
    force_new_workflow: Option<bool>,
    /// `blocking = <bool>`.
    blocking: Option<bool>,
}
106
107impl Parse for TaskAttrs {
108 fn parse(input: ParseStream) -> syn::Result<Self> {
109 let mut max_retries = None;
110 let mut module = None;
111 let mut concurrency = None;
112 let mut registration_concurrency = None;
113 let mut key_arguments = None;
114 let mut cache_results = None;
115 let mut disable_cache_args = None;
116 let mut retry_for_errors = None;
117 let mut on_diff_non_key_args_raise = None;
118 let mut parallel_batch_size = None;
119 let mut force_new_workflow = None;
120 let mut reroute_on_cc = None;
121 let mut blocking = None;
122
123 while !input.is_empty() {
124 let key: Ident = input.parse()?;
125 input.parse::<Token![=]>()?;
126 let key_str = key.to_string();
127
128 macro_rules! check_dup {
130 ($opt:expr) => {
131 if $opt.is_some() {
132 return Err(syn::Error::new(
133 key.span(),
134 format!("duplicate task attribute: `{}`", key_str),
135 ));
136 }
137 };
138 }
139
140 match key_str.as_str() {
141 "max_retries" => {
142 check_dup!(max_retries);
143 let lit: LitInt = input.parse()?;
144 max_retries = Some(lit.base10_parse()?);
145 }
146 "module" => {
147 check_dup!(module);
148 let lit: LitStr = input.parse()?;
149 module = Some(lit.value());
150 }
151 "concurrency" => {
152 check_dup!(concurrency);
153 let lit: LitStr = input.parse()?;
154 validate_concurrency_str(&lit)?;
155 concurrency = Some(lit.value());
156 }
157 "registration_concurrency" => {
158 check_dup!(registration_concurrency);
159 let lit: LitStr = input.parse()?;
160 validate_concurrency_str(&lit)?;
161 registration_concurrency = Some(lit.value());
162 }
163 "key_arguments" => {
164 check_dup!(key_arguments);
165 let content;
166 syn::bracketed!(content in input);
167 let items: Punctuated<LitStr, Token![,]> =
168 Punctuated::parse_terminated(&content)?;
169 key_arguments = Some(items.iter().map(LitStr::value).collect());
170 }
171 "cache_results" => {
172 check_dup!(cache_results);
173 let lit: LitBool = input.parse()?;
174 cache_results = Some(lit.value());
175 }
176 "force_new_workflow" => {
177 check_dup!(force_new_workflow);
178 let lit: LitBool = input.parse()?;
179 force_new_workflow = Some(lit.value());
180 }
181 "reroute_on_cc" => {
182 check_dup!(reroute_on_cc);
183 let lit: LitBool = input.parse()?;
184 reroute_on_cc = Some(lit.value());
185 }
186 "blocking" => {
187 check_dup!(blocking);
188 let lit: LitBool = input.parse()?;
189 blocking = Some(lit.value());
190 }
191 "disable_cache_args" => {
192 check_dup!(disable_cache_args);
193 let content;
194 syn::bracketed!(content in input);
195 let items: Punctuated<LitStr, Token![,]> =
196 Punctuated::parse_terminated(&content)?;
197 disable_cache_args = Some(items.iter().map(LitStr::value).collect());
198 }
199 "retry_for_errors" => {
200 check_dup!(retry_for_errors);
201 let content;
202 syn::bracketed!(content in input);
203 let items: Punctuated<LitStr, Token![,]> =
204 Punctuated::parse_terminated(&content)?;
205 retry_for_errors = Some(items.iter().map(LitStr::value).collect());
206 }
207 "on_diff_non_key_args_raise" => {
208 check_dup!(on_diff_non_key_args_raise);
209 let lit: LitBool = input.parse()?;
210 on_diff_non_key_args_raise = Some(lit.value());
211 }
212 "parallel_batch_size" => {
213 check_dup!(parallel_batch_size);
214 let lit: LitInt = input.parse()?;
215 parallel_batch_size = Some(lit.base10_parse()?);
216 }
217 other => {
218 let known = [
219 "max_retries",
220 "module",
221 "concurrency",
222 "registration_concurrency",
223 "key_arguments",
224 "cache_results",
225 "disable_cache_args",
226 "retry_for_errors",
227 "on_diff_non_key_args_raise",
228 "parallel_batch_size",
229 "force_new_workflow",
230 "reroute_on_cc",
231 "blocking",
232 ];
233 let suggestion = known
234 .iter()
235 .filter(|k| {
236 k.starts_with(&other[..1.min(other.len())])
239 || other.contains(&k[..3.min(k.len())])
240 || k.contains(&other[..3.min(other.len())])
241 })
242 .copied()
243 .next();
244 let msg = match suggestion {
245 Some(s) => format!(
246 "unknown task attribute: `{other}`. Did you mean `{s}`?\n\
247 Valid attributes: {}",
248 known.join(", ")
249 ),
250 None => format!(
251 "unknown task attribute: `{other}`.\n\
252 Valid attributes: {}",
253 known.join(", ")
254 ),
255 };
256 return Err(syn::Error::new(key.span(), msg));
257 }
258 }
259
260 if !input.is_empty() {
261 input.parse::<Token![,]>()?;
262 }
263 }
264
265 Ok(Self {
266 max_retries,
267 module,
268 concurrency,
269 registration_concurrency,
270 key_arguments,
271 cache_results,
272 disable_cache_args,
273 retry_for_errors,
274 on_diff_non_key_args_raise,
275 parallel_batch_size,
276 force_new_workflow,
277 reroute_on_cc,
278 blocking,
279 })
280 }
281}
282
283fn validate_concurrency_str(lit: &LitStr) -> syn::Result<()> {
284 match lit.value().as_str() {
285 "unlimited" | "task" | "argument" | "none" => Ok(()),
286 other => Err(syn::Error::new(
287 lit.span(),
288 format!(
289 "invalid concurrency value: `{other}`. \
290 Expected one of: \"unlimited\", \"task\", \"argument\", \"none\""
291 ),
292 )),
293 }
294}
295
296fn expand_task(attrs: TaskAttrs, func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
301 validate_function(&func)?;
302 validate_attrs_against_params(&attrs, &func)?;
303
304 let fn_name = &func.sig.ident;
305 let fn_name_str = fn_name.to_string();
306 let vis = &func.vis;
307
308 let task_struct = format_ident!("{}Task", to_pascal_case(&fn_name_str));
310 let params_struct = format_ident!("{}Params", to_pascal_case(&fn_name_str));
311 let fn_name_register = format_ident!("__rustvello_register_{}", fn_name_str);
312
313 let core_path = quote! { ::rustvello::__private::rustvello_core };
315 let proto_path = quote! { ::rustvello::__private::rustvello_proto };
316 let serde_path = quote! { ::rustvello::__private::serde };
317 let serde_crate_str = "::rustvello::__private::serde";
318
319 let (result_type, wrap_ok) = parse_return_type(&func.sig.output)?;
321
322 let params = extract_params(&func)?;
324 let param_names: Vec<&Ident> = params.iter().map(|(name, _)| name).collect();
325
326 let config_body = build_config(&attrs, &proto_path);
328
329 let module_expr = match &attrs.module {
331 Some(m) => quote! { #m },
332 None => quote! { module_path!() },
333 };
334
335 let fn_call = if params.is_empty() {
337 quote! { #fn_name() }
338 } else {
339 quote! { #fn_name(#(#param_names),*) }
340 };
341
342 let run_body = if wrap_ok {
343 quote! { Ok(#fn_call) }
344 } else {
345 quote! { #fn_call }
346 };
347
348 let generated = if params.is_empty() {
349 generate_no_params(
350 &func,
351 vis,
352 &task_struct,
353 &fn_name_register,
354 &core_path,
355 &proto_path,
356 &module_expr,
357 fn_name_str.as_str(),
358 &result_type,
359 &config_body,
360 &run_body,
361 )
362 } else {
363 generate_with_params(
364 &func,
365 vis,
366 &task_struct,
367 ¶ms_struct,
368 &fn_name_register,
369 &core_path,
370 &proto_path,
371 &serde_path,
372 serde_crate_str,
373 &module_expr,
374 fn_name_str.as_str(),
375 &result_type,
376 &config_body,
377 &run_body,
378 ¶ms,
379 ¶m_names,
380 )
381 };
382
383 Ok(generated)
384}
385
386fn validate_function(func: &ItemFn) -> syn::Result<()> {
388 if func.sig.asyncness.is_some() {
389 return Err(syn::Error::new_spanned(
390 &func.sig,
391 "#[rustvello::task] does not support async functions yet",
392 ));
393 }
394 if func.sig.unsafety.is_some() {
395 return Err(syn::Error::new_spanned(
396 &func.sig,
397 "#[rustvello::task] does not support unsafe functions",
398 ));
399 }
400 if !func.sig.generics.params.is_empty() {
401 return Err(syn::Error::new_spanned(
402 &func.sig.generics,
403 "#[rustvello::task] does not support generic functions",
404 ));
405 }
406 Ok(())
407}
408
409fn validate_attrs_against_params(attrs: &TaskAttrs, func: &ItemFn) -> syn::Result<()> {
411 let param_name_strs: Vec<String> = extract_params(func)?
412 .iter()
413 .map(|(name, _)| name.to_string())
414 .collect();
415
416 if let Some(ref keys) = attrs.key_arguments {
417 for key in keys {
418 if !param_name_strs.contains(key) {
419 return Err(syn::Error::new(
420 func.sig.ident.span(),
421 format!(
422 "key_arguments entry `{key}` does not match any function parameter. \
423 Valid parameters: {}",
424 param_name_strs.join(", ")
425 ),
426 ));
427 }
428 }
429 }
430
431 if let Some(ref args) = attrs.disable_cache_args {
432 for arg in args {
433 if !param_name_strs.contains(arg) {
434 return Err(syn::Error::new(
435 func.sig.ident.span(),
436 format!(
437 "disable_cache_args entry `{arg}` does not match any function parameter. \
438 Valid parameters: {}",
439 param_name_strs.join(", ")
440 ),
441 ));
442 }
443 }
444 }
445 Ok(())
446}
447
448#[allow(clippy::too_many_arguments)]
450fn generate_no_params(
451 func: &ItemFn,
452 vis: &syn::Visibility,
453 task_struct: &Ident,
454 fn_name_register: &Ident,
455 core_path: &proc_macro2::TokenStream,
456 proto_path: &proc_macro2::TokenStream,
457 module_expr: &proc_macro2::TokenStream,
458 fn_name_str: &str,
459 result_type: &proc_macro2::TokenStream,
460 config_body: &proc_macro2::TokenStream,
461 run_body: &proc_macro2::TokenStream,
462) -> proc_macro2::TokenStream {
463 quote! {
464 #func
465
466 #vis struct #task_struct {
467 task_id: #proto_path::identifiers::TaskId,
468 config: #proto_path::config::TaskConfig,
469 }
470
471 impl #task_struct {
472 pub fn new() -> Self {
474 Self {
475 task_id: #proto_path::identifiers::TaskId::new(#module_expr, #fn_name_str),
476 config: #config_body,
477 }
478 }
479 }
480
481 impl Default for #task_struct {
482 fn default() -> Self {
483 Self::new()
484 }
485 }
486
487 impl #core_path::task::Task for #task_struct {
488 type Params = ();
489 type Result = #result_type;
490
491 fn task_id(&self) -> &#proto_path::identifiers::TaskId {
492 &self.task_id
493 }
494
495 fn config(&self) -> &#proto_path::config::TaskConfig {
496 &self.config
497 }
498
499 fn run(
500 &self,
501 _params: (),
502 ) -> #core_path::error::RustvelloResult<#result_type> {
503 #run_body
504 }
505 }
506
507 fn #fn_name_register(
508 registry: &mut #core_path::task::TaskRegistry,
509 ) -> #core_path::error::RustvelloResult<()> {
510 registry.register_typed(#task_struct::new())
511 }
512
513 ::rustvello::__private::inventory::submit! {
514 ::rustvello::__private::TaskEntry {
515 register_fn: #fn_name_register,
516 }
517 }
518 }
519}
520
521#[allow(clippy::too_many_arguments)]
523fn generate_with_params(
524 func: &ItemFn,
525 vis: &syn::Visibility,
526 task_struct: &Ident,
527 params_struct: &Ident,
528 fn_name_register: &Ident,
529 core_path: &proc_macro2::TokenStream,
530 proto_path: &proc_macro2::TokenStream,
531 serde_path: &proc_macro2::TokenStream,
532 serde_crate_str: &str,
533 module_expr: &proc_macro2::TokenStream,
534 fn_name_str: &str,
535 result_type: &proc_macro2::TokenStream,
536 config_body: &proc_macro2::TokenStream,
537 run_body: &proc_macro2::TokenStream,
538 params: &[(Ident, Type)],
539 param_names: &[&Ident],
540) -> proc_macro2::TokenStream {
541 let param_fields: Vec<_> = params
542 .iter()
543 .map(|(name, ty)| quote! { pub #name: #ty })
544 .collect();
545
546 quote! {
547 #func
548
549 #[derive(Debug, Clone, #serde_path::Serialize, #serde_path::Deserialize)]
550 #[serde(crate = #serde_crate_str)]
551 #vis struct #params_struct {
552 #(#param_fields,)*
553 }
554
555 #vis struct #task_struct {
556 task_id: #proto_path::identifiers::TaskId,
557 config: #proto_path::config::TaskConfig,
558 }
559
560 impl #task_struct {
561 pub fn new() -> Self {
563 Self {
564 task_id: #proto_path::identifiers::TaskId::new(#module_expr, #fn_name_str),
565 config: #config_body,
566 }
567 }
568 }
569
570 impl Default for #task_struct {
571 fn default() -> Self {
572 Self::new()
573 }
574 }
575
576 impl #core_path::task::Task for #task_struct {
577 type Params = #params_struct;
578 type Result = #result_type;
579
580 fn task_id(&self) -> &#proto_path::identifiers::TaskId {
581 &self.task_id
582 }
583
584 fn config(&self) -> &#proto_path::config::TaskConfig {
585 &self.config
586 }
587
588 fn run(
589 &self,
590 params: #params_struct,
591 ) -> #core_path::error::RustvelloResult<#result_type> {
592 let #params_struct { #(#param_names),* } = params;
593 #run_body
594 }
595 }
596
597 fn #fn_name_register(
598 registry: &mut #core_path::task::TaskRegistry,
599 ) -> #core_path::error::RustvelloResult<()> {
600 registry.register_typed(#task_struct::new())
601 }
602
603 ::rustvello::__private::inventory::submit! {
604 ::rustvello::__private::TaskEntry {
605 register_fn: #fn_name_register,
606 }
607 }
608 }
609}
610
611fn extract_params(func: &ItemFn) -> syn::Result<Vec<(Ident, Type)>> {
616 func.sig
617 .inputs
618 .iter()
619 .map(|arg| match arg {
620 FnArg::Typed(pat_type) => {
621 let name = match &*pat_type.pat {
622 Pat::Ident(pi) => pi.ident.clone(),
623 _ => {
624 return Err(syn::Error::new_spanned(
625 pat_type,
626 "expected a simple parameter name",
627 ))
628 }
629 };
630 Ok((name, (*pat_type.ty).clone()))
631 }
632 FnArg::Receiver(r) => Err(syn::Error::new_spanned(
633 r,
634 "#[rustvello::task] functions cannot take self",
635 )),
636 })
637 .collect()
638}
639
640#[allow(clippy::unnecessary_wraps)]
647fn parse_return_type(ret: &ReturnType) -> syn::Result<(proc_macro2::TokenStream, bool)> {
648 match ret {
649 ReturnType::Default => Ok((quote! { () }, true)),
650 ReturnType::Type(_, ty) => {
651 if let Some(inner) = unwrap_result_type(ty) {
652 Ok((quote! { #inner }, false))
653 } else {
654 Ok((quote! { #ty }, true))
655 }
656 }
657 }
658}
659
660fn unwrap_result_type(ty: &Type) -> Option<&Type> {
662 let Type::Path(tp) = ty else { return None };
663 let last = tp.path.segments.last()?;
664 let name = last.ident.to_string();
665 if name != "Result" && name != "RustvelloResult" {
666 return None;
667 }
668 let PathArguments::AngleBracketed(args) = &last.arguments else {
669 return None;
670 };
671 match args.args.first()? {
672 GenericArgument::Type(inner) => Some(inner),
673 _ => None,
674 }
675}
676
677fn build_config(
678 attrs: &TaskAttrs,
679 proto_path: &proc_macro2::TokenStream,
680) -> proc_macro2::TokenStream {
681 let base = quote! { let mut config = #proto_path::config::TaskConfig::default(); };
682 let mut setters = Vec::new();
683
684 if let Some(retries) = attrs.max_retries {
685 setters.push(quote! { config.max_retries = #retries; });
686 }
687
688 if let Some(ref cc) = attrs.concurrency {
689 let variant = concurrency_variant(cc);
690 setters.push(quote! {
691 config.concurrency_control = #proto_path::status::ConcurrencyControlType::#variant;
692 });
693 }
694
695 if let Some(ref rc) = attrs.registration_concurrency {
696 let variant = concurrency_variant(rc);
697 setters.push(quote! {
698 config.registration_concurrency = #proto_path::status::ConcurrencyControlType::#variant;
699 });
700 }
701
702 if let Some(ref keys) = attrs.key_arguments {
703 setters.push(quote! {
704 config.key_arguments = vec![#(#keys.to_string()),*];
705 });
706 }
707
708 if let Some(cache) = attrs.cache_results {
709 setters.push(quote! { config.cache_results = #cache; });
710 }
711
712 if let Some(force) = attrs.force_new_workflow {
713 setters.push(quote! { config.force_new_workflow = #force; });
714 }
715
716 if let Some(reroute) = attrs.reroute_on_cc {
717 setters.push(quote! { config.reroute_on_cc = #reroute; });
718 }
719
720 if let Some(blocking) = attrs.blocking {
721 setters.push(quote! { config.blocking = #blocking; });
722 }
723
724 if let Some(ref args) = attrs.disable_cache_args {
725 setters.push(quote! {
726 config.disable_cache_args = vec![#(#args.to_string()),*];
727 });
728 }
729
730 if let Some(ref errors) = attrs.retry_for_errors {
731 setters.push(quote! {
732 config.retry_for_errors = vec![#(#errors.to_string()),*];
733 });
734 }
735
736 if let Some(raise) = attrs.on_diff_non_key_args_raise {
737 setters.push(quote! { config.on_diff_non_key_args_raise = #raise; });
738 }
739
740 if let Some(batch) = attrs.parallel_batch_size {
741 setters.push(quote! { config.parallel_batch_size = #batch; });
742 }
743
744 quote! {
745 {
746 #base
747 #(#setters)*
748 config
749 }
750 }
751}
752
753fn concurrency_variant(s: &str) -> proc_macro2::TokenStream {
754 match s {
755 "unlimited" => quote! { Unlimited },
756 "task" => quote! { Task },
757 "argument" => quote! { Argument },
758 "none" => quote! { None },
759 _ => unreachable!("validated in parse"),
760 }
761}
762
/// Converts a `snake_case` function name to `PascalCase` for generated type names.
///
/// Underscores are dropped. The first character of the string, and the first
/// character after each run of underscores, is upper-cased using full Unicode
/// case mapping. Every other character is kept unchanged, so `myTask_x`
/// becomes `MyTaskX`.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    let mut at_word_start = true;

    for ch in s.chars() {
        if ch == '_' {
            at_word_start = true;
        } else if at_word_start {
            pascal.extend(ch.to_uppercase());
            at_word_start = false;
        } else {
            pascal.push(ch);
        }
    }

    pascal
}