aargvark_proc_macros/
lib.rs

1use {
2    convert_case::{
3        Case,
4        Casing,
5    },
6    darling::{
7        FromDeriveInput,
8        FromField,
9        FromVariant,
10    },
11    flowcontrol::shed,
12    proc_macro2::TokenStream,
13    quote::{
14        format_ident,
15        quote,
16        ToTokens,
17    },
18    std::collections::HashSet,
19    syn::{
20        self,
21        parse_macro_input,
22        spanned::Spanned,
23        Attribute,
24        DeriveInput,
25        Expr,
26        Fields,
27        Lit,
28        Type,
29    },
30};
31
32#[derive(Default, Clone, FromDeriveInput)]
33#[darling(attributes(vark))]
34#[darling(default)]
35struct TypeAttr {
36    break_help: bool,
37    placeholder: Option<String>,
38}
39
40#[derive(Default, Clone, FromField)]
41#[darling(attributes(vark))]
42#[darling(default)]
43struct FieldAttr {
44    break_help: bool,
45    #[darling(multiple)]
46    flag: Vec<String>,
47    placeholder: Option<String>,
48}
49
50#[derive(Default, Clone, FromVariant)]
51#[darling(attributes(vark))]
52#[darling(default)]
53struct VariantAttr {
54    break_help: bool,
55    name: Option<String>,
56    placeholder: Option<String>,
57}
58
59fn get_docstr(attrs: &Vec<Attribute>) -> String {
60    let mut out = String::new();
61    for attr in attrs {
62        match &attr.meta {
63            syn::Meta::NameValue(meta) => {
64                if !meta.path.is_ident("doc") {
65                    continue;
66                }
67                let Expr::Lit(syn::ExprLit { lit: Lit::Str(s), .. }) = &meta.value else {
68                    continue;
69                };
70                out.push_str(&s.value());
71            },
72            _ => continue,
73        }
74    }
75    return out.trim().to_string();
76}
77
78struct GenRec {
79    vark: TokenStream,
80    help_pattern: TokenStream,
81}
82
83fn gen_impl_type(ty: &Type, path: &str) -> GenRec {
84    match ty {
85        Type::Path(t) => {
86            return GenRec {
87                vark: quote!{
88                    < #t >:: vark(state)
89                },
90                help_pattern: quote!{
91                    < #t as a:: AargvarkTrait >:: build_help_pattern(state)
92                },
93            };
94        },
95        Type::Tuple(t) => {
96            return gen_impl_unnamed(
97                path,
98                ty.to_token_stream(),
99                quote!(),
100                "TUPLE",
101                "",
102                0,
103                t.elems.iter().map(|e| (FieldAttr::default(), String::new(), e)).collect::<Vec<_>>().as_slice(),
104            );
105        },
106        _ => panic!("Unsupported type {} in {}", ty.to_token_stream(), path),
107    }
108}
109
110fn gen_impl_unnamed(
111    path: &str,
112    parent_ident: TokenStream,
113    ident: TokenStream,
114    help_placeholder: &str,
115    help_docstr: &str,
116    subtype_index: usize,
117    fields: &[(FieldAttr, String, &Type)],
118) -> GenRec {
119    let mut parse_positional = vec![];
120    let mut copy_fields = vec![];
121    let mut help_fields = vec![];
122    let mut help_field_patterns = vec![];
123    let help_unit_transparent = fields.len() == 1 && fields[0].1.is_empty();
124    for (i, (field_vark_attr, field_help_docstr, field_ty)) in fields.iter().enumerate() {
125        let f_ident = format_ident!("v{}", i);
126        let gen = gen_impl_type(field_ty, path);
127        let vark = gen.vark;
128        let placeholder = field_vark_attr.placeholder.clone().unwrap_or_else(|| {
129            let mut placeholder = vec![];
130            let mut placeholder_i = i;
131            loop {
132                placeholder.push((('A' as u8) + (placeholder_i % 27) as u8) as char);
133                if placeholder_i < 27 {
134                    break;
135                }
136                placeholder_i = placeholder_i / 27;
137            }
138            placeholder.reverse();
139            String::from_iter(placeholder)
140        });
141        parse_positional.push(quote!{
142            //. .
143            let r = #vark;
144            //. .
145            let #f_ident = match r {
146                a:: R:: Ok(v) => v,
147                a:: R:: Help(b) => break a:: R:: Help(b),
148                a:: R:: Err => break a:: R:: Err,
149            };
150        });
151        copy_fields.push(f_ident.to_token_stream());
152        let field_help_pattern = gen.help_pattern;
153        help_fields.push(quote!{
154            struct_.fields.push(a:: HelpField {
155                id: #placeholder.to_string(),
156                pattern: #field_help_pattern,
157                description: #field_help_docstr.to_string(),
158            });
159        });
160        help_field_patterns.push(field_help_pattern);
161    }
162    return GenRec {
163        vark: quote!{
164            loop {
165                #(#parse_positional) * 
166                //. .
167                break state.r_ok(#ident(#(#copy_fields), *), None);
168            }
169        },
170        help_pattern: if fields.is_empty() {
171            quote!{
172                a::HelpPattern(vec![])
173            }
174        } else if help_unit_transparent {
175            help_field_patterns.pop().unwrap()
176        } else {
177            quote!{
178                {
179                    let(
180                        key,
181                        struct_
182                    ) = state.add_struct(
183                        std:: any:: TypeId:: of::< #parent_ident >(),
184                        #subtype_index,
185                        #help_placeholder,
186                        #help_docstr
187                    );
188                    let mut struct_ = struct_.as_ref().borrow_mut();
189                    #(#help_fields) * 
190                    //. .
191                    a:: HelpPattern(vec![a::HelpPatternElement::Reference(key)])
192                }
193            }
194        },
195    };
196}
197
198fn get_optional_type(t: &Type) -> Option<&Type> {
199    let Type::Path(t) = &t else {
200        return None;
201    };
202    if t.qself.is_some() {
203        return None;
204    }
205    if t.path.leading_colon.is_some() {
206        return None;
207    }
208    if t.path.segments.len() != 1 {
209        return None;
210    }
211    let s = t.path.segments.first().unwrap();
212    if &s.ident.to_string() != "Option" {
213        return None;
214    }
215    let syn::PathArguments::AngleBracketed(a) = &s.arguments else {
216        return None;
217    };
218    if a.args.len() != 1 {
219        return None;
220    }
221    let syn::GenericArgument::Type(t) = &a.args[0] else {
222        return None;
223    };
224    return Some(t);
225}
226
227fn gen_impl_struct(
228    parent_ident: TokenStream,
229    ident: TokenStream,
230    decl_generics: &TokenStream,
231    forward_generics: &TokenStream,
232    help_placeholder: &str,
233    type_break_help: bool,
234    help_docstr: &str,
235    subtype_index: usize,
236    d: &Fields,
237) -> Result<GenRec, syn::Error> {
238    match d {
239        Fields::Named(d) => {
240            let mut help_fields = vec![];
241            let mut partial_help_fields = vec![];
242            let mut vark_flag_fields = vec![];
243            let mut vark_flag_fields_default = vec![];
244            let mut vark_parse_flag_cases = vec![];
245            let mut vark_parse_positional = vec![];
246            let mut vark_copy_flag_fields = vec![];
247            let mut seen_flags = HashSet::new();
248            let mut required_i = 0usize;
249            let mut init_need_flags = vec![];
250            'next_field: for (i, f) in d.named.iter().enumerate() {
251                let field_vark_attr = FieldAttr::from_field(f)?;
252                let field_help_docstr = get_docstr(&f.attrs);
253                let field_ident = f.ident.as_ref().expect("Named field missing name");
254                let f_local_ident = format_ident!("v{}", i);
255
256                // If a flag (non-positional) field, generate parsers and skip positional parsing
257                shed!{
258                    'no_flags _;
259                    let mut flags = field_vark_attr.flag.clone();
260                    if flags.is_empty() {
261                        flags.push(format!("--{}", field_ident.to_string().to_case(Case::Kebab)));
262                    }
263                    for flag in &flags {
264                        if !seen_flags.insert(flag.clone()) {
265                            return Err(
266                                syn::Error::new(f.span(), format!("Duplicate flag [{}] in [{}]", flag, ident)),
267                            );
268                        }
269                    }
270                    let ty;
271                    let copy;
272                    let optional;
273                    if let Some(ty1) = get_optional_type(&f.ty) {
274                        ty = ty1;
275                        copy = quote!(flag_fields.#field_ident);
276                        optional = true;
277                    }
278                    else if ! field_vark_attr.flag.is_empty() {
279                        ty = &f.ty;
280                        copy = quote!(if let Some(f) = flag_fields.#field_ident {
281                            f
282                        } else {
283                            return state.r_err(
284                                format!("Missing required flags: {:?}", need_flags),
285                                Some(build_completer(need_flags)),
286                            );
287                        });
288                        optional = false;
289                    }
290                    else {
291                        break 'no_flags;
292                    }
293                    vark_flag_fields.push(quote!{
294                        #field_ident: Option < #ty >,
295                    });
296                    vark_flag_fields_default.push(quote!{
297                        #field_ident: None,
298                    });
299                    vark_copy_flag_fields.push(quote!{
300                        #field_ident: #copy
301                    });
302                    let gen = gen_impl_type(ty, &field_ident.to_string());
303                    let vark = gen.vark;
304                    for flag in &flags {
305                        if !optional {
306                            init_need_flags.push(quote!{
307                                #flag,
308                            });
309                        }
310                        vark_parse_flag_cases.push(quote!{
311                            #flag => {
312                                need_flags.remove(#flag);
313                                if flag_fields.#field_ident.is_some() {
314                                    return state.r_err(
315                                        format!("The argument {} was already specified", #flag),
316                                        Some(a::empty_completer()),
317                                    );
318                                }
319                                state.consume();
320                                let #f_local_ident = match #vark {
321                                    a:: R:: Ok(v) => v,
322                                    a:: R:: Help(b) => return a:: R:: Help(b),
323                                    a:: R:: Err => return a:: R:: Err,
324                                };
325                                flag_fields.#field_ident = Some(#f_local_ident);
326                                return a::R::Ok(true);
327                            }
328                        });
329                    }
330                    let field_help_pattern;
331                    if type_break_help || field_vark_attr.break_help {
332                        field_help_pattern = quote!(a::HelpPattern(vec![]));
333                    }
334                    else {
335                        field_help_pattern = gen.help_pattern;
336                    }
337                    let help_field = quote!{
338                        a:: HelpFlagField {
339                            option: #optional,
340                            flags: vec ![#(#flags.to_string()), *],
341                            pattern: #field_help_pattern,
342                            description: #field_help_docstr.to_string(),
343                        }
344                    };
345                    help_fields.push(quote!{
346                        struct_.flag_fields.push(#help_field);
347                    });
348                    partial_help_fields.push(quote!{
349                        if flag_fields.#field_ident.is_none() {
350                            help_flag_fields.push(#help_field);
351                        }
352                    });
353                    continue 'next_field;
354                };
355
356                // Positional/required parsing
357                let field_help_placeholder =
358                    field_vark_attr
359                        .placeholder
360                        .unwrap_or_else(|| field_ident.to_string().to_case(Case::UpperKebab));
361                let gen = gen_impl_type(&f.ty, &ident.to_string());
362                let vark = gen.vark;
363                let field_help_pattern = gen.help_pattern;
364                vark_parse_positional.push(quote!{
365                    let #f_local_ident = loop {
366                        let peek = state.peek();
367                        if match peek {
368                            a:: PeekR:: None => false,
369                            a:: PeekR:: Help => return a:: R:: Help(Box:: new(move | state | {
370                                return a:: HelpPartialProduction {
371                                    description: #help_docstr.to_string(),
372                                    content: build_partial_help(state, #required_i, &flag_fields),
373                                };
374                            })),
375                            a:: PeekR:: Ok(
376                                s
377                            ) => match parse_flags(&mut need_flags, &mut flag_fields, state, s.to_string()) {
378                                a:: R:: Ok(v) => v,
379                                a:: R:: Help(b) => break a:: R:: Help(b),
380                                a:: R:: Err => break a:: R:: Err,
381                            },
382                        }
383                        {
384                            continue;
385                        }
386                        break #vark;
387                    };
388                    let #f_local_ident = match #f_local_ident {
389                        a:: R:: Ok(v) => v,
390                        a:: R:: Help(b) => break a:: R:: Help(b),
391                        a:: R:: Err => break a:: R:: Err,
392                    };
393                });
394                vark_copy_flag_fields.push(quote!{
395                    #field_ident: #f_local_ident
396                });
397                let help_field = quote!{
398                    a:: HelpField {
399                        id: #field_help_placeholder.to_string(),
400                        pattern: #field_help_pattern,
401                        description: #field_help_docstr.to_string(),
402                    }
403                };
404                help_fields.push(quote!{
405                    struct_.fields.push(#help_field);
406                });
407                partial_help_fields.push(quote!{
408                    if required_i <= #required_i {
409                        help_fields.push(#help_field);
410                    }
411                });
412                required_i += 1;
413            }
414
415            // Assemble code
416            let vark = quote!{
417                {
418                    loop {
419                        struct FlagFields #decl_generics {
420                            #(#vark_flag_fields) *
421                        }
422                        let mut flag_fields = FlagFields {
423                            #(#vark_flag_fields_default) *
424                        };
425                        type NeedFlags = std::collections::HashSet<&'static str>;
426                        let mut need_flags =[#(#init_need_flags) *].into_iter().collect::< NeedFlags >();
427                        fn parse_flags #decl_generics(
428                            need_flags: & mut NeedFlags,
429                            flag_fields:& mut FlagFields #forward_generics,
430                            state:& mut a:: VarkState,
431                            s: String
432                        ) -> a:: R < bool > {
433                            match s.as_str() {
434                                #(#vark_parse_flag_cases) * 
435                                //. .
436                                _ => return a:: R:: Ok(false),
437                            };
438                        }
439                        fn build_partial_help #decl_generics(
440                            state:& mut a:: HelpState,
441                            required_i: usize,
442                            flag_fields:& FlagFields #forward_generics,
443                        ) -> a:: HelpPartialContent {
444                            let mut help_fields = vec![];
445                            let mut help_flag_fields = vec![];
446                            #(#partial_help_fields) * 
447                            //. .
448                            return a:: HelpPartialContent:: struct_(help_fields, help_flag_fields);
449                        }
450                        #(#vark_parse_positional) * 
451                        // Parse any remaining optional args
452                        let flag_search_res = loop {
453                            match state.peek() {
454                                a:: PeekR:: None => {
455                                    break state.r_ok((), None);
456                                },
457                                a:: PeekR:: Help => return a:: R:: Help(Box:: new(move | state | {
458                                    return a:: HelpPartialProduction {
459                                        description: #help_docstr.to_string(),
460                                        content: build_partial_help(state, #required_i, &flag_fields),
461                                    };
462                                })),
463                                a:: PeekR:: Ok(
464                                    s
465                                ) => match parse_flags(&mut need_flags, &mut flag_fields, state, s.to_string()) {
466                                    a:: R:: Ok(v) => {
467                                        if !v {
468                                            break state.r_ok((), None);
469                                        }
470                                    },
471                                    a:: R:: Help(b) => break a:: R:: Help(b),
472                                    a:: R:: Err => break a:: R:: Err,
473                                },
474                            };
475                        };
476                        match flag_search_res {
477                            a::R::Ok(()) => { },
478                            a::R::Help(b) => break a::R::Help(b),
479                            a::R::Err => break a::R::Err,
480                        };
481                        fn build_completer(need_flags: NeedFlags) -> a::AargvarkCompleter {
482                            return Box::new(move || {
483                                return need_flags.iter().map(|v| vec![v.to_string()]).collect();
484                            });
485                        }
486                        // Build obj + return
487                        break state.r_ok(#ident {
488                            #(#vark_copy_flag_fields),
489                            *
490                        }, None);
491                    }
492                }
493            };
494            return Ok(GenRec {
495                vark: vark,
496                help_pattern: quote!{
497                    {
498                        let(
499                            key,
500                            struct_
501                        ) = state.add_struct(
502                            std::any::TypeId::of::<Self>(),
503                            #subtype_index,
504                            #help_placeholder,
505                            #help_docstr
506                        );
507                        let mut struct_ = struct_.as_ref().borrow_mut();
508                        #(#help_fields) * 
509                        //. .
510                        a:: HelpPattern(vec![a::HelpPatternElement::Reference(key)])
511                    }
512                },
513            });
514        },
515        Fields::Unnamed(d) => {
516            let mut fields = vec![];
517            for f in &d.unnamed {
518                fields.push((FieldAttr::from_field(f)?, get_docstr(&f.attrs), &f.ty));
519            }
520            return Ok(
521                gen_impl_unnamed(
522                    &ident.to_string(),
523                    parent_ident,
524                    ident.to_token_stream(),
525                    help_placeholder,
526                    help_docstr,
527                    subtype_index,
528                    &fields,
529                ),
530            );
531        },
532        Fields::Unit => {
533            return Ok(GenRec {
534                vark: quote!{
535                    state.r_ok(#ident, None)
536                },
537                help_pattern: quote!{
538                    a::HelpPattern(vec![])
539                },
540            });
541        },
542    };
543}
544
545fn gen_impl(ast: syn::DeriveInput) -> Result<TokenStream, syn::Error> {
546    let ident = &ast.ident;
547    let type_attr = TypeAttr::from_derive_input(&ast)?;
548    let help_docstr = get_docstr(&ast.attrs);
549    let decl_generics = ast.generics.to_token_stream();
550    let forward_generics;
551    {
552        let mut parts = vec![];
553        for p in ast.generics.params {
554            match p {
555                syn::GenericParam::Type(p) => parts.push(p.ident.to_token_stream()),
556                syn::GenericParam::Lifetime(p) => parts.push(p.lifetime.to_token_stream()),
557                syn::GenericParam::Const(p) => parts.push(p.ident.to_token_stream()),
558            }
559        }
560        if parts.is_empty() {
561            forward_generics = quote!();
562        } else {
563            forward_generics = quote!(< #(#parts), *>);
564        }
565    }
566    let help_placeholder =
567        type_attr.placeholder.clone().unwrap_or_else(|| ident.to_string().to_case(Case::UpperKebab));
568    let impl_vark;
569    let impl_help_build;
570    match &ast.data {
571        syn::Data::Struct(d) => {
572            let gen =
573                gen_impl_struct(
574                    ast.ident.to_token_stream(),
575                    ast.ident.to_token_stream(),
576                    &decl_generics,
577                    &forward_generics,
578                    &help_placeholder,
579                    type_attr.break_help,
580                    &help_docstr,
581                    0,
582                    &d.fields,
583                )?;
584            impl_vark = gen.vark;
585            impl_help_build = gen.help_pattern;
586        },
587        syn::Data::Enum(d) => {
588            let mut all_tags = vec![];
589            let mut vark_cases = vec![];
590            let mut help_variants = vec![];
591            for (subtype_index, v) in d.variants.iter().enumerate() {
592                let variant_vark_attr = VariantAttr::from_variant(v)?;
593                let variant_help_docstr = get_docstr(&v.attrs);
594                let variant_ident = &v.ident;
595                let name_str =
596                    variant_vark_attr
597                        .name
598                        .clone()
599                        .unwrap_or_else(|| variant_ident.to_string().to_case(Case::Kebab));
600                let help_placeholder =
601                    variant_vark_attr
602                        .placeholder
603                        .unwrap_or_else(|| variant_ident.to_string().to_case(Case::UpperKebab));
604                let gen =
605                    gen_impl_struct(
606                        ident.to_token_stream(),
607                        quote!(#ident:: #variant_ident),
608                        &decl_generics,
609                        &forward_generics,
610                        &help_placeholder,
611                        variant_vark_attr.break_help,
612                        "",
613                        subtype_index + 1,
614                        &v.fields,
615                    )?;
616                all_tags.push(name_str.clone());
617                let vark = gen.vark;
618                let partial_help_variant_pattern = gen.help_pattern;
619                vark_cases.push(quote!{
620                    #name_str => {
621                        state.consume();
622                        #vark
623                    }
624                });
625                let help_variant_pattern;
626                if type_attr.break_help || variant_vark_attr.break_help {
627                    help_variant_pattern =
628                        quote!(a::HelpPattern(vec![a::HelpPatternElement::PseudoReference("...".to_string())]));
629                } else {
630                    help_variant_pattern = partial_help_variant_pattern;
631                }
632                help_variants.push(quote!{
633                    variants.push(a:: HelpVariant {
634                        literal: #name_str.to_string(),
635                        pattern: #help_variant_pattern,
636                        description: #variant_help_docstr.to_string(),
637                    });
638                });
639            }
640            impl_vark = quote!{
641                {
642                    fn build_completer(arg: & str) -> a:: AargvarkCompleter {
643                        let arg = arg.to_string();
644                        return Box:: new(move || {
645                            let mut out = vec![];
646                            for want_arg in &[#(#all_tags), *] {
647                                if want_arg.starts_with(&arg) {
648                                    out.push(vec![want_arg.to_string()]);
649                                }
650                            }
651                            return out;
652                        });
653                    }
654                    let tag = match state.peek() {
655                        a:: PeekR:: None => {
656                            return state.r_err(
657                                format!("Need variant tag - choices are {:?}", vec![#(#all_tags), *]),
658                                Some(build_completer("")),
659                            );
660                        },
661                        a:: PeekR:: Help => return a:: R:: Help(Box:: new(move | state | {
662                            let mut variants = vec![];
663                            #(#help_variants) * 
664                            //. .
665                            return a:: HelpPartialProduction {
666                                description: #help_docstr.to_string(),
667                                content: a:: HelpPartialContent:: enum_(variants),
668                            };
669                        })),
670                        a:: PeekR:: Ok(s) => s,
671                    };
672                    match tag {
673                        #(#vark_cases) * 
674                        //. .
675                        _ => {
676                            state.r_err(
677                                format!("Unrecognized variant {} - choices are {:?}", tag, vec![#(#all_tags), *]),
678                                Some(build_completer(tag)),
679                            )
680                        }
681                    }
682                }
683            };
684            impl_help_build = quote!{
685                let(
686                    key,
687                    variants
688                ) = state.add_enum(std::any::TypeId::of::<Self>(), 0, #help_placeholder, #help_docstr);
689                let mut variants = variants.as_ref().borrow_mut();
690                #(#help_variants) * 
691                //. .
692                return a:: HelpPattern(vec![a::HelpPatternElement::Reference(key)]);
693            };
694        },
695        syn::Data::Union(_) => panic!("Union not supported"),
696    };
697    return Ok(quote!{
698        impl #decl_generics aargvark:: traits:: AargvarkTrait for #ident #forward_generics {
699            fn vark(state:& mut aargvark:: base:: VarkState) -> aargvark:: base:: R < #ident #forward_generics > {
700                mod a {
701                    pub use aargvark::help::*;
702                    pub use aargvark::base::*;
703                    pub use aargvark::traits::*;
704                }
705                #impl_vark
706            }
707            fn build_help_pattern(state:& mut aargvark:: help:: HelpState) -> aargvark:: help:: HelpPattern {
708                mod a {
709                    pub use aargvark::help::*;
710                    pub use aargvark::traits::*;
711                }
712                #impl_help_build
713            }
714        }
715    });
716}
717
718#[proc_macro_derive(Aargvark, attributes(vark))]
719pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
720    return match gen_impl(parse_macro_input!(input as DeriveInput)) {
721        Ok(x) => x,
722        Err(e) => e.to_compile_error(),
723    }.into();
724}