//! # pyro-macro 0.2.1
//!
//! Derive macros for Pyroduct.
mod codegen;
mod parse;
mod spec;

use proc_macro2::TokenStream;
use syn::parse2;

pub use codegen::{expand, expand_session};
pub use parse::ModuleAttrs;
pub use spec::generate_module_spec;

/// For generating module code from source content (used by build tools)
pub fn generate_module(content: &str) -> syn::Result<syn::File> {
    let file = syn::parse_file(content)?;

    let mut generated_code = quote::quote! {
        //! Automatically generated by pyroduct. DO NOT EDIT.
        #![allow(unused_imports, dead_code, unused_variables, nonstandard_style)]
    };

    for item in file.items {
        match item {
            syn::Item::Fn(item_fn) => {
                if has_module_attr(&item_fn.attrs) {
                    let attr =
                        extract_module_attr(&item_fn.attrs)?.ok_or(syn::Error::new_spanned(
                            &item_fn,
                            "Module attribute requires arguments: #[module(output = ...)]",
                        ))?;
                    let config: ModuleAttrs = parse2(attr)?;

                    // Clone the function without the #[module] attribute
                    let mut clean_fn = item_fn.clone();
                    clean_fn.attrs.retain(|a| !is_module_attr(a));

                    let expanded = if config.session {
                        expand_session(config, clean_fn)?
                    } else {
                        expand(config, clean_fn)?
                    };
                    generated_code.extend(expanded);
                } else {
                    // Pass through non-module functions
                    generated_code.extend(quote::quote! { #item_fn });
                }
            }
            other => {
                // Pass through other items unchanged
                generated_code.extend(quote::quote! { #other });
            }
        }
    }
    let code: syn::File = syn::parse2(generated_code)?;
    Ok(code)
}

fn has_module_attr(attrs: &[syn::Attribute]) -> bool {
    attrs.iter().any(is_module_attr)
}

/// Returns `true` when `attr` is the pyroduct module attribute, written either as the bare
/// `#[module...]` or as the two-segment path `#[pyroduct::module...]`.
fn is_module_attr(attr: &syn::Attribute) -> bool {
    let path = attr.path();

    // Bare single-identifier form.
    if path.is_ident("module") {
        return true;
    }

    // Qualified form: exactly two segments, `pyroduct` then `module`.
    let mut idents = path.segments.iter().map(|segment| &segment.ident);
    match (idents.next(), idents.next(), idents.next()) {
        (Some(first), Some(second), None) => *first == "pyroduct" && *second == "module",
        _ => false,
    }
}

fn extract_module_attr(attrs: &[syn::Attribute]) -> syn::Result<Option<TokenStream>> {
    for attr in attrs {
        if is_module_attr(attr) {
            match &attr.meta {
                syn::Meta::List(list) => {
                    return Ok(Some(list.tokens.clone()));
                }
                syn::Meta::Path(_) => {
                    return Err(syn::Error::new_spanned(
                        attr,
                        "Module attribute requires arguments: #[module(output = ...)]",
                    ));
                }
                syn::Meta::NameValue(_) => {
                    return Err(syn::Error::new_spanned(
                        attr,
                        "Invalid module attribute format",
                    ));
                }
            }
        }
    }
    Ok(None)
}