pyo3_macros_backend/
frompyobject.rs

1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3use crate::utils::{self, Ctx};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7    ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
8    Fields, Ident, Result, Token,
9};
10
/// Describes derivation input of an enum.
struct Enum<'a> {
    /// Identifier of the enum; its string form is passed to `failed_to_extract_enum`
    /// as the type name when no variant could be extracted.
    enum_ident: &'a Ident,
    /// One container per enum variant, in declaration order (the order in which
    /// extraction is attempted).
    variants: Vec<Container<'a>>,
}
16
17impl<'a> Enum<'a> {
18    /// Construct a new enum representation.
19    ///
20    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
21    /// `Identifier` of the enum.
22    fn new(
23        data_enum: &'a DataEnum,
24        ident: &'a Ident,
25        options: ContainerAttributes,
26    ) -> Result<Self> {
27        ensure_spanned!(
28            !data_enum.variants.is_empty(),
29            ident.span() => "cannot derive FromPyObject for empty enum"
30        );
31        let variants = data_enum
32            .variants
33            .iter()
34            .map(|variant| {
35                let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
36                if let Some(rename_all) = &options.rename_all {
37                    ensure_spanned!(
38                        variant_options.rename_all.is_none(),
39                        variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
40                    );
41                    variant_options.rename_all = Some(rename_all.clone());
42
43                }
44                let var_ident = &variant.ident;
45                Container::new(
46                    &variant.fields,
47                    parse_quote!(#ident::#var_ident),
48                    variant_options,
49                )
50            })
51            .collect::<Result<Vec<_>>>()?;
52
53        Ok(Enum {
54            enum_ident: ident,
55            variants,
56        })
57    }
58
59    /// Build derivation body for enums.
60    fn build(&self, ctx: &Ctx) -> TokenStream {
61        let Ctx { pyo3_path, .. } = ctx;
62        let mut var_extracts = Vec::new();
63        let mut variant_names = Vec::new();
64        let mut error_names = Vec::new();
65
66        for var in &self.variants {
67            let struct_derive = var.build(ctx);
68            let ext = quote!({
69                let maybe_ret = || -> #pyo3_path::PyResult<Self> {
70                    #struct_derive
71                }();
72
73                match maybe_ret {
74                    ok @ ::std::result::Result::Ok(_) => return ok,
75                    ::std::result::Result::Err(err) => err
76                }
77            });
78
79            var_extracts.push(ext);
80            variant_names.push(var.path.segments.last().unwrap().ident.to_string());
81            error_names.push(&var.err_name);
82        }
83        let ty_name = self.enum_ident.to_string();
84        quote!(
85            let errors = [
86                #(#var_extracts),*
87            ];
88            ::std::result::Result::Err(
89                #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
90                    obj.py(),
91                    #ty_name,
92                    &[#(#variant_names),*],
93                    &[#(#error_names),*],
94                    &errors
95                )
96            )
97        )
98    }
99}
100
/// A named field of a struct (or struct-like enum variant) together with its
/// field-level attributes.
struct NamedStructField<'a> {
    /// The field's identifier, used in the generated struct literal.
    ident: &'a syn::Ident,
    /// How the value is looked up on the Python object. `None` means attribute lookup
    /// using the field name (after any `rename_all` rule is applied).
    getter: Option<FieldGetter>,
    /// Custom extraction function; when present, `extract_struct_field_with` is used
    /// instead of `extract_struct_field`.
    from_py_with: Option<FromPyWithAttribute>,
    /// Fallback used when the lookup fails; without an explicit expression this is
    /// `Default::default()`.
    default: Option<DefaultAttribute>,
}
107
/// An element of a tuple struct (or tuple-like enum variant). Only `from_py_with` is
/// accepted on tuple elements; `getter` and `default` are rejected in `Container::new`.
struct TupleStructField {
    /// Custom extraction function for this element, if any.
    from_py_with: Option<FromPyWithAttribute>,
}
111
/// Container Style
///
/// Covers Structs, Tuplestructs and corresponding Newtypes.
enum ContainerType<'a> {
    /// Struct Container, e.g. `struct Foo { a: String }`
    ///
    /// Variant contains one `NamedStructField` (identifier plus field attributes) per field;
    /// each field is looked up on the object and extracted individually.
    Struct(Vec<NamedStructField<'a>>),
    /// Newtype struct container, e.g. a `transparent` `struct Foo { a: String }`
    ///
    /// The field specified by the identifier is extracted directly from the object.
    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
    /// Tuple struct, e.g. `struct Foo(String, u32)`.
    ///
    /// Variant contains a list of conversion methods for each of the fields that are directly
    ///  extracted from the tuple.
    Tuple(Vec<TupleStructField>),
    /// Tuple newtype, e.g. `struct Foo(String)`
    ///
    /// The wrapped field is directly extracted from the object. Every 1-field tuple struct
    /// uses this style, whether or not it is marked `transparent`.
    TupleNewtype(Option<FromPyWithAttribute>),
}
134
/// Data container
///
/// Either describes a struct or an enum variant.
struct Container<'a> {
    /// Path used to construct the value in generated code, e.g. `Foo` or `MyEnum::Variant`.
    path: syn::Path,
    /// Shape of the container and its per-field extraction information.
    ty: ContainerType<'a>,
    /// Name reported for this container when enum extraction fails: the `annotation`
    /// attribute if given, otherwise the last segment of `path`.
    err_name: String,
    /// Rule from `rename_all`, applied to field names that have no explicit lookup name.
    rename_rule: Option<RenamingRule>,
}
144
145impl<'a> Container<'a> {
146    /// Construct a container based on fields, identifier and attributes.
147    ///
148    /// Fails if the variant has no fields or incompatible attributes.
149    fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
150        let style = match fields {
151            Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
152                ensure_spanned!(
153                    options.rename_all.is_none(),
154                    options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
155                );
156                let mut tuple_fields = unnamed
157                    .unnamed
158                    .iter()
159                    .map(|field| {
160                        let attrs = FieldAttributes::from_attrs(&field.attrs)?;
161                        ensure_spanned!(
162                            attrs.getter.is_none(),
163                            field.span() => "`getter` is not permitted on tuple struct elements."
164                        );
165                        ensure_spanned!(
166                            attrs.default.is_none(),
167                            field.span() => "`default` is not permitted on tuple struct elements."
168                        );
169                        Ok(TupleStructField {
170                            from_py_with: attrs.from_py_with,
171                        })
172                    })
173                    .collect::<Result<Vec<_>>>()?;
174
175                if tuple_fields.len() == 1 {
176                    // Always treat a 1-length tuple struct as "transparent", even without the
177                    // explicit annotation.
178                    let field = tuple_fields.pop().unwrap();
179                    ContainerType::TupleNewtype(field.from_py_with)
180                } else if options.transparent.is_some() {
181                    bail_spanned!(
182                        fields.span() => "transparent structs and variants can only have 1 field"
183                    );
184                } else {
185                    ContainerType::Tuple(tuple_fields)
186                }
187            }
188            Fields::Named(named) if !named.named.is_empty() => {
189                let mut struct_fields = named
190                    .named
191                    .iter()
192                    .map(|field| {
193                        let ident = field
194                            .ident
195                            .as_ref()
196                            .expect("Named fields should have identifiers");
197                        let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
198
199                        if let Some(ref from_item_all) = options.from_item_all {
200                            if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
201                            {
202                                match replaced {
203                                    FieldGetter::GetItem(item, Some(item_name)) => {
204                                        attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
205                                    }
206                                    FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
207                                    FieldGetter::GetAttr(_, _) => bail_spanned!(
208                                        from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
209                                    ),
210                                }
211                            }
212                        }
213
214                        Ok(NamedStructField {
215                            ident,
216                            getter: attrs.getter,
217                            from_py_with: attrs.from_py_with,
218                            default: attrs.default,
219                        })
220                    })
221                    .collect::<Result<Vec<_>>>()?;
222                if struct_fields.iter().all(|field| field.default.is_some()) {
223                    bail_spanned!(
224                        fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
225                    )
226                } else if options.transparent.is_some() {
227                    ensure_spanned!(
228                        struct_fields.len() == 1,
229                        fields.span() => "transparent structs and variants can only have 1 field"
230                    );
231                    ensure_spanned!(
232                        options.rename_all.is_none(),
233                        options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
234                    );
235                    let field = struct_fields.pop().unwrap();
236                    ensure_spanned!(
237                        field.getter.is_none(),
238                        field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
239                    );
240                    ContainerType::StructNewtype(field.ident, field.from_py_with)
241                } else {
242                    ContainerType::Struct(struct_fields)
243                }
244            }
245            _ => bail_spanned!(
246                fields.span() => "cannot derive FromPyObject for empty structs and variants"
247            ),
248        };
249        let err_name = options.annotation.map_or_else(
250            || path.segments.last().unwrap().ident.to_string(),
251            |lit_str| lit_str.value(),
252        );
253
254        let v = Container {
255            path,
256            ty: style,
257            err_name,
258            rename_rule: options.rename_all.map(|v| v.value.rule),
259        };
260        Ok(v)
261    }
262
263    fn name(&self) -> String {
264        let mut value = String::new();
265        for segment in &self.path.segments {
266            if !value.is_empty() {
267                value.push_str("::");
268            }
269            value.push_str(&segment.ident.to_string());
270        }
271        value
272    }
273
274    /// Build derivation body for a struct.
275    fn build(&self, ctx: &Ctx) -> TokenStream {
276        match &self.ty {
277            ContainerType::StructNewtype(ident, from_py_with) => {
278                self.build_newtype_struct(Some(ident), from_py_with, ctx)
279            }
280            ContainerType::TupleNewtype(from_py_with) => {
281                self.build_newtype_struct(None, from_py_with, ctx)
282            }
283            ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
284            ContainerType::Struct(tups) => self.build_struct(tups, ctx),
285        }
286    }
287
288    fn build_newtype_struct(
289        &self,
290        field_ident: Option<&Ident>,
291        from_py_with: &Option<FromPyWithAttribute>,
292        ctx: &Ctx,
293    ) -> TokenStream {
294        let Ctx { pyo3_path, .. } = ctx;
295        let self_ty = &self.path;
296        let struct_name = self.name();
297        if let Some(ident) = field_ident {
298            let field_name = ident.to_string();
299            if let Some(FromPyWithAttribute {
300                kw,
301                value: expr_path,
302            }) = from_py_with
303            {
304                let extractor = quote_spanned! { kw.span =>
305                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
306                };
307                quote! {
308                    Ok(#self_ty {
309                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
310                    })
311                }
312            } else {
313                quote! {
314                    Ok(#self_ty {
315                        #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
316                    })
317                }
318            }
319        } else if let Some(FromPyWithAttribute {
320            kw,
321            value: expr_path,
322        }) = from_py_with
323        {
324            let extractor = quote_spanned! { kw.span =>
325                { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
326            };
327            quote! {
328                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
329            }
330        } else {
331            quote! {
332                #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
333            }
334        }
335    }
336
337    fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
338        let Ctx { pyo3_path, .. } = ctx;
339        let self_ty = &self.path;
340        let struct_name = &self.name();
341        let field_idents: Vec<_> = (0..struct_fields.len())
342            .map(|i| format_ident!("arg{}", i))
343            .collect();
344        let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
345            if let Some(FromPyWithAttribute {
346                kw,
347                value: expr_path, ..
348            }) = &field.from_py_with {
349                let extractor = quote_spanned! { kw.span =>
350                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
351                };
352               quote! {
353                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
354               }
355            } else {
356                quote!{
357                    #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
358            }}
359        });
360
361        quote!(
362            match #pyo3_path::types::PyAnyMethods::extract(obj) {
363                ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
364                ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
365            }
366        )
367    }
368
369    fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
370        let Ctx { pyo3_path, .. } = ctx;
371        let self_ty = &self.path;
372        let struct_name = self.name();
373        let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
374        for field in struct_fields {
375            let ident = field.ident;
376            let field_name = ident.unraw().to_string();
377            let getter = match field
378                .getter
379                .as_ref()
380                .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
381            {
382                FieldGetter::GetAttr(_, Some(name)) => {
383                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
384                }
385                FieldGetter::GetAttr(_, None) => {
386                    let name = self
387                        .rename_rule
388                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
389                    let name = name.as_deref().unwrap_or(&field_name);
390                    quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
391                }
392                FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
393                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
394                }
395                FieldGetter::GetItem(_, Some(key)) => {
396                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
397                }
398                FieldGetter::GetItem(_, None) => {
399                    let name = self
400                        .rename_rule
401                        .map(|rule| utils::apply_renaming_rule(rule, &field_name));
402                    let name = name.as_deref().unwrap_or(&field_name);
403                    quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
404                }
405            };
406            let extractor = if let Some(FromPyWithAttribute {
407                kw,
408                value: expr_path,
409            }) = &field.from_py_with
410            {
411                let extractor = quote_spanned! { kw.span =>
412                    { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
413                };
414                quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
415            } else {
416                quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
417            };
418            let extracted = if let Some(default) = &field.default {
419                let default_expr = if let Some(default_expr) = &default.value {
420                    default_expr.to_token_stream()
421                } else {
422                    quote!(::std::default::Default::default())
423                };
424                quote!(if let ::std::result::Result::Ok(value) = #getter {
425                    #extractor
426                } else {
427                    #default_expr
428                })
429            } else {
430                quote!({
431                    let value = #getter?;
432                    #extractor
433                })
434            };
435
436            fields.push(quote!(#ident: #extracted));
437        }
438
439        quote!(::std::result::Result::Ok(#self_ty{#fields}))
440    }
441}
442
443fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
444    let mut lifetimes = generics.lifetimes();
445    let lifetime = lifetimes.next();
446    ensure_spanned!(
447        lifetimes.next().is_none(),
448        generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
449    );
450    Ok(lifetime)
451}
452
453/// Derive FromPyObject for enums and structs.
454///
455///   * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
456///   * At least one field, in case of `#[transparent]`, exactly one field
457///   * At least one variant for enums.
458///   * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
459///   * Derivation for structs with generic fields like `struct<T> Foo(T)`
460///     adds `T: FromPyObject` on the derived implementation.
461pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
462    let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
463    let ctx = &Ctx::new(&options.krate, None);
464    let Ctx { pyo3_path, .. } = &ctx;
465
466    let (_, ty_generics, _) = tokens.generics.split_for_impl();
467    let mut trait_generics = tokens.generics.clone();
468    let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
469        lt.clone()
470    } else {
471        trait_generics.params.push(parse_quote!('py));
472        parse_quote!('py)
473    };
474    let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
475
476    let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
477    for param in trait_generics.type_params() {
478        let gen_ident = &param.ident;
479        where_clause
480            .predicates
481            .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
482    }
483
484    let derives = match &tokens.data {
485        syn::Data::Enum(en) => {
486            if options.transparent.is_some() || options.annotation.is_some() {
487                bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
488                                                at top level for enums");
489            }
490            let en = Enum::new(en, &tokens.ident, options)?;
491            en.build(ctx)
492        }
493        syn::Data::Struct(st) => {
494            if let Some(lit_str) = &options.annotation {
495                bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
496            }
497            let ident = &tokens.ident;
498            let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
499            st.build(ctx)
500        }
501        syn::Data::Union(_) => bail_spanned!(
502            tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
503        ),
504    };
505
506    let ident = &tokens.ident;
507    Ok(quote!(
508        #[automatically_derived]
509        impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
510            fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self>  {
511                #derives
512            }
513        }
514    ))
515}