Skip to main content

genja_core_derive/
lib.rs

1//! Procedural macros used by `genja-core`.
2//!
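//! `DerefMacro` and `DerefMutMacro` generate `Deref` and `DerefMut`
//! implementations for tuple-wrapper types. The generated `Deref` impl takes its
//! `Target` from `<Wrapper as DerefTarget>::Target`, so a trait named
//! `DerefTarget` must be in scope where the derive is used and the wrapper must
//! implement it (see the Deref example below).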
5//!
6//! `genja_task` is the public task-authoring macro. It generates both
7//! `TaskInfo` and `Task` implementations from an inherent `impl` block and
8//! infers execution mode from `fn start(...)` versus `async fn start_async(...)`.
9//!
10//! # Task Authoring Example
11//! ```ignore
12//! use genja_core::genja_task;
13//! use genja_core::inventory::Host;
14//! use genja_core::task::{HostTaskResult, TaskRuntimeContext, TaskSuccess};
15//!
16//! struct CollectFacts;
17//!
18//! #[genja_task(
19//!     name = "collect_facts",
20//!     connection_plugin_name = "ssh",
21//!     processors = ["audit"],
22//! )]
23//! impl CollectFacts {
24//!     async fn start_async(
25//!         &self,
26//!         host: &Host,
27//!         _context: &TaskRuntimeContext,
28//!     ) -> Result<HostTaskResult, genja_core::task::TaskError> {
29//!         Ok(HostTaskResult::passed(
30//!             TaskSuccess::new().with_summary(format!(
31//!                 "collected facts for {:?}",
32//!                 host.hostname()
33//!             )),
34//!         ))
35//!     }
36//! }
37//! ```
38//!
39//! # Deref Example
40//! ```
41//! use genja_core_derive::{DerefMacro, DerefMutMacro};
42//!
43//! pub trait DerefTarget {
44//!     type Target;
45//! }
46//!
47//! pub type DefaultListTarget = Vec<String>;
48//!
49//! impl DerefTarget for DefaultsList {
50//!     type Target = DefaultListTarget;
51//! }
52//!
53//! #[derive(DerefMacro, DerefMutMacro, PartialEq)]
54//! pub struct DefaultsList(DefaultListTarget);
55//!
56//! let mut defaults_list = DefaultsList(DefaultListTarget::new());
57//!
58//! defaults_list.push("default1".to_string());
59//!
60//! assert_eq!(defaults_list.as_ref(), vec!["default1".to_string()]);
61//! ```
62
63use proc_macro::TokenStream;
64use quote::quote;
65use syn::{
66    DeriveInput, Expr, ExprArray, ExprLit, FnArg, GenericArgument, ImplItem, ItemImpl, Lit, LitStr,
67    PathArguments, ReturnType, Token, Type, TypePath,
68    parse::{Parse, ParseStream},
69    parse_macro_input,
70};
71
72/// Generates an implementation of the `Deref` trait for the given type.
73///
74/// This function is used as a procedural macro to automatically derive the `Deref` trait
75/// for a struct. It creates an implementation that dereferences to the first field of the struct.
76///
77/// # Parameters
78///
79/// * `input`: A `TokenStream` representing the input tokens of the derive macro.
80///
81/// # Returns
82///
83/// A `TokenStream` containing the generated implementation of the `Deref` trait.
84#[proc_macro_derive(DerefMacro)]
85pub fn derive_deref(input: TokenStream) -> TokenStream {
86    let input = parse_macro_input!(input as DeriveInput);
87    let name = &input.ident;
88    if let Err(error) = reject_generics(&input, "DerefMacro") {
89        return error.to_compile_error().into();
90    }
91    if let Err(error) = require_tuple_wrapper(&input, "DerefMacro") {
92        return error.to_compile_error().into();
93    }
94
95    let expanded = quote! {
96        impl std::ops::Deref for #name {
97            /*
98            * Define the Target type. To ensure the correct implementation is
99            * to specify `<#name as .. >` which results to the name of the
100            * struct. Otherwise it will result in an **ambiguous error**
101            * if only `DerefTarget::Target` is used.
102            */
103            type Target = <#name as DerefTarget>::Target; //
104
105            fn deref(&self) -> &Self::Target {
106                &self.0
107            }
108        }
109    };
110    TokenStream::from(expanded)
111}
112
113/// Generates an implementation of the `DerefMut` trait for the given type.
114///
115/// This function is used as a procedural macro to automatically derive the `DerefMut` trait
116/// for a struct. It creates an implementation that allows mutable dereferencing to the first
117/// field of the struct.
118///
119/// # Parameters
120///
121/// * `input`: A `TokenStream` representing the input tokens of the derive macro.
122///
123/// # Returns
124///
125/// A `TokenStream` containing the generated implementation of the `DerefMut` trait.
126#[proc_macro_derive(DerefMutMacro)]
127pub fn derive_deref_mut(input: TokenStream) -> TokenStream {
128    let input = parse_macro_input!(input as DeriveInput);
129    let name = &input.ident;
130    if let Err(error) = reject_generics(&input, "DerefMutMacro") {
131        return error.to_compile_error().into();
132    }
133    if let Err(error) = require_tuple_wrapper(&input, "DerefMutMacro") {
134        return error.to_compile_error().into();
135    }
136
137    let expanded = quote! {
138        impl std::ops::DerefMut for #name {
139            fn deref_mut(&mut self) -> &mut Self::Target {
140                &mut self.0
141            }
142        }
143    };
144
145    TokenStream::from(expanded)
146}
147
148#[proc_macro_attribute]
149pub fn genja_task(args: TokenStream, input: TokenStream) -> TokenStream {
150    let args = parse_macro_input!(args as GenjaTaskArgs);
151    let item_impl = parse_macro_input!(input as ItemImpl);
152
153    match expand_genja_task(args, item_impl) {
154        Ok(tokens) => tokens.into(),
155        Err(error) => error.to_compile_error().into(),
156    }
157}
158
159#[derive(Default)]
160struct GenjaTaskArgs {
161    name: Option<LitStr>,
162    connection_plugin_name: Option<LitStr>,
163    processors: Vec<LitStr>,
164}
165
166impl Parse for GenjaTaskArgs {
167    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
168        let mut args = Self::default();
169
170        while !input.is_empty() {
171            let key: syn::Ident = input.parse()?;
172            input.parse::<Token![=]>()?;
173
174            match key.to_string().as_str() {
175                "name" => {
176                    if args.name.is_some() {
177                        return Err(syn::Error::new_spanned(key, "duplicate `name`"));
178                    }
179                    args.name = Some(input.parse()?);
180                }
181                "connection_plugin_name" => {
182                    if args.connection_plugin_name.is_some() {
183                        return Err(syn::Error::new_spanned(
184                            key,
185                            "duplicate `connection_plugin_name`",
186                        ));
187                    }
188                    args.connection_plugin_name = Some(input.parse()?);
189                }
190                "processors" => {
191                    if !args.processors.is_empty() {
192                        return Err(syn::Error::new_spanned(key, "duplicate `processors`"));
193                    }
194                    let array: ExprArray = input.parse()?;
195                    args.processors = parse_processor_exprs(&array)?;
196                }
197                _ => {
198                    return Err(syn::Error::new_spanned(
199                        key,
200                        "unsupported key; expected `name`, `connection_plugin_name`, or `processors`",
201                    ));
202                }
203            }
204
205            if input.is_empty() {
206                break;
207            }
208
209            input.parse::<Token![,]>()?;
210        }
211
212        if args.name.is_none() {
213            return Err(syn::Error::new(
214                proc_macro2::Span::call_site(),
215                "`name = \"...\"` is required",
216            ));
217        }
218
219        Ok(args)
220    }
221}
222
223fn expand_genja_task(
224    args: GenjaTaskArgs,
225    item_impl: ItemImpl,
226) -> syn::Result<proc_macro2::TokenStream> {
227    if item_impl.trait_.is_some() {
228        return Err(syn::Error::new_spanned(
229            &item_impl.self_ty,
230            "`#[genja_task(...)]` can only be applied to inherent impl blocks",
231        ));
232    }
233
234    if !item_impl.generics.params.is_empty() || item_impl.generics.where_clause.is_some() {
235        return Err(syn::Error::new_spanned(
236            &item_impl.generics,
237            "`genja_task` does not support generic parameters or where clauses",
238        ));
239    }
240
241    let self_ty = &item_impl.self_ty;
242    let mut has_start = false;
243    let mut has_start_async = false;
244    let mut has_options = false;
245    let mut has_sub_tasks = false;
246
247    for item in &item_impl.items {
248        let ImplItem::Fn(method) = item else {
249            continue;
250        };
251
252        match method.sig.ident.to_string().as_str() {
253            "start" => {
254                validate_start_method(method, false)?;
255                has_start = true;
256            }
257            "start_async" => {
258                validate_start_method(method, true)?;
259                has_start_async = true;
260            }
261            "options" => {
262                validate_options_method(method)?;
263                has_options = true;
264            }
265            "sub_tasks" => {
266                validate_sub_tasks_method(method)?;
267                has_sub_tasks = true;
268            }
269            _ => {}
270        }
271    }
272
273    if has_start == has_start_async {
274        return Err(syn::Error::new_spanned(
275            &item_impl.self_ty,
276            if has_start {
277                "define exactly one of `fn start(...)` or `async fn start_async(...)`"
278            } else {
279                "define one of `fn start(...)` or `async fn start_async(...)`"
280            },
281        ));
282    }
283
284    let name = args.name.expect("validated above");
285    let connection_plugin_name = args.connection_plugin_name;
286    let processors = args.processors;
287
288    let connection_impl = match connection_plugin_name {
289        Some(plugin_name) => quote! { Some(#plugin_name) },
290        None => quote! { None },
291    };
292
293    let options_impl = if has_options {
294        quote! {
295            fn options(&self) -> Option<&serde_json::Value> {
296                #self_ty::options(self)
297            }
298        }
299    } else {
300        quote! {}
301    };
302
303    let sub_tasks_impl = if has_sub_tasks {
304        quote! {
305            fn sub_tasks(&self) -> Vec<std::sync::Arc<dyn genja_core::task::Task>> {
306                #self_ty::sub_tasks(self)
307            }
308        }
309    } else {
310        quote! {}
311    };
312
313    let processor_names_impl = if processors.is_empty() {
314        quote! {}
315    } else {
316        quote! {
317            fn processor_names(&self) -> Vec<&str> {
318                vec![#(#processors),*]
319            }
320        }
321    };
322
323    let task_impl = if has_start {
324        quote! {
325            #[genja_core::async_trait]
326            impl genja_core::task::Task for #self_ty {
327                fn start(
328                    &self,
329                    host: &genja_core::inventory::Host,
330                    context: &genja_core::task::BlockingTaskRuntimeContext,
331                ) -> Result<genja_core::task::HostTaskResult, genja_core::task::TaskError> {
332                    #self_ty::start(self, host, context)
333                }
334
335                #sub_tasks_impl
336
337                fn execution_mode(&self) -> genja_core::task::TaskExecutionMode {
338                    genja_core::task::TaskExecutionMode::Blocking
339                }
340            }
341        }
342    } else {
343        quote! {
344            #[genja_core::async_trait]
345            impl genja_core::task::Task for #self_ty {
346                async fn start_async(
347                    &self,
348                    host: &genja_core::inventory::Host,
349                    context: &genja_core::task::TaskRuntimeContext,
350                ) -> Result<genja_core::task::HostTaskResult, genja_core::task::TaskError> {
351                    #self_ty::start_async(self, host, context).await
352                }
353
354                #sub_tasks_impl
355
356                fn execution_mode(&self) -> genja_core::task::TaskExecutionMode {
357                    genja_core::task::TaskExecutionMode::Async
358                }
359            }
360        }
361    };
362
363    Ok(quote! {
364        #item_impl
365
366        impl genja_core::task::TaskInfo for #self_ty {
367            fn name(&self) -> &str {
368                #name
369            }
370
371            fn connection_plugin_name(&self) -> Option<&str> {
372                #connection_impl
373            }
374
375            #options_impl
376
377            #processor_names_impl
378        }
379
380        #task_impl
381    })
382}
383
384fn reject_generics(input: &DeriveInput, macro_name: &str) -> syn::Result<()> {
385    if input.generics.params.is_empty() && input.generics.where_clause.is_none() {
386        return Ok(());
387    }
388
389    Err(syn::Error::new_spanned(
390        &input.generics,
391        format!("`{macro_name}` does not support generic parameters or where clauses"),
392    ))
393}
394
395fn parse_processor_exprs(array: &ExprArray) -> syn::Result<Vec<LitStr>> {
396    array
397        .elems
398        .iter()
399        .map(|expr| match expr {
400            Expr::Lit(ExprLit {
401                lit: Lit::Str(value),
402                ..
403            }) => Ok(value.clone()),
404            _ => Err(syn::Error::new_spanned(
405                expr,
406                "`processors` must be an array of string literals",
407            )),
408        })
409        .collect()
410}
411
412fn validate_start_method(method: &syn::ImplItemFn, is_async: bool) -> syn::Result<()> {
413    if method.sig.asyncness.is_some() != is_async {
414        let expected = if is_async {
415            "`start_async` must be declared as `async fn`"
416        } else {
417            "`start` must be declared as `fn`, not `async fn`"
418        };
419        return Err(syn::Error::new_spanned(&method.sig.ident, expected));
420    }
421
422    validate_shared_method_shape(method)?;
423
424    if method.sig.inputs.len() != 3 {
425        return Err(syn::Error::new_spanned(
426            &method.sig.inputs,
427            "task start methods must take `&self`, `host`, and `context`",
428        ));
429    }
430
431    let mut inputs = method.sig.inputs.iter();
432    validate_receiver(inputs.next().unwrap())?;
433    validate_typed_arg(
434        inputs.next().unwrap(),
435        is_host_ref,
436        "`host` must be `&Host`",
437    )?;
438    validate_typed_arg(
439        inputs.next().unwrap(),
440        if is_async {
441            is_async_context_ref
442        } else {
443            is_blocking_context_ref
444        },
445        if is_async {
446            "`context` must be `&TaskRuntimeContext`"
447        } else {
448            "`context` must be `&BlockingTaskRuntimeContext`"
449        },
450    )?;
451
452    validate_return_type(
453        &method.sig.output,
454        is_result_host_task_error,
455        if is_async {
456            "`start_async` must return `Result<HostTaskResult, TaskError>`"
457        } else {
458            "`start` must return `Result<HostTaskResult, TaskError>`"
459        },
460    )
461}
462
463fn validate_options_method(method: &syn::ImplItemFn) -> syn::Result<()> {
464    if method.sig.asyncness.is_some() {
465        return Err(syn::Error::new_spanned(
466            &method.sig.ident,
467            "`options` must not be async",
468        ));
469    }
470
471    validate_shared_method_shape(method)?;
472
473    if method.sig.inputs.len() != 1 {
474        return Err(syn::Error::new_spanned(
475            &method.sig.inputs,
476            "`options` must take only `&self`",
477        ));
478    }
479
480    validate_receiver(method.sig.inputs.first().unwrap())?;
481    validate_return_type(
482        &method.sig.output,
483        is_option_value_ref,
484        "`options` must return `Option<&serde_json::Value>`",
485    )
486}
487
488fn validate_sub_tasks_method(method: &syn::ImplItemFn) -> syn::Result<()> {
489    if method.sig.asyncness.is_some() {
490        return Err(syn::Error::new_spanned(
491            &method.sig.ident,
492            "`sub_tasks` must not be async",
493        ));
494    }
495
496    validate_shared_method_shape(method)?;
497
498    if method.sig.inputs.len() != 1 {
499        return Err(syn::Error::new_spanned(
500            &method.sig.inputs,
501            "`sub_tasks` must take only `&self`",
502        ));
503    }
504
505    validate_receiver(method.sig.inputs.first().unwrap())?;
506    validate_return_type(
507        &method.sig.output,
508        is_vec_of_task_arcs,
509        "`sub_tasks` must return `Vec<Arc<dyn Task>>`",
510    )
511}
512
513fn validate_shared_method_shape(method: &syn::ImplItemFn) -> syn::Result<()> {
514    if method.sig.constness.is_some()
515        || method.sig.unsafety.is_some()
516        || method.sig.abi.is_some()
517        || method.sig.variadic.is_some()
518        || !method.sig.generics.params.is_empty()
519        || method.sig.generics.where_clause.is_some()
520    {
521        return Err(syn::Error::new_spanned(
522            &method.sig,
523            "Genja task hook methods cannot be const, unsafe, generic, extern, or variadic",
524        ));
525    }
526
527    Ok(())
528}
529
530fn validate_receiver(arg: &FnArg) -> syn::Result<()> {
531    match arg {
532        FnArg::Receiver(receiver)
533            if receiver.reference.is_some() && receiver.mutability.is_none() =>
534        {
535            Ok(())
536        }
537        _ => Err(syn::Error::new_spanned(
538            arg,
539            "first argument must be `&self`",
540        )),
541    }
542}
543
544fn validate_typed_arg(arg: &FnArg, predicate: fn(&Type) -> bool, message: &str) -> syn::Result<()> {
545    match arg {
546        FnArg::Typed(typed) if predicate(&typed.ty) => Ok(()),
547        FnArg::Typed(typed) => Err(syn::Error::new_spanned(&typed.ty, message)),
548        FnArg::Receiver(_) => Err(syn::Error::new_spanned(arg, message)),
549    }
550}
551
552fn validate_return_type(
553    output: &ReturnType,
554    predicate: fn(&Type) -> bool,
555    message: &str,
556) -> syn::Result<()> {
557    match output {
558        ReturnType::Type(_, ty) if predicate(ty) => Ok(()),
559        ReturnType::Type(_, ty) => Err(syn::Error::new_spanned(ty, message)),
560        ReturnType::Default => Err(syn::Error::new(proc_macro2::Span::call_site(), message)),
561    }
562}
563
564fn is_result_host_task_error(ty: &Type) -> bool {
565    let Type::Path(TypePath { path, .. }) = ty else {
566        return false;
567    };
568    let Some(seg) = path.segments.last() else {
569        return false;
570    };
571    if seg.ident != "Result" {
572        return false;
573    }
574    let PathArguments::AngleBracketed(args) = &seg.arguments else {
575        return false;
576    };
577    if args.args.len() != 2 {
578        return false;
579    }
580
581    let mut args_iter = args.args.iter();
582    let ok = match args_iter.next() {
583        Some(GenericArgument::Type(ty)) => type_ends_with(ty, "HostTaskResult"),
584        _ => false,
585    };
586    let err = match args_iter.next() {
587        Some(GenericArgument::Type(ty)) => type_ends_with(ty, "TaskError"),
588        _ => false,
589    };
590    ok && err
591}
592
593fn is_option_value_ref(ty: &Type) -> bool {
594    let Type::Path(TypePath { path, .. }) = ty else {
595        return false;
596    };
597    let Some(seg) = path.segments.last() else {
598        return false;
599    };
600    if seg.ident != "Option" {
601        return false;
602    }
603    let PathArguments::AngleBracketed(args) = &seg.arguments else {
604        return false;
605    };
606    if args.args.len() != 1 {
607        return false;
608    }
609    match args.args.first() {
610        Some(GenericArgument::Type(Type::Reference(reference))) => {
611            type_ends_with(&reference.elem, "Value")
612        }
613        _ => false,
614    }
615}
616
617fn is_vec_of_task_arcs(ty: &Type) -> bool {
618    let Type::Path(TypePath { path, .. }) = ty else {
619        return false;
620    };
621    let Some(seg) = path.segments.last() else {
622        return false;
623    };
624    if seg.ident != "Vec" {
625        return false;
626    }
627    let PathArguments::AngleBracketed(args) = &seg.arguments else {
628        return false;
629    };
630    if args.args.len() != 1 {
631        return false;
632    }
633    match args.args.first() {
634        Some(GenericArgument::Type(inner)) => is_arc_task(inner),
635        _ => false,
636    }
637}
638
639fn is_arc_task(ty: &Type) -> bool {
640    match ty {
641        Type::Path(TypePath { path, .. }) => {
642            let Some(seg) = path.segments.last() else {
643                return false;
644            };
645            if seg.ident != "Arc" {
646                return false;
647            }
648            match &seg.arguments {
649                PathArguments::AngleBracketed(args) => args
650                    .args
651                    .iter()
652                    .filter_map(|arg| match arg {
653                        GenericArgument::Type(ty) => Some(ty),
654                        _ => None,
655                    })
656                    .any(is_task_trait_object),
657                _ => false,
658            }
659        }
660        _ => false,
661    }
662}
663
664fn is_task_trait_object(ty: &Type) -> bool {
665    match ty {
666        Type::TraitObject(obj) => obj.bounds.iter().any(|bound| match bound {
667            syn::TypeParamBound::Trait(trait_bound) => trait_bound
668                .path
669                .segments
670                .last()
671                .map(|seg| seg.ident == "Task")
672                .unwrap_or(false),
673            _ => false,
674        }),
675        _ => false,
676    }
677}
678
679fn is_host_ref(ty: &Type) -> bool {
680    matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "Host"))
681}
682
683fn is_async_context_ref(ty: &Type) -> bool {
684    matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "TaskRuntimeContext"))
685}
686
687fn is_blocking_context_ref(ty: &Type) -> bool {
688    matches!(ty, Type::Reference(reference) if type_ends_with(&reference.elem, "BlockingTaskRuntimeContext"))
689}
690
691fn type_ends_with(ty: &Type, ident: &str) -> bool {
692    match ty {
693        Type::Path(TypePath { path, .. }) => path
694            .segments
695            .last()
696            .map(|segment| segment.ident == ident)
697            .unwrap_or(false),
698        _ => false,
699    }
700}
701
702fn require_tuple_wrapper(input: &DeriveInput, macro_name: &str) -> syn::Result<()> {
703    match &input.data {
704        syn::Data::Struct(data) => match &data.fields {
705            syn::Fields::Unnamed(fields) if !fields.unnamed.is_empty() => Ok(()),
706            _ => Err(syn::Error::new_spanned(
707                &input.ident,
708                format!("`{macro_name}` requires a tuple struct with the wrapped value in field 0"),
709            )),
710        },
711        _ => Err(syn::Error::new_spanned(
712            &input.ident,
713            format!("`{macro_name}` can only be derived for tuple structs"),
714        )),
715    }
716}