pyo3_derive_backend/
module.rs

1// Copyright (c) 2017-present PyO3 Project and Contributors
//! Code generation for the function that initializes a Python module and adds classes and functions to it.
3
4use crate::method;
5use crate::pyfunction::PyFunctionAttr;
6use crate::pymethod;
7use crate::pymethod::get_arg_names;
8use crate::utils;
9use proc_macro2::{Span, TokenStream};
10use quote::{format_ident, quote};
11use syn::Ident;
12
13/// Generates the function that is called by the python interpreter to initialize the native
14/// module
15pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::LitStr) -> TokenStream {
16    let cb_name = Ident::new(&format!("PyInit_{}", name), Span::call_site());
17
18    quote! {
19        #[no_mangle]
20        #[allow(non_snake_case)]
21        /// This autogenerated function is called by the python interpreter when importing
22        /// the module.
23        pub unsafe extern "C" fn #cb_name() -> *mut pyo3::ffi::PyObject {
24            use pyo3::derive_utils::ModuleDef;
25            const NAME: &'static str = concat!(stringify!(#name), "\0");
26            static MODULE_DEF: ModuleDef = unsafe { ModuleDef::new(NAME) };
27
28            pyo3::callback_body!(_py, { MODULE_DEF.make_module(#doc, #fnname) })
29        }
30    }
31}
32
33/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
34pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
35    let mut stmts: Vec<syn::Stmt> = Vec::new();
36
37    for stmt in func.block.stmts.iter_mut() {
38        if let syn::Stmt::Item(syn::Item::Fn(ref mut func)) = stmt {
39            if let Some((module_name, python_name, pyfn_attrs)) =
40                extract_pyfn_attrs(&mut func.attrs)?
41            {
42                let function_to_python = add_fn_to_module(func, python_name, pyfn_attrs)?;
43                let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
44                let item: syn::ItemFn = syn::parse_quote! {
45                    fn block_wrapper() {
46                        #function_to_python
47                        #module_name.add_function(#function_wrapper_ident(#module_name)?)?;
48                    }
49                };
50                stmts.extend(item.block.stmts.into_iter());
51            }
52        };
53        stmts.push(stmt.clone());
54    }
55
56    func.block.stmts = stmts;
57    Ok(())
58}
59
60/// Transforms a rust fn arg parsed with syn into a method::FnArg
61fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>> {
62    let (mutability, by_ref, ident) = match *cap.pat {
63        syn::Pat::Ident(ref patid) => (&patid.mutability, &patid.by_ref, &patid.ident),
64        _ => return Err(syn::Error::new_spanned(&cap.pat, "Unsupported argument")),
65    };
66
67    Ok(method::FnArg {
68        name: ident,
69        mutability,
70        by_ref,
71        ty: &cap.ty,
72        optional: utils::option_type_argument(&cap.ty),
73        py: utils::is_python(&cap.ty),
74    })
75}
76
77/// Extracts the data from the #[pyfn(...)] attribute of a function
78fn extract_pyfn_attrs(
79    attrs: &mut Vec<syn::Attribute>,
80) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
81    let mut new_attrs = Vec::new();
82    let mut fnname = None;
83    let mut modname = None;
84    let mut fn_attrs = PyFunctionAttr::default();
85
86    for attr in attrs.iter() {
87        match attr.parse_meta() {
88            Ok(syn::Meta::List(ref list)) if list.path.is_ident("pyfn") => {
89                let meta: Vec<_> = list.nested.iter().cloned().collect();
90                if meta.len() >= 2 {
91                    // read module name
92                    match meta[0] {
93                        syn::NestedMeta::Meta(syn::Meta::Path(ref path)) => {
94                            modname = Some(path.clone())
95                        }
96                        _ => {
97                            return Err(syn::Error::new_spanned(
98                                &meta[0],
99                                "The first parameter of pyfn must be a MetaItem",
100                            ))
101                        }
102                    }
103                    // read Python function name
104                    match meta[1] {
105                        syn::NestedMeta::Lit(syn::Lit::Str(ref lits)) => {
106                            fnname = Some(syn::Ident::new(&lits.value(), lits.span()));
107                        }
108                        _ => {
109                            return Err(syn::Error::new_spanned(
110                                &meta[1],
111                                "The second parameter of pyfn must be a Literal",
112                            ))
113                        }
114                    }
115                    // Read additional arguments
116                    if list.nested.len() >= 3 {
117                        fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
118                    }
119                } else {
120                    return Err(syn::Error::new_spanned(
121                        attr,
122                        format!("can not parse 'pyfn' params {:?}", attr),
123                    ));
124                }
125            }
126            _ => new_attrs.push(attr.clone()),
127        }
128    }
129
130    *attrs = new_attrs;
131    match (modname, fnname) {
132        (Some(modname), Some(fnname)) => Ok(Some((modname, fnname, fn_attrs))),
133        _ => Ok(None),
134    }
135}
136
137/// Coordinates the naming of a the add-function-to-python-module function
138fn function_wrapper_ident(name: &Ident) -> Ident {
139    // Make sure this ident matches the one of wrap_pyfunction
140    format_ident!("__pyo3_get_function_{}", name)
141}
142
143/// Generates python wrapper over a function that allows adding it to a python module as a python
144/// function
145pub fn add_fn_to_module(
146    func: &mut syn::ItemFn,
147    python_name: Ident,
148    pyfn_attrs: PyFunctionAttr,
149) -> syn::Result<TokenStream> {
150    let mut arguments = Vec::new();
151
152    for (i, input) in func.sig.inputs.iter().enumerate() {
153        match input {
154            syn::FnArg::Receiver(_) => {
155                return Err(syn::Error::new_spanned(
156                    input,
157                    "Unexpected receiver for #[pyfn]",
158                ))
159            }
160            syn::FnArg::Typed(ref cap) => {
161                if pyfn_attrs.pass_module && i == 0 {
162                    if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
163                        if let syn::Type::Path(typath) = tyref.elem.as_ref() {
164                            if typath
165                                .path
166                                .segments
167                                .last()
168                                .map(|seg| seg.ident == "PyModule")
169                                .unwrap_or(false)
170                            {
171                                continue;
172                            }
173                        }
174                    }
175                    return Err(syn::Error::new_spanned(
176                        cap,
177                        "Expected &PyModule as first argument with `pass_module`.",
178                    ));
179                } else {
180                    arguments.push(wrap_fn_argument(cap)?);
181                }
182            }
183        }
184    }
185
186    let ty = method::get_return_info(&func.sig.output);
187
188    let text_signature = utils::parse_text_signature_attrs(&mut func.attrs, &python_name)?;
189    let doc = utils::get_doc(&func.attrs, text_signature, true)?;
190
191    let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
192
193    let spec = method::FnSpec {
194        tp: method::FnType::FnStatic,
195        name: &function_wrapper_ident,
196        python_name,
197        attrs: pyfn_attrs.arguments,
198        args: arguments,
199        output: ty,
200        doc,
201    };
202
203    let doc = syn::LitByteStr::new(spec.doc.value().as_bytes(), spec.doc.span());
204
205    let python_name = &spec.python_name;
206
207    let name = &func.sig.ident;
208    let wrapper_ident = format_ident!("__pyo3_raw_{}", name);
209    let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, pyfn_attrs.pass_module);
210    Ok(quote! {
211        #wrapper
212        fn #function_wrapper_ident<'a>(
213            args: impl Into<pyo3::derive_utils::PyFunctionArguments<'a>>
214        ) -> pyo3::PyResult<&'a pyo3::types::PyCFunction> {
215            let name = concat!(stringify!(#python_name), "\0");
216            let name = std::ffi::CStr::from_bytes_with_nul(name.as_bytes()).unwrap();
217            let doc = std::ffi::CStr::from_bytes_with_nul(#doc).unwrap();
218            pyo3::types::PyCFunction::internal_new(
219                name,
220                doc,
221                pyo3::class::PyMethodType::PyCFunctionWithKeywords(#wrapper_ident),
222                pyo3::ffi::METH_VARARGS | pyo3::ffi::METH_KEYWORDS,
223                args.into(),
224            )
225        }
226    })
227}
228
229/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
230fn function_c_wrapper(
231    name: &Ident,
232    wrapper_ident: &Ident,
233    spec: &method::FnSpec<'_>,
234    pass_module: bool,
235) -> TokenStream {
236    let names: Vec<Ident> = get_arg_names(&spec);
237    let cb;
238    let slf_module;
239    if pass_module {
240        cb = quote! {
241            #name(_slf, #(#names),*)
242        };
243        slf_module = Some(quote! {
244            let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
245        });
246    } else {
247        cb = quote! {
248            #name(#(#names),*)
249        };
250        slf_module = None;
251    };
252    let body = pymethod::impl_arg_params(spec, None, cb);
253    quote! {
254        unsafe extern "C" fn #wrapper_ident(
255            _slf: *mut pyo3::ffi::PyObject,
256            _args: *mut pyo3::ffi::PyObject,
257            _kwargs: *mut pyo3::ffi::PyObject) -> *mut pyo3::ffi::PyObject
258        {
259            const _LOCATION: &'static str = concat!(stringify!(#name), "()");
260            pyo3::callback_body!(_py, {
261                #slf_module
262                let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
263                let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);
264
265                #body
266            })
267        }
268    }
269}