pyo3_macros_backend/
frompyobject.rs

1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3#[cfg(feature = "experimental-inspect")]
4use crate::introspection::{elide_lifetimes, ConcatenationBuilder};
5use crate::utils::{self, Ctx};
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use syn::{
9    ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
10    Fields, Ident, Result, Token,
11};
12
13/// Describes derivation input of an enum.
14struct Enum<'a> {
15    enum_ident: &'a Ident,
16    variants: Vec<Container<'a>>,
17}
18
19impl<'a> Enum<'a> {
20    /// Construct a new enum representation.
21    ///
22    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
23    /// `Identifier` of the enum.
24    fn new(
25        data_enum: &'a DataEnum,
26        ident: &'a Ident,
27        options: ContainerAttributes,
28    ) -> Result<Self> {
29        ensure_spanned!(
30            !data_enum.variants.is_empty(),
31            ident.span() => "cannot derive FromPyObject for empty enum"
32        );
33        let variants = data_enum
34            .variants
35            .iter()
36            .map(|variant| {
37                let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
38                if let Some(rename_all) = &options.rename_all {
39                    ensure_spanned!(
40                        variant_options.rename_all.is_none(),
41                        variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
42                    );
43                    variant_options.rename_all = Some(rename_all.clone());
44
45                }
46                let var_ident = &variant.ident;
47                Container::new(
48                    &variant.fields,
49                    parse_quote!(#ident::#var_ident),
50                    variant_options,
51                )
52            })
53            .collect::<Result<Vec<_>>>()?;
54
55        Ok(Enum {
56            enum_ident: ident,
57            variants,
58        })
59    }
60
61    /// Build derivation body for enums.
62    fn build(&self, ctx: &Ctx) -> TokenStream {
63        let Ctx { pyo3_path, .. } = ctx;
64        let mut var_extracts = Vec::new();
65        let mut variant_names = Vec::new();
66        let mut error_names = Vec::new();
67
68        for var in &self.variants {
69            let struct_derive = var.build(ctx);
70            let ext = quote!({
71                let maybe_ret = || -> #pyo3_path::PyResult<Self> {
72                    #struct_derive
73                }();
74
75                match maybe_ret {
76                    ok @ ::std::result::Result::Ok(_) => return ok,
77                    ::std::result::Result::Err(err) => err
78                }
79            });
80
81            var_extracts.push(ext);
82            variant_names.push(var.path.segments.last().unwrap().ident.to_string());
83            error_names.push(&var.err_name);
84        }
85        let ty_name = self.enum_ident.to_string();
86        quote!(
87            let errors = [
88                #(#var_extracts),*
89            ];
90            ::std::result::Result::Err(
91                #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
92                    obj.py(),
93                    #ty_name,
94                    &[#(#variant_names),*],
95                    &[#(#error_names),*],
96                    &errors
97                )
98            )
99        )
100    }
101
102    #[cfg(feature = "experimental-inspect")]
103    fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
104        for (i, var) in self.variants.iter().enumerate() {
105            if i > 0 {
106                builder.push_str(" | ");
107            }
108            var.write_input_type(builder, ctx);
109        }
110    }
111}
112
113struct NamedStructField<'a> {
114    ident: &'a syn::Ident,
115    getter: Option<FieldGetter>,
116    from_py_with: Option<FromPyWithAttribute>,
117    default: Option<DefaultAttribute>,
118    ty: &'a syn::Type,
119}
120
121struct TupleStructField {
122    from_py_with: Option<FromPyWithAttribute>,
123    ty: syn::Type,
124}
125
126/// Container Style
127///
128/// Covers Structs, Tuplestructs and corresponding Newtypes.
129enum ContainerType<'a> {
130    /// Struct Container, e.g. `struct Foo { a: String }`
131    ///
132    /// Variant contains the list of field identifiers and the corresponding extraction call.
133    Struct(Vec<NamedStructField<'a>>),
134    /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
135    ///
136    /// The field specified by the identifier is extracted directly from the object.
137    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
138    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>, &'a syn::Type),
139    /// Tuple struct, e.g. `struct Foo(String)`.
140    ///
141    /// Variant contains a list of conversion methods for each of the fields that are directly
142    ///  extracted from the tuple.
143    Tuple(Vec<TupleStructField>),
144    /// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
145    ///
146    /// The wrapped field is directly extracted from the object.
147    #[cfg_attr(not(feature = "experimental-inspect"), allow(unused))]
148    TupleNewtype(Option<FromPyWithAttribute>, Box<syn::Type>),
149}
150
151/// Data container
152///
153/// Either describes a struct or an enum variant.
154struct Container<'a> {
155    path: syn::Path,
156    ty: ContainerType<'a>,
157    err_name: String,
158    rename_rule: Option<RenamingRule>,
159}
160
161impl<'a> Container<'a> {
162    /// Construct a container based on fields, identifier and attributes.
163    ///
164    /// Fails if the variant has no fields or incompatible attributes.
165    fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
166        let style = match fields {
167            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
168                ensure_spanned!(
169                    options.rename_all.is_none(),
170                    options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
171                );
172                let mut tuple_fields = unnamed
173                    .unnamed
174                    .iter()
175                    .map(|field| {
176                        let attrs = FieldAttributes::from_attrs(&field.attrs)?;
177                        ensure_spanned!(
178                            attrs.getter.is_none(),
179                            field.span() => "`getter` is not permitted on tuple struct elements."
180                        );
181                        ensure_spanned!(
182                            attrs.default.is_none(),
183                            field.span() => "`default` is not permitted on tuple struct elements."
184                        );
185                        Ok(TupleStructField {
186                            from_py_with: attrs.from_py_with,
187                            ty: field.ty.clone(),
188                        })
189                    })
190                    .collect::<Result<Vec<_>>>()?;
191
192                if tuple_fields.len() == 1 {
193                    // Always treat a 1-length tuple struct as "transparent", even without the
194                    // explicit annotation.
195                    let field = tuple_fields.pop().unwrap();
196                    ContainerType::TupleNewtype(field.from_py_with, Box::new(field.ty))
197                } else if options.transparent.is_some() {
198                    bail_spanned!(
199                        fields.span() => "transparent structs and variants can only have 1 field"
200                    );
201                } else {
202                    ContainerType::Tuple(tuple_fields)
203                }
204            }
205            Fields::Named(named) if !named.named.is_empty() => {
206                let mut struct_fields = named
207                    .named
208                    .iter()
209                    .map(|field| {
210                        let ident = field
211                            .ident
212                            .as_ref()
213                            .expect("Named fields should have identifiers");
214                        let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
215
216                        if let Some(ref from_item_all) = options.from_item_all {
217                            if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
218                            {
219                                match replaced {
220                                    FieldGetter::GetItem(item, Some(item_name)) => {
221                                        attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
222                                    }
223                                    FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
224                                    FieldGetter::GetAttr(_, _) => bail_spanned!(
225                                        from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
226                                    ),
227                                }
228                            }
229                        }
230
231                        Ok(NamedStructField {
232                            ident,
233                            getter: attrs.getter,
234                            from_py_with: attrs.from_py_with,
235                            default: attrs.default,
236                            ty: &field.ty,
237                        })
238                    })
239                    .collect::<Result<Vec<_>>>()?;
240                if struct_fields.iter().all(|field| field.default.is_some()) {
241                    bail_spanned!(
242                        fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
243                    )
244                } else if options.transparent.is_some() {
245                    ensure_spanned!(
246                        struct_fields.len() == 1,
247                        fields.span() => "transparent structs and variants can only have 1 field"
248                    );
249                    ensure_spanned!(
250                        options.rename_all.is_none(),
251                        options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
252                    );
253                    let field = struct_fields.pop().unwrap();
254                    ensure_spanned!(
255                        field.getter.is_none(),
256                        field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
257                    );
258                    ContainerType::StructNewtype(field.ident, field.from_py_with, field.ty)
259                } else {
260                    ContainerType::Struct(struct_fields)
261                }
262            }
263            _ => bail_spanned!(
264                fields.span() => "cannot derive FromPyObject for empty structs and variants"
265            ),
266        };
267        let err_name = options.annotation.map_or_else(
268            || path.segments.last().unwrap().ident.to_string(),
269            |lit_str| lit_str.value(),
270        );
271
272        let v = Container {
273            path,
274            ty: style,
275            err_name,
276            rename_rule: options.rename_all.map(|v| v.value.rule),
277        };
278        Ok(v)
279    }
280
281    fn name(&self) -> String {
282        let mut value = String::new();
283        for segment in &self.path.segments {
284            if !value.is_empty() {
285                value.push_str("::");
286            }
287            value.push_str(&segment.ident.to_string());
288        }
289        value
290    }
291
292    /// Build derivation body for a struct.
293    fn build(&self, ctx: &Ctx) -> TokenStream {
294        match &self.ty {
295            ContainerType::StructNewtype(ident, from_py_with, _) => {
296                self.build_newtype_struct(Some(ident), from_py_with, ctx)
297            }
298            ContainerType::TupleNewtype(from_py_with, _) => {
299                self.build_newtype_struct(None, from_py_with, ctx)
300            }
301            ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
302            ContainerType::Struct(tups) => self.build_struct(tups, ctx),
303        }
304    }
305
306    fn build_newtype_struct(
307        &self,
308        field_ident: Option<&Ident>,
309        from_py_with: &Option<FromPyWithAttribute>,
310        ctx: &Ctx,
311    ) -> TokenStream {
312        let Ctx { pyo3_path, .. } = ctx;
313        let self_ty = &self.path;
314        let struct_name = self.name();
315        if let Some(ident) = field_ident {
316            let field_name = ident.to_string();
317            if let Some(FromPyWithAttribute {
318                kw,
319                value: expr_path,
320            }) = from_py_with
321            {
322                let extractor = quote_spanned! { kw.span =>
323                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
324                };
325                quote! {
326                    Ok(#self_ty {
327                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
328                    })
329                }
330            } else {
331                quote! {
332                    Ok(#self_ty {
333                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
334                    })
335                }
336            }
337        } else if let Some(FromPyWithAttribute {
338            kw,
339            value: expr_path,
340        }) = from_py_with
341        {
342            let extractor = quote_spanned! { kw.span =>
343                { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
344            };
345            quote! {
346                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
347            }
348        } else {
349            quote! {
350                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
351            }
352        }
353    }
354
355    fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
356        let Ctx { pyo3_path, .. } = ctx;
357        let self_ty = &self.path;
358        let struct_name = &self.name();
359        let field_idents: Vec<_> = (0..struct_fields.len())
360            .map(|i| format_ident!("arg{}", i))
361            .collect();
362        let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
363            if let Some(FromPyWithAttribute {
364                kw,
365                value: expr_path, ..
366            }) = &field.from_py_with {
367                let extractor = quote_spanned! { kw.span =>
368                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
369                };
370               quote! {
371                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
372               }
373            } else {
374                quote!{
375                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
376            }}
377        });
378
379        quote!(
380            match #pyo3_path::types::PyAnyMethods::extract(obj) {
381                ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
382                ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
383            }
384        )
385    }
386
387    fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
388        let Ctx { pyo3_path, .. } = ctx;
389        let self_ty = &self.path;
390        let struct_name = self.name();
391        let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
392        for field in struct_fields {
393            let ident = field.ident;
394            let field_name = ident.unraw().to_string();
395            let getter = match field
396                .getter
397                .as_ref()
398                .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
399            {
400                FieldGetter::GetAttr(_, Some(name)) => {
401                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
402                }
403                FieldGetter::GetAttr(_, None) => {
404                    let name = self
405                        .rename_rule
406                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
407                    let name = name.as_deref().unwrap_or(&field_name);
408                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
409                }
410                FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
411                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
412                }
413                FieldGetter::GetItem(_, Some(key)) => {
414                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
415                }
416                FieldGetter::GetItem(_, None) => {
417                    let name = self
418                        .rename_rule
419                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
420                    let name = name.as_deref().unwrap_or(&field_name);
421                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
422                }
423            };
424            let extractor = if let Some(FromPyWithAttribute {
425                kw,
426                value: expr_path,
427            }) = &field.from_py_with
428            {
429                let extractor = quote_spanned! { kw.span =>
430                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
431                };
432                quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
433            } else {
434                quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
435            };
436            let extracted = if let Some(default) = &field.default {
437                let default_expr = if let Some(default_expr) = &default.value {
438                    default_expr.to_token_stream()
439                } else {
440                    quote!(::std::default::Default::default())
441                };
442                quote!(if let ::std::result::Result::Ok(value) = #getter {
443                    #extractor
444                } else {
445                    #default_expr
446                })
447            } else {
448                quote!({
449                    let value = #getter?;
450                    #extractor
451                })
452            };
453
454            fields.push(quote!(#ident: #extracted));
455        }
456
457        quote!(::std::result::Result::Ok(#self_ty{#fields}))
458    }
459
460    #[cfg(feature = "experimental-inspect")]
461    fn write_input_type(&self, builder: &mut ConcatenationBuilder, ctx: &Ctx) {
462        match &self.ty {
463            ContainerType::StructNewtype(_, from_py_with, ty) => {
464                Self::write_field_input_type(from_py_with, ty, builder, ctx);
465            }
466            ContainerType::TupleNewtype(from_py_with, ty) => {
467                Self::write_field_input_type(from_py_with, ty, builder, ctx);
468            }
469            ContainerType::Tuple(tups) => {
470                builder.push_str("tuple[");
471                for (i, TupleStructField { from_py_with, ty }) in tups.iter().enumerate() {
472                    if i > 0 {
473                        builder.push_str(", ");
474                    }
475                    Self::write_field_input_type(from_py_with, ty, builder, ctx);
476                }
477                builder.push_str("]");
478            }
479            ContainerType::Struct(_) => {
480                // TODO: implement using a Protocol?
481                builder.push_str("_typeshed.Incomplete")
482            }
483        }
484    }
485
486    #[cfg(feature = "experimental-inspect")]
487    fn write_field_input_type(
488        from_py_with: &Option<FromPyWithAttribute>,
489        ty: &syn::Type,
490        builder: &mut ConcatenationBuilder,
491        ctx: &Ctx,
492    ) {
493        if from_py_with.is_some() {
494            // We don't know what from_py_with is doing
495            builder.push_str("_typeshed.Incomplete")
496        } else {
497            let mut ty = ty.clone();
498            elide_lifetimes(&mut ty);
499            let pyo3_crate_path = &ctx.pyo3_path;
500            builder.push_tokens(
501                quote! { <#ty as #pyo3_crate_path::FromPyObject<'_, '_>>::INPUT_TYPE.as_bytes() },
502            )
503        }
504    }
505}
506
507fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
508    let mut lifetimes = generics.lifetimes();
509    let lifetime = lifetimes.next();
510    ensure_spanned!(
511        lifetimes.next().is_none(),
512        generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
513    );
514    Ok(lifetime)
515}
516
517/// Derive FromPyObject for enums and structs.
518///
519///   * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
520///   * At least one field, in case of `#[transparent]`, exactly one field
521///   * At least one variant for enums.
522///   * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
523///   * Derivation for structs with generic fields like `struct<T> Foo(T)`
524///     adds `T: FromPyObject` on the derived implementation.
525pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
526    let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
527    let ctx = &Ctx::new(&options.krate, None);
528    let Ctx { pyo3_path, .. } = &ctx;
529
530    let (_, ty_generics, _) = tokens.generics.split_for_impl();
531    let mut trait_generics = tokens.generics.clone();
532    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
533        lt.clone()
534    } else {
535        trait_generics.params.push(parse_quote!('py));
536        parse_quote!('py)
537    };
538    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
539
540    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
541    for param in trait_generics.type_params() {
542        let gen_ident = &param.ident;
543        where_clause
544            .predicates
545            .push(parse_quote!(#gen_ident: #pyo3_path::conversion::FromPyObjectOwned<#lt_param>))
546    }
547
548    let derives = match &tokens.data {
549        syn::Data::Enum(en) => {
550            if options.transparent.is_some() || options.annotation.is_some() {
551                bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
552                                                at top level for enums");
553            }
554            let en = Enum::new(en, &tokens.ident, options.clone())?;
555            en.build(ctx)
556        }
557        syn::Data::Struct(st) => {
558            if let Some(lit_str) = &options.annotation {
559                bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
560            }
561            let ident = &tokens.ident;
562            let st = Container::new(&st.fields, parse_quote!(#ident), options.clone())?;
563            st.build(ctx)
564        }
565        syn::Data::Union(_) => bail_spanned!(
566            tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
567        ),
568    };
569
570    #[cfg(feature = "experimental-inspect")]
571    let input_type = {
572        let mut builder = ConcatenationBuilder::default();
573        if tokens
574            .generics
575            .params
576            .iter()
577            .all(|p| matches!(p, syn::GenericParam::Lifetime(_)))
578        {
579            match &tokens.data {
580                syn::Data::Enum(en) => {
581                    Enum::new(en, &tokens.ident, options)?.write_input_type(&mut builder, ctx)
582                }
583                syn::Data::Struct(st) => {
584                    let ident = &tokens.ident;
585                    Container::new(&st.fields, parse_quote!(#ident), options.clone())?
586                        .write_input_type(&mut builder, ctx)
587                }
588                syn::Data::Union(_) => {
589                    // Not supported at this point
590                    builder.push_str("_typeshed.Incomplete")
591                }
592            }
593        } else {
594            // We don't know how to deal with generic parameters
595            // Blocked by https://github.com/rust-lang/rust/issues/76560
596            builder.push_str("_typeshed.Incomplete")
597        };
598        let input_type = builder.into_token_stream(&ctx.pyo3_path);
599        quote! { const INPUT_TYPE: &'static str = unsafe { ::std::str::from_utf8_unchecked(#input_type) }; }
600    };
601    #[cfg(not(feature = "experimental-inspect"))]
602    let input_type = quote! {};
603
604    let ident = &tokens.ident;
605    Ok(quote!(
606        #[automatically_derived]
607        impl #impl_generics #pyo3_path::FromPyObject<'_, #lt_param> for #ident #ty_generics #where_clause {
608            type Error = #pyo3_path::PyErr;
609            fn extract(obj: #pyo3_path::Borrowed<'_, #lt_param, #pyo3_path::PyAny>) -> ::std::result::Result<Self, Self::Error> {
610                let obj: &#pyo3_path::Bound<'_, _> = &*obj;
611                #derives
612            }
613            #input_type
614        }
615    ))
616}