1use crate::method::{FnType, SelfType};
4use crate::pymethod::{
5 impl_py_getter_def, impl_py_setter_def, impl_wrap_getter, impl_wrap_setter, PropertyType,
6};
7use crate::utils;
8use proc_macro2::{Span, TokenStream};
9use quote::quote;
10use syn::ext::IdentExt;
11use syn::parse::{Parse, ParseStream};
12use syn::punctuated::Punctuated;
13use syn::{parse_quote, Expr, Token};
14
15pub struct PyClassArgs {
17 pub freelist: Option<syn::Expr>,
18 pub name: Option<syn::Expr>,
19 pub flags: Vec<syn::Expr>,
20 pub base: syn::TypePath,
21 pub has_extends: bool,
22 pub has_unsendable: bool,
23 pub module: Option<syn::LitStr>,
24}
25
26impl Parse for PyClassArgs {
27 fn parse(input: ParseStream) -> syn::parse::Result<Self> {
28 let mut slf = PyClassArgs::default();
29
30 let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
31 for expr in vars {
32 slf.add_expr(&expr)?;
33 }
34 Ok(slf)
35 }
36}
37
38impl Default for PyClassArgs {
39 fn default() -> Self {
40 PyClassArgs {
41 freelist: None,
42 name: None,
43 module: None,
44 flags: vec![parse_quote! { 0 }],
47 base: parse_quote! { pyo3::PyAny },
48 has_extends: false,
49 has_unsendable: false,
50 }
51 }
52}
53
54impl PyClassArgs {
55 fn add_expr(&mut self, expr: &Expr) -> syn::parse::Result<()> {
58 match expr {
59 syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => self.add_path(exp),
60 syn::Expr::Assign(ref assign) => self.add_assign(assign),
61 _ => Err(syn::Error::new_spanned(expr, "Failed to parse arguments")),
62 }
63 }
64
65 fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> {
67 let syn::ExprAssign { left, right, .. } = assign;
68 let key = match &**left {
69 syn::Expr::Path(exp) if exp.path.segments.len() == 1 => {
70 exp.path.segments.first().unwrap().ident.to_string()
71 }
72 _ => {
73 return Err(syn::Error::new_spanned(assign, "Failed to parse arguments"));
74 }
75 };
76
77 macro_rules! expected {
78 ($expected: literal) => {
79 expected!($expected, right)
80 };
81 ($expected: literal, $span: ident) => {
82 return Err(syn::Error::new_spanned(
83 $span,
84 concat!("Expected ", $expected),
85 ));
86 };
87 }
88
89 match key.as_str() {
90 "freelist" => {
91 self.freelist = Some(syn::Expr::clone(right));
93 }
94 "name" => match &**right {
95 syn::Expr::Path(exp) if exp.path.segments.len() == 1 => {
96 self.name = Some(exp.clone().into());
97 }
98 _ => expected!("type name (e.g., Name)"),
99 },
100 "extends" => match &**right {
101 syn::Expr::Path(exp) => {
102 self.base = syn::TypePath {
103 path: exp.path.clone(),
104 qself: None,
105 };
106 self.has_extends = true;
107 }
108 _ => expected!("type path (e.g., my_mod::BaseClass)"),
109 },
110 "module" => match &**right {
111 syn::Expr::Lit(syn::ExprLit {
112 lit: syn::Lit::Str(lit),
113 ..
114 }) => {
115 self.module = Some(lit.clone());
116 }
117 _ => expected!(r#"string literal (e.g., "my_mod")"#),
118 },
119 _ => expected!("one of freelist/name/extends/module", left),
120 };
121
122 Ok(())
123 }
124
125 fn add_path(&mut self, exp: &syn::ExprPath) -> syn::Result<()> {
127 let flag = exp.path.segments.first().unwrap().ident.to_string();
128 let mut push_flag = |flag| {
129 self.flags.push(syn::Expr::Path(flag));
130 };
131 match flag.as_str() {
132 "gc" => push_flag(parse_quote! {pyo3::type_flags::GC}),
133 "weakref" => push_flag(parse_quote! {pyo3::type_flags::WEAKREF}),
134 "subclass" => push_flag(parse_quote! {pyo3::type_flags::BASETYPE}),
135 "dict" => push_flag(parse_quote! {pyo3::type_flags::DICT}),
136 "unsendable" => {
137 self.has_unsendable = true;
138 }
139 _ => {
140 return Err(syn::Error::new_spanned(
141 &exp.path,
142 "Expected one of gc/weakref/subclass/dict/unsendable",
143 ))
144 }
145 };
146 Ok(())
147 }
148}
149
150pub fn build_py_class(class: &mut syn::ItemStruct, attr: &PyClassArgs) -> syn::Result<TokenStream> {
151 let text_signature = utils::parse_text_signature_attrs(
152 &mut class.attrs,
153 &get_class_python_name(&class.ident, attr),
154 )?;
155 let doc = utils::get_doc(&class.attrs, text_signature, true)?;
156 let mut descriptors = Vec::new();
157
158 check_generics(class)?;
159 if let syn::Fields::Named(ref mut fields) = class.fields {
160 for field in fields.named.iter_mut() {
161 let field_descs = parse_descriptors(field)?;
162 if !field_descs.is_empty() {
163 descriptors.push((field.clone(), field_descs));
164 }
165 }
166 } else {
167 return Err(syn::Error::new_spanned(
168 &class.fields,
169 "#[pyclass] can only be used with C-style structs",
170 ));
171 }
172
173 impl_class(&class.ident, &attr, doc, descriptors)
174}
175
176fn parse_descriptors(item: &mut syn::Field) -> syn::Result<Vec<FnType>> {
178 let mut descs = Vec::new();
179 let mut new_attrs = Vec::new();
180 for attr in item.attrs.iter() {
181 if let Ok(syn::Meta::List(ref list)) = attr.parse_meta() {
182 if list.path.is_ident("pyo3") {
183 for meta in list.nested.iter() {
184 if let syn::NestedMeta::Meta(ref metaitem) = meta {
185 if metaitem.path().is_ident("get") {
186 descs.push(FnType::Getter(SelfType::Receiver { mutable: false }));
187 } else if metaitem.path().is_ident("set") {
188 descs.push(FnType::Setter(SelfType::Receiver { mutable: true }));
189 } else {
190 return Err(syn::Error::new_spanned(
191 metaitem,
192 "Only get and set are supported",
193 ));
194 }
195 }
196 }
197 } else {
198 new_attrs.push(attr.clone())
199 }
200 } else {
201 new_attrs.push(attr.clone());
202 }
203 }
204 item.attrs.clear();
205 item.attrs.extend(new_attrs);
206 Ok(descs)
207}
208
209fn impl_methods_inventory(cls: &syn::Ident) -> TokenStream {
211 let name = format!("Pyo3MethodsInventoryFor{}", cls);
213 let inventory_cls = syn::Ident::new(&name, Span::call_site());
214
215 quote! {
216 #[doc(hidden)]
217 pub struct #inventory_cls {
218 methods: Vec<pyo3::class::PyMethodDefType>,
219 }
220 impl pyo3::class::methods::PyMethodsInventory for #inventory_cls {
221 fn new(methods: Vec<pyo3::class::PyMethodDefType>) -> Self {
222 Self { methods }
223 }
224 fn get(&'static self) -> &'static [pyo3::class::PyMethodDefType] {
225 &self.methods
226 }
227 }
228
229 impl pyo3::class::methods::HasMethodsInventory for #cls {
230 type Methods = #inventory_cls;
231 }
232
233 pyo3::inventory::collect!(#inventory_cls);
234 }
235}
236
237fn impl_proto_registry(cls: &syn::Ident) -> TokenStream {
239 quote! {
240 impl pyo3::class::proto_methods::HasProtoRegistry for #cls {
241 fn registry() -> &'static pyo3::class::proto_methods::PyProtoRegistry {
242 static REGISTRY: pyo3::class::proto_methods::PyProtoRegistry
243 = pyo3::class::proto_methods::PyProtoRegistry::new();
244 ®ISTRY
245 }
246 }
247 }
248}
249
250fn get_class_python_name(cls: &syn::Ident, attr: &PyClassArgs) -> TokenStream {
251 match &attr.name {
252 Some(name) => quote! { #name },
253 None => quote! { #cls },
254 }
255}
256
257fn impl_class(
258 cls: &syn::Ident,
259 attr: &PyClassArgs,
260 doc: syn::LitStr,
261 descriptors: Vec<(syn::Field, Vec<FnType>)>,
262) -> syn::Result<TokenStream> {
263 let cls_name = get_class_python_name(cls, attr).to_string();
264
265 let extra = {
266 if let Some(freelist) = &attr.freelist {
267 quote! {
268 impl pyo3::freelist::PyClassWithFreeList for #cls {
269 #[inline]
270 fn get_free_list() -> &'static mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> {
271 static mut FREELIST: *mut pyo3::freelist::FreeList<*mut pyo3::ffi::PyObject> = 0 as *mut _;
272 unsafe {
273 if FREELIST.is_null() {
274 FREELIST = Box::into_raw(Box::new(
275 pyo3::freelist::FreeList::with_capacity(#freelist)));
276 }
277 &mut *FREELIST
278 }
279 }
280 }
281 }
282 } else {
283 quote! {
284 impl pyo3::pyclass::PyClassAlloc for #cls {}
285 }
286 }
287 };
288
289 let extra = if !descriptors.is_empty() {
290 let path = syn::Path::from(syn::PathSegment::from(cls.clone()));
291 let ty = syn::Type::from(syn::TypePath { path, qself: None });
292 let desc_impls = impl_descriptors(&ty, descriptors)?;
293 quote! {
294 #desc_impls
295 #extra
296 }
297 } else {
298 extra
299 };
300
301 let mut has_weakref = false;
303 let mut has_dict = false;
304 let mut has_gc = false;
305 for f in attr.flags.iter() {
306 if let syn::Expr::Path(ref epath) = f {
307 if epath.path == parse_quote! { pyo3::type_flags::WEAKREF } {
308 has_weakref = true;
309 } else if epath.path == parse_quote! { pyo3::type_flags::DICT } {
310 has_dict = true;
311 } else if epath.path == parse_quote! { pyo3::type_flags::GC } {
312 has_gc = true;
313 }
314 }
315 }
316
317 let weakref = if has_weakref {
318 quote! { pyo3::pyclass_slots::PyClassWeakRefSlot }
319 } else if attr.has_extends {
320 quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::WeakRef }
321 } else {
322 quote! { pyo3::pyclass_slots::PyClassDummySlot }
323 };
324 let dict = if has_dict {
325 quote! { pyo3::pyclass_slots::PyClassDictSlot }
326 } else if attr.has_extends {
327 quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::Dict }
328 } else {
329 quote! { pyo3::pyclass_slots::PyClassDummySlot }
330 };
331 let module = if let Some(m) = &attr.module {
332 quote! { Some(#m) }
333 } else {
334 quote! { None }
335 };
336
337 let gc_impl = if has_gc {
339 let closure_name = format!("__assertion_closure_{}", cls);
340 let closure_token = syn::Ident::new(&closure_name, Span::call_site());
341 quote! {
342 fn #closure_token() {
343 use pyo3::class;
344
345 fn _assert_implements_protocol<'p, T: pyo3::class::PyGCProtocol<'p>>() {}
346 _assert_implements_protocol::<#cls>();
347 }
348 }
349 } else {
350 quote! {}
351 };
352
353 let impl_inventory = impl_methods_inventory(&cls);
354 let impl_proto_registry = impl_proto_registry(&cls);
355
356 let base = &attr.base;
357 let flags = &attr.flags;
358 let extended = if attr.has_extends {
359 quote! { pyo3::type_flags::EXTENDED }
360 } else {
361 quote! { 0 }
362 };
363 let base_layout = if attr.has_extends {
364 quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::LayoutAsBase }
365 } else {
366 quote! { pyo3::pycell::PyCellBase<pyo3::PyAny> }
367 };
368 let base_nativetype = if attr.has_extends {
369 quote! { <Self::BaseType as pyo3::derive_utils::PyBaseTypeUtils>::BaseNativeType }
370 } else {
371 quote! { pyo3::PyAny }
372 };
373
374 let into_pyobject = if !attr.has_extends {
376 quote! {
377 impl pyo3::IntoPy<pyo3::PyObject> for #cls {
378 fn into_py(self, py: pyo3::Python) -> pyo3::PyObject {
379 pyo3::IntoPy::into_py(pyo3::Py::new(py, self).unwrap(), py)
380 }
381 }
382 }
383 } else {
384 quote! {}
385 };
386
387 let thread_checker = if attr.has_unsendable {
388 quote! { pyo3::pyclass::ThreadCheckerImpl<#cls> }
389 } else if attr.has_extends {
390 quote! {
391 pyo3::pyclass::ThreadCheckerInherited<#cls, <#cls as pyo3::type_object::PyTypeInfo>::BaseType>
392 }
393 } else {
394 quote! { pyo3::pyclass::ThreadCheckerStub<#cls> }
395 };
396
397 Ok(quote! {
398 unsafe impl pyo3::type_object::PyTypeInfo for #cls {
399 type Type = #cls;
400 type BaseType = #base;
401 type Layout = pyo3::PyCell<Self>;
402 type BaseLayout = #base_layout;
403 type Initializer = pyo3::pyclass_init::PyClassInitializer<Self>;
404 type AsRefTarget = pyo3::PyCell<Self>;
405
406 const NAME: &'static str = #cls_name;
407 const MODULE: Option<&'static str> = #module;
408 const DESCRIPTION: &'static str = #doc;
409 const FLAGS: usize = #(#flags)|* | #extended;
410
411 #[inline]
412 fn type_object_raw(py: pyo3::Python) -> *mut pyo3::ffi::PyTypeObject {
413 use pyo3::type_object::LazyStaticType;
414 static TYPE_OBJECT: LazyStaticType = LazyStaticType::new();
415 TYPE_OBJECT.get_or_init::<Self>(py)
416 }
417 }
418
419 impl pyo3::PyClass for #cls {
420 type Dict = #dict;
421 type WeakRef = #weakref;
422 type BaseNativeType = #base_nativetype;
423 }
424
425 impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a #cls
426 {
427 type Target = pyo3::PyRef<'a, #cls>;
428 }
429
430 impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls
431 {
432 type Target = pyo3::PyRefMut<'a, #cls>;
433 }
434
435 impl pyo3::pyclass::PyClassSend for #cls {
436 type ThreadChecker = #thread_checker;
437 }
438
439 #into_pyobject
440
441 #impl_inventory
442
443 #impl_proto_registry
444
445 #extra
446
447 #gc_impl
448 })
449}
450
451fn impl_descriptors(
452 cls: &syn::Type,
453 descriptors: Vec<(syn::Field, Vec<FnType>)>,
454) -> syn::Result<TokenStream> {
455 let py_methods: Vec<TokenStream> = descriptors
456 .iter()
457 .flat_map(|&(ref field, ref fns)| {
458 fns.iter()
459 .map(|desc| {
460 let name = field.ident.as_ref().unwrap().unraw();
461 let doc = utils::get_doc(&field.attrs, None, true)
462 .unwrap_or_else(|_| syn::LitStr::new(&name.to_string(), name.span()));
463
464 match desc {
465 FnType::Getter(self_ty) => Ok(impl_py_getter_def(
466 &name,
467 &doc,
468 &impl_wrap_getter(&cls, PropertyType::Descriptor(&field), &self_ty)?,
469 )),
470 FnType::Setter(self_ty) => Ok(impl_py_setter_def(
471 &name,
472 &doc,
473 &impl_wrap_setter(&cls, PropertyType::Descriptor(&field), &self_ty)?,
474 )),
475 _ => unreachable!(),
476 }
477 })
478 .collect::<Vec<syn::Result<TokenStream>>>()
479 })
480 .collect::<syn::Result<_>>()?;
481
482 Ok(quote! {
483 pyo3::inventory::submit! {
484 #![crate = pyo3] {
485 type Inventory = <#cls as pyo3::class::methods::HasMethodsInventory>::Methods;
486 <Inventory as pyo3::class::methods::PyMethodsInventory>::new(vec![#(#py_methods),*])
487 }
488 }
489 })
490}
491
492fn check_generics(class: &mut syn::ItemStruct) -> syn::Result<()> {
493 if class.generics.params.is_empty() {
494 Ok(())
495 } else {
496 Err(syn::Error::new_spanned(
497 &class.generics,
498 "#[pyclass] cannot have generic parameters",
499 ))
500 }
501}