Skip to main content

pyforge_macros_backend/
module.rs

//! Code generation for the function that initializes a Python module and adds its classes and functions.
2
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{
5    attribute_introspection_code, introspection_id_const, module_introspection_code,
6};
7#[cfg(feature = "experimental-inspect")]
8use crate::py_expr::PyExpr;
9use crate::{
10    attributes::{
11        self, kw, take_attributes, take_pyo3_options, CrateAttribute, GILUsedAttribute,
12        ModuleAttribute, NameAttribute, SubmoduleAttribute,
13    },
14    combine_errors::CombineErrors,
15    get_doc,
16    pyclass::PyClassPyForgeOption,
17    pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
18    utils::{has_attribute, has_attribute_with_namespace, Ctx, IdentOrStr, PythonDoc},
19};
20use proc_macro2::{Span, TokenStream};
21use quote::{quote, ToTokens};
22use std::ffi::CString;
23use syn::LitCStr;
24use syn::{
25    ext::IdentExt,
26    parse::{Parse, ParseStream},
27    parse_quote, parse_quote_spanned,
28    punctuated::Punctuated,
29    spanned::Spanned,
30    token::Comma,
31    Item, Meta, Path, Result,
32};
33
34#[derive(Default)]
35pub struct PyModuleOptions {
36    krate: Option<CrateAttribute>,
37    name: Option<NameAttribute>,
38    module: Option<ModuleAttribute>,
39    submodule: Option<kw::submodule>,
40    gil_used: Option<GILUsedAttribute>,
41}
42
43impl Parse for PyModuleOptions {
44    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
45        let mut options: PyModuleOptions = Default::default();
46
47        options.add_attributes(
48            Punctuated::<PyModulePyForgeOption, syn::Token![,]>::parse_terminated(input)?,
49        )?;
50
51        Ok(options)
52    }
53}
54
55impl PyModuleOptions {
56    fn take_pyo3_options(&mut self, attrs: &mut Vec<syn::Attribute>) -> Result<()> {
57        self.add_attributes(take_pyo3_options(attrs)?)
58    }
59
60    fn add_attributes(
61        &mut self,
62        attrs: impl IntoIterator<Item = PyModulePyForgeOption>,
63    ) -> Result<()> {
64        macro_rules! set_option {
65            ($key:ident $(, $extra:literal)?) => {
66                {
67                    ensure_spanned!(
68                        self.$key.is_none(),
69                        $key.span() => concat!("`", stringify!($key), "` may only be specified once" $(, $extra)?)
70                    );
71                    self.$key = Some($key);
72                }
73            };
74        }
75        attrs
76            .into_iter()
77            .map(|attr| {
78                match attr {
79                    PyModulePyForgeOption::Crate(krate) => set_option!(krate),
80                    PyModulePyForgeOption::Name(name) => set_option!(name),
81                    PyModulePyForgeOption::Module(module) => set_option!(module),
82                    PyModulePyForgeOption::Submodule(submodule) => set_option!(
83                        submodule,
84                        " (it is implicitly always specified for nested modules)"
85                    ),
86                    PyModulePyForgeOption::GILUsed(gil_used) => {
87                        set_option!(gil_used)
88                    }
89                }
90
91                Ok(())
92            })
93            .try_combine_syn_errors()?;
94        Ok(())
95    }
96}
97
98pub fn pymodule_module_impl(
99    module: &mut syn::ItemMod,
100    mut options: PyModuleOptions,
101) -> Result<TokenStream> {
102    let syn::ItemMod {
103        attrs,
104        vis,
105        unsafety: _,
106        ident,
107        mod_token,
108        content,
109        semi: _,
110    } = module;
111    let items = if let Some((_, items)) = content {
112        items
113    } else {
114        bail_spanned!(mod_token.span() => "`#[pymodule]` can only be used on inline modules")
115    };
116    options.take_pyo3_options(attrs)?;
117    let ctx = &Ctx::new(&options.krate, None);
118    let Ctx { pyo3_path, .. } = ctx;
119    let doc = get_doc(attrs, None);
120    let name = options
121        .name
122        .map_or_else(|| ident.unraw(), |name| name.value.0);
123    let full_name = if let Some(module) = &options.module {
124        format!("{}.{}", module.value.value(), name)
125    } else {
126        name.to_string()
127    };
128
129    let mut module_items = Vec::new();
130    let mut module_items_cfg_attrs = Vec::new();
131    #[cfg(feature = "experimental-inspect")]
132    let mut introspection_chunks = Vec::new();
133    #[cfg(not(feature = "experimental-inspect"))]
134    let introspection_chunks = Vec::<TokenStream>::new();
135
136    fn extract_use_items(
137        source: &syn::UseTree,
138        cfg_attrs: &[syn::Attribute],
139        target_items: &mut Vec<syn::Ident>,
140        target_cfg_attrs: &mut Vec<Vec<syn::Attribute>>,
141    ) -> Result<()> {
142        match source {
143            syn::UseTree::Name(name) => {
144                target_items.push(name.ident.clone());
145                target_cfg_attrs.push(cfg_attrs.to_vec());
146            }
147            syn::UseTree::Path(path) => {
148                extract_use_items(&path.tree, cfg_attrs, target_items, target_cfg_attrs)?
149            }
150            syn::UseTree::Group(group) => {
151                for tree in &group.items {
152                    extract_use_items(tree, cfg_attrs, target_items, target_cfg_attrs)?
153                }
154            }
155            syn::UseTree::Glob(glob) => {
156                bail_spanned!(glob.span() => "#[pymodule] cannot import glob statements")
157            }
158            syn::UseTree::Rename(rename) => {
159                target_items.push(rename.rename.clone());
160                target_cfg_attrs.push(cfg_attrs.to_vec());
161            }
162        }
163        Ok(())
164    }
165
166    let mut pymodule_init = None;
167    let mut module_consts = Vec::new();
168    let mut module_consts_cfg_attrs = Vec::new();
169
170    let _: Vec<()> = (*items).iter_mut().map(|item|{
171        match item {
172            Item::Use(item_use) => {
173                let is_pymodule_export =
174                    find_and_remove_attribute(&mut item_use.attrs, "pymodule_export");
175                if is_pymodule_export {
176                    let cfg_attrs = get_cfg_attributes(&item_use.attrs);
177                    extract_use_items(
178                        &item_use.tree,
179                        &cfg_attrs,
180                        &mut module_items,
181                        &mut module_items_cfg_attrs,
182                    )?;
183                }
184            }
185            Item::Fn(item_fn) => {
186                ensure_spanned!(
187                    !has_attribute(&item_fn.attrs, "pymodule_export"),
188                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
189                );
190                let is_pymodule_init =
191                    find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
192                let ident = &item_fn.sig.ident;
193                if is_pymodule_init {
194                    ensure_spanned!(
195                        !has_attribute(&item_fn.attrs, "pyfunction"),
196                        item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
197                    );
198                    ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
199                    pymodule_init = Some(quote! { #ident(module)?; });
200                } else if has_attribute(&item_fn.attrs, "pyfunction")
201                    || has_attribute_with_namespace(
202                        &item_fn.attrs,
203                        Some(pyo3_path),
204                        &["pyfunction"],
205                    )
206                    || has_attribute_with_namespace(
207                        &item_fn.attrs,
208                        Some(pyo3_path),
209                        &["prelude", "pyfunction"],
210                    )
211                {
212                    module_items.push(ident.clone());
213                    module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
214                }
215            }
216            Item::Struct(item_struct) => {
217                ensure_spanned!(
218                    !has_attribute(&item_struct.attrs, "pymodule_export"),
219                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
220                );
221                if has_attribute(&item_struct.attrs, "pyclass")
222                    || has_attribute_with_namespace(
223                        &item_struct.attrs,
224                        Some(pyo3_path),
225                        &["pyclass"],
226                    )
227                    || has_attribute_with_namespace(
228                        &item_struct.attrs,
229                        Some(pyo3_path),
230                        &["prelude", "pyclass"],
231                    )
232                {
233                    module_items.push(item_struct.ident.clone());
234                    module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
235                    if !has_pyo3_module_declared::<PyClassPyForgeOption>(
236                        &item_struct.attrs,
237                        "pyclass",
238                        |option| matches!(option, PyClassPyForgeOption::Module(_)),
239                    )? {
240                        set_module_attribute(&mut item_struct.attrs, &full_name);
241                    }
242                }
243            }
244            Item::Enum(item_enum) => {
245                ensure_spanned!(
246                    !has_attribute(&item_enum.attrs, "pymodule_export"),
247                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
248                );
249                if has_attribute(&item_enum.attrs, "pyclass")
250                    || has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
251                    || has_attribute_with_namespace(
252                        &item_enum.attrs,
253                        Some(pyo3_path),
254                        &["prelude", "pyclass"],
255                    )
256                {
257                    module_items.push(item_enum.ident.clone());
258                    module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
259                    if !has_pyo3_module_declared::<PyClassPyForgeOption>(
260                        &item_enum.attrs,
261                        "pyclass",
262                        |option| matches!(option, PyClassPyForgeOption::Module(_)),
263                    )? {
264                        set_module_attribute(&mut item_enum.attrs, &full_name);
265                    }
266                }
267            }
268            Item::Mod(item_mod) => {
269                ensure_spanned!(
270                    !has_attribute(&item_mod.attrs, "pymodule_export"),
271                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
272                );
273                if has_attribute(&item_mod.attrs, "pymodule")
274                    || has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
275                    || has_attribute_with_namespace(
276                        &item_mod.attrs,
277                        Some(pyo3_path),
278                        &["prelude", "pymodule"],
279                    )
280                {
281                    module_items.push(item_mod.ident.clone());
282                    module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
283                    if !has_pyo3_module_declared::<PyModulePyForgeOption>(
284                        &item_mod.attrs,
285                        "pymodule",
286                        |option| matches!(option, PyModulePyForgeOption::Module(_)),
287                    )? {
288                        set_module_attribute(&mut item_mod.attrs, &full_name);
289                    }
290                    item_mod
291                        .attrs
292                        .push(parse_quote_spanned!(item_mod.mod_token.span()=> #[pyo3(submodule)]));
293                }
294            }
295            Item::ForeignMod(item) => {
296                ensure_spanned!(
297                    !has_attribute(&item.attrs, "pymodule_export"),
298                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
299                );
300            }
301            Item::Trait(item) => {
302                ensure_spanned!(
303                    !has_attribute(&item.attrs, "pymodule_export"),
304                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
305                );
306            }
307            Item::Const(item) => {
308                if !find_and_remove_attribute(&mut item.attrs, "pymodule_export") {
309                    return Ok(());
310                }
311                module_consts.push(item.ident.clone());
312                module_consts_cfg_attrs.push(get_cfg_attributes(&item.attrs));
313                #[cfg(feature = "experimental-inspect")]
314                {
315                    let cfg_attrs = get_cfg_attributes(&item.attrs);
316                    let chunk = attribute_introspection_code(
317                        pyo3_path,
318                        None,
319                        item.ident.unraw().to_string(),
320                        PyExpr::constant_from_expression(&item.expr),
321                        (*item.ty).clone(),
322                        get_doc(&item.attrs, None).as_ref(),
323                        true,
324                    );
325                    introspection_chunks.push(quote! {
326                        #(#cfg_attrs)*
327                        #chunk
328                    });
329                }
330            }
331            Item::Static(item) => {
332                ensure_spanned!(
333                    !has_attribute(&item.attrs, "pymodule_export"),
334                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
335                );
336            }
337            Item::Macro(item) => {
338                ensure_spanned!(
339                    !has_attribute(&item.attrs, "pymodule_export"),
340                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
341                );
342            }
343            Item::ExternCrate(item) => {
344                ensure_spanned!(
345                    !has_attribute(&item.attrs, "pymodule_export"),
346                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
347                );
348            }
349            Item::Impl(item) => {
350                ensure_spanned!(
351                    !has_attribute(&item.attrs, "pymodule_export"),
352                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
353                );
354            }
355            Item::TraitAlias(item) => {
356                ensure_spanned!(
357                    !has_attribute(&item.attrs, "pymodule_export"),
358                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
359                );
360            }
361            Item::Type(item) => {
362                ensure_spanned!(
363                    !has_attribute(&item.attrs, "pymodule_export"),
364                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
365                );
366            }
367            Item::Union(item) => {
368                ensure_spanned!(
369                    !has_attribute(&item.attrs, "pymodule_export"),
370                    item.span() => "`#[pymodule_export]` may only be used on `use` or `const` statements"
371                );
372            }
373            _ => (),
374        }
375        Ok(())
376    }).try_combine_syn_errors()?;
377
378    #[cfg(feature = "experimental-inspect")]
379    let introspection = module_introspection_code(
380        pyo3_path,
381        &name.to_string(),
382        &module_items,
383        &module_items_cfg_attrs,
384        doc.as_ref(),
385        pymodule_init.is_some(),
386    );
387    #[cfg(not(feature = "experimental-inspect"))]
388    let introspection = quote! {};
389    #[cfg(feature = "experimental-inspect")]
390    let introspection_id = introspection_id_const();
391    #[cfg(not(feature = "experimental-inspect"))]
392    let introspection_id = quote! {};
393
394    let gil_used = options.gil_used.is_some_and(|op| op.value.value);
395
396    let initialization = module_initialization(
397        &full_name,
398        &name,
399        ctx,
400        quote! { __pyo3_pymodule },
401        options.submodule.is_some(),
402        gil_used,
403        doc.as_ref(),
404    )?;
405
406    let module_consts_names = module_consts.iter().map(|i| i.unraw().to_string());
407
408    Ok(quote!(
409        #(#attrs)*
410        #vis #mod_token #ident {
411            #(#items)*
412
413            #initialization
414            #introspection
415            #introspection_id
416            #(#introspection_chunks)*
417
418            fn __pyo3_pymodule(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
419                use #pyo3_path::impl_::pymodule::PyAddToModule;
420                #(
421                    #(#module_items_cfg_attrs)*
422                    #module_items::_PYO3_DEF.add_to_module(module)?;
423                )*
424
425                #(
426                    #(#module_consts_cfg_attrs)*
427                    #pyo3_path::types::PyModuleMethods::add(module, #module_consts_names, #module_consts)?;
428                )*
429
430                #pymodule_init
431                ::std::result::Result::Ok(())
432            }
433        }
434    ))
435}
436
437/// Generates the function that is called by the python interpreter to initialize the native
438/// module
439pub fn pymodule_function_impl(
440    function: &mut syn::ItemFn,
441    mut options: PyModuleOptions,
442) -> Result<TokenStream> {
443    options.take_pyo3_options(&mut function.attrs)?;
444    process_functions_in_module(&options, function)?;
445    let ctx = &Ctx::new(&options.krate, None);
446    let Ctx { pyo3_path, .. } = ctx;
447    let ident = &function.sig.ident;
448    let name = options
449        .name
450        .map_or_else(|| ident.unraw(), |name| name.value.0);
451    let vis = &function.vis;
452    let doc = get_doc(&function.attrs, None);
453
454    let gil_used = options.gil_used.is_some_and(|op| op.value.value);
455
456    let initialization = module_initialization(
457        &name.to_string(),
458        &name,
459        ctx,
460        quote! { ModuleExec::__pyo3_module_exec },
461        false,
462        gil_used,
463        doc.as_ref(),
464    )?;
465
466    #[cfg(feature = "experimental-inspect")]
467    let introspection = module_introspection_code(
468        pyo3_path,
469        &name.unraw().to_string(),
470        &[],
471        &[],
472        doc.as_ref(),
473        true,
474    );
475    #[cfg(not(feature = "experimental-inspect"))]
476    let introspection = quote! {};
477    #[cfg(feature = "experimental-inspect")]
478    let introspection_id = introspection_id_const();
479    #[cfg(not(feature = "experimental-inspect"))]
480    let introspection_id = quote! {};
481
482    // Module function called with optional Python<'_> marker as first arg, followed by the module.
483    let mut module_args = Vec::new();
484    if function.sig.inputs.len() == 2 {
485        module_args.push(quote!(module.py()));
486    }
487    module_args.push(quote!(::std::convert::Into::into(module)));
488
489    Ok(quote! {
490        #[doc(hidden)]
491        #vis mod #ident {
492            #initialization
493            #introspection
494            #introspection_id
495        }
496
497        // Generate the definition inside an anonymous function in the same scope as the original function -
498        // this avoids complications around the fact that the generated module has a different scope
499        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pymodule] is
500        // inside a function body)
501        #[allow(unknown_lints, non_local_definitions)]
502        impl #ident::ModuleExec {
503            fn __pyo3_module_exec(module: &#pyo3_path::Bound<'_, #pyo3_path::types::PyModule>) -> #pyo3_path::PyResult<()> {
504                #ident(#(#module_args),*)
505            }
506        }
507    })
508}
509
510fn module_initialization(
511    full_name: &str,
512    name: &syn::Ident,
513    ctx: &Ctx,
514    module_exec: TokenStream,
515    is_submodule: bool,
516    gil_used: bool,
517    doc: Option<&PythonDoc>,
518) -> Result<TokenStream> {
519    let Ctx { pyo3_path, .. } = ctx;
520    let pyinit_symbol = format!("PyInit_{name}");
521    let pymodexport_symbol = format!("PyModExport_{name}");
522    let pyo3_name = LitCStr::new(&CString::new(full_name).unwrap(), Span::call_site());
523    let doc = if let Some(doc) = doc {
524        doc.to_cstr_stream(ctx)?
525    } else {
526        c"".into_token_stream()
527    };
528
529    let mut result = quote! {
530        #[doc(hidden)]
531        pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name;
532
533        // This structure exists for `fn` modules declared within `fn` bodies, where due to the hidden
534        // module (used for importing) the `fn` to initialize the module cannot be seen from the #module_def
535        // declaration just below.
536        #[doc(hidden)]
537        pub(super) struct ModuleExec;
538
539        #[doc(hidden)]
540        pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = {
541            use #pyo3_path::impl_::pymodule as impl_;
542
543            unsafe extern "C" fn __pyo3_module_exec(module: *mut #pyo3_path::ffi::PyObject) -> ::std::os::raw::c_int {
544                #pyo3_path::impl_::trampoline::module_exec(module, #module_exec)
545            }
546
547            // The full slots, used for the PyModExport initializaiton
548            static SLOTS: impl_::PyModuleSlots = impl_::PyModuleSlotsBuilder::new()
549                .with_mod_exec(__pyo3_module_exec)
550                .with_abi_info()
551                .with_gil_used(#gil_used)
552                .with_name(__PYO3_NAME)
553                .with_doc(#doc)
554                .build();
555
556            // Since the macros need to be written agnostic to the Python version
557            // we need to explicitly pass the name and docstring for PyModuleDef
558            // initializaiton.
559            impl_::ModuleDef::new(__PYO3_NAME, #doc, &SLOTS)
560        };
561    };
562    if !is_submodule {
563        result.extend(quote! {
564            /// This autogenerated function is called by the python interpreter when importing
565            /// the module on Python 3.14 and older.
566            #[doc(hidden)]
567            #[export_name = #pyinit_symbol]
568            pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject {
569                _PYO3_DEF.init_multi_phase()
570            }
571
572            /// This autogenerated function is called by the python interpreter when importing
573            /// the module on Python 3.15 and newer.
574            #[doc(hidden)]
575            #[export_name = #pymodexport_symbol]
576            pub unsafe extern "C" fn __pyo3_export() -> *mut #pyo3_path::ffi::PyModuleDef_Slot {
577                _PYO3_DEF.get_slots()
578            }
579        });
580    }
581    Ok(result)
582}
583
584/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
585fn process_functions_in_module(options: &PyModuleOptions, func: &mut syn::ItemFn) -> Result<()> {
586    let ctx = &Ctx::new(&options.krate, None);
587    let Ctx { pyo3_path, .. } = ctx;
588    let mut stmts: Vec<syn::Stmt> = Vec::new();
589
590    for mut stmt in func.block.stmts.drain(..) {
591        if let syn::Stmt::Item(Item::Fn(func)) = &mut stmt {
592            if let Some((pyfn_span, pyfn_args)) = get_pyfn_attr(&mut func.attrs)? {
593                let module_name = pyfn_args.modname;
594                let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
595                let name = &func.sig.ident;
596                let statements: Vec<syn::Stmt> = syn::parse_quote_spanned! {
597                    pyfn_span =>
598                    #wrapped_function
599                    {
600                        use #pyo3_path::types::PyModuleMethods;
601                        #module_name.add_function(#pyo3_path::wrap_pyfunction!(#name, #module_name.as_borrowed())?)?;
602                        #[deprecated(note = "`pyfn` will be removed in a future PyForge version, use declarative `#[pymodule]` with `mod` instead")]
603                        #[allow(dead_code)]
604                        const PYFN_ATTRIBUTE: () = ();
605                        const _: () = PYFN_ATTRIBUTE;
606                    }
607                };
608                stmts.extend(statements);
609            }
610        };
611        stmts.push(stmt);
612    }
613
614    func.block.stmts = stmts;
615    Ok(())
616}
617
618pub struct PyFnArgs {
619    modname: Path,
620    options: PyFunctionOptions,
621}
622
623impl Parse for PyFnArgs {
624    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
625        let modname = input.parse().map_err(
626            |e| err_spanned!(e.span() => "expected module as first argument to #[pyfn()]"),
627        )?;
628
629        if input.is_empty() {
630            return Ok(Self {
631                modname,
632                options: Default::default(),
633            });
634        }
635
636        let _: Comma = input.parse()?;
637
638        Ok(Self {
639            modname,
640            options: input.parse()?,
641        })
642    }
643}
644
645/// Extracts the data from the #[pyfn(...)] attribute of a function
646fn get_pyfn_attr(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Option<(Span, PyFnArgs)>> {
647    let mut pyfn_args: Option<(Span, PyFnArgs)> = None;
648
649    take_attributes(attrs, |attr| {
650        if attr.path().is_ident("pyfn") {
651            ensure_spanned!(
652                pyfn_args.is_none(),
653                attr.span() => "`#[pyfn] may only be specified once"
654            );
655            pyfn_args = Some((attr.path().span(), attr.parse_args()?));
656            Ok(true)
657        } else {
658            Ok(false)
659        }
660    })?;
661
662    if let Some((_, pyfn_args)) = &mut pyfn_args {
663        pyfn_args
664            .options
665            .add_attributes(take_pyo3_options(attrs)?)?;
666    }
667
668    Ok(pyfn_args)
669}
670
671fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<syn::Attribute> {
672    attrs
673        .iter()
674        .filter(|attr| attr.path().is_ident("cfg"))
675        .cloned()
676        .collect()
677}
678
679fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bool {
680    let mut found = false;
681    attrs.retain(|attr| {
682        if attr.path().is_ident(ident) {
683            found = true;
684            false
685        } else {
686            true
687        }
688    });
689    found
690}
691
692impl PartialEq<syn::Ident> for IdentOrStr<'_> {
693    fn eq(&self, other: &syn::Ident) -> bool {
694        match self {
695            IdentOrStr::Str(s) => other == s,
696            IdentOrStr::Ident(i) => other == i,
697        }
698    }
699}
700
701fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
702    attrs.push(parse_quote!(#[pyo3(module = #module_name)]));
703}
704
705fn has_pyo3_module_declared<T: Parse>(
706    attrs: &[syn::Attribute],
707    root_attribute_name: &str,
708    is_module_option: impl Fn(&T) -> bool + Copy,
709) -> Result<bool> {
710    for attr in attrs {
711        if (attr.path().is_ident("pyo3") || attr.path().is_ident(root_attribute_name))
712            && matches!(attr.meta, Meta::List(_))
713        {
714            for option in &attr.parse_args_with(Punctuated::<T, Comma>::parse_terminated)? {
715                if is_module_option(option) {
716                    return Ok(true);
717                }
718            }
719        }
720    }
721    Ok(false)
722}
723
724enum PyModulePyForgeOption {
725    Submodule(SubmoduleAttribute),
726    Crate(CrateAttribute),
727    Name(NameAttribute),
728    Module(ModuleAttribute),
729    GILUsed(GILUsedAttribute),
730}
731
732impl Parse for PyModulePyForgeOption {
733    fn parse(input: ParseStream<'_>) -> Result<Self> {
734        let lookahead = input.lookahead1();
735        if lookahead.peek(attributes::kw::name) {
736            input.parse().map(PyModulePyForgeOption::Name)
737        } else if lookahead.peek(syn::Token![crate]) {
738            input.parse().map(PyModulePyForgeOption::Crate)
739        } else if lookahead.peek(attributes::kw::module) {
740            input.parse().map(PyModulePyForgeOption::Module)
741        } else if lookahead.peek(attributes::kw::submodule) {
742            input.parse().map(PyModulePyForgeOption::Submodule)
743        } else if lookahead.peek(attributes::kw::gil_used) {
744            input.parse().map(PyModulePyForgeOption::GILUsed)
745        } else {
746            Err(lookahead.error())
747        }
748    }
749}