pyo3_derive_backend/
pyproto.rs

1// Copyright (c) 2017-present PyO3 Project and Contributors
2
3use crate::defs;
4use crate::method::{FnSpec, FnType};
5use crate::proto_method::impl_method_proto;
6use crate::pymethod;
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9use quote::ToTokens;
10use std::collections::HashSet;
11
12pub fn build_py_proto(ast: &mut syn::ItemImpl) -> syn::Result<TokenStream> {
13    if let Some((_, ref mut path, _)) = ast.trait_ {
14        let proto = if let Some(ref mut segment) = path.segments.last() {
15            match segment.ident.to_string().as_str() {
16                "PyObjectProtocol" => &defs::OBJECT,
17                "PyAsyncProtocol" => &defs::ASYNC,
18                "PyMappingProtocol" => &defs::MAPPING,
19                "PyIterProtocol" => &defs::ITER,
20                "PyContextProtocol" => &defs::CONTEXT,
21                "PySequenceProtocol" => &defs::SEQ,
22                "PyNumberProtocol" => &defs::NUM,
23                "PyDescrProtocol" => &defs::DESCR,
24                "PyBufferProtocol" => &defs::BUFFER,
25                "PyGCProtocol" => &defs::GC,
26                _ => {
27                    return Err(syn::Error::new_spanned(
28                        path,
29                        "#[pyproto] can not be used with this block",
30                    ));
31                }
32            }
33        } else {
34            return Err(syn::Error::new_spanned(
35                path,
36                "#[pyproto] can only be used with protocol trait implementations",
37            ));
38        };
39
40        let tokens = impl_proto_impl(&ast.self_ty, &mut ast.items, proto)?;
41
42        // attach lifetime
43        let mut seg = path.segments.pop().unwrap().into_value();
44        seg.arguments = syn::PathArguments::AngleBracketed(syn::parse_quote! {<'p>});
45        path.segments.push(seg);
46        ast.generics.params = syn::parse_quote! {'p};
47
48        Ok(tokens)
49    } else {
50        Err(syn::Error::new_spanned(
51            ast,
52            "#[pyproto] can only be used with protocol trait implementations",
53        ))
54    }
55}
56
57fn impl_proto_impl(
58    ty: &syn::Type,
59    impls: &mut Vec<syn::ImplItem>,
60    proto: &defs::Proto,
61) -> syn::Result<TokenStream> {
62    let mut trait_impls = TokenStream::new();
63    let mut py_methods = Vec::new();
64    let mut method_names = HashSet::new();
65
66    for iimpl in impls.iter_mut() {
67        if let syn::ImplItem::Method(ref mut met) = iimpl {
68            // impl Py~Protocol<'p> { type = ... }
69            if let Some(m) = proto.get_proto(&met.sig.ident) {
70                impl_method_proto(ty, &mut met.sig, m)?.to_tokens(&mut trait_impls);
71                // Insert the method to the HashSet
72                method_names.insert(met.sig.ident.to_string());
73            }
74            // Add non-slot methods to inventory like `#[pymethods]`
75            if let Some(m) = proto.get_method(&met.sig.ident) {
76                let name = &met.sig.ident;
77                let fn_spec = FnSpec::parse(&met.sig, &mut met.attrs, false)?;
78
79                let method = if let FnType::Fn(self_ty) = &fn_spec.tp {
80                    pymethod::impl_proto_wrap(ty, &fn_spec, &self_ty)
81                } else {
82                    return Err(syn::Error::new_spanned(
83                        &met.sig,
84                        "Expected method with receiver for #[pyproto] method",
85                    ));
86                };
87
88                let coexist = if m.can_coexist {
89                    // We need METH_COEXIST here to prevent __add__  from overriding __radd__
90                    quote!(pyo3::ffi::METH_COEXIST)
91                } else {
92                    quote!(0)
93                };
94                // TODO(kngwyu): Set ml_doc
95                py_methods.push(quote! {
96                    pyo3::class::PyMethodDefType::Method({
97                        #method
98                        pyo3::class::PyMethodDef::cfunction_with_keywords(
99                            concat!(stringify!(#name), "\0"),
100                            __wrap,
101                            #coexist,
102                            "\0"
103                        )
104                    })
105                });
106            }
107        }
108    }
109    let inventory_submission = inventory_submission(py_methods, ty);
110    let slot_initialization = slot_initialization(method_names, ty, proto)?;
111    Ok(quote! {
112        #trait_impls
113        #inventory_submission
114        #slot_initialization
115    })
116}
117
118fn inventory_submission(py_methods: Vec<TokenStream>, ty: &syn::Type) -> TokenStream {
119    if py_methods.is_empty() {
120        return quote! {};
121    }
122    quote! {
123        pyo3::inventory::submit! {
124            #![crate = pyo3] {
125                type Inventory = <#ty as pyo3::class::methods::HasMethodsInventory>::Methods;
126                <Inventory as pyo3::class::methods::PyMethodsInventory>::new(vec![#(#py_methods),*])
127            }
128        }
129    }
130}
131
132fn slot_initialization(
133    method_names: HashSet<String>,
134    ty: &syn::Type,
135    proto: &defs::Proto,
136) -> syn::Result<TokenStream> {
137    // Collect initializers
138    let mut initializers: Vec<TokenStream> = vec![];
139    for setter in proto.setters(method_names) {
140        // Add slot methods to PyProtoRegistry
141        let set = syn::Ident::new(setter, Span::call_site());
142        initializers.push(quote! { table.#set::<#ty>(); });
143    }
144    if initializers.is_empty() {
145        return Ok(quote! {});
146    }
147    let table: syn::Path = syn::parse_str(proto.slot_table)?;
148    let set = syn::Ident::new(proto.set_slot_table, Span::call_site());
149    let ty_hash = typename_hash(ty);
150    let init = syn::Ident::new(
151        &format!("__init_{}_{}", proto.name, ty_hash),
152        Span::call_site(),
153    );
154    Ok(quote! {
155        #[allow(non_snake_case)]
156        #[pyo3::ctor::ctor]
157        fn #init() {
158            let mut table = #table::default();
159            #(#initializers)*
160            <#ty as pyo3::class::proto_methods::HasProtoRegistry>::registry().#set(table);
161        }
162    })
163}
164
165fn typename_hash(ty: &syn::Type) -> u64 {
166    use std::hash::{Hash, Hasher};
167    let mut hasher = std::collections::hash_map::DefaultHasher::new();
168    ty.hash(&mut hasher);
169    hasher.finish()
170}