Skip to main content

pyforge_macros_backend/
pyfunction.rs

1use crate::attributes::KeywordAttribute;
2use crate::combine_errors::CombineErrors;
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{function_introspection_code, introspection_id_const};
5#[cfg(feature = "experimental-inspect")]
6use crate::utils::get_doc;
7use crate::utils::Ctx;
8use crate::{
9    attributes::{
10        self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
11        FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
12    },
13    method::{self, CallingConvention, FnArg},
14    pymethod::check_generic,
15};
16use proc_macro2::{Span, TokenStream};
17use quote::{format_ident, quote, ToTokens};
18use std::cmp::PartialEq;
19use std::ffi::CString;
20#[cfg(feature = "experimental-inspect")]
21use std::iter::empty;
22use syn::parse::{Parse, ParseStream};
23use syn::punctuated::Punctuated;
24use syn::LitCStr;
25use syn::{ext::IdentExt, spanned::Spanned, LitStr, Path, Result, Token};
26
27mod signature;
28
29pub use self::signature::{ConstructorAttribute, FunctionSignature, SignatureAttribute};
30
/// Per-argument options collected from the `#[pyo3(...)]` attributes on one function parameter.
///
/// Built by [`PyFunctionArgPyForgeAttributes::from_attrs`], which rejects setting both fields.
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyForgeAttributes {
    /// The `from_py_with` option, if present on the argument.
    pub from_py_with: Option<FromPyWithAttribute>,
    /// The `cancel_handle` keyword, if present on the argument.
    pub cancel_handle: Option<attributes::kw::cancel_handle>,
}
36
/// A single option that may appear inside a function parameter's `#[pyo3(...)]` attribute.
enum PyFunctionArgPyForgeAttribute {
    /// `from_py_with` option.
    FromPyWith(FromPyWithAttribute),
    /// `cancel_handle` keyword.
    CancelHandle(attributes::kw::cancel_handle),
}
41
42impl Parse for PyFunctionArgPyForgeAttribute {
43    fn parse(input: ParseStream<'_>) -> Result<Self> {
44        let lookahead = input.lookahead1();
45        if lookahead.peek(attributes::kw::cancel_handle) {
46            input.parse().map(PyFunctionArgPyForgeAttribute::CancelHandle)
47        } else if lookahead.peek(attributes::kw::from_py_with) {
48            input.parse().map(PyFunctionArgPyForgeAttribute::FromPyWith)
49        } else {
50            Err(lookahead.error())
51        }
52    }
53}
54
55impl PyFunctionArgPyForgeAttributes {
56    /// Parses #[pyo3(from_python_with = "func")]
57    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
58        let mut attributes = PyFunctionArgPyForgeAttributes {
59            from_py_with: None,
60            cancel_handle: None,
61        };
62        take_attributes(attrs, |attr| {
63            if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
64                for attr in pyo3_attrs {
65                    match attr {
66                        PyFunctionArgPyForgeAttribute::FromPyWith(from_py_with) => {
67                            ensure_spanned!(
68                                attributes.from_py_with.is_none(),
69                                from_py_with.span() => "`from_py_with` may only be specified once per argument"
70                            );
71                            attributes.from_py_with = Some(from_py_with);
72                        }
73                        PyFunctionArgPyForgeAttribute::CancelHandle(cancel_handle) => {
74                            ensure_spanned!(
75                                attributes.cancel_handle.is_none(),
76                                cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
77                            );
78                            attributes.cancel_handle = Some(cancel_handle);
79                        }
80                    }
81                    ensure_spanned!(
82                        attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
83                        attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
84                    );
85                }
86                Ok(true)
87            } else {
88                Ok(false)
89            }
90        })?;
91        Ok(attributes)
92    }
93}
94
/// The `message` entry of a `warn(...)` option; its value is a string literal.
type PyFunctionWarningMessageAttribute = KeywordAttribute<attributes::kw::message, LitStr>;
/// The `category` entry of a `warn(...)` option; its value is a path to the warning type.
type PyFunctionWarningCategoryAttribute = KeywordAttribute<attributes::kw::category, Path>;

/// Parsed (not yet resolved) form of a single `warn(...)` function option.
pub struct PyFunctionWarningAttribute {
    /// The required `message` entry.
    pub message: PyFunctionWarningMessageAttribute,
    /// The optional `category` entry. Converting to `PyFunctionWarning` maps `None` to
    /// `PyFunctionWarningCategory::UserWarning`.
    pub category: Option<PyFunctionWarningCategoryAttribute>,
    /// Span of the `warn` keyword.
    pub span: Span,
}
103
/// Python warning class used when generating code for a `warn` option.
#[derive(PartialEq, Clone)]
pub enum PyFunctionWarningCategory {
    /// A user-supplied path to the warning type, emitted verbatim.
    Path(Path),
    /// `exceptions::PyUserWarning`; used when no `category` is given.
    UserWarning,
    /// `exceptions::PyDeprecationWarning`.
    DeprecationWarning, // TODO: unused for now, intended for pyo3(deprecated) special-case
}
110
/// A `warn` option with its category resolved, ready for code generation.
#[derive(Clone)]
pub struct PyFunctionWarning {
    /// Warning text; turned into a C string literal in the generated code.
    pub message: LitStr,
    /// Python warning class to emit.
    pub category: PyFunctionWarningCategory,
    /// Source span; when built via `From<PyFunctionWarningAttribute>` this is the `warn`
    /// keyword's span.
    pub span: Span,
}
117
118impl From<PyFunctionWarningAttribute> for PyFunctionWarning {
119    fn from(value: PyFunctionWarningAttribute) -> Self {
120        Self {
121            message: value.message.value,
122            category: value
123                .category
124                .map_or(PyFunctionWarningCategory::UserWarning, |cat| {
125                    PyFunctionWarningCategory::Path(cat.value)
126                }),
127            span: value.span,
128        }
129    }
130}
131
/// Source of generated code that emits one or more Python warnings.
pub trait WarningFactory {
    /// Generates the tokens that emit the warning(s), using `ctx` to locate the pyo3 crate.
    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream;
    /// Source span associated with the warning(s).
    fn span(&self) -> Span;
}
136
137impl WarningFactory for PyFunctionWarning {
138    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
139        let message = &self.message.value();
140        let c_message = LitCStr::new(
141            &CString::new(message.clone()).unwrap(),
142            Spanned::span(&message),
143        );
144        let pyo3_path = &ctx.pyo3_path;
145        let category = match &self.category {
146            PyFunctionWarningCategory::Path(path) => quote! {#path},
147            PyFunctionWarningCategory::UserWarning => {
148                quote! {#pyo3_path::exceptions::PyUserWarning}
149            }
150            PyFunctionWarningCategory::DeprecationWarning => {
151                quote! {#pyo3_path::exceptions::PyDeprecationWarning}
152            }
153        };
154        quote! {
155            #pyo3_path::PyErr::warn(py, &<#category as #pyo3_path::PyTypeInfo>::type_object(py), #c_message, 1)?;
156        }
157    }
158
159    fn span(&self) -> Span {
160        self.span
161    }
162}
163
164impl<T: WarningFactory> WarningFactory for Vec<T> {
165    fn build_py_warning(&self, ctx: &Ctx) -> TokenStream {
166        let warnings = self.iter().map(|warning| warning.build_py_warning(ctx));
167
168        quote! {
169            #(#warnings)*
170        }
171    }
172
173    fn span(&self) -> Span {
174        self.iter()
175            .map(|val| val.span())
176            .reduce(|acc, span| acc.join(span).unwrap_or(acc))
177            .unwrap()
178    }
179}
180
181impl Parse for PyFunctionWarningAttribute {
182    fn parse(input: ParseStream<'_>) -> Result<Self> {
183        let mut message: Option<PyFunctionWarningMessageAttribute> = None;
184        let mut category: Option<PyFunctionWarningCategoryAttribute> = None;
185
186        let span = input.parse::<attributes::kw::warn>()?.span();
187
188        let content;
189        syn::parenthesized!(content in input);
190
191        while !content.is_empty() {
192            let lookahead = content.lookahead1();
193
194            if lookahead.peek(attributes::kw::message) {
195                message = content
196                    .parse::<PyFunctionWarningMessageAttribute>()
197                    .map(Some)?;
198            } else if lookahead.peek(attributes::kw::category) {
199                category = content
200                    .parse::<PyFunctionWarningCategoryAttribute>()
201                    .map(Some)?;
202            } else {
203                return Err(lookahead.error());
204            }
205
206            if content.peek(Token![,]) {
207                content.parse::<Token![,]>()?;
208            }
209        }
210
211        Ok(PyFunctionWarningAttribute {
212            message: message.ok_or(syn::Error::new(
213                content.span(),
214                "missing `message` in `warn` attribute",
215            ))?,
216            category,
217            span,
218        })
219    }
220}
221
222impl ToTokens for PyFunctionWarningAttribute {
223    fn to_tokens(&self, tokens: &mut TokenStream) {
224        let message_tokens = self.message.to_token_stream();
225        let category_tokens = self
226            .category
227            .as_ref()
228            .map_or(quote! {}, |cat| cat.to_token_stream());
229
230        let token_stream = quote! {
231            warn(#message_tokens, #category_tokens)
232        };
233
234        tokens.extend(token_stream);
235    }
236}
237
/// Function-level options for a `#[pyfunction]`.
///
/// Filled either by parsing a comma-separated option list (`Parse`) or from `#[pyo3(...)]`
/// attributes (`from_attrs`). Both paths go through `add_attributes`.
#[derive(Default)]
pub struct PyFunctionOptions {
    /// `pass_module`: the first Rust argument receives the module and is not part of the
    /// Python signature.
    pub pass_module: Option<attributes::kw::pass_module>,
    /// `name`: overrides the Python-visible name (defaults to the Rust ident).
    pub name: Option<NameAttribute>,
    /// `signature`: explicit signature, reconciled with the Rust arguments.
    pub signature: Option<SignatureAttribute>,
    /// `text_signature` option.
    pub text_signature: Option<TextSignatureAttribute>,
    /// `crate` option; passed to `Ctx::new` to determine the pyo3 path used in generated code.
    pub krate: Option<CrateAttribute>,
    /// Every `warn(...)` option, in the order encountered.
    pub warnings: Vec<PyFunctionWarning>,
}
247
248impl Parse for PyFunctionOptions {
249    fn parse(input: ParseStream<'_>) -> Result<Self> {
250        let mut options = PyFunctionOptions::default();
251
252        let attrs = Punctuated::<PyFunctionOption, syn::Token![,]>::parse_terminated(input)?;
253        options.add_attributes(attrs)?;
254
255        Ok(options)
256    }
257}
258
/// One function-level option, as produced by its `Parse` impl and merged by
/// `PyFunctionOptions::add_attributes`.
pub enum PyFunctionOption {
    /// `name` option: Python-visible name.
    Name(NameAttribute),
    /// `pass_module` keyword.
    PassModule(attributes::kw::pass_module),
    /// `signature` option.
    Signature(SignatureAttribute),
    /// `text_signature` option.
    TextSignature(TextSignatureAttribute),
    /// `crate` option (stored as `krate` in `PyFunctionOptions`).
    Crate(CrateAttribute),
    /// `warn(...)` option; may be repeated.
    Warning(PyFunctionWarningAttribute),
}
267
268impl Parse for PyFunctionOption {
269    fn parse(input: ParseStream<'_>) -> Result<Self> {
270        let lookahead = input.lookahead1();
271        if lookahead.peek(attributes::kw::name) {
272            input.parse().map(PyFunctionOption::Name)
273        } else if lookahead.peek(attributes::kw::pass_module) {
274            input.parse().map(PyFunctionOption::PassModule)
275        } else if lookahead.peek(attributes::kw::signature) {
276            input.parse().map(PyFunctionOption::Signature)
277        } else if lookahead.peek(attributes::kw::text_signature) {
278            input.parse().map(PyFunctionOption::TextSignature)
279        } else if lookahead.peek(syn::Token![crate]) {
280            input.parse().map(PyFunctionOption::Crate)
281        } else if lookahead.peek(attributes::kw::warn) {
282            input.parse().map(PyFunctionOption::Warning)
283        } else {
284            Err(lookahead.error())
285        }
286    }
287}
288
289impl PyFunctionOptions {
290    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
291        let mut options = PyFunctionOptions::default();
292        options.add_attributes(take_pyo3_options(attrs)?)?;
293        Ok(options)
294    }
295
296    pub fn add_attributes(
297        &mut self,
298        attrs: impl IntoIterator<Item = PyFunctionOption>,
299    ) -> Result<()> {
300        macro_rules! set_option {
301            ($key:ident) => {
302                {
303                    ensure_spanned!(
304                        self.$key.is_none(),
305                        $key.span() => concat!("`", stringify!($key), "` may only be specified once")
306                    );
307                    self.$key = Some($key);
308                }
309            };
310        }
311        for attr in attrs {
312            match attr {
313                PyFunctionOption::Name(name) => set_option!(name),
314                PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
315                PyFunctionOption::Signature(signature) => set_option!(signature),
316                PyFunctionOption::TextSignature(text_signature) => set_option!(text_signature),
317                PyFunctionOption::Crate(krate) => set_option!(krate),
318                PyFunctionOption::Warning(warning) => {
319                    self.warnings.push(warning.into());
320                }
321            }
322        }
323        Ok(())
324    }
325}
326
327pub fn build_py_function(
328    ast: &mut syn::ItemFn,
329    mut options: PyFunctionOptions,
330) -> syn::Result<TokenStream> {
331    options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
332    impl_wrap_pyfunction(ast, options)
333}
334
/// Generates python wrapper over a function that allows adding it to a python module as a python
/// function
///
/// The returned tokens contain:
/// - a hidden module named after the function, exposing a `_PYO3_DEF: PyFunctionDef` static;
/// - an `impl` block on that module's `MakeDef` that builds the definition from the generated
///   method def;
/// - the wrapper produced by `FnSpec::get_wrapper_function`, named `__pyfunction_<name>`;
/// - introspection data when the `experimental-inspect` feature is enabled.
///
/// `func` is taken mutably because its arguments are parsed through `iter_mut()` with
/// `FnArg::parse` (presumably to strip per-argument attributes — confirm in `FnArg::parse`).
///
/// # Errors
/// Returns an error if the signature is generic, if `pass_module` is set but the first
/// argument is missing or is a `self` receiver, or if argument parsing, signature
/// reconciliation, or wrapper/method-def generation fails.
pub fn impl_wrap_pyfunction(
    func: &mut syn::ItemFn,
    options: PyFunctionOptions,
) -> syn::Result<TokenStream> {
    check_generic(&func.sig)?;
    let PyFunctionOptions {
        pass_module,
        name,
        signature,
        text_signature,
        krate,
        warnings,
    } = options;

    let ctx = &Ctx::new(&krate, Some(&func.sig));
    let Ctx { pyo3_path, .. } = &ctx;

    // Python-visible name: the `name` option if given, otherwise the Rust ident, with any
    // raw-identifier `r#` prefix removed.
    let python_name = name
        .as_ref()
        .map_or_else(|| &func.sig.ident, |name| &name.value.0)
        .unraw();

    // With `pass_module`, the first Rust argument receives the module, so it must exist and be
    // a typed (non-`self`) argument; the span of its type is stored in `FnType::FnModule`.
    let tp = if pass_module.is_some() {
        let span = match func.sig.inputs.first() {
            Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
            Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
                func.sig.paren_token.span.join() => "expected `&PyModule` or `Py<PyModule>` as first argument with `pass_module`"
            ),
        };
        method::FnType::FnModule(span)
    } else {
        method::FnType::FnStatic
    };

    // Parse the Rust arguments that appear in the Python signature (the module argument is
    // skipped under `pass_module`). `try_combine_syn_errors` presumably reports all argument
    // errors together rather than only the first.
    let arguments = func
        .sig
        .inputs
        .iter_mut()
        .skip(if tp.skip_first_rust_argument_in_python_signature() {
            1
        } else {
            0
        })
        .map(FnArg::parse)
        .try_combine_syn_errors()?;

    // An explicit `signature` option is reconciled with the parsed arguments; otherwise the
    // signature is derived from the arguments alone.
    let signature = if let Some(signature) = signature {
        FunctionSignature::from_arguments_and_attribute(arguments, signature)?
    } else {
        FunctionSignature::from_arguments(arguments)
    };

    let spec = method::FnSpec {
        tp,
        name: &func.sig.ident,
        python_name,
        signature,
        text_signature,
        asyncness: func.sig.asyncness,
        unsafety: func.sig.unsafety,
        warnings,
        output: func.sig.output.clone(),
    };

    let vis = &func.vis;
    // From here on `name` is the Rust ident (the `name` option was consumed above).
    let name = &func.sig.ident;

    // NOTE(review): the name string passed below is the Rust ident, not `python_name`, so a
    // `#[pyo3(name = "...")]` rename may not be reflected in introspection data — confirm what
    // `function_introspection_code` expects for this argument.
    #[cfg(feature = "experimental-inspect")]
    let introspection = function_introspection_code(
        pyo3_path,
        Some(name),
        &name.to_string(),
        &spec.signature,
        None,
        func.sig.output.clone(),
        empty(),
        func.sig.asyncness.is_some(),
        false,
        get_doc(&func.attrs, None).as_ref(),
        None,
    );
    #[cfg(not(feature = "experimental-inspect"))]
    let introspection = quote! {};
    #[cfg(feature = "experimental-inspect")]
    let introspection_id = introspection_id_const();
    #[cfg(not(feature = "experimental-inspect"))]
    let introspection_id = quote! {};

    let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
    // PyForge: async support is always enabled (no feature gate required)
    let calling_convention = CallingConvention::from_signature(&spec.signature);
    let wrapper = spec.get_wrapper_function(&wrapper_ident, None, calling_convention, ctx)?;
    let methoddef = spec.get_methoddef(
        wrapper_ident,
        spec.get_doc(&func.attrs).as_ref(),
        calling_convention,
        ctx,
    )?;

    let wrapped_pyfunction = quote! {
        // Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
        // will actually bring both the module and the function into scope.
        #[doc(hidden)]
        #vis mod #name {
            pub(crate) struct MakeDef;
            pub static _PYO3_DEF: #pyo3_path::impl_::pyfunction::PyFunctionDef = MakeDef::_PYO3_DEF;
            #introspection_id
        }

        // Generate the definition in the same scope as the original function -
        // this avoids complications around the fact that the generated module has a different scope
        // (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction]` is
        // inside a function body)
        #[allow(unknown_lints, non_local_definitions)]
        impl #name::MakeDef {
            // We're using this to initialize a static, so it's fine.
            #[allow(clippy::declare_interior_mutable_const)]
            const _PYO3_DEF: #pyo3_path::impl_::pyfunction::PyFunctionDef =
                #pyo3_path::impl_::pyfunction::PyFunctionDef::from_method_def(#methoddef);
        }

        #[allow(non_snake_case)]
        #wrapper

        #introspection
    };
    Ok(wrapped_pyfunction)
}