Skip to main content

pyforge_macros_backend/
pyimpl.rs

1use std::collections::HashSet;
2
3use crate::combine_errors::CombineErrors;
4#[cfg(feature = "experimental-inspect")]
5use crate::get_doc;
6#[cfg(feature = "experimental-inspect")]
7use crate::introspection::{attribute_introspection_code, function_introspection_code};
8#[cfg(feature = "experimental-inspect")]
9use crate::method::{FnSpec, FnType};
10#[cfg(feature = "experimental-inspect")]
11use crate::py_expr::PyExpr;
12use crate::utils::{has_attribute, has_attribute_with_namespace, Ctx, PyForgeCratePath};
13use crate::{
14    attributes::{take_pyo3_options, CrateAttribute},
15    konst::{ConstAttributes, ConstSpec},
16    pyfunction::PyFunctionOptions,
17    pymethod::{
18        self, is_proto_method, GeneratedPyMethod, MethodAndMethodDef, MethodAndSlotDef, PyMethod,
19    },
20};
21use proc_macro2::TokenStream;
22use quote::{format_ident, quote};
23use syn::{
24    parse::{Parse, ParseStream},
25    spanned::Spanned,
26    ImplItemFn, Result,
27};
28#[cfg(feature = "experimental-inspect")]
29use syn::{parse_quote, Ident, ReturnType};
30
/// The mechanism used to collect `#[pymethods]` into the type object
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PyClassMethodsType {
    /// The collected items are returned from a `PyMethods` impl on
    /// `PyClassImplCollector<T>` (see `impl_py_methods`).
    Specialization,
    /// The collected items are submitted through `inventory::submit!`
    /// (see `submit_methods_inventory`).
    Inventory,
}
37
38enum PyImplPyForgeOption {
39    Crate(CrateAttribute),
40}
41
42impl Parse for PyImplPyForgeOption {
43    fn parse(input: ParseStream<'_>) -> Result<Self> {
44        let lookahead = input.lookahead1();
45        if lookahead.peek(syn::Token![crate]) {
46            input.parse().map(PyImplPyForgeOption::Crate)
47        } else {
48            Err(lookahead.error())
49        }
50    }
51}
52
53#[derive(Default)]
54pub struct PyImplOptions {
55    krate: Option<CrateAttribute>,
56}
57
58impl PyImplOptions {
59    pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
60        let mut options: PyImplOptions = Default::default();
61
62        for option in take_pyo3_options(attrs)? {
63            match option {
64                PyImplPyForgeOption::Crate(path) => options.set_crate(path)?,
65            }
66        }
67
68        Ok(options)
69    }
70
71    fn set_crate(&mut self, path: CrateAttribute) -> Result<()> {
72        ensure_spanned!(
73            self.krate.is_none(),
74            path.span() => "`crate` may only be specified once"
75        );
76
77        self.krate = Some(path);
78        Ok(())
79    }
80}
81
82pub fn build_py_methods(
83    ast: &mut syn::ItemImpl,
84    methods_type: PyClassMethodsType,
85) -> syn::Result<TokenStream> {
86    if let Some((_, path, _)) = &ast.trait_ {
87        bail_spanned!(path.span() => "#[pymethods] cannot be used on trait impl blocks");
88    } else if ast.generics != Default::default() {
89        bail_spanned!(
90            ast.generics.span() =>
91            "#[pymethods] cannot be used with lifetime parameters or generics"
92        );
93    } else {
94        let options = PyImplOptions::from_attrs(&mut ast.attrs)?;
95        impl_methods(&ast.self_ty, &mut ast.items, methods_type, options)
96    }
97}
98
99fn check_pyfunction(pyo3_path: &PyForgeCratePath, meth: &mut ImplItemFn) -> syn::Result<()> {
100    let mut error = None;
101
102    meth.attrs.retain(|attr| {
103        let attrs = [attr.clone()];
104
105        if has_attribute(&attrs, "pyfunction")
106            || has_attribute_with_namespace(&attrs, Some(pyo3_path),  &["pyfunction"])
107            || has_attribute_with_namespace(&attrs, Some(pyo3_path),  &["prelude", "pyfunction"]) {
108                error = Some(err_spanned!(meth.sig.span() => "functions inside #[pymethods] do not need to be annotated with #[pyfunction]"));
109                false
110        } else {
111            true
112        }
113    });
114
115    error.map_or(Ok(()), Err)
116}
117
118pub fn impl_methods(
119    ty: &syn::Type,
120    impls: &mut [syn::ImplItem],
121    methods_type: PyClassMethodsType,
122    options: PyImplOptions,
123) -> syn::Result<TokenStream> {
124    let mut extra_fragments = Vec::new();
125    let mut proto_impls = Vec::new();
126    let mut methods = Vec::new();
127    let mut associated_methods = Vec::new();
128
129    let mut implemented_proto_fragments = HashSet::new();
130
131    let _: Vec<()> = impls
132        .iter_mut()
133        .map(|iimpl| {
134            match iimpl {
135                syn::ImplItem::Fn(meth) => {
136                    let ctx = &Ctx::new(&options.krate, Some(&meth.sig));
137                    let mut fun_options = PyFunctionOptions::from_attrs(&mut meth.attrs)?;
138                    fun_options.krate = fun_options.krate.or_else(|| options.krate.clone());
139
140                    check_pyfunction(&ctx.pyo3_path, meth)?;
141                    let method = PyMethod::parse(&mut meth.sig, &mut meth.attrs, fun_options)?;
142                    #[cfg(feature = "experimental-inspect")]
143                    extra_fragments.push(method_introspection_code(
144                        &method.spec,
145                        &meth.attrs,
146                        ty,
147                        method.is_returning_not_implemented_on_extraction_error(),
148                        ctx,
149                    ));
150                    match pymethod::gen_py_method(ty, method, &meth.attrs, ctx)? {
151                        GeneratedPyMethod::Method(MethodAndMethodDef {
152                            associated_method,
153                            method_def,
154                        }) => {
155                            let attrs = get_cfg_attributes(&meth.attrs);
156                            associated_methods.push(quote!(#(#attrs)* #associated_method));
157                            methods.push(quote!(#(#attrs)* #method_def));
158                        }
159                        GeneratedPyMethod::SlotTraitImpl(method_name, token_stream) => {
160                            implemented_proto_fragments.insert(method_name);
161                            let attrs = get_cfg_attributes(&meth.attrs);
162                            extra_fragments.push(quote!(#(#attrs)* #token_stream));
163                        }
164                        GeneratedPyMethod::Proto(MethodAndSlotDef {
165                            associated_method,
166                            slot_def,
167                        }) => {
168                            let attrs = get_cfg_attributes(&meth.attrs);
169                            proto_impls.push(quote!(#(#attrs)* #slot_def));
170                            associated_methods.push(quote!(#(#attrs)* #associated_method));
171                        }
172                    }
173                }
174                syn::ImplItem::Const(konst) => {
175                    let ctx = &Ctx::new(&options.krate, None);
176                    #[cfg(feature = "experimental-inspect")]
177                    let doc = get_doc(&konst.attrs, None);
178                    let attributes = ConstAttributes::from_attrs(&mut konst.attrs)?;
179                    if attributes.is_class_attr {
180                        let spec = ConstSpec {
181                            rust_ident: konst.ident.clone(),
182                            attributes,
183                            #[cfg(feature = "experimental-inspect")]
184                            expr: Some(konst.expr.clone()),
185                            #[cfg(feature = "experimental-inspect")]
186                            ty: konst.ty.clone(),
187                            #[cfg(feature = "experimental-inspect")]
188                            doc,
189                        };
190                        let attrs = get_cfg_attributes(&konst.attrs);
191                        let MethodAndMethodDef {
192                            associated_method,
193                            method_def,
194                        } = gen_py_const(ty, &spec, ctx);
195                        methods.push(quote!(#(#attrs)* #method_def));
196                        associated_methods.push(quote!(#(#attrs)* #associated_method));
197                        if is_proto_method(&spec.python_name().to_string()) {
198                            // If this is a known protocol method e.g. __contains__, then allow this
199                            // symbol even though it's not an uppercase constant.
200                            konst
201                                .attrs
202                                .push(syn::parse_quote!(#[allow(non_upper_case_globals)]));
203                        }
204                    }
205                }
206                syn::ImplItem::Macro(m) => bail_spanned!(
207                    m.span() =>
208                    "macros cannot be used as items in `#[pymethods]` impl blocks\n\
209                    = note: this was previously accepted and ignored"
210                ),
211                _ => {}
212            }
213            Ok(())
214        })
215        .try_combine_syn_errors()?;
216
217    let ctx = &Ctx::new(&options.krate, None);
218
219    add_shared_proto_slots(ty, &mut proto_impls, implemented_proto_fragments, ctx);
220
221    let items = match methods_type {
222        PyClassMethodsType::Specialization => impl_py_methods(ty, methods, proto_impls, ctx),
223        PyClassMethodsType::Inventory => submit_methods_inventory(ty, methods, proto_impls, ctx),
224    };
225
226    Ok(quote! {
227        #(#extra_fragments)*
228
229        #items
230
231        #[doc(hidden)]
232        #[allow(non_snake_case)]
233        impl #ty {
234            #(#associated_methods)*
235        }
236    })
237}
238
239pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec, ctx: &Ctx) -> MethodAndMethodDef {
240    let member = &spec.rust_ident;
241    let wrapper_ident = format_ident!("__pymethod_{}__", member);
242    let python_name = spec.null_terminated_python_name();
243    let Ctx { pyo3_path, .. } = ctx;
244
245    let associated_method = quote! {
246        fn #wrapper_ident(py: #pyo3_path::Python<'_>) -> #pyo3_path::PyResult<#pyo3_path::Py<#pyo3_path::PyAny>> {
247            #pyo3_path::IntoPyObjectExt::into_py_any(#cls::#member, py)
248        }
249    };
250
251    let method_def = quote! {
252        #pyo3_path::impl_::pymethods::PyMethodDefType::ClassAttribute({
253            #pyo3_path::impl_::pymethods::PyClassAttributeDef::new(
254                #python_name,
255                #cls::#wrapper_ident
256            )
257        })
258    };
259
260    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused_mut))]
261    let mut def = MethodAndMethodDef {
262        associated_method,
263        method_def,
264    };
265
266    #[cfg(feature = "experimental-inspect")]
267    def.add_introspection(attribute_introspection_code(
268        &ctx.pyo3_path,
269        Some(cls),
270        spec.python_name().to_string(),
271        spec.expr
272            .as_ref()
273            .map_or_else(PyExpr::ellipsis, PyExpr::constant_from_expression),
274        spec.ty.clone(),
275        spec.doc.as_ref(),
276        true,
277    ));
278
279    def
280}
281
282fn impl_py_methods(
283    ty: &syn::Type,
284    methods: Vec<TokenStream>,
285    proto_impls: Vec<TokenStream>,
286    ctx: &Ctx,
287) -> TokenStream {
288    let Ctx { pyo3_path, .. } = ctx;
289    quote! {
290        #[allow(unknown_lints, non_local_definitions)]
291        impl #pyo3_path::impl_::pyclass::PyMethods<#ty>
292            for #pyo3_path::impl_::pyclass::PyClassImplCollector<#ty>
293        {
294            fn py_methods(self) -> &'static #pyo3_path::impl_::pyclass::PyClassItems {
295                static ITEMS: #pyo3_path::impl_::pyclass::PyClassItems = #pyo3_path::impl_::pyclass::PyClassItems {
296                    methods: &[#(#methods),*],
297                    slots: &[#(#proto_impls),*]
298                };
299                &ITEMS
300            }
301        }
302    }
303}
304
305fn add_shared_proto_slots(
306    ty: &syn::Type,
307    proto_impls: &mut Vec<TokenStream>,
308    mut implemented_proto_fragments: HashSet<String>,
309    ctx: &Ctx,
310) {
311    let Ctx { pyo3_path, .. } = ctx;
312    macro_rules! try_add_shared_slot {
313        ($slot:ident, $($fragments:literal),*) => {{
314            let mut implemented = false;
315            $(implemented |= implemented_proto_fragments.remove($fragments));*;
316            if implemented {
317                proto_impls.push(quote! { #pyo3_path::impl_::pyclass::$slot!(#ty) })
318            }
319        }};
320    }
321
322    try_add_shared_slot!(
323        generate_pyclass_getattro_slot,
324        "__getattribute__",
325        "__getattr__"
326    );
327    try_add_shared_slot!(generate_pyclass_setattr_slot, "__setattr__", "__delattr__");
328    try_add_shared_slot!(generate_pyclass_setdescr_slot, "__set__", "__delete__");
329    try_add_shared_slot!(generate_pyclass_setitem_slot, "__setitem__", "__delitem__");
330    try_add_shared_slot!(generate_pyclass_add_slot, "__add__", "__radd__");
331    try_add_shared_slot!(generate_pyclass_sub_slot, "__sub__", "__rsub__");
332    try_add_shared_slot!(generate_pyclass_mul_slot, "__mul__", "__rmul__");
333    try_add_shared_slot!(generate_pyclass_mod_slot, "__mod__", "__rmod__");
334    try_add_shared_slot!(generate_pyclass_divmod_slot, "__divmod__", "__rdivmod__");
335    try_add_shared_slot!(generate_pyclass_lshift_slot, "__lshift__", "__rlshift__");
336    try_add_shared_slot!(generate_pyclass_rshift_slot, "__rshift__", "__rrshift__");
337    try_add_shared_slot!(generate_pyclass_and_slot, "__and__", "__rand__");
338    try_add_shared_slot!(generate_pyclass_or_slot, "__or__", "__ror__");
339    try_add_shared_slot!(generate_pyclass_xor_slot, "__xor__", "__rxor__");
340    try_add_shared_slot!(generate_pyclass_matmul_slot, "__matmul__", "__rmatmul__");
341    try_add_shared_slot!(generate_pyclass_truediv_slot, "__truediv__", "__rtruediv__");
342    try_add_shared_slot!(
343        generate_pyclass_floordiv_slot,
344        "__floordiv__",
345        "__rfloordiv__"
346    );
347    try_add_shared_slot!(generate_pyclass_pow_slot, "__pow__", "__rpow__");
348    try_add_shared_slot!(
349        generate_pyclass_richcompare_slot,
350        "__lt__",
351        "__le__",
352        "__eq__",
353        "__ne__",
354        "__gt__",
355        "__ge__"
356    );
357
358    // if this assertion trips, a slot fragment has been implemented which has not been added in the
359    // list above
360    assert!(implemented_proto_fragments.is_empty());
361}
362
363fn submit_methods_inventory(
364    ty: &syn::Type,
365    methods: Vec<TokenStream>,
366    proto_impls: Vec<TokenStream>,
367    ctx: &Ctx,
368) -> TokenStream {
369    let Ctx { pyo3_path, .. } = ctx;
370    quote! {
371        #pyo3_path::inventory::submit! {
372            type Inventory = <#ty as #pyo3_path::impl_::pyclass::PyClassImpl>::Inventory;
373            Inventory::new(#pyo3_path::impl_::pyclass::PyClassItems { methods: &[#(#methods),*], slots: &[#(#proto_impls),*] })
374        }
375    }
376}
377
378pub(crate) fn get_cfg_attributes(attrs: &[syn::Attribute]) -> Vec<&syn::Attribute> {
379    attrs
380        .iter()
381        .filter(|attr| attr.path().is_ident("cfg"))
382        .collect()
383}
384
/// Generates introspection code for one method of a `#[pymethods]` block.
///
/// * `__richcmp__` is expanded into one entry per comparison operator.
/// * `__concat__`, `__repeat__`, `__inplace_concat__` and `__inplace_repeat__` are
///   renamed to `__add__`, `__mul__`, `__iadd__` and `__imul__`.
/// * `__getbuffer__`, `__releasebuffer__`, `__traverse__` and `__clear__` produce
///   no code.
/// * Class attributes are described as attributes.
/// * Everything else is described as a function, with a `self`/`cls` first argument
///   and decorators chosen from the method kind.
#[cfg(feature = "experimental-inspect")]
pub fn method_introspection_code(
    spec: &FnSpec<'_>,
    attrs: &[syn::Attribute],
    parent: &syn::Type,
    is_returning_not_implemented_on_extraction_error: bool,
    ctx: &Ctx,
) -> TokenStream {
    let Ctx { pyo3_path, .. } = ctx;

    let name = spec.python_name.to_string();

    // `__richcmp__` special case: we expand it into each individual comparison method.
    if name == "__richcmp__" {
        return ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
            .into_iter()
            .map(|method_name| {
                let mut spec = (*spec).clone();
                spec.python_name = Ident::new(method_name, spec.python_name.span());
                // Remove the trailing `CompareOp` argument. This is safe because the
                // signature is always the same: first the value to compare with, then the
                // `CompareOp`. We want to keep the type of the first argument, hence this hack.
                spec.signature.arguments.pop();
                spec.signature.python_signature.positional_parameters.pop();
                method_introspection_code(
                    &spec,
                    attrs,
                    parent,
                    is_returning_not_implemented_on_extraction_error,
                    ctx,
                )
            })
            .collect();
    }

    // We map or ignore some magic methods
    // TODO: this might create a naming conflict
    let name = match name.as_str() {
        "__concat__" => "__add__".into(),
        "__repeat__" => "__mul__".into(),
        "__inplace_concat__" => "__iadd__".into(),
        "__inplace_repeat__" => "__imul__".into(),
        "__getbuffer__" | "__releasebuffer__" | "__traverse__" | "__clear__" => return quote! {},
        _ => name,
    };

    let is_new = spec.python_name == "__new__";

    // Choose the implicit first argument (`self`/`cls`) and the decorators to emit.
    let (first_argument, decorators) = match &spec.tp {
        FnType::Getter(_) => (Some("self"), vec![PyExpr::builtin("property")]),
        FnType::Setter(_) => (
            Some("self"),
            vec![PyExpr::attribute(
                PyExpr::attribute(PyExpr::from_type(parent.clone(), None), name.clone()),
                "setter",
            )],
        ),
        FnType::Deleter(_) => (
            Some("self"),
            vec![PyExpr::attribute(
                PyExpr::attribute(PyExpr::from_type(parent.clone(), None), name.clone()),
                "deleter",
            )],
        ),
        FnType::Fn(_) => (Some("self"), Vec::new()),
        FnType::FnClass(_) => (
            Some("cls"),
            // special case __new__ - does not get the decorator
            if is_new {
                Vec::new()
            } else {
                vec![PyExpr::builtin("classmethod")]
            },
        ),
        FnType::FnStatic => {
            if is_new {
                // special case __new__ - does not get the decorator and gets first argument
                (Some("cls"), Vec::new())
            } else {
                (None, vec![PyExpr::builtin("staticmethod")])
            }
        }
        FnType::FnModule(_) => (None, Vec::new()), // TODO: not sure this can happen
        FnType::ClassAttribute => {
            // We return an attribute because there is no decorator for this case
            return attribute_introspection_code(
                pyo3_path,
                Some(parent),
                name,
                PyExpr::ellipsis(),
                if let ReturnType::Type(_, t) = &spec.output {
                    (**t).clone()
                } else {
                    parse_quote!(#pyo3_path::Py<#pyo3_path::types::PyNone>)
                },
                get_doc(attrs, None).as_ref(),
                true,
            );
        }
    };

    let return_type = if is_new {
        // Hack to return Self while implementing IntoPyObject
        parse_quote!(-> #pyo3_path::PyRef<Self>)
    } else {
        spec.output.clone()
    };
    function_introspection_code(
        pyo3_path,
        None,
        &name,
        &spec.signature,
        first_argument,
        return_type,
        decorators,
        spec.asyncness.is_some(),
        is_returning_not_implemented_on_extraction_error,
        get_doc(attrs, None).as_ref(),
        Some(parent),
    )
}