action_derive/
lib.rs

1#![allow(clippy::missing_panics_doc)]
2
3mod ident;
4mod manifest;
5
6use manifest::Manifest;
7use proc_macro2::TokenStream;
8use quote::quote;
9use std::path::{Path, PathBuf};
10
/// Resolves `path` relative to the crate being compiled.
///
/// The base directory is `CARGO_MANIFEST_DIR`, or `.` when that variable is unset or not
/// valid Unicode. If `<base>/<path>` exists it is returned. Otherwise the path is assumed to
/// live under `<base>/src/`, and that location is returned without an existence check.
fn resolve_path(path: impl AsRef<Path>) -> PathBuf {
    let relative = path.as_ref();
    let base = std::env::var("CARGO_MANIFEST_DIR").map_or_else(|_| PathBuf::from("."), PathBuf::from);

    let direct = base.join(relative);
    if direct.exists() {
        return direct;
    }
    base.join("src/").join(relative)
}
19
20fn quote_option<T: quote::ToTokens>(value: Option<&T>) -> TokenStream {
21    if let Some(v) = value {
22        quote! { Some(#v) }
23    } else {
24        quote! { None }
25    }
26}
27
28fn get_attribute(attr: &syn::Attribute) -> String {
29    match &attr.meta {
30        syn::Meta::NameValue(syn::MetaNameValue { path, value, .. }) => {
31            debug_assert!(path.is_ident("action"));
32            match value {
33                syn::Expr::Lit(syn::ExprLit {
34                    lit: syn::Lit::Str(s),
35                    ..
36                }) => s.value(),
37                _ => panic!("action attribute must be a literal string"),
38            }
39        }
40        _ => panic!(r#"action attribute must be of the form `action = "..."`"#),
41    }
42}
43
44fn parse_derive(ast: &syn::DeriveInput) -> (&syn::Ident, &syn::Generics, PathBuf) {
45    let name = &ast.ident;
46    let generics = &ast.generics;
47
48    let manifests: Vec<_> = ast
49        .attrs
50        .iter()
51        .filter(|attr| attr.path().is_ident("action"))
52        .map(get_attribute)
53        .map(resolve_path)
54        .collect();
55
56    let manifest = manifests.into_iter().next().expect("a path to an action manifest (action.yml) file needs to be provided with the #[action = \"PATH\"] attribute");
57    (name, generics, manifest)
58}
59
60#[proc_macro_derive(Action, attributes(action))]
61pub fn action_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
62    let ast: syn::DeriveInput = syn::parse2(input.into()).unwrap();
63    let (struct_name, generics, manifest_path) = parse_derive(&ast);
64    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
65
66    let manifest = Manifest::from_action_yml(manifest_path);
67    // dbg!(&manifest);
68
69    let input_enum_variants: Vec<_> = manifest
70        .inputs
71        .keys()
72        .map(|name| {
73            let variant = ident::str_to_enum_variant(name);
74            quote! { #variant }
75        })
76        .collect();
77
78    let input_enum_matches: Vec<_> = manifest
79        .inputs
80        .keys()
81        .map(|name| {
82            let variant = ident::str_to_enum_variant(name);
83            quote! { #name => Ok(Self::#variant) }
84        })
85        .collect();
86
87    let input_enum_ident = quote::format_ident!("{}Input", struct_name);
88    let input_enum = quote! {
89        #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
90        pub enum #input_enum_ident {
91            #(#input_enum_variants,)*
92        }
93
94        impl std::str::FromStr for #input_enum_ident {
95            type Err = ();
96            fn from_str(input: &str) -> Result<Self, Self::Err> {
97                match input {
98                    #(#input_enum_matches,)*
99                    _  => Err(()),
100                }
101            }
102        }
103    };
104    // eprintln!("{}", pretty_print(&quote! { #input_enum }));
105
106    let parse_impl = quote! {
107        #[allow(clippy::all)]
108        impl #impl_generics ::action_core::Parse for #struct_name #ty_generics #where_clause {
109            type Input = #input_enum_ident;
110
111            fn parse_from<E: ::action_core::env::Read>(env: &E) -> std::collections::HashMap<Self::Input, Option<String>> {
112                Self::inputs().iter().filter_map(|(name, input)| {
113                    let value = ::action_core::input::ParseInput::parse_input::<String>(env, name);
114                    let default = input.default.map(|s| s.to_string());
115                    match std::str::FromStr::from_str(&name) {
116                        Ok(variant) => Some((variant, value.unwrap().or(default))),
117                        Err(_) => None,
118                    }
119                }).collect()
120            }
121        }
122    };
123
124    let input_impl_methods = input_impl_methods(&manifest);
125    let input_impl = quote! {
126        #[allow(clippy::all)]
127        impl #impl_generics #struct_name #ty_generics #where_clause {
128            #input_impl_methods
129        }
130    };
131
132    let tokens = quote! {
133        #input_enum
134        #input_impl
135        #parse_impl
136    };
137    // eprintln!("{}", pretty_print(&tokens));
138    tokens.into()
139}
140
141fn input_impl_methods(manifest: &Manifest) -> TokenStream {
142    let Manifest {
143        name,
144        description,
145        author,
146        ..
147    } = manifest;
148
149    let derived_methods: TokenStream = manifest
150        .inputs
151        .keys()
152        .map(|name| {
153            let fn_name = ident::parse_str(name);
154            quote! {
155                pub fn #fn_name<T>() -> Result<Option<T>, <T as ::action_core::input::Parse>::Error>
156                where T: ::action_core::input::Parse {
157                    let env = ::action_core::env::OsEnv::default();
158                    ::action_core::input::ParseInput::parse_input::<T>(&env, #name)
159                }
160            }
161        })
162        .collect();
163
164    let inputs: Vec<_> = manifest
165        .inputs
166        .iter()
167        .map(|(name, input)| {
168            let description = quote_option(input.description.as_ref());
169            let deprecation_message = quote_option(input.deprecation_message.as_ref());
170            let r#default = quote_option(input.default.as_ref());
171            let required = quote_option(input.required.as_ref());
172            quote! {
173                (#name, ::action_core::input::Input {
174                    description: #description,
175                    deprecation_message: #deprecation_message,
176                    default: #r#default,
177                    required: #required,
178                })
179            }
180        })
181        .collect();
182    // eprintln!("{}", pretty_print(&quote! { vec![#(#inputs,)*]; }));
183
184    quote! {
185        /// Inputs of this action.
186        pub fn inputs() -> ::std::collections::HashMap<
187            &'static str, ::action_core::input::Input<'static>
188        > {
189            static inputs: &'static [(&'static str, ::action_core::input::Input<'static>)] = &[
190                #(#inputs,)*
191            ];
192            inputs.iter().cloned().collect()
193        }
194
195        /// Description of this action.
196        pub fn description() -> &'static str {
197            #description
198        }
199
200        /// Name of this action.
201        pub fn name() -> &'static str {
202            #name
203        }
204
205        /// Author of this trait.
206        pub fn author() -> &'static str {
207            #author
208        }
209
210        #derived_methods
211    }
212}
213
214#[allow(dead_code)]
215fn pretty_print(tokens: &TokenStream) -> String {
216    let _file = syn::parse_file(&tokens.to_string()).unwrap();
217    // TODO: this will not work until prettyplease updates to syn 2+
218    // prettyplease::unparse(&file);
219    tokens.to_string()
220}