// pyro_macro/module/mod.rs

1mod codegen;
2mod parse;
3mod spec;
4
5use proc_macro2::TokenStream;
6use syn::parse2;
7
8pub use codegen::expand;
9pub use parse::ModuleAttrs;
10pub use spec::generate_module_spec;
11
12/// For generating module code from source content (used by build tools)
13pub fn generate_module(content: &str) -> syn::Result<syn::File> {
14    let file = syn::parse_file(content)?;
15
16    let mut generated_code = quote::quote! {
17        //! Automatically generated by pyroduct. DO NOT EDIT.
18        #![allow(unused_imports, dead_code, unused_variables, nonstandard_style)]
19    };
20
21    for item in file.items {
22        match item {
23            syn::Item::Fn(item_fn) => {
24                if has_module_attr(&item_fn.attrs) {
25                    let attr =
26                        extract_module_attr(&item_fn.attrs)?.ok_or(syn::Error::new_spanned(
27                            &item_fn,
28                            "Module attribute requires arguments: #[module(output = ...)]",
29                        ))?;
30                    let config: ModuleAttrs = parse2(attr)?;
31
32                    // Clone the function without the #[module] attribute
33                    let mut clean_fn = item_fn.clone();
34                    clean_fn.attrs.retain(|a| !is_module_attr(a));
35
36                    let expanded = expand(config, clean_fn)?;
37                    generated_code.extend(expanded);
38                } else {
39                    // Pass through non-module functions
40                    generated_code.extend(quote::quote! { #item_fn });
41                }
42            }
43            other => {
44                // Pass through other items unchanged
45                generated_code.extend(quote::quote! { #other });
46            }
47        }
48    }
49    let code: syn::File = syn::parse2(generated_code)?;
50    Ok(code)
51}
52
53fn has_module_attr(attrs: &[syn::Attribute]) -> bool {
54    attrs.iter().any(|a| is_module_attr(a))
55}
56
57fn is_module_attr(attr: &syn::Attribute) -> bool {
58    if attr.path().is_ident("module") {
59        return true;
60    }
61    if attr.path().segments.len() == 2
62        && attr.path().segments[0].ident == "pyroduct"
63        && attr.path().segments[1].ident == "module"
64    {
65        return true;
66    }
67    false
68}
69
70fn extract_module_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<TokenStream>> {
71    for attr in attrs {
72        if is_module_attr(attr) {
73            match &attr.meta {
74                syn::Meta::List(list) => {
75                    return Ok(Some(list.tokens.clone()));
76                }
77                syn::Meta::Path(_) => {
78                    return Err(syn::Error::new_spanned(
79                        attr,
80                        "Module attribute requires arguments: #[module(output = ...)]",
81                    ));
82                }
83                syn::Meta::NameValue(_) => {
84                    return Err(syn::Error::new_spanned(
85                        attr,
86                        "Invalid module attribute format",
87                    ));
88                }
89            }
90        }
91    }
92    Ok(None)
93}