pyo3_derive_backend/
pyproto.rs1use crate::defs;
4use crate::method::{FnSpec, FnType};
5use crate::proto_method::impl_method_proto;
6use crate::pymethod;
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9use quote::ToTokens;
10use std::collections::HashSet;
11
12pub fn build_py_proto(ast: &mut syn::ItemImpl) -> syn::Result<TokenStream> {
13 if let Some((_, ref mut path, _)) = ast.trait_ {
14 let proto = if let Some(ref mut segment) = path.segments.last() {
15 match segment.ident.to_string().as_str() {
16 "PyObjectProtocol" => &defs::OBJECT,
17 "PyAsyncProtocol" => &defs::ASYNC,
18 "PyMappingProtocol" => &defs::MAPPING,
19 "PyIterProtocol" => &defs::ITER,
20 "PyContextProtocol" => &defs::CONTEXT,
21 "PySequenceProtocol" => &defs::SEQ,
22 "PyNumberProtocol" => &defs::NUM,
23 "PyDescrProtocol" => &defs::DESCR,
24 "PyBufferProtocol" => &defs::BUFFER,
25 "PyGCProtocol" => &defs::GC,
26 _ => {
27 return Err(syn::Error::new_spanned(
28 path,
29 "#[pyproto] can not be used with this block",
30 ));
31 }
32 }
33 } else {
34 return Err(syn::Error::new_spanned(
35 path,
36 "#[pyproto] can only be used with protocol trait implementations",
37 ));
38 };
39
40 let tokens = impl_proto_impl(&ast.self_ty, &mut ast.items, proto)?;
41
42 let mut seg = path.segments.pop().unwrap().into_value();
44 seg.arguments = syn::PathArguments::AngleBracketed(syn::parse_quote! {<'p>});
45 path.segments.push(seg);
46 ast.generics.params = syn::parse_quote! {'p};
47
48 Ok(tokens)
49 } else {
50 Err(syn::Error::new_spanned(
51 ast,
52 "#[pyproto] can only be used with protocol trait implementations",
53 ))
54 }
55}
56
57fn impl_proto_impl(
58 ty: &syn::Type,
59 impls: &mut Vec<syn::ImplItem>,
60 proto: &defs::Proto,
61) -> syn::Result<TokenStream> {
62 let mut trait_impls = TokenStream::new();
63 let mut py_methods = Vec::new();
64 let mut method_names = HashSet::new();
65
66 for iimpl in impls.iter_mut() {
67 if let syn::ImplItem::Method(ref mut met) = iimpl {
68 if let Some(m) = proto.get_proto(&met.sig.ident) {
70 impl_method_proto(ty, &mut met.sig, m)?.to_tokens(&mut trait_impls);
71 method_names.insert(met.sig.ident.to_string());
73 }
74 if let Some(m) = proto.get_method(&met.sig.ident) {
76 let name = &met.sig.ident;
77 let fn_spec = FnSpec::parse(&met.sig, &mut met.attrs, false)?;
78
79 let method = if let FnType::Fn(self_ty) = &fn_spec.tp {
80 pymethod::impl_proto_wrap(ty, &fn_spec, &self_ty)
81 } else {
82 return Err(syn::Error::new_spanned(
83 &met.sig,
84 "Expected method with receiver for #[pyproto] method",
85 ));
86 };
87
88 let coexist = if m.can_coexist {
89 quote!(pyo3::ffi::METH_COEXIST)
91 } else {
92 quote!(0)
93 };
94 py_methods.push(quote! {
96 pyo3::class::PyMethodDefType::Method({
97 #method
98 pyo3::class::PyMethodDef::cfunction_with_keywords(
99 concat!(stringify!(#name), "\0"),
100 __wrap,
101 #coexist,
102 "\0"
103 )
104 })
105 });
106 }
107 }
108 }
109 let inventory_submission = inventory_submission(py_methods, ty);
110 let slot_initialization = slot_initialization(method_names, ty, proto)?;
111 Ok(quote! {
112 #trait_impls
113 #inventory_submission
114 #slot_initialization
115 })
116}
117
118fn inventory_submission(py_methods: Vec<TokenStream>, ty: &syn::Type) -> TokenStream {
119 if py_methods.is_empty() {
120 return quote! {};
121 }
122 quote! {
123 pyo3::inventory::submit! {
124 #![crate = pyo3] {
125 type Inventory = <#ty as pyo3::class::methods::HasMethodsInventory>::Methods;
126 <Inventory as pyo3::class::methods::PyMethodsInventory>::new(vec![#(#py_methods),*])
127 }
128 }
129 }
130}
131
132fn slot_initialization(
133 method_names: HashSet<String>,
134 ty: &syn::Type,
135 proto: &defs::Proto,
136) -> syn::Result<TokenStream> {
137 let mut initializers: Vec<TokenStream> = vec![];
139 for setter in proto.setters(method_names) {
140 let set = syn::Ident::new(setter, Span::call_site());
142 initializers.push(quote! { table.#set::<#ty>(); });
143 }
144 if initializers.is_empty() {
145 return Ok(quote! {});
146 }
147 let table: syn::Path = syn::parse_str(proto.slot_table)?;
148 let set = syn::Ident::new(proto.set_slot_table, Span::call_site());
149 let ty_hash = typename_hash(ty);
150 let init = syn::Ident::new(
151 &format!("__init_{}_{}", proto.name, ty_hash),
152 Span::call_site(),
153 );
154 Ok(quote! {
155 #[allow(non_snake_case)]
156 #[pyo3::ctor::ctor]
157 fn #init() {
158 let mut table = #table::default();
159 #(#initializers)*
160 <#ty as pyo3::class::proto_methods::HasProtoRegistry>::registry().#set(table);
161 }
162 })
163}
164
165fn typename_hash(ty: &syn::Type) -> u64 {
166 use std::hash::{Hash, Hasher};
167 let mut hasher = std::collections::hash_map::DefaultHasher::new();
168 ty.hash(&mut hasher);
169 hasher.finish()
170}