pyo3_derive_backend/
pyclass.rs

1// Copyright (c) 2017-present PyO3 Project and Contributors
2
3use crate::method::{FnType, SelfType};
4use crate::pymethod::{
5    impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType,
6};
7use crate::utils;
8use proc_macro2::{Span, TokenStream};
9use quote::quote;
10use syn::ext::IdentExt;
11use syn::parse::{Parse, ParseStream};
12use syn::punctuated::Punctuated;
13use syn::{parse_quote, Expr, Token};
14
15/// The parsed arguments of the pyclass macro
16pub struct PyClassArgs {
17    pub freelist: Option<syn::Expr>,
18    pub name: Option<syn::Expr>,
19    pub flags: Vec<syn::Expr>,
20    pub base: syn::TypePath,
21    pub has_extends: bool,
22    pub has_unsendable: bool,
23    pub module: Option<syn::LitStr>,
24}
25
26impl Parse for PyClassArgs {
27    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
28        let mut slf = PyClassArgs::default();
29
30        let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
31        for expr in vars {
32            slf.add_expr(&expr)?;
33        }
34        Ok(slf)
35    }
36}
37
38impl Default for PyClassArgs {
39    fn default() -> Self {
40        PyClassArgs {
41            freelist: None,
42            name: None,
43            module: None,
44            // We need the 0 as value for the constant we're later building using quote for when there
45            // are no other flags
46            flags: vec![parse_quote! { 0 }],
47            base: parse_quote! { pyo3::PyAny },
48            has_extends: false,
49            has_unsendable: false,
50        }
51    }
52}
53
54impl PyClassArgs {
55    /// Adda single expression from the comma separated list in the attribute, which is
56    /// either a single word or an assignment expression
57    fn add_expr(&mut self, expr: &Expr) -> syn::parse::Result<()> {
58        match expr {
59            syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => self.add_path(exp),
60            syn::Expr::Assign(ref assign) => self.add_assign(assign),
61            _ => Err(syn::Error::new_spanned(expr, "Failed to parse arguments")),
62        }
63    }
64
65    /// Match a key/value flag
66    fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> {
67        let syn::ExprAssign { left, right, .. } = assign;
68        let key = match &**left {
69            syn::Expr::Path(exp) if exp.path.segments.len() == 1 => {
70                exp.path.segments.first().unwrap().ident.to_string()
71            }
72            _ => {
73                return Err(syn::Error::new_spanned(assign, "Failed to parse arguments"));
74            }
75        };
76
77        macro_rules! expected {
78            ($expected: literal) => {
79                expected!($expected, right)
80            };
81            ($expected: literal, $span: ident) => {
82                return Err(syn::Error::new_spanned(
83                    $span,
84                    concat!("Expected ", $expected),
85                ));
86            };
87        }
88
89        match key.as_str() {
90            "freelist" => {
91                // We allow arbitrary expressions here so you can e.g. use `8*64`
92                self.freelist = Some(syn::Expr::clone(right));
93            }
94            "name" => match &**right {
95                syn::Expr::Path(exp) if exp.path.segments.len() == 1 => {
96                    self.name = Some(exp.clone().into());
97                }
98                _ => expected!("type name (e.g., Name)"),
99            },
100            "extends" => match &**right {
101                syn::Expr::Path(exp) => {
102                    self.base = syn::TypePath {
103                        path: exp.path.clone(),
104                        qself: None,
105                    };
106                    self.has_extends = true;
107                }
108                _ => expected!("type path (e.g., my_mod::BaseClass)"),
109            },
110            "module" => match &**right {
111                syn::Expr::Lit(syn::ExprLit {
112                    lit: syn::Lit::Str(lit),
113                    ..
114                }) => {
115                    self.module = Some(lit.clone());
116                }
117                _ => expected!(r#"string literal (e.g., "my_mod")"#),
118            },
119            _ => expected!("one of freelist/name/extends/module", left),
120        };
121
122        Ok(())
123    }
124
125    /// Match a single flag
126    fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> {
127        let flag = exp.path.segments.first().unwrap().ident.to_string();
128        let mut push_flag = |flag| {
129            self.flags.push(syn::Expr::Path(flag));
130        };
131        match flag.as_str() {
132            "gc" => push_flag(parse_quote! {pyo3::type_flags::GC}),
133            "weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}),
134            "subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}),
135            "dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}),
136            "unsendable" => {
137                self.has_unsendable = true;
138            }
139            _ => {
140                return Err(syn::Error::new_spanned(
141                    &exp.path,
142                    "Expected one of gc/weakref/subclass/dict/unsendable",
143                ))
144            }
145        };
146        Ok(())
147    }
148}
149
150pub fn build_py_class(class: &mut syn::ItemStruct, attr: &PyClassArgs) -> syn::Result<TokenStream> {
151    let text_signature = utils::parse_text_signature_attrs(
152        &mut class.attrs,
153        &get_class_python_name(&class.ident, attr),
154    )?;
155    let doc = utils::get_doc(&class.attrs, text_signature, true)?;
156    let mut descriptors = Vec::new();
157
158    check_generics(class)?;
159    if let syn::Fields::Named(ref mut fields) = class.fields {
160        for field in fields.named.iter_mut() {
161            let field_descs = parse_descriptors(field)?;
162            if !field_descs.is_empty() {
163                descriptors.push((field.clone(), field_descs));
164            }
165        }
166    } else {
167        return Err(syn::Error::new_spanned(
168            &class.fields,
169            "#[pyclass] can only be used with C-style structs",
170        ));
171    }
172
173    impl_class(&class.ident, &attr, doc, descriptors)
174}
175
176/// Parses `#[pyo3(get, set)]`
177fn parse_descriptors(item: &mut syn::Field) -> syn::Result<Vec<FnType>> {
178    let mut descs = Vec::new();
179    let mut new_attrs = Vec::new();
180    for attr in item.attrs.iter() {
181        if let Ok(syn::Meta::List(ref list)) = attr.parse_meta() {
182            if list.path.is_ident("pyo3") {
183                for meta in list.nested.iter() {
184                    if let syn::NestedMeta::Meta(ref metaitem) = meta {
185                        if metaitem.path().is_ident("get") {
186                            descs.push(FnType::Getter(SelfType::Receiver { mutable: false }));
187                        } else if metaitem.path().is_ident("set") {
188                            descs.push(FnType::Setter(SelfType::Receiver { mutable: true }));
189                        } else {
190                            return Err(syn::Error::new_spanned(
191                                metaitem,
192                                "Only get and set are supported",
193                            ));
194                        }
195                    }
196                }
197            } else {
198                new_attrs.push(attr.clone())
199            }
200        } else {
201            new_attrs.push(attr.clone());
202        }
203    }
204    item.attrs.clear();
205    item.attrs.extend(new_attrs);
206    Ok(descs)
207}
208
209/// To allow multiple #[pymethods]/#[pyproto] block, we define inventory types.
210fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
211    // Try to build a unique type for better error messages
212    let name = format!("Pyo3MethodsInventoryFor{}", cls);
213    let inventory_cls = syn::Ident::new(&name, Span::call_site());
214
215    quote! {
216        #[doc(hidden)]
217        pub struct #inventory_cls {
218            methods: Vec<pyo3::class::PyMethodDefType>,
219        }
220        impl pyo3::class::methods::PyMethodsInventory for #inventory_cls {
221            fn new(methods: Vec<pyo3::class::PyMethodDefType>) -> Self {
222                Self { methods }
223            }
224            fn get(&'static self) -> &'static [pyo3::class::PyMethodDefType] {
225                &self.methods
226            }
227        }
228
229        impl pyo3::class::methods::HasMethodsInventory for #cls {
230            type Methods = #inventory_cls;
231        }
232
233        pyo3::inventory::collect!(#inventory_cls);
234    }
235}
236
237/// Implement `HasProtoRegistry` for the class for lazy protocol initialization.
238fn impl_proto_registry(cls: &syn::Ident) -> TokenStream {
239    quote! {
240        impl pyo3::class::proto_methods::HasProtoRegistry for #cls {
241            fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry {
242                static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry
243                    = pyo3::class::proto_methods::PyProtoRegistry::new();
244                &REGISTRY
245            }
246        }
247    }
248}
249
250fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream {
251    match &attr.name {
252        Some(name) => quote! { #name },
253        None => quote! { #cls },
254    }
255}
256
257fn impl_class(
258    cls: &syn::Ident,
259    attr: &PyClassArgs,
260    doc: syn::LitStr,
261    descriptors: Vec<(syn::Field, Vec<FnType>)>,
262) -> syn::Result<TokenStream> {
263    let cls_name = get_class_python_name(cls, attr).to_string();
264
265    let extra = {
266        if let Some(freelist) = &attr.freelist {
267            quote! {
268                impl pyo3::freelist::PyClassWithFreeList for #cls {
269                    #[inline]
270                    fn get_free_list() -> &'static mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> {
271                        static mut FREELIST: *mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> = 0 as *mut _;
272                        unsafe {
273                            if FREELIST.is_null() {
274                                FREELIST = Box::into_raw(Box::new(
275                                    pyo3::freelist::FreeList::with_capacity(#freelist)));
276                            }
277                            &mut *FREELIST
278                        }
279                    }
280                }
281            }
282        } else {
283            quote! {
284                impl pyo3::pyclass::PyClassAlloc for #cls {}
285            }
286        }
287    };
288
289    let extra = if !descriptors.is_empty() {
290        let path = syn::Path::from(syn::PathSegment::from(cls.clone()));
291        let ty = syn::Type::from(syn::TypePath { path, qself: None });
292        let desc_impls = impl_descriptors(&ty, descriptors)?;
293        quote! {
294            #desc_impls
295            #extra
296        }
297    } else {
298        extra
299    };
300
301    // insert space for weak ref
302    let mut has_weakref = false;
303    let mut has_dict = false;
304    let mut has_gc = false;
305    for f in attr.flags.iter() {
306        if let syn::Expr::Path(ref epath) = f {
307            if epath.path == parse_quote! { pyo3::type_flags::WEAKREF } {
308                has_weakref = true;
309            } else if epath.path == parse_quote! { pyo3::type_flags::DICT } {
310                has_dict = true;
311            } else if epath.path == parse_quote! { pyo3::type_flags::GC } {
312                has_gc = true;
313            }
314        }
315    }
316
317    let weakref = if has_weakref {
318        quote! { pyo3::pyclass_slots::PyClassWeakRefSlot }
319    } else if attr.has_extends {
320        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::WeakRef }
321    } else {
322        quote! { pyo3::pyclass_slots::PyClassDummySlot }
323    };
324    let dict = if has_dict {
325        quote! { pyo3::pyclass_slots::PyClassDictSlot }
326    } else if attr.has_extends {
327        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::Dict }
328    } else {
329        quote! { pyo3::pyclass_slots::PyClassDummySlot }
330    };
331    let module = if let Some(m) = &attr.module {
332        quote! { Some(#m) }
333    } else {
334        quote! { None }
335    };
336
337    // Enforce at compile time that PyGCProtocol is implemented
338    let gc_impl = if has_gc {
339        let closure_name = format!("__assertion_closure_{}", cls);
340        let closure_token = syn::Ident::new(&closure_name, Span::call_site());
341        quote! {
342            fn #closure_token() {
343                use pyo3::class;
344
345                fn _assert_implements_protocol<'p, T: pyo3::class::PyGCProtocol<'p>>() {}
346                _assert_implements_protocol::<#cls>();
347            }
348        }
349    } else {
350        quote! {}
351    };
352
353    let impl_inventory = impl_methods_inventory(&cls);
354    let impl_proto_registry = impl_proto_registry(&cls);
355
356    let base = &attr.base;
357    let flags = &attr.flags;
358    let extended = if attr.has_extends {
359        quote! { pyo3::type_flags::EXTENDED }
360    } else {
361        quote! { 0 }
362    };
363    let base_layout = if attr.has_extends {
364        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::LayoutAsBase }
365    } else {
366        quote! { pyo3::pycell::PyCellBase<pyo3::PyAny> }
367    };
368    let base_nativetype = if attr.has_extends {
369        quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::BaseNativeType }
370    } else {
371        quote! { pyo3::PyAny }
372    };
373
374    // If #cls is not extended type, we allow Self->PyObject conversion
375    let into_pyobject = if !attr.has_extends {
376        quote! {
377            impl pyo3::IntoPy<pyo3::PyObject> for #cls {
378                fn into_py(self, py: pyo3::Python) -> pyo3::PyObject {
379                    pyo3::IntoPy::into_py(pyo3::Py::new(py, self).unwrap(), py)
380                }
381            }
382        }
383    } else {
384        quote! {}
385    };
386
387    let thread_checker = if attr.has_unsendable {
388        quote! { pyo3::pyclass::ThreadCheckerImpl<#cls> }
389    } else if attr.has_extends {
390        quote! {
391            pyo3::pyclass::ThreadCheckerInherited<#cls, <#cls as pyo3::type_object::PyTypeInfo>::BaseType>
392        }
393    } else {
394        quote! { pyo3::pyclass::ThreadCheckerStub<#cls> }
395    };
396
397    Ok(quote! {
398        unsafe impl pyo3::type_object::PyTypeInfo for #cls {
399            type Type = #cls;
400            type BaseType = #base;
401            type Layout = pyo3::PyCell<Self>;
402            type BaseLayout = #base_layout;
403            type Initializer = pyo3::pyclass_init::PyClassInitializer<Self>;
404            type AsRefTarget = pyo3::PyCell<Self>;
405
406            const NAME: &'static str = #cls_name;
407            const MODULE: Option<&'static str> = #module;
408            const DESCRIPTION: &'static str = #doc;
409            const FLAGS: usize = #(#flags)|* | #extended;
410
411            #[inline]
412            fn type_object_raw(py: pyo3::Python) -> *mut pyo3::ffi::PyTypeObject {
413                use pyo3::type_object::LazyStaticType;
414                static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
415                TYPE_OBJECT.get_or_init::<Self>(py)
416            }
417        }
418
419        impl pyo3::PyClass for #cls {
420            type Dict = #dict;
421            type WeakRef = #weakref;
422            type BaseNativeType = #base_nativetype;
423        }
424
425        impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
426        {
427            type Target = pyo3::PyRef<'a, #cls>;
428        }
429
430        impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
431        {
432            type Target = pyo3::PyRefMut<'a, #cls>;
433        }
434
435        impl pyo3::pyclass::PyClassSend for #cls {
436            type ThreadChecker = #thread_checker;
437        }
438
439        #into_pyobject
440
441        #impl_inventory
442
443        #impl_proto_registry
444
445        #extra
446
447        #gc_impl
448    })
449}
450
451fn impl_descriptors(
452    cls: &syn::Type,
453    descriptors: Vec<(syn::Field, Vec<FnType>)>,
454) -> syn::Result<TokenStream> {
455    let py_methods: Vec<TokenStream> = descriptors
456        .iter()
457        .flat_map(|&(ref field, ref fns)| {
458            fns.iter()
459                .map(|desc| {
460                    let name = field.ident.as_ref().unwrap().unraw();
461                    let doc = utils::get_doc(&field.attrs, None, true)
462                        .unwrap_or_else(|_| syn::LitStr::new(&name.to_string(), name.span()));
463
464                    match desc {
465                        FnType::Getter(self_ty) => Ok(impl_py_getter_def(
466                            &name,
467                            &doc,
468                            &impl_wrap_getter(&cls, PropertyType::Descriptor(&field), &self_ty)?,
469                        )),
470                        FnType::Setter(self_ty) => Ok(impl_py_setter_def(
471                            &name,
472                            &doc,
473                            &impl_wrap_setter(&cls, PropertyType::Descriptor(&field), &self_ty)?,
474                        )),
475                        _ => unreachable!(),
476                    }
477                })
478                .collect::<Vec<syn::Result<TokenStream>>>()
479        })
480        .collect::<syn::Result<_>>()?;
481
482    Ok(quote! {
483        pyo3::inventory::submit! {
484            #![crate = pyo3] {
485                type Inventory = <#cls as pyo3::class::methods::HasMethodsInventory>::Methods;
486                <Inventory as pyo3::class::methods::PyMethodsInventory>::new(vec![#(#py_methods),*])
487            }
488        }
489    })
490}
491
492fn check_generics(class: &mut syn::ItemStruct) -> syn::Result<()> {
493    if class.generics.params.is_empty() {
494        Ok(())
495    } else {
496        Err(syn::Error::new_spanned(
497            &class.generics,
498            "#[pyclass] cannot have generic parameters",
499        ))
500    }
501}