// clipanion_derive/lib.rs

1extern crate proc_macro;
2
3use std::collections::HashMap;
4
5use proc_macro::TokenStream;
6use proc_macro2::Span;
7use quote::quote;
8use syn::{parse::{Parse, ParseStream}, parse_macro_input, punctuated::Punctuated, Attribute, DeriveInput, Expr, ExprLit, Fields, Ident, Lit, LitBool, LitStr, Meta, Path, Token};
9
/// Expands to a closure that unwraps an `Expr` into the payload of the given
/// `Lit` variant (for instance `Lit::Bool` or `Lit::Str`).
///
/// The closure returns `Ok(payload)` when the expression is a literal of that
/// variant, and otherwise a `syn::Error` spanned on the offending expression.
macro_rules! expect_lit {
    ($variant:path) => {
        |expr| {
            if let Expr::Lit(ExprLit { lit: $variant(inner), .. }) = expr {
                Ok(inner)
            } else {
                Err(syn::Error::new_spanned(expr, "Invalid literal type"))
            }
        }
    };
}
18
19fn to_lit_str<T: AsRef<str>>(str: T) -> LitStr {
20    LitStr::new(str.as_ref(), proc_macro2::Span::call_site())
21}
22
23#[derive(Clone, Default)]
24struct AttributeBag {
25    attributes: HashMap<String, Expr>,
26}
27
28impl AttributeBag {
29    pub fn expect_empty(&self) -> syn::Result<()> {
30        if !self.attributes.is_empty() {
31            return Err(syn::Error::new_spanned(self.attributes.iter().next().unwrap().1, "Unsupported extra attributes"));
32        }
33
34        Ok(())
35    }
36
37    pub fn take(&mut self, key: &str) -> Option<Expr> {
38        self.attributes.remove(key)
39    }
40}
41
42impl Parse for AttributeBag {
43    fn parse(input: ParseStream) -> syn::Result<Self> {
44        // Prepare a vector to hold the named parameters
45        let mut attributes = HashMap::new();
46        
47        // Parse the remaining named parameters
48        while !input.is_empty() {
49            let ident: Ident = input.parse()?;
50            let name = ident.to_string();
51
52            if input.peek(Token![=]) {
53                input.parse::<Token![=]>()?;  // Consume the equals sign
54
55                let value: Expr = input.parse()?;
56                attributes.insert(name, value);
57            } else {
58                attributes.insert(name, Expr::Lit(ExprLit {
59                    attrs: vec![],
60                    lit: Lit::Bool(LitBool {
61                        value: true,
62                        span: proc_macro2::Span::call_site(),
63                    }),
64                }));
65            }
66
67            if !input.is_empty() {
68                input.parse::<Token![,]>()?;
69            }
70        }
71
72        if !input.is_empty() {
73            return Err(input.error("Unexpected token"));
74        }
75
76        Ok(Self {attributes})
77    }
78}
79
80#[derive(Clone, Default)]
81struct OptionBag {
82    path: Vec<String>,
83    attributes: AttributeBag,
84}
85
86impl OptionBag {
87    fn parse_with_path(input: ParseStream) -> syn::Result<Self> {
88        let path: LitStr = input.parse()?;
89
90        let path = path.value()
91            .split(',')
92            .map(|s| s.trim().to_string())
93            .collect::<Vec<_>>();
94
95        let mut attributes = AttributeBag::default();
96        if input.peek(Token![,]) {
97            input.parse::<Token![,]>()?;            
98            attributes = input.parse()?;
99        }
100
101        Ok(Self {
102            path,
103            attributes,
104        })
105    }
106
107    fn parse_without_path(input: ParseStream) -> syn::Result<Self> {
108        Ok(Self {
109            path: vec![],
110            attributes: input.parse()?,
111        })
112    }
113}
114
115impl Parse for OptionBag {
116    fn parse(input: ParseStream) -> syn::Result<Self> {
117        if input.peek(syn::Ident) {
118            Self::parse_without_path(input)
119        } else {
120            Self::parse_with_path(input)
121        }
122    }
123}
124
125#[derive(Clone, Default)]
126struct CliAttributes {
127    attributes: HashMap<String, Vec<Attribute>>,
128}
129
130impl CliAttributes {
131    fn parse_args<T: Default + Parse>(attr: &Attribute) -> syn::Result<T> {
132        match attr.meta {
133            Meta::Path(_) => Ok(T::default()),
134            _ => attr.parse_args::<T>(),
135        }
136    }
137
138    fn extract(attrs: &mut Vec<Attribute>) -> syn::Result<Self> {
139        let mut cli_attributes = CliAttributes::default();
140        let mut remaining_attributes = vec![];
141    
142        for attr in std::mem::take(attrs).into_iter(){
143            let path = attr.path();
144            if path.segments.is_empty() || path.segments[0].ident != "cli" {
145                remaining_attributes.push(attr.clone());
146                continue;
147            }
148    
149            if path.segments.len() != 2 {
150                return Err(syn::Error::new_spanned(attr, "Expected a named attribute"));
151            }
152    
153            let name = path.segments[1].ident.to_string();
154    
155            cli_attributes.attributes.entry(name)
156                .or_insert_with(Vec::new)
157                .push(attr);
158        }
159
160        *attrs = remaining_attributes;
161    
162        Ok(cli_attributes)
163    }
164
165    fn take_unique<T: Default + Parse>(&mut self, key: &str) -> syn::Result<Option<T>> {
166        match self.attributes.remove(key) {
167            Some(values) => {
168                if values.len() > 1 {
169                    return Err(syn::Error::new(proc_macro2::Span::call_site(), "Attribute must be unique"));
170                }
171
172                let attr = &values[0];
173                Self::parse_args(attr).map(Some)
174            },
175
176            None => Ok(None),
177        }
178    }
179
180    fn take_paths(&mut self) -> syn::Result<Vec<Vec<LitStr>>> {
181        let path_attributes = self.attributes.remove("path")
182            .unwrap_or_default();
183
184        let punctuated_paths = path_attributes.into_iter()
185            .map(|attr| attr.parse_args_with(Punctuated::<LitStr, Token![,]>::parse_terminated))
186            .collect::<syn::Result<Vec<_>>>()?;
187
188        let path_lits = punctuated_paths.into_iter()
189            .map(|punctuated| punctuated.into_iter().collect())
190            .collect();
191
192        Ok(path_lits)        
193    }
194}
195
196fn command_impl(args: TokenStream, mut input: DeriveInput) -> Result<TokenStream, syn::Error> {
197    let struct_input = if let syn::Data::Struct(data) = &mut input.data {
198        data
199    } else {
200        panic!("Only structs are supported");
201    };
202
203    let mut builder = vec![];
204
205    let mut option_hydrater = vec![];
206    let mut positional_hydrater = vec![];
207    
208    let mut command_cli_attributes
209        = CliAttributes::extract(&mut input.attrs)?;
210
211    let mut command_attribute_bag
212        = syn::parse::<AttributeBag>(args)?;
213
214    let is_default = command_attribute_bag.take("default")
215        .map(expect_lit!(Lit::Bool))
216        .transpose()?
217        .map(|lit| lit.value)
218        .unwrap_or(false);
219
220    let is_proxy = command_attribute_bag.take("proxy")
221        .map(expect_lit!(Lit::Bool))
222        .transpose()?
223        .map(|lit| lit.value)
224        .unwrap_or(false);
225
226    let explicit_positionals = command_attribute_bag.take("explicit_positionals")
227        .map(expect_lit!(Lit::Bool))
228        .transpose()?
229        .map(|lit| lit.value)
230        .unwrap_or(false);
231
232    let paths_lits
233        = command_cli_attributes.take_paths()?;
234
235    command_attribute_bag.expect_empty()?;
236
237    let mut partial_struct_members
238        = vec![];
239
240    let mut initialization_members
241        = vec![];
242
243    if !is_default && paths_lits.is_empty() {
244        return Err(syn::Error::new_spanned(input.ident, "The command must have a path"));
245    }
246
247    if is_default {
248        builder.push(quote! {
249            builder.make_default();
250        });
251    }
252
253    for path_lits in paths_lits {
254        builder.push(quote! {
255            builder.add_path(vec![#(#path_lits.to_string()),*]);
256        });
257    }
258
259    for field in &mut struct_input.fields {
260        let field_ident = &field.ident;
261        let field_type = &field.ty;
262
263        let mut internal_field_type = &field.ty;
264        let mut is_option_type = false;
265        let mut is_vec_type = false;
266
267        if let syn::Type::Path(type_path) = &internal_field_type {
268            if &type_path.path.segments[0].ident == "Option" {
269                let inner_type = &type_path.path.segments[0].arguments;
270                if let syn::PathArguments::AngleBracketed(args) = inner_type {
271                    if let syn::GenericArgument::Type(ty) = &args.args[0] {
272                        internal_field_type = ty;
273                        is_option_type = true;
274                    }
275                }
276            }
277
278            if &type_path.path.segments[0].ident == "Vec" {
279                let inner_type = &type_path.path.segments[0].arguments;
280                if let syn::PathArguments::AngleBracketed(args) = inner_type {
281                    if let syn::GenericArgument::Type(ty) = &args.args[0] {
282                        internal_field_type = ty;
283                        is_vec_type = true;
284                    }
285                }
286            }
287        }
288
289        let mut cli_attributes
290            = CliAttributes::extract(&mut field.attrs)?;
291
292        if !explicit_positionals && !cli_attributes.attributes.contains_key("option") && !cli_attributes.attributes.contains_key("positional") {
293            cli_attributes.attributes.insert("positional".to_string(), vec![Attribute {
294                pound_token: Default::default(),
295                style: syn::AttrStyle::Outer,
296                bracket_token: Default::default(),
297                meta: Meta::Path(Path::from(Ident::new("positional", Span::call_site()))),
298            }]);
299        }
300
301        if let Some(mut option_bag) = cli_attributes.take_unique::<OptionBag>("option")? {
302            let mut is_bool = false;
303            let mut arity = 1;
304
305            if let syn::Type::Path(type_path) = &internal_field_type {
306                if &type_path.path.segments[0].ident == "bool" {
307                    is_bool = true;
308                    arity = 0;
309                }
310            }
311    
312            if let syn::Type::Tuple(tuple) = internal_field_type {
313                arity = tuple.elems.len();
314            }
315
316            let description = option_bag.attributes.take("help")
317                .map(expect_lit!(Lit::Str))
318                .transpose()?
319                .map(|lit| lit.value())
320                .unwrap_or_default();
321
322            let is_required = option_bag.attributes.take("required")
323                .map(expect_lit!(Lit::Bool))
324                .transpose()?
325                .map(|lit| lit.value)
326                .unwrap_or(false);
327
328            let preferred_name_lit = option_bag.path.iter()
329                .max_by_key(|s| s.len())
330                .map(to_lit_str)
331                .unwrap();
332
333            let name_set_lit = option_bag.path
334                .iter()
335                .map(to_lit_str)
336                .collect::<Vec<_>>();
337
338            let value_type = if arity > 1 {
339                quote! {Array}
340            } else if is_bool {
341                quote! {Bool}
342            } else {
343                quote! {String}
344            };
345
346            let value_converter = if arity > 1 {
347                quote! {value.iter().map(|s| s.parse().unwrap()).collect::<Result<Vec<_>, _>>()}
348            } else if is_bool {
349                quote! {Result::<bool, std::convert::Infallible>::Ok(value)}
350            } else {
351                quote! {value.parse()}
352            };
353
354            let value_converter = quote! {
355                #value_converter.map_err(|err| clipanion::details::HydrationError::new(err))?
356            };
357
358            let default_value
359                = option_bag.attributes.take("default");
360
361            if is_vec_type {
362                partial_struct_members.push(quote! {
363                    #field_ident: Vec<#internal_field_type>,
364                });
365
366                option_hydrater.push(quote! {
367                    if option.0.as_str() == #preferred_name_lit {
368                        if let clipanion::core::OptionValue::#value_type(value) = option.1 {
369                            partial.#field_ident.push(#value_converter);
370                            continue;
371                        }
372                    }
373                });
374
375                initialization_members.push(quote! {
376                    #field_ident: Vec::new(),
377                });
378            } else {
379                partial_struct_members.push(quote! {
380                    #field_ident: Option<#field_type>,
381                });
382
383                if is_option_type {
384                    option_hydrater.push(quote! {
385                        if option.0.as_str() == #preferred_name_lit {
386                            if let clipanion::core::OptionValue::#value_type(value) = option.1 {
387                                partial.#field_ident = Some(Some(#value_converter));
388                                continue;
389                            }
390                        }
391                    });
392                } else {
393                    option_hydrater.push(quote! {
394                        if option.0.as_str() == #preferred_name_lit {
395                            if let clipanion::core::OptionValue::#value_type(value) = option.1 {
396                                partial.#field_ident = Some(#value_converter);
397                                continue;
398                            }
399                        }
400                    });
401                }
402
403                let accessor = match default_value {
404                    Some(expr) => quote! { partial.#field_ident.or_else(|| Some(#expr)) },
405                    None => quote! { partial.#field_ident },
406                };
407
408                if is_option_type {
409                    initialization_members.push(quote! {
410                        #field_ident: #accessor.unwrap_or_default(),
411                    });
412                } else {
413                    initialization_members.push(quote! {
414                        #field_ident: #accessor.unwrap(),
415                    });
416                }
417            }
418
419            builder.push(quote! {
420                builder.add_option(clipanion::core::OptionDefinition {
421                    name_set: vec![#(#name_set_lit.to_string()),*],
422                    description: #description.to_string(),
423                    required: #is_required,
424                    arity: #arity,
425                    ..Default::default()
426                })?;
427            });
428
429            option_bag.attributes.expect_empty()?;
430        } else if let Some(positional_bag) = cli_attributes.take_unique::<AttributeBag>("positional")? {
431            let field_name_upper = field.ident.as_ref().unwrap()
432                .to_string()
433                .to_uppercase();
434
435            if is_vec_type {
436                partial_struct_members.push(quote! {
437                    #field_ident: Vec<#internal_field_type>,
438                });
439    
440                positional_hydrater.push(quote! {
441                    while let Some(clipanion::core::Positional::Rest(value)) = clipanion::details::cautious_take_if(&mut positional_it, |item| matches!(item, clipanion::core::Positional::Rest(_))) {
442                        let value = value.as_str().parse()
443                            .map_err(|err| clipanion::details::HydrationError::new(err))?;
444
445                        partial.#field_ident.push(value);
446                    }
447                });
448
449                initialization_members.push(quote! {
450                    #field_ident: partial.#field_ident,
451                });
452
453                let add_cmd = match is_proxy {
454                    true => quote! {add_proxy},
455                    false => quote! {add_rest},
456                };
457
458                builder.push(quote! {
459                    builder.#add_cmd(#field_name_upper)?;
460                });
461            } else {
462                partial_struct_members.push(quote! {
463                    #field_ident: Option<#field_type>,
464                });
465
466                if is_option_type {
467                    positional_hydrater.push(quote! {
468                        if let Some(clipanion::core::Positional::Optional(value)) = clipanion::details::cautious_take_if(&mut positional_it, |item| matches!(item, clipanion::core::Positional::Required(_))) {
469                            let value = value.as_str().try_into()
470                                .map_err(|err| clipanion::details::HydrationError::new(err))?;
471
472                            partial.#field_ident = Some(Some(value));
473                        }
474                    });
475                } else {
476                    positional_hydrater.push(quote! {
477                        if let Some(clipanion::core::Positional::Required(value)) = positional_it.next() {
478                            let value = value.as_str().try_into()
479                                .map_err(|err| clipanion::details::HydrationError::new(err))?;
480
481                            partial.#field_ident = Some(value);
482                        } else {
483                            panic!("Internal error: Unexpected positional type during the Clipanion hydration");
484                        }
485                    });
486                }
487
488                initialization_members.push(quote! {
489                    #field_ident: partial.#field_ident.unwrap(),
490                });
491
492                builder.push(quote! {
493                    builder.add_positional(!#is_option_type, #field_name_upper)?;
494                });
495            }
496
497            positional_bag.expect_empty()?;
498        }
499    }
500
501    if let Fields::Named(fields) = &mut struct_input.fields {
502        fields.named.push(syn::parse_quote! {cli_path: Vec<String>});
503        fields.named.push(syn::parse_quote! {cli_info: clipanion::advanced::Info});
504    }
505
506    let struct_name
507        = &input.ident;
508
509    let expanded = quote! {
510        #input
511
512        impl clipanion::details::CommandController for #struct_name {
513            fn command_usage(opts: clipanion::core::CommandUsageOptions) -> Result<clipanion::core::CommandUsageResult, clipanion::core::BuildError> {
514                let mut cli_builder = clipanion::core::CliBuilder::new();
515                let mut builder = cli_builder.add_command();
516
517                #(#builder)*
518
519                Ok(builder.usage(opts))
520            }
521
522            fn attach_command_to_cli(builder: &mut clipanion::core::CommandBuilder) -> Result<(), clipanion::core::BuildError> {
523                #(#builder)*
524
525                Ok(())
526            }
527
528            fn hydrate_command_from_state(info: &clipanion::advanced::Info, state: clipanion::core::RunState) -> Result<Self, clipanion::details::HydrationError> {
529                #[derive(Default, Debug)]
530                struct Partial {
531                    #(#partial_struct_members)*
532                }
533
534                let mut partial
535                    = Partial::default();
536
537                for option in state.options {
538                    #(#option_hydrater)*
539                }
540
541                let mut positional_it = state.positionals
542                    .into_iter()
543                    .peekable();
544
545                #(#positional_hydrater)*
546
547                Ok(Self {
548                    cli_path: state.path.clone(),
549                    cli_info: info.clone(),
550                    #(#initialization_members)*
551                })
552            }
553        }
554    };
555
556    Ok(TokenStream::from(expanded))
557}
558
559#[proc_macro_attribute]
560pub fn command(args: TokenStream, input: TokenStream) -> TokenStream {
561    let input = parse_macro_input!(input as DeriveInput);
562
563    match command_impl(args, input) {
564        Ok(token_stream) => token_stream,
565        Err(err) => err.to_compile_error().into(),
566    }
567}