clap_dispatch/
lib.rs

1//! Ergonomic way to dispatch CLI subcommands.
2//!
3//! Useful when your CLI defines subcommands, all of which should do the same kind of action, just in a different way. \
4//! I.e., when the subcommands are variants of running a certain action.
5//!
6//! It becomes especially useful when you have nested subcommands.
7//! In this case you can dispatch all the way down to the leaves of your command tree!
8//!
9//! # Example
10//!
//! Suppose you implement a CLI for sorting numbers. \
//! You have two algorithms, Quick Sort and Merge Sort, and each of them has its own subcommand.
//! ```ignore
14//! #[derive(Parser)]
15//! enum Cli {
16//!     Quick(QuickArgs),
17//!     Merge(MergeArgs),
18//! }
19//! ```
//! The point is that both algorithms essentially just implement a `sort(...)` function with the same signature.
//! You could model this as both `QuickArgs` and `MergeArgs` implementing a function like:
//! ```ignore
23//! fn sort(self, nums: Vec<i32>) -> Vec<i32>
24//! ```
25//! (The `self` is there so that they can make use of the special arguments passed for the respective algorithm.) \
26//! So you could put such a function into a trait, and then implement the trait for both `QuickArgs` and `MergeArgs`.
27//!
//! The annoying part is that, to dispatch the `sort(...)` function, you then have to `match` over your `Cli` enum and call `sort(...)` on every variant.
//! That's boilerplate.
30//!
//! This crate writes that boilerplate for you.
//!
//! In this case, you would invoke the macro as follows:
//! ```ignore
35//! #[derive(Parser)]
36//! #[clap_dispatch(fn sort(self, nums: Vec<i32>) -> Vec<i32>)] // macro call
37//! enum Cli {
38//!     Quick(QuickArgs),
39//!     Merge(MergeArgs),
40//! }
41//! ```
42//! This defines a trait `Sort` which contains the `sort(...)` function, and it also implements `Sort` for `Cli`, where the implementation just dispatches to the variants.
43//! So what's left to you is only:
44//! - implement `Sort` for `QuickArgs` and `MergeArgs` (i.e. implement the algorithms)
45//! - call `cli.sort(...)`
46//!
47//! # Usage
48//!
49//! A minimal explanation is in the definition of the [macro@clap_dispatch] macro.
50//!
51//! The full code for the above example can be found in the `example/` folder. ([latest GitHub version](https://github.com/jbirnick/clap-dispatch/tree/main/example))
52
53use heck::ToUpperCamelCase;
54use proc_macro::TokenStream;
55use proc_macro2::Span;
56use quote::quote;
57use syn::{Ident, ItemEnum, Signature};
58
59/// The main macro.
60///
61/// Needs to be attached to an `enum` and given a function signature as an attribute.
62/// ```
63/// #[clap_dispatch(fn run(self))]
64/// enum MyCommand {
65///   Foo(FooArgs)
66///   Bar(BarArgs)
67/// }
68/// ```
69///
70/// It does two things:
71///
72/// 1. It creates a new trait, named like the provided function transformed to UpperCamelCase.
73///    The trait will contain only one function, which has exactly the provided signature.
74///
75///    In this case it will generate:
76///    ```
77///    trait Run {
78///        fn run(self);
79///    }
80///    ```
81///
82/// 2. It implements the trait for the enum.
83///    The implementation is just to dispatch onto the different variants.
84///    **This means the fields of all the enum variants need to implement the generated trait from above.**
85///    It's your job to make those implementations by hand.
86///
87///    In this case it will generate:
88///    ```
89///    impl Run for MyCommand {
90///       fn run(self) {
91///           match self {
92///               Self::Foo(args) => args.run(),
93///               Self::Bar(args) => args.run(),
94///           }
95///       }
96///    }
97///    ```
98///
99#[proc_macro_attribute]
100pub fn clap_dispatch(attr: TokenStream, mut item: TokenStream) -> TokenStream {
101    let generated =
102        clap_dispatch_gen(&attr, &item).unwrap_or_else(|error| error.to_compile_error().into());
103    item.extend(generated);
104    item
105}
106
107fn clap_dispatch_gen(attr: &TokenStream, item: &TokenStream) -> Result<TokenStream, syn::Error> {
108    // parse the enum and the attribute
109    let item_enum: ItemEnum = syn::parse(item.clone())?;
110    let signature: Signature = syn::parse(attr.clone())?;
111
112    // generate new things which should be appended after the enum
113    generate(item_enum, signature)
114}
115
116// generates both:
117// 1. the new trait whose only function is given by the provided signature
118// 2. the implementation of this trait for the enum
119fn generate(
120    // the enum on which the attribute macro was placed
121    item_enum: ItemEnum,
122    // the function signature that was provided with the attribute macro
123    signature: Signature,
124) -> Result<TokenStream, syn::Error> {
125    // make sure the user provided everything in the correct form
126    validity_checks(&item_enum, &signature)?;
127
128    // relevant identifiers
129    let enum_ident = item_enum.ident;
130    let signature_ident = &signature.ident;
131    let trait_ident = upper_camel_case(signature_ident);
132
133    // the arguments which need to be passed to the function, except `self`
134    let call_args = signature.inputs.iter().skip(1).map(|fn_arg| {
135        if let syn::FnArg::Typed(pat_type) = fn_arg {
136            &pat_type.pat
137        } else {
138            // all functions arguments except the first one should be FnArg::Typed (not FnArg::Receiver)
139            unreachable!()
140        }
141    });
142
143    // the match arms for the implementation of the trait
144    let match_arms = item_enum.variants.into_iter().map(|variant| {
145        let variant_ident = variant.ident;
146        let call_args = call_args.clone();
147
148        quote! {
149            Self::#variant_ident(args) => self::#trait_ident::#signature_ident(args, #(#call_args),*),
150        }
151    });
152
153    // the final generated code
154    let generated = quote! {
155        trait #trait_ident {
156            #signature;
157        }
158
159        impl #trait_ident for #enum_ident {
160            #signature {
161                match self {
162                    #(#match_arms)*
163                }
164            }
165        }
166    };
167
168    Ok(generated.into())
169}
170
171fn upper_camel_case(ident: &Ident) -> Ident {
172    let new_ident = ident.to_string().to_upper_camel_case();
173    Ident::new(&new_ident, Span::call_site())
174}
175
176fn validity_checks(item_enum: &ItemEnum, signature: &Signature) -> Result<(), syn::Error> {
177    // make sure the enum doesn't use generics
178    if item_enum.generics.lt_token.is_some() {
179        return Err(syn::Error::new_spanned(
180            &item_enum.generics,
181            "generics are not yet supported by clap-dispatch",
182        ));
183    }
184
185    // make sure signature has no generics
186    if signature.generics.lt_token.is_some() {
187        return Err(syn::Error::new_spanned(
188            &signature.generics,
189            "generics are not yet supported by clap-dispatch",
190        ));
191    }
192
193    // make sure signature has no variadic
194    if signature.variadic.is_some() {
195        return Err(syn::Error::new_spanned(
196            &signature.variadic,
197            "variadics are not yet supported by clap-dispatch",
198        ));
199    }
200
201    // make sure first argument of signature is some form of `self`
202    match signature.inputs.first() {
203        Some(fn_arg) => {
204            if !matches!(fn_arg, syn::FnArg::Receiver(_)) {
205                return Err(syn::Error::new_spanned(
206                    fn_arg,
207                    "first argument of function must be `self` or `&self` or `&mut self`",
208                ));
209            }
210        }
211        None => {
212            return Err(syn::Error::new_spanned(
213                &signature.inputs,
214                "function needs at least a `self` argument (or `&self` or `&mut self`)",
215            ))
216        }
217    }
218
219    // make sure the enum variants have exactly one unnamed field
220    for variant in item_enum.variants.iter() {
221        match &variant.fields {
222            syn::Fields::Named(fields_named) => {
223                return Err(syn::Error::new_spanned(
224                    fields_named,
225                    "must have unnamed field, not named",
226                ));
227            }
228            syn::Fields::Unnamed(fields_unnamed) => {
229                if fields_unnamed.unnamed.len() != 1 {
230                    return Err(syn::Error::new_spanned(
231                        fields_unnamed,
232                        "number of unnamed fields must be exactly one",
233                    ));
234                }
235            }
236            syn::Fields::Unit => {
237                return Err(syn::Error::new_spanned(
238                    &variant.ident,
239                    "variant must have an unnamed field",
240                ));
241            }
242        };
243    }
244
245    Ok(())
246}