// pyro_macro/module/mod.rs

1mod codegen;
2mod parse;
3mod spec;
4
5use proc_macro2::TokenStream;
6use syn::parse2;
7
8pub use codegen::{expand, expand_session};
9pub use parse::ModuleAttrs;
10pub use spec::generate_module_spec;
11
12/// For generating module code from source content (used by build tools)
13pub fn generate_module(content: &str) -> syn::Result<syn::File> {
14    let file = syn::parse_file(content)?;
15
16    let mut generated_code = quote::quote! {
17        //! Automatically generated by pyroduct. DO NOT EDIT.
18        #![allow(unused_imports, dead_code, unused_variables, nonstandard_style)]
19    };
20
21    for item in file.items {
22        match item {
23            syn::Item::Fn(item_fn) => {
24                if has_module_attr(&item_fn.attrs) {
25                    let attr =
26                        extract_module_attr(&item_fn.attrs)?.ok_or(syn::Error::new_spanned(
27                            &item_fn,
28                            "Module attribute requires arguments: #[module(output = ...)]",
29                        ))?;
30                    let config: ModuleAttrs = parse2(attr)?;
31
32                    // Clone the function without the #[module] attribute
33                    let mut clean_fn = item_fn.clone();
34                    clean_fn.attrs.retain(|a| !is_module_attr(a));
35
36                    let expanded = if config.session {
37                        expand_session(config, clean_fn)?
38                    } else {
39                        expand(config, clean_fn)?
40                    };
41                    generated_code.extend(expanded);
42                } else {
43                    // Pass through non-module functions
44                    generated_code.extend(quote::quote! { #item_fn });
45                }
46            }
47            other => {
48                // Pass through other items unchanged
49                generated_code.extend(quote::quote! { #other });
50            }
51        }
52    }
53    let code: syn::File = syn::parse2(generated_code)?;
54    Ok(code)
55}
56
57fn has_module_attr(attrs: &[syn::Attribute]) -> bool {
58    attrs.iter().any(is_module_attr)
59}
60
61fn is_module_attr(attr: &syn::Attribute) -> bool {
62    if attr.path().is_ident("module") {
63        return true;
64    }
65    if attr.path().segments.len() == 2
66        && attr.path().segments[0].ident == "pyroduct"
67        && attr.path().segments[1].ident == "module"
68    {
69        return true;
70    }
71    false
72}
73
74fn extract_module_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<TokenStream>> {
75    for attr in attrs {
76        if is_module_attr(attr) {
77            match &attr.meta {
78                syn::Meta::List(list) => {
79                    return Ok(Some(list.tokens.clone()));
80                }
81                syn::Meta::Path(_) => {
82                    return Err(syn::Error::new_spanned(
83                        attr,
84                        "Module attribute requires arguments: #[module(output = ...)]",
85                    ));
86                }
87                syn::Meta::NameValue(_) => {
88                    return Err(syn::Error::new_spanned(
89                        attr,
90                        "Invalid module attribute format",
91                    ));
92                }
93            }
94        }
95    }
96    Ok(None)
97}