pyo3_macros_backend/pyfunction.rs

1use crate::attributes::KeywordAttribute;
2use crate::combine_errors::CombineErrors;
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{function_introspection_code, introspection_id_const};
5use crate::utils::{Ctx, LitCStr};
6use crate::{
7    attributes::{
8        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
9        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
10    },
11    method::{self, CallingConvention, FnArg},
12    pymethod::check_generic,
13};
14use proc_macro2::{Span, TokenStream};
15use quote::{format_ident, quote, ToTokens};
16use std::cmp::PartialEq;
17use std::ffi::CString;
18use syn::parse::{Parse, ParseStream};
19use syn::punctuated::Punctuated;
20use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
21
22mod signature;
23
24pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
25
/// Options collected from the `#[pyo3(...)]` attributes placed on a single function argument.
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
    /// The `from_py_with = ...` option, if it was given for this argument.
    pub from_py_with: Option<FromPyWithAttribute>,
    /// The bare `cancel_handle` keyword, if it was given for this argument.
    pub cancel_handle: Option<attributes::kw::cancel_handle>,
}
31
/// One option parsed out of an argument's `#[pyo3(...)]` attribute list.
enum PyFunctionArgPyO3Attribute {
    /// `from_py_with = ...`
    FromPyWith(FromPyWithAttribute),
    /// The bare `cancel_handle` keyword.
    CancelHandle(attributes::kw::cancel_handle),
}
36
37impl Parse for PyFunctionArgPyO3Attribute {
38    fn parse(input: ParseStream<'_>) -> Result<Self> {
39        let lookahead = input.lookahead1();
40        if lookahead.peek(attributes::kw::cancel_handle) {
41            input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
42        } else if lookahead.peek(attributes::kw::from_py_with) {
43            input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
44        } else {
45            Err(lookahead.error())
46        }
47    }
48}
49
50impl PyFunctionArgPyO3Attributes {
51    /// Parses #[pyo3(from_python_with = "func")]
52    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
53        let mut attributes = PyFunctionArgPyO3Attributes {
54            from_py_with: None,
55            cancel_handle: None,
56        };
57        take_attributes(attrs, |attr| {
58            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
59                for attr in pyo3_attrs {
60                    match attr {
61                        PyFunctionArgPyO3Attribute::FromPyWith(from_py_with) => {
62                            ensure_spanned!(
63                                attributes.from_py_with.is_none(),
64                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
65                            );
66                            attributes.from_py_with = Some(from_py_with);
67                        }
68                        PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
69                            ensure_spanned!(
70                                attributes.cancel_handle.is_none(),
71                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
72                            );
73                            attributes.cancel_handle = Some(cancel_handle);
74                        }
75                    }
76                    ensure_spanned!(
77                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
78                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
79                    );
80                }
81                Ok(true)
82            } else {
83                Ok(false)
84            }
85        })?;
86        Ok(attributes)
87    }
88}
89
/// The `message` entry of a `warn(...)` option: the `message` keyword with a string-literal value.
type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
/// The `category` entry of a `warn(...)` option: the `category` keyword with a path value.
type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;
92
/// A `warn(...)` option as parsed from the attribute, before defaults are applied.
///
/// Converted into [`PyFunctionWarning`] via `From`.
pub struct PyFunctionWarningAttribute {
    /// The required `message` entry.
    pub message: PyFunctionWarningMessageAttribute,
    /// The optional `category` entry; `None` when it was omitted.
    pub category: Option<PyFunctionWarningCategoryAttribute>,
    /// Span of the `warn` keyword itself.
    pub span: Span,
}
98
/// The Python warning class used when emitting a warning.
#[derive(PartialEq, Clone)]
pub enum PyFunctionWarningCategory {
    /// A user-supplied path to the warning type (from `category = ...`).
    Path(Path),
    /// `exceptions::PyUserWarning`; the default when no `category` is given.
    UserWarning,
    /// `exceptions::PyDeprecationWarning`.
    DeprecationWarning, // TODO: unused for now, intended for pyo3(deprecated) special-case
}
105
/// A resolved `warn(...)` option. The code that emits it is produced by
/// [`WarningFactory::build_py_warning`].
#[derive(Clone)]
pub struct PyFunctionWarning {
    /// The warning text.
    pub message: LitStr,
    /// The warning class; already defaulted to `UserWarning` if none was given.
    pub category: PyFunctionWarningCategory,
    /// Span of the `warn` keyword, used for diagnostics.
    pub span: Span,
}
112
113impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
114    fn from(value: PyFunctionWarningAttribute) -> Self {
115        Self {
116            message: value.message.value,
117            category: value
118                .category
119                .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
120                    PyFunctionWarningCategory::Path(cat.value)
121                }),
122            span: value.span,
123        }
124    }
125}
126
/// Something that can generate code emitting one or more Python warnings.
pub trait WarningFactory {
    /// Generates the statement(s) that emit the warning(s).
    ///
    /// The `PyFunctionWarning` implementation's output refers to a `py` binding and uses `?`.
    /// It must therefore be spliced where both are valid.
    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
    /// Span to which diagnostics about the warning declaration(s) should point.
    fn span(&self) -> Span;
}
131
132impl WarningFactory for PyFunctionWarning {
133    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
134        let message = &self.message.value();
135        let c_message = LitCStr::new(
136            CString::new(message.clone()).unwrap(),
137            Spanned::span(&message),
138            ctx,
139        );
140        let pyo3_path = &ctx.pyo3_path;
141        let category = match &self.category {
142            PyFunctionWarningCategory::Path(path) => quote! {#path},
143            PyFunctionWarningCategory::UserWarning => {
144                quote! {#pyo3_path::exceptions::PyUserWarning}
145            }
146            PyFunctionWarningCategory::DeprecationWarning => {
147                quote! {#pyo3_path::exceptions::PyDeprecationWarning}
148            }
149        };
150        quote! {
151            #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
152        }
153    }
154
155    fn span(&self) -> Span {
156        self.span
157    }
158}
159
160impl<T: WarningFactory> WarningFactory for Vec<T> {
161    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
162        let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
163
164        quote! {
165            #(#warnings)*
166        }
167    }
168
169    fn span(&self) -> Span {
170        self.iter()
171            .map(|val| val.span())
172            .reduce(|acc, span| acc.join(span).unwrap_or(acc))
173            .unwrap()
174    }
175}
176
177impl Parse for PyFunctionWarningAttribute {
178    fn parse(input: ParseStream<'_>) -> Result<Self> {
179        let mut message: Option<PyFunctionWarningMessageAttribute> = None;
180        let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
181
182        let span = input.parse::<attributes::kw::warn>()?.span();
183
184        let content;
185        syn::parenthesized!(content in input);
186
187        while !content.is_empty() {
188            let lookahead = content.lookahead1();
189
190            if lookahead.peek(attributes::kw::message) {
191                message = content
192                    .parse::<PyFunctionWarningMessageAttribute>()
193                    .map(Some)?;
194            } else if lookahead.peek(attributes::kw::category) {
195                category = content
196                    .parse::<PyFunctionWarningCategoryAttribute>()
197                    .map(Some)?;
198            } else {
199                return Err(lookahead.error());
200            }
201
202            if content.peek(Token![,]) {
203                content.parse::<Token![,]>()?;
204            }
205        }
206
207        Ok(PyFunctionWarningAttribute {
208            message: message.ok_or(syn::Error::new(
209                content.span(),
210                "missing `message` in `warn` attribute",
211            ))?,
212            category,
213            span,
214        })
215    }
216}
217
218impl ToTokens for PyFunctionWarningAttribute {
219    fn to_tokens(&self, tokens: &mut TokenStream) {
220        let message_tokens = self.message.to_token_stream();
221        let category_tokens = self
222            .category
223            .as_ref()
224            .map_or(quote! {}, |cat| cat.to_token_stream());
225
226        let token_stream = quote! {
227            warn(#message_tokens, #category_tokens)
228        };
229
230        tokens.extend(token_stream);
231    }
232}
233
/// Options for a `#[pyfunction]`, gathered from a comma-separated [`PyFunctionOption`] list
/// and/or from `#[pyo3(...)]` attributes on the function.
#[derive(Default)]
pub struct PyFunctionOptions {
    /// The `pass_module` flag: the first Rust argument receives the module.
    pub pass_module: Option<attributes::kw::pass_module>,
    /// `name = ...`: overrides the Python-visible name (otherwise the Rust ident is used).
    pub name: Option<NameAttribute>,
    /// `signature = ...`
    pub signature: Option<SignatureAttribute>,
    /// `text_signature = ...`
    pub text_signature: Option<TextSignatureAttribute>,
    /// `crate = ...`. The field is named `krate` because `crate` is a Rust keyword.
    pub krate: Option<CrateAttribute>,
    /// Every `warn(...)` option, in declaration order. Unlike the other options, this may repeat.
    pub warnings: Vec<PyFunctionWarning>,
}
243
244impl Parse for PyFunctionOptions {
245    fn parse(input: ParseStream<'_>) -> Result<Self> {
246        let mut options = PyFunctionOptions::default();
247
248        let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
249        options.add_attributes(attrs)?;
250
251        Ok(options)
252    }
253}
254
/// A single option accepted in a `#[pyfunction]` option list or a function-level
/// `#[pyo3(...)]` attribute.
pub enum PyFunctionOption {
    /// `name = ...`
    Name(NameAttribute),
    /// The bare `pass_module` keyword.
    PassModule(attributes::kw::pass_module),
    /// `signature = ...`
    Signature(SignatureAttribute),
    /// `text_signature = ...`
    TextSignature(TextSignatureAttribute),
    /// `crate = ...`
    Crate(CrateAttribute),
    /// `warn(message = ..., category = ...)`
    Warning(PyFunctionWarningAttribute),
}
263
264impl Parse for PyFunctionOption {
265    fn parse(input: ParseStream<'_>) -> Result<Self> {
266        let lookahead = input.lookahead1();
267        if lookahead.peek(attributes::kw::name) {
268            input.parse().map(PyFunctionOption::Name)
269        } else if lookahead.peek(attributes::kw::pass_module) {
270            input.parse().map(PyFunctionOption::PassModule)
271        } else if lookahead.peek(attributes::kw::signature) {
272            input.parse().map(PyFunctionOption::Signature)
273        } else if lookahead.peek(attributes::kw::text_signature) {
274            input.parse().map(PyFunctionOption::TextSignature)
275        } else if lookahead.peek(syn::Token![crate]) {
276            input.parse().map(PyFunctionOption::Crate)
277        } else if lookahead.peek(attributes::kw::warn) {
278            input.parse().map(PyFunctionOption::Warning)
279        } else {
280            Err(lookahead.error())
281        }
282    }
283}
284
285impl PyFunctionOptions {
286    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
287        let mut options = PyFunctionOptions::default();
288        options.add_attributes(take_pyo3_options(attrs)?)?;
289        Ok(options)
290    }
291
292    pub fn add_attributes(
293        &mut self,
294        attrs: impl IntoIterator<Item = PyFunctionOption>,
295    ) -> Result<()> {
296        macro_rules! set_option {
297            ($key:ident) => {
298                {
299                    ensure_spanned!(
300                        self.$key.is_none(),
301                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
302                    );
303                    self.$key = Some($key);
304                }
305            };
306        }
307        for attr in attrs {
308            match attr {
309                PyFunctionOption::Name(name) => set_option!(name),
310                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
311                PyFunctionOption::Signature(signature) => set_option!(signature),
312                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
313                PyFunctionOption::Crate(krate) => set_option!(krate),
314                PyFunctionOption::Warning(warning) => {
315                    self.warnings.push(warning.into());
316                }
317            }
318        }
319        Ok(())
320    }
321}
322
323pub fn build_py_function(
324    ast: &mut syn::ItemFn,
325    mut options: PyFunctionOptions,
326) -> syn::Result<TokenStream> {
327    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
328    impl_wrap_pyfunction(ast, options)
329}
330
/// Generates python wrapper over a function that allows adding it to a python module as a python
/// function
///
/// The output contains:
/// - a hidden module named like the function, exposing `_PYO3_DEF` (a `PyMethodDef`);
/// - an `impl` block on that module's `MakeDef` struct that builds the definition;
/// - the generated wrapper function;
/// - introspection data, when the `experimental-inspect` feature is enabled.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `check_generic` rejects the signature;
/// - `pass_module` is set but the first argument is missing or is a receiver;
/// - any argument or the `signature` option fails to parse;
/// - the function is `async` without the `experimental-async` feature;
/// - generating the wrapper or its docs fails.
pub fn impl_wrap_pyfunction(
    func: &mut syn::ItemFn,
    options: PyFunctionOptions,
) -> syn::Result<TokenStream> {
    check_generic(&func.sig)?;
    let PyFunctionOptions {
        pass_module,
        name,
        signature,
        text_signature,
        krate,
        warnings,
    } = options;

    let ctx = &Ctx::new(&krate, Some(&func.sig));
    let Ctx { pyo3_path, .. } = &ctx;

    // Python-visible name: the `name = ...` option if given, otherwise the Rust ident, with any
    // raw-identifier `r#` prefix removed by `unraw`.
    let python_name = name
        .as_ref()
        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
        .unraw();

    // With `pass_module`, the first Rust argument must be a typed argument. Its type span is
    // recorded in `FnModule`.
    let tp = if pass_module.is_some() {
        let span = match func.sig.inputs.first() {
            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
            ),
        };
        method::FnType::FnModule(span)
    } else {
        method::FnType::FnStatic
    };

    // Parse the arguments that appear in the Python signature. The first Rust argument is
    // skipped when the function type says so (e.g. the module argument). Errors from all
    // arguments are gathered via `CombineErrors`, presumably so they are reported together.
    let arguments = func
        .sig
        .inputs
        .iter_mut()
        .skip(if tp.skip_first_rust_argument_in_python_signature() {
            1
        } else {
            0
        })
        .map(FnArg::parse)
        .try_combine_syn_errors()?;

    // An explicit `signature = ...` option is combined with the parsed arguments; otherwise the
    // signature is derived from the arguments alone.
    let signature = if let Some(signature) = signature {
        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
    } else {
        FunctionSignature::from_arguments(arguments)
    };

    let vis = &func.vis;
    // From here on `name` is the Rust ident (it shadows the `name` option destructured above).
    let name = &func.sig.ident;

    #[cfg(feature = "experimental-inspect")]
    let introspection = function_introspection_code(
        pyo3_path,
        Some(name),
        &name.to_string(),
        &signature,
        None,
        func.sig.output.clone(),
        [] as [String; 0],
        None,
    );
    #[cfg(not(feature = "experimental-inspect"))]
    let introspection = quote! {};
    #[cfg(feature = "experimental-inspect")]
    let introspection_id = introspection_id_const();
    #[cfg(not(feature = "experimental-inspect"))]
    let introspection_id = quote! {};

    let spec = method::FnSpec {
        tp,
        name: &func.sig.ident,
        convention: CallingConvention::from_signature(&signature),
        python_name,
        signature,
        text_signature,
        asyncness: func.sig.asyncness,
        unsafety: func.sig.unsafety,
        warnings,
        #[cfg(feature = "experimental-inspect")]
        output: func.sig.output.clone(),
    };

    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
    if spec.asyncness.is_some() {
        ensure_spanned!(
            cfg!(feature = "experimental-async"),
            spec.asyncness.span() => "async functions are only supported with the `experimental-async` feature"
        );
    }
    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, ctx)?;
    let methoddef = spec.get_methoddef(wrapper_ident, &spec.get_doc(&func.attrs, ctx)?, ctx);

    let wrapped_pyfunction = quote! {
        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
        // will actually bring both the module and the function into scope.
        #[doc(hidden)]
        #vis mod #name {
            pub(crate) struct MakeDef;
            pub const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = MakeDef::_PYO3_DEF;
            #introspection_id
        }

        // Generate the definition inside an `impl` block in the same scope as the original function -
        // this avoids complications around the fact that the generated module has a different scope
        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction]` is
        // inside a function body)
        #[allow(unknown_lints, non_local_definitions)]
        impl #name::MakeDef {
            const _PYO3_DEF: #pyo3_path::impl_::pymethods::PyMethodDef = #methoddef;
        }

        #[allow(non_snake_case)]
        #wrapper

        #introspection
    };
    Ok(wrapped_pyfunction)
}