rust_actions_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use syn::parse::{Parse, ParseStream};
7use syn::{parse_macro_input, DeriveInput, ItemFn, FnArg, Type, LitStr, Token};
8
9#[proc_macro_attribute]
10pub fn step(attr: TokenStream, item: TokenStream) -> TokenStream {
11    let step_name = parse_macro_input!(attr as LitStr);
12    let input = parse_macro_input!(item as ItemFn);
13
14    let fn_name = &input.sig.ident;
15
16    let mut params = input.sig.inputs.iter();
17
18    let world_type = match params.next() {
19        Some(FnArg::Typed(pat_type)) => {
20            extract_world_type(&pat_type.ty)
21        }
22        _ => {
23            return syn::Error::new_spanned(
24                &input.sig,
25                "Step function must have a world parameter as first argument"
26            ).to_compile_error().into();
27        }
28    };
29
30    let has_args = params.next().is_some();
31
32    let step_call = if has_args {
33        quote! {
34            let parsed_args = match ::rust_actions::args::FromArgs::from_args(&args) {
35                Ok(a) => a,
36                Err(e) => return Box::pin(async move { Err(e) }),
37            };
38            Box::pin(async move {
39                let result = #fn_name(world, parsed_args).await?;
40                Ok(::rust_actions::outputs::IntoOutputs::into_outputs(result))
41            })
42        }
43    } else {
44        quote! {
45            Box::pin(async move {
46                let result = #fn_name(world).await?;
47                Ok(::rust_actions::outputs::IntoOutputs::into_outputs(result))
48            })
49        }
50    };
51
52    let step_name_str = step_name.value();
53    let erased_fn_name = syn::Ident::new(
54        &format!("__erased_{}", fn_name),
55        fn_name.span()
56    );
57
58    let expanded = quote! {
59        #input
60
61        #[doc(hidden)]
62        #[allow(non_upper_case_globals)]
63        fn #erased_fn_name<'a>(
64            world_any: &'a mut dyn ::std::any::Any,
65            args: ::rust_actions::args::RawArgs,
66        ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ::rust_actions::Result<::rust_actions::outputs::StepOutputs>> + Send + 'a>> {
67            let world = match world_any.downcast_mut::<#world_type>() {
68                Some(w) => w,
69                None => {
70                    let msg = format!(
71                        "World type mismatch: expected {}",
72                        ::std::any::type_name::<#world_type>()
73                    );
74                    return Box::pin(async move {
75                        Err(::rust_actions::Error::Custom(msg))
76                    });
77                }
78            };
79
80            #step_call
81        }
82
83        ::rust_actions::inventory::submit! {
84            ::rust_actions::registry::ErasedStepDef::new(
85                #step_name_str,
86                {
87                    use ::std::any::TypeId;
88                    TypeId::of::<#world_type>()
89                },
90                #erased_fn_name,
91            )
92        }
93    };
94
95    TokenStream::from(expanded)
96}
97
98fn extract_world_type(ty: &Type) -> proc_macro2::TokenStream {
99    match ty {
100        Type::Reference(type_ref) => {
101            if let Type::Path(type_path) = &*type_ref.elem {
102                let path = &type_path.path;
103                quote! { #path }
104            } else {
105                quote! { compile_error!("Expected a type path for world parameter") }
106            }
107        }
108        _ => {
109            quote! { compile_error!("World parameter must be a mutable reference") }
110        }
111    }
112}
113
114#[proc_macro_derive(World, attributes(world))]
115pub fn derive_world(input: TokenStream) -> TokenStream {
116    let input = parse_macro_input!(input as DeriveInput);
117    let name = &input.ident;
118
119    let expanded = quote! {
120        impl ::rust_actions::world::World for #name {
121            fn new() -> impl ::std::future::Future<Output = ::rust_actions::Result<Self>> + Send {
122                Self::setup()
123            }
124        }
125    };
126
127    TokenStream::from(expanded)
128}
129
130#[proc_macro_derive(Args, attributes(arg))]
131pub fn derive_args(input: TokenStream) -> TokenStream {
132    let input = parse_macro_input!(input as DeriveInput);
133    let name = &input.ident;
134
135    let expanded = quote! {
136        impl ::rust_actions::args::FromArgs for #name {
137            fn from_args(args: &::rust_actions::args::RawArgs) -> ::rust_actions::Result<Self> {
138                let value = ::rust_actions::serde_json::Value::Object(
139                    args.iter()
140                        .map(|(k, v)| (k.clone(), v.clone()))
141                        .collect()
142                );
143                ::rust_actions::serde_json::from_value(value)
144                    .map_err(|e| ::rust_actions::Error::Args(e.to_string()))
145            }
146        }
147    };
148
149    TokenStream::from(expanded)
150}
151
152#[proc_macro_derive(Outputs)]
153pub fn derive_outputs(input: TokenStream) -> TokenStream {
154    let input = parse_macro_input!(input as DeriveInput);
155    let name = &input.ident;
156
157    let expanded = quote! {
158        impl ::rust_actions::outputs::IntoOutputs for #name {
159            fn into_outputs(self) -> ::rust_actions::outputs::StepOutputs {
160                ::rust_actions::serde_json::to_value(&self)
161                    .map(|v| ::rust_actions::outputs::StepOutputs::from_value(v))
162                    .unwrap_or_default()
163            }
164        }
165    };
166
167    TokenStream::from(expanded)
168}
169
170#[proc_macro_attribute]
171pub fn before_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
172    let input = parse_macro_input!(item as ItemFn);
173    TokenStream::from(quote! { #input })
174}
175
176#[proc_macro_attribute]
177pub fn after_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
178    let input = parse_macro_input!(item as ItemFn);
179    TokenStream::from(quote! { #input })
180}
181
182#[proc_macro_attribute]
183pub fn before_scenario(_attr: TokenStream, item: TokenStream) -> TokenStream {
184    let input = parse_macro_input!(item as ItemFn);
185    TokenStream::from(quote! { #input })
186}
187
188#[proc_macro_attribute]
189pub fn after_scenario(_attr: TokenStream, item: TokenStream) -> TokenStream {
190    let input = parse_macro_input!(item as ItemFn);
191    TokenStream::from(quote! { #input })
192}
193
194#[proc_macro_attribute]
195pub fn before_step(_attr: TokenStream, item: TokenStream) -> TokenStream {
196    let input = parse_macro_input!(item as ItemFn);
197    TokenStream::from(quote! { #input })
198}
199
200#[proc_macro_attribute]
201pub fn after_step(_attr: TokenStream, item: TokenStream) -> TokenStream {
202    let input = parse_macro_input!(item as ItemFn);
203    TokenStream::from(quote! { #input })
204}
205
206struct GenerateTestsArgs {
207    path: LitStr,
208    world_type: syn::Path,
209}
210
211impl Parse for GenerateTestsArgs {
212    fn parse(input: ParseStream) -> syn::Result<Self> {
213        let path: LitStr = input.parse()?;
214        input.parse::<Token![,]>()?;
215        let world_type: syn::Path = input.parse()?;
216        Ok(GenerateTestsArgs { path, world_type })
217    }
218}
219
220#[derive(Debug, Deserialize)]
221struct WorkflowHeader {
222    #[allow(dead_code)]
223    name: Option<String>,
224    #[serde(default)]
225    on: Option<WorkflowTrigger>,
226}
227
228#[derive(Debug, Deserialize)]
229struct WorkflowTrigger {
230    #[serde(default)]
231    workflow_call: Option<HashMap<String, serde_yaml::Value>>,
232}
233
234fn is_reusable_workflow(path: &Path) -> bool {
235    let content = match std::fs::read_to_string(path) {
236        Ok(c) => c,
237        Err(_) => return false,
238    };
239
240    let header: WorkflowHeader = match serde_yaml::from_str(&content) {
241        Ok(h) => h,
242        Err(_) => return false,
243    };
244
245    header
246        .on
247        .as_ref()
248        .map(|t| t.workflow_call.is_some())
249        .unwrap_or(false)
250}
251
252fn discover_yaml_files(dir: &Path) -> Vec<PathBuf> {
253    walkdir::WalkDir::new(dir)
254        .into_iter()
255        .filter_map(|e| e.ok())
256        .filter(|e| {
257            e.path().is_file()
258                && e.path()
259                    .extension()
260                    .map(|ext| ext == "yaml" || ext == "yml")
261                    .unwrap_or(false)
262        })
263        .map(|e| e.path().to_path_buf())
264        .collect()
265}
266
267fn path_to_test_name(path: &Path, base: &Path) -> proc_macro2::Ident {
268    let rel_path = path.strip_prefix(base).unwrap_or(path);
269
270    let name = rel_path
271        .to_string_lossy()
272        .replace(std::path::MAIN_SEPARATOR, "_")
273        .replace(".yaml", "")
274        .replace(".yml", "")
275        .replace('-', "_")
276        .replace('.', "_");
277
278    let name = format!("test_{}", name);
279    proc_macro2::Ident::new(&name, proc_macro2::Span::call_site())
280}
281
282#[proc_macro]
283pub fn generate_tests(input: TokenStream) -> TokenStream {
284    let args = parse_macro_input!(input as GenerateTestsArgs);
285    let workflows_path = args.path.value();
286    let world_type = &args.world_type;
287
288    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR")
289        .expect("CARGO_MANIFEST_DIR not set");
290    let full_path = Path::new(&manifest_dir).join(&workflows_path);
291
292    if !full_path.exists() {
293        let err = format!("Workflows path does not exist: {}", full_path.display());
294        return syn::Error::new_spanned(&args.path, err)
295            .to_compile_error()
296            .into();
297    }
298
299    let yaml_files = discover_yaml_files(&full_path);
300
301    let tests = yaml_files
302        .iter()
303        .filter(|f| !is_reusable_workflow(f))
304        .map(|file| {
305            let rel_path = file.strip_prefix(&manifest_dir).unwrap_or(file);
306            let test_name = path_to_test_name(file, &full_path);
307            let path_str = rel_path.to_string_lossy();
308
309            quote! {
310                #[::tokio::test(flavor = "current_thread", start_paused = true)]
311                async fn #test_name() {
312                    ::rust_actions::prelude::RustActions::<#world_type>::new()
313                        .workflow(#path_str)
314                        .run()
315                        .await;
316                }
317            }
318        });
319
320    let expanded = quote! {
321        #(#tests)*
322    };
323
324    TokenStream::from(expanded)
325}