1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{elide_lifetimes, ConcatenationBuilder};
5use crate::utils::{self, Ctx};
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use syn::{
9 ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
10 Fields, Ident, Result, Token,
11};
12
/// An enum that `#[derive(FromPyObject)]` is applied to.
///
/// Extraction tries each variant in declaration order (see `Enum::build`).
struct Enum<'a> {
    /// The enum's own identifier; its string form is passed as the type name to
    /// `failed_to_extract_enum` when no variant can be extracted.
    enum_ident: &'a Ident,
    /// One parsed container per enum variant, in declaration order.
    variants: Vec<Container<'a>>,
}
18
19impl<'a> Enum<'a> {
20 fn new(
25 data_enum: &'a DataEnum,
26 ident: &'a Ident,
27 options: ContainerAttributes,
28 ) -> Result<Self> {
29 ensure_spanned!(
30 !data_enum.variants.is_empty(),
31 ident.span() => "cannot derive FromPyObject for empty enum"
32 );
33 let variants = data_enum
34 .variants
35 .iter()
36 .map(|variant| {
37 let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
38 if let Some(rename_all) = &options.rename_all {
39 ensure_spanned!(
40 variant_options.rename_all.is_none(),
41 variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
42 );
43 variant_options.rename_all = Some(rename_all.clone());
44
45 }
46 let var_ident = &variant.ident;
47 Container::new(
48 &variant.fields,
49 parse_quote!(#ident::#var_ident),
50 variant_options,
51 )
52 })
53 .collect::<Result<Vec<_>>>()?;
54
55 Ok(Enum {
56 enum_ident: ident,
57 variants,
58 })
59 }
60
61 fn build(&self, ctx: &Ctx) -> TokenStream {
63 let Ctx { pyo3_path, .. } = ctx;
64 let mut var_extracts = Vec::new();
65 let mut variant_names = Vec::new();
66 let mut error_names = Vec::new();
67
68 for var in &self.variants {
69 let struct_derive = var.build(ctx);
70 let ext = quote!({
71 let maybe_ret = || -> #pyo3_path::PyResult<Self> {
72 #struct_derive
73 }();
74
75 match maybe_ret {
76 ok @ ::std::result::Result::Ok(_) => return ok,
77 ::std::result::Result::Err(err) => err
78 }
79 });
80
81 var_extracts.push(ext);
82 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
83 error_names.push(&var.err_name);
84 }
85 let ty_name = self.enum_ident.to_string();
86 quote!(
87 let errors = [
88 #(#var_extracts),*
89 ];
90 ::std::result::Result::Err(
91 #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
92 obj.py(),
93 #ty_name,
94 &[#(#variant_names),*],
95 &[#(#error_names),*],
96 &errors
97 )
98 )
99 )
100 }
101
102 #[cfg(feature = "experimental-inspect")]
103 fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
104 for (i, var) in self.variants.iter().enumerate() {
105 if i > 0 {
106 builder.push_str(" | ");
107 }
108 var.write_input_type(builder, ctx);
109 }
110 }
111}
112
/// A named field of a struct (or of a struct-like enum variant), together with
/// its `#[pyo3(...)]` field attributes.
struct NamedStructField<'a> {
    /// The field's identifier; its unraw'd string form is the default lookup name.
    ident: &'a syn::Ident,
    /// Explicit `attribute`/`item` getter. `None` means attribute lookup by the
    /// (possibly `rename_all`-transformed) field name.
    getter: Option<FieldGetter>,
    /// Optional user-supplied extraction function (`from_py_with`).
    from_py_with: Option<FromPyWithAttribute>,
    /// Value used when the lookup returns an error (`default`).
    default: Option<DefaultAttribute>,
    /// The declared Rust type of the field.
    ty: &'a syn::Type,
}
120
/// A positional field of a tuple struct (or of a tuple-like enum variant).
struct TupleStructField {
    /// Optional user-supplied extraction function (`from_py_with`).
    from_py_with: Option<FromPyWithAttribute>,
    /// The declared Rust type of the field (owned copy of the field's type).
    ty: syn::Type,
}
125
/// The shape of a container, which determines how its extraction code is generated.
enum ContainerType<'a> {
    /// A struct with named fields; each field is looked up on the input object
    /// (by attribute or by item) and then extracted.
    Struct(Vec<NamedStructField<'a>>),
    /// A `transparent` struct with exactly one named field, extracted directly
    /// from the input object: `(field ident, from_py_with, field type)`.
    ///
    /// The field type is only read by the `experimental-inspect` input-type
    /// generation, hence `allow(unused)` when that feature is off.
    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>, &'a syn::Type),
    /// A tuple struct with two or more fields. The input is first extracted as a
    /// Rust tuple of the same arity, then each element is extracted in turn.
    Tuple(Vec<TupleStructField>),
    /// A tuple struct with exactly one field, extracted directly from the input
    /// object: `(from_py_with, field type)`.
    ///
    /// The field type is only read by the `experimental-inspect` input-type
    /// generation, hence `allow(unused)` when that feature is off.
    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
    TupleNewtype(Option<FromPyWithAttribute>, Box<syn::Type>),
}
150
/// A struct, or a single enum variant, for which extraction code is generated.
struct Container<'a> {
    /// Path used to construct the value, e.g. `Foo` or `MyEnum::Variant`.
    path: syn::Path,
    /// The field layout, which selects the extraction strategy.
    ty: ContainerType<'a>,
    /// Name reported for this container when enum extraction fails: the
    /// `annotation` attribute if present, otherwise the last path segment.
    err_name: String,
    /// `rename_all` rule applied to field names that have no explicit lookup name.
    rename_rule: Option<RenamingRule>,
}
160
161impl<'a> Container<'a> {
162 fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
166 let style = match fields {
167 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
168 ensure_spanned!(
169 options.rename_all.is_none(),
170 options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
171 );
172 let mut tuple_fields = unnamed
173 .unnamed
174 .iter()
175 .map(|field| {
176 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
177 ensure_spanned!(
178 attrs.getter.is_none(),
179 field.span() => "`getter` is not permitted on tuple struct elements."
180 );
181 ensure_spanned!(
182 attrs.default.is_none(),
183 field.span() => "`default` is not permitted on tuple struct elements."
184 );
185 Ok(TupleStructField {
186 from_py_with: attrs.from_py_with,
187 ty: field.ty.clone(),
188 })
189 })
190 .collect::<Result<Vec<_>>>()?;
191
192 if tuple_fields.len() == 1 {
193 let field = tuple_fields.pop().unwrap();
196 ContainerType::TupleNewtype(field.from_py_with, Box::new(field.ty))
197 } else if options.transparent.is_some() {
198 bail_spanned!(
199 fields.span() => "transparent structs and variants can only have 1 field"
200 );
201 } else {
202 ContainerType::Tuple(tuple_fields)
203 }
204 }
205 Fields::Named(named) if !named.named.is_empty() => {
206 let mut struct_fields = named
207 .named
208 .iter()
209 .map(|field| {
210 let ident = field
211 .ident
212 .as_ref()
213 .expect("Named fields should have identifiers");
214 let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
215
216 if let Some(ref from_item_all) = options.from_item_all {
217 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
218 {
219 match replaced {
220 FieldGetter::GetItem(item, Some(item_name)) => {
221 attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
222 }
223 FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
224 FieldGetter::GetAttr(_, _) => bail_spanned!(
225 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
226 ),
227 }
228 }
229 }
230
231 Ok(NamedStructField {
232 ident,
233 getter: attrs.getter,
234 from_py_with: attrs.from_py_with,
235 default: attrs.default,
236 ty: &field.ty,
237 })
238 })
239 .collect::<Result<Vec<_>>>()?;
240 if struct_fields.iter().all(|field| field.default.is_some()) {
241 bail_spanned!(
242 fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
243 )
244 } else if options.transparent.is_some() {
245 ensure_spanned!(
246 struct_fields.len() == 1,
247 fields.span() => "transparent structs and variants can only have 1 field"
248 );
249 ensure_spanned!(
250 options.rename_all.is_none(),
251 options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
252 );
253 let field = struct_fields.pop().unwrap();
254 ensure_spanned!(
255 field.getter.is_none(),
256 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
257 );
258 ContainerType::StructNewtype(field.ident, field.from_py_with, field.ty)
259 } else {
260 ContainerType::Struct(struct_fields)
261 }
262 }
263 _ => bail_spanned!(
264 fields.span() => "cannot derive FromPyObject for empty structs and variants"
265 ),
266 };
267 let err_name = options.annotation.map_or_else(
268 || path.segments.last().unwrap().ident.to_string(),
269 |lit_str| lit_str.value(),
270 );
271
272 let v = Container {
273 path,
274 ty: style,
275 err_name,
276 rename_rule: options.rename_all.map(|v| v.value.rule),
277 };
278 Ok(v)
279 }
280
281 fn name(&self) -> String {
282 let mut value = String::new();
283 for segment in &self.path.segments {
284 if !value.is_empty() {
285 value.push_str("::");
286 }
287 value.push_str(&segment.ident.to_string());
288 }
289 value
290 }
291
292 fn build(&self, ctx: &Ctx) -> TokenStream {
294 match &self.ty {
295 ContainerType::StructNewtype(ident, from_py_with, _) => {
296 self.build_newtype_struct(Some(ident), from_py_with, ctx)
297 }
298 ContainerType::TupleNewtype(from_py_with, _) => {
299 self.build_newtype_struct(None, from_py_with, ctx)
300 }
301 ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
302 ContainerType::Struct(tups) => self.build_struct(tups, ctx),
303 }
304 }
305
306 fn build_newtype_struct(
307 &self,
308 field_ident: Option<&Ident>,
309 from_py_with: &Option<FromPyWithAttribute>,
310 ctx: &Ctx,
311 ) -> TokenStream {
312 let Ctx { pyo3_path, .. } = ctx;
313 let self_ty = &self.path;
314 let struct_name = self.name();
315 if let Some(ident) = field_ident {
316 let field_name = ident.to_string();
317 if let Some(FromPyWithAttribute {
318 kw,
319 value: expr_path,
320 }) = from_py_with
321 {
322 let extractor = quote_spanned! { kw.span =>
323 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
324 };
325 quote! {
326 Ok(#self_ty {
327 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
328 })
329 }
330 } else {
331 quote! {
332 Ok(#self_ty {
333 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
334 })
335 }
336 }
337 } else if let Some(FromPyWithAttribute {
338 kw,
339 value: expr_path,
340 }) = from_py_with
341 {
342 let extractor = quote_spanned! { kw.span =>
343 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
344 };
345 quote! {
346 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
347 }
348 } else {
349 quote! {
350 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
351 }
352 }
353 }
354
355 fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
356 let Ctx { pyo3_path, .. } = ctx;
357 let self_ty = &self.path;
358 let struct_name = &self.name();
359 let field_idents: Vec<_> = (0..struct_fields.len())
360 .map(|i| format_ident!("arg{}", i))
361 .collect();
362 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
363 if let Some(FromPyWithAttribute {
364 kw,
365 value: expr_path, ..
366 }) = &field.from_py_with {
367 let extractor = quote_spanned! { kw.span =>
368 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
369 };
370 quote! {
371 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
372 }
373 } else {
374 quote!{
375 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
376 }}
377 });
378
379 quote!(
380 match #pyo3_path::types::PyAnyMethods::extract(obj) {
381 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
382 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
383 }
384 )
385 }
386
387 fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
388 let Ctx { pyo3_path, .. } = ctx;
389 let self_ty = &self.path;
390 let struct_name = self.name();
391 let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
392 for field in struct_fields {
393 let ident = field.ident;
394 let field_name = ident.unraw().to_string();
395 let getter = match field
396 .getter
397 .as_ref()
398 .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
399 {
400 FieldGetter::GetAttr(_, Some(name)) => {
401 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
402 }
403 FieldGetter::GetAttr(_, None) => {
404 let name = self
405 .rename_rule
406 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
407 let name = name.as_deref().unwrap_or(&field_name);
408 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
409 }
410 FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
411 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
412 }
413 FieldGetter::GetItem(_, Some(key)) => {
414 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
415 }
416 FieldGetter::GetItem(_, None) => {
417 let name = self
418 .rename_rule
419 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
420 let name = name.as_deref().unwrap_or(&field_name);
421 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
422 }
423 };
424 let extractor = if let Some(FromPyWithAttribute {
425 kw,
426 value: expr_path,
427 }) = &field.from_py_with
428 {
429 let extractor = quote_spanned! { kw.span =>
430 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
431 };
432 quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
433 } else {
434 quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
435 };
436 let extracted = if let Some(default) = &field.default {
437 let default_expr = if let Some(default_expr) = &default.value {
438 default_expr.to_token_stream()
439 } else {
440 quote!(::std::default::Default::default())
441 };
442 quote!(if let ::std::result::Result::Ok(value) = #getter {
443 #extractor
444 } else {
445 #default_expr
446 })
447 } else {
448 quote!({
449 let value = #getter?;
450 #extractor
451 })
452 };
453
454 fields.push(quote!(#ident: #extracted));
455 }
456
457 quote!(::std::result::Result::Ok(#self_ty{#fields}))
458 }
459
460 #[cfg(feature = "experimental-inspect")]
461 fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
462 match &self.ty {
463 ContainerType::StructNewtype(_, from_py_with, ty) => {
464 Self::write_field_input_type(from_py_with, ty, builder, ctx);
465 }
466 ContainerType::TupleNewtype(from_py_with, ty) => {
467 Self::write_field_input_type(from_py_with, ty, builder, ctx);
468 }
469 ContainerType::Tuple(tups) => {
470 builder.push_str("tuple[");
471 for (i, TupleStructField { from_py_with, ty }) in tups.iter().enumerate() {
472 if i > 0 {
473 builder.push_str(", ");
474 }
475 Self::write_field_input_type(from_py_with, ty, builder, ctx);
476 }
477 builder.push_str("]");
478 }
479 ContainerType::Struct(_) => {
480 builder.push_str("_typeshed.Incomplete")
482 }
483 }
484 }
485
486 #[cfg(feature = "experimental-inspect")]
487 fn write_field_input_type(
488 from_py_with: &Option<FromPyWithAttribute>,
489 ty: &syn::Type,
490 builder: &mut ConcatenationBuilder,
491 ctx: &Ctx,
492 ) {
493 if from_py_with.is_some() {
494 builder.push_str("_typeshed.Incomplete")
496 } else {
497 let mut ty = ty.clone();
498 elide_lifetimes(&mut ty);
499 let pyo3_crate_path = &ctx.pyo3_path;
500 builder.push_tokens(
501 quote! { <#ty as #pyo3_crate_path::FromPyObject<'_, '_>>::INPUT_TYPE.as_bytes() },
502 )
503 }
504 }
505}
506
507fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
508 let mut lifetimes = generics.lifetimes();
509 let lifetime = lifetimes.next();
510 ensure_spanned!(
511 lifetimes.next().is_none(),
512 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
513 );
514 Ok(lifetime)
515}
516
517pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
526 let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
527 let ctx = &Ctx::new(&options.krate, None);
528 let Ctx { pyo3_path, .. } = &ctx;
529
530 let (_, ty_generics, _) = tokens.generics.split_for_impl();
531 let mut trait_generics = tokens.generics.clone();
532 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
533 lt.clone()
534 } else {
535 trait_generics.params.push(parse_quote!('py));
536 parse_quote!('py)
537 };
538 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
539
540 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
541 for param in trait_generics.type_params() {
542 let gen_ident = ¶m.ident;
543 where_clause
544 .predicates
545 .push(parse_quote!(#gen_ident: #pyo3_path::conversion::FromPyObjectOwned<#lt_param>))
546 }
547
548 let derives = match &tokens.data {
549 syn::Data::Enum(en) => {
550 if options.transparent.is_some() || options.annotation.is_some() {
551 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
552 at top level for enums");
553 }
554 let en = Enum::new(en, &tokens.ident, options.clone())?;
555 en.build(ctx)
556 }
557 syn::Data::Struct(st) => {
558 if let Some(lit_str) = &options.annotation {
559 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
560 }
561 let ident = &tokens.ident;
562 let st = Container::new(&st.fields, parse_quote!(#ident), options.clone())?;
563 st.build(ctx)
564 }
565 syn::Data::Union(_) => bail_spanned!(
566 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
567 ),
568 };
569
570 #[cfg(feature = "experimental-inspect")]
571 let input_type = {
572 let mut builder = ConcatenationBuilder::default();
573 if tokens
574 .generics
575 .params
576 .iter()
577 .all(|p| matches!(p, syn::GenericParam::Lifetime(_)))
578 {
579 match &tokens.data {
580 syn::Data::Enum(en) => {
581 Enum::new(en, &tokens.ident, options)?.write_input_type(&mut builder, ctx)
582 }
583 syn::Data::Struct(st) => {
584 let ident = &tokens.ident;
585 Container::new(&st.fields, parse_quote!(#ident), options.clone())?
586 .write_input_type(&mut builder, ctx)
587 }
588 syn::Data::Union(_) => {
589 builder.push_str("_typeshed.Incomplete")
591 }
592 }
593 } else {
594 builder.push_str("_typeshed.Incomplete")
597 };
598 let input_type = builder.into_token_stream(&ctx.pyo3_path);
599 quote! { const INPUT_TYPE: &'static str = unsafe { ::std::str::from_utf8_unchecked(#input_type) }; }
600 };
601 #[cfg(not(feature = "experimental-inspect"))]
602 let input_type = quote! {};
603
604 let ident = &tokens.ident;
605 Ok(quote!(
606 #[automatically_derived]
607 impl #impl_generics #pyo3_path::FromPyObject<'_, #lt_param> for #ident #ty_generics #where_clause {
608 type Error = #pyo3_path::PyErr;
609 fn extract(obj: #pyo3_path::Borrowed<'_, #lt_param, #pyo3_path::PyAny>) -> ::std::result::Result<Self, Self::Error> {
610 let obj: &#pyo3_path::Bound<'_, _> = &*obj;
611 #derives
612 }
613 #input_type
614 }
615 ))
616}