// pyo3_derive_backend/from_pyobject.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::punctuated::Punctuated;
4use syn::spanned::Spanned;
5use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, MetaList, Result};
6
7/// Describes derivation input of an enum.
8#[derive(Debug)]
9struct Enum<'a> {
10    enum_ident: &'a Ident,
11    variants: Vec<Container<'a>>,
12}
13
14impl<'a> Enum<'a> {
15    /// Construct a new enum representation.
16    ///
17    /// `data_enum` is the `syn` representation of the input enum, `ident` is the
18    /// `Identifier` of the enum.
19    fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
20        if data_enum.variants.is_empty() {
21            return Err(spanned_err(
22                &ident,
23                "Cannot derive FromPyObject for empty enum.",
24            ));
25        }
26        let vars = data_enum
27            .variants
28            .iter()
29            .map(|variant| {
30                let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?;
31                let var_ident = &variant.ident;
32                Container::new(
33                    &variant.fields,
34                    parse_quote!(#ident::#var_ident),
35                    attrs,
36                    true,
37                )
38            })
39            .collect::<Result<Vec<_>>>()?;
40
41        Ok(Enum {
42            enum_ident: ident,
43            variants: vars,
44        })
45    }
46
47    /// Build derivation body for enums.
48    fn build(&self) -> TokenStream {
49        let mut var_extracts = Vec::new();
50        let mut error_names = String::new();
51        for (i, var) in self.variants.iter().enumerate() {
52            let struct_derive = var.build();
53            let ext = quote!(
54                let maybe_ret = || -> ::pyo3::PyResult<Self> {
55                    #struct_derive
56                }();
57                if maybe_ret.is_ok() {
58                    return maybe_ret
59                }
60            );
61
62            var_extracts.push(ext);
63            error_names.push_str(&var.err_name);
64            if i < self.variants.len() - 1 {
65                error_names.push_str(", ");
66            }
67        }
68        let error_names = if self.variants.len() > 1 {
69            format!("Union[{}]", error_names)
70        } else {
71            error_names
72        };
73        quote!(
74            #(#var_extracts)*
75            let type_name = obj.get_type().name();
76            let from = obj
77                .repr()
78                .map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
79                .unwrap_or_else(|_| type_name.to_string());
80            let err_msg = format!("Can't convert {} to {}", from, #error_names);
81            Err(::pyo3::exceptions::PyTypeError::new_err(err_msg))
82        )
83    }
84}
85
86/// Container Style
87///
88/// Covers Structs, Tuplestructs and corresponding Newtypes.
89#[derive(Debug)]
90enum ContainerType<'a> {
91    /// Struct Container, e.g. `struct Foo { a: String }`
92    ///
93    /// Variant contains the list of field identifiers and the corresponding extraction call.
94    Struct(Vec<(&'a Ident, FieldAttribute)>),
95    /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
96    ///
97    /// The field specified by the identifier is extracted directly from the object.
98    StructNewtype(&'a Ident),
99    /// Tuple struct, e.g. `struct Foo(String)`.
100    ///
101    /// Fields are extracted from a tuple.
102    Tuple(usize),
103    /// Tuple newtype, e.g. `#[transparent] struct Foo(String)`
104    ///
105    /// The wrapped field is directly extracted from the object.
106    TupleNewtype,
107}
108
109/// Data container
110///
111/// Either describes a struct or an enum variant.
112#[derive(Debug)]
113struct Container<'a> {
114    path: syn::Path,
115    ty: ContainerType<'a>,
116    err_name: String,
117    is_enum_variant: bool,
118}
119
120impl<'a> Container<'a> {
121    /// Construct a container based on fields, identifier and attributes.
122    ///
123    /// Fails if the variant has no fields or incompatible attributes.
124    fn new(
125        fields: &'a Fields,
126        path: syn::Path,
127        attrs: Vec<ContainerAttribute>,
128        is_enum_variant: bool,
129    ) -> Result<Self> {
130        if fields.is_empty() {
131            return Err(spanned_err(
132                fields,
133                "Cannot derive FromPyObject for empty structs and variants.",
134            ));
135        }
136        let transparent = attrs
137            .iter()
138            .any(|attr| *attr == ContainerAttribute::Transparent);
139        if transparent {
140            Self::check_transparent_len(fields)?;
141        }
142        let style = match (fields, transparent) {
143            (Fields::Unnamed(_), true) => ContainerType::TupleNewtype,
144            (Fields::Unnamed(unnamed), false) => {
145                if unnamed.unnamed.len() == 1 {
146                    ContainerType::TupleNewtype
147                } else {
148                    ContainerType::Tuple(unnamed.unnamed.len())
149                }
150            }
151            (Fields::Named(named), true) => {
152                let field = named
153                    .named
154                    .iter()
155                    .next()
156                    .expect("Check for len 1 is done above");
157                let ident = field
158                    .ident
159                    .as_ref()
160                    .expect("Named fields should have identifiers");
161                ContainerType::StructNewtype(ident)
162            }
163            (Fields::Named(named), false) => {
164                let mut fields = Vec::new();
165                for field in named.named.iter() {
166                    let ident = field
167                        .ident
168                        .as_ref()
169                        .expect("Named fields should have identifiers");
170                    let attr = FieldAttribute::parse_attrs(&field.attrs)?
171                        .unwrap_or_else(|| FieldAttribute::GetAttr(None));
172                    fields.push((ident, attr))
173                }
174                ContainerType::Struct(fields)
175            }
176            (Fields::Unit, _) => {
177                // covered by length check above
178                return Err(spanned_err(
179                    &fields,
180                    "Cannot derive FromPyObject for Unit structs and variants",
181                ));
182            }
183        };
184        let err_name = attrs
185            .iter()
186            .find_map(|a| a.annotation())
187            .unwrap_or_else(|| path.segments.last().unwrap().ident.to_string());
188
189        let v = Container {
190            path,
191            ty: style,
192            err_name,
193            is_enum_variant,
194        };
195        Ok(v)
196    }
197
198    fn verify_struct_container_attrs(
199        attrs: &'a [ContainerAttribute],
200        original: &[Attribute],
201    ) -> Result<()> {
202        for attr in attrs {
203            match attr {
204                ContainerAttribute::Transparent => continue,
205                ContainerAttribute::ErrorAnnotation(_) => {
206                    let span = original
207                        .iter()
208                        .map(|a| a.span())
209                        .fold(None, |mut acc: Option<Span>, span| {
210                            if let Some(all) = acc.as_mut() {
211                                all.join(span)
212                            } else {
213                                Some(span)
214                            }
215                        })
216                        .unwrap_or_else(Span::call_site);
217                    return Err(syn::Error::new(
218                        span,
219                        "Annotating error messages for structs is \
220                                               not supported. Remove the annotation attribute.",
221                    ));
222                }
223            }
224        }
225        Ok(())
226    }
227
228    /// Build derivation body for a struct.
229    fn build(&self) -> TokenStream {
230        match &self.ty {
231            ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
232            ContainerType::TupleNewtype => self.build_newtype_struct(None),
233            ContainerType::Tuple(len) => self.build_tuple_struct(*len),
234            ContainerType::Struct(tups) => self.build_struct(tups),
235        }
236    }
237
238    fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
239        let self_ty = &self.path;
240        if let Some(ident) = field_ident {
241            quote!(
242                Ok(#self_ty{#ident: obj.extract()?})
243            )
244        } else {
245            quote!(Ok(#self_ty(obj.extract()?)))
246        }
247    }
248
249    fn build_tuple_struct(&self, len: usize) -> TokenStream {
250        let self_ty = &self.path;
251        let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
252        for i in 0..len {
253            fields.push(quote!(slice[#i].extract()?));
254        }
255        let msg = if self.is_enum_variant {
256            quote!(format!(
257                "Expected tuple of length {}, but got length {}.",
258                #len,
259                s.len()
260            ))
261        } else {
262            quote!("")
263        };
264        quote!(
265            let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
266            if s.len() != #len {
267                return Err(::pyo3::exceptions::PyValueError::new_err(#msg))
268            }
269            let slice = s.as_slice();
270            Ok(#self_ty(#fields))
271        )
272    }
273
274    fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream {
275        let self_ty = &self.path;
276        let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
277        for (ident, attr) in tups {
278            let ext_fn = match attr {
279                FieldAttribute::GetAttr(Some(name)) => quote!(getattr(#name)),
280                FieldAttribute::GetAttr(None) => quote!(getattr(stringify!(#ident))),
281                FieldAttribute::GetItem(Some(key)) => quote!(get_item(#key)),
282                FieldAttribute::GetItem(None) => quote!(get_item(stringify!(#ident))),
283            };
284            fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
285        }
286        quote!(Ok(#self_ty{#fields}))
287    }
288
289    fn check_transparent_len(fields: &Fields) -> Result<()> {
290        if fields.len() != 1 {
291            return Err(spanned_err(
292                fields,
293                "Transparent structs and variants can only have 1 field",
294            ));
295        }
296        Ok(())
297    }
298}
299
300/// Attributes for deriving FromPyObject scoped on containers.
301#[derive(Clone, Debug, PartialEq)]
302enum ContainerAttribute {
303    /// Treat the Container as a Wrapper, directly extract its fields from the input object.
304    Transparent,
305    /// Change the name of an enum variant in the generated error message.
306    ErrorAnnotation(String),
307}
308
309impl ContainerAttribute {
310    /// Convenience method to access `ErrorAnnotation`.
311    fn annotation(&self) -> Option<String> {
312        match self {
313            ContainerAttribute::ErrorAnnotation(s) => Some(s.to_string()),
314            _ => None,
315        }
316    }
317
318    /// Parse valid container arguments
319    ///
320    /// Fails if any are invalid.
321    fn parse_attrs(value: &[Attribute]) -> Result<Vec<Self>> {
322        let mut attrs = Vec::new();
323        let list = get_pyo3_meta_list(value)?;
324        for meta in list.nested {
325            if let syn::NestedMeta::Meta(metaitem) = &meta {
326                match metaitem {
327                    Meta::Path(p) if p.is_ident("transparent") => {
328                        attrs.push(ContainerAttribute::Transparent);
329                        continue;
330                    }
331                    Meta::NameValue(nv) if nv.path.is_ident("annotation") => {
332                        if let syn::Lit::Str(s) = &nv.lit {
333                            attrs.push(ContainerAttribute::ErrorAnnotation(s.value()))
334                        } else {
335                            return Err(spanned_err(&nv.lit, "Expected string literal."));
336                        }
337                        continue;
338                    }
339                    _ => {} // return Err below
340                }
341            }
342
343            return Err(spanned_err(meta, "Unrecognized `pyo3` container attribute"));
344        }
345        Ok(attrs)
346    }
347}
348
349/// Attributes for deriving FromPyObject scoped on fields.
350#[derive(Clone, Debug)]
351enum FieldAttribute {
352    GetItem(Option<syn::Lit>),
353    GetAttr(Option<syn::LitStr>),
354}
355
356impl FieldAttribute {
357    /// Extract the field attribute.
358    ///
359    /// Currently fails if more than 1 attribute is passed in `pyo3`
360    fn parse_attrs(attrs: &[Attribute]) -> Result<Option<Self>> {
361        let list = get_pyo3_meta_list(attrs)?;
362        let metaitem = match list.nested.len() {
363            0 => return Ok(None),
364            1 => list.nested.into_iter().next().unwrap(),
365            _ => {
366                return Err(spanned_err(
367                    list.nested,
368                    "Only one of `item`, `attribute` can be provided, possibly with an \
369                     additional argument: `item(\"key\")` or `attribute(\"name\").",
370                ))
371            }
372        };
373        let meta = match metaitem {
374            syn::NestedMeta::Meta(meta) => meta,
375            syn::NestedMeta::Lit(lit) => {
376                return Err(spanned_err(
377                    lit,
378                    "Expected `attribute` or `item`, not a literal.",
379                ))
380            }
381        };
382        let path = meta.path();
383        if path.is_ident("attribute") {
384            Ok(Some(FieldAttribute::GetAttr(Self::attribute_arg(meta)?)))
385        } else if path.is_ident("item") {
386            Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?)))
387        } else {
388            Err(spanned_err(meta, "Expected `attribute` or `item`."))
389        }
390    }
391
392    fn attribute_arg(meta: Meta) -> syn::Result<Option<syn::LitStr>> {
393        let arg_list = match meta {
394            Meta::List(list) => list,
395            Meta::Path(_) => return Ok(None),
396            Meta::NameValue(nv) => {
397                let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`";
398                return Err(spanned_err(nv, err_msg));
399            }
400        };
401        let arg_msg = "Expected a single string literal argument.";
402        let first = match arg_list.nested.len() {
403            1 => arg_list.nested.first().unwrap(),
404            _ => return Err(spanned_err(arg_list, arg_msg)),
405        };
406        if let syn::NestedMeta::Lit(syn::Lit::Str(litstr)) = first {
407            if litstr.value().is_empty() {
408                return Err(spanned_err(litstr, "Attribute name cannot be empty."));
409            }
410            return Ok(Some(parse_quote!(#litstr)));
411        }
412        Err(spanned_err(first, arg_msg))
413    }
414
415    fn item_arg(meta: Meta) -> syn::Result<Option<syn::Lit>> {
416        let arg_list = match meta {
417            Meta::List(list) => list,
418            Meta::Path(_) => return Ok(None),
419            Meta::NameValue(nv) => {
420                return Err(spanned_err(
421                    nv,
422                    "Expected a literal or no argument: `pyo3(item(\"key\") or `pyo3(item)`",
423                ))
424            }
425        };
426        let arg_msg = "Expected a single literal argument.";
427        if arg_list.nested.is_empty() {
428            return Err(spanned_err(arg_list, arg_msg));
429        } else if arg_list.nested.len() > 1 {
430            return Err(spanned_err(arg_list.nested, arg_msg));
431        }
432        let first = arg_list.nested.first().unwrap();
433        if let syn::NestedMeta::Lit(lit) = first {
434            return Ok(Some(parse_quote!(#lit)));
435        }
436        Err(spanned_err(first, arg_msg))
437    }
438}
439
440fn spanned_err<T: ToTokens>(tokens: T, msg: &str) -> syn::Error {
441    syn::Error::new_spanned(tokens, msg)
442}
443
444/// Extract pyo3 metalist, flattens multiple lists into a single one.
445fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<MetaList> {
446    let mut list: Punctuated<syn::NestedMeta, syn::Token![,]> = Punctuated::new();
447    for value in attrs {
448        match value.parse_meta()? {
449            Meta::List(ml) if value.path.is_ident("pyo3") => {
450                for meta in ml.nested {
451                    list.push(meta);
452                }
453            }
454            _ => continue,
455        }
456    }
457    Ok(MetaList {
458        path: parse_quote!(pyo3),
459        paren_token: syn::token::Paren::default(),
460        nested: list,
461    })
462}
463
464fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
465    let lifetimes = generics.lifetimes().collect::<Vec<_>>();
466    if lifetimes.len() > 1 {
467        return Err(spanned_err(
468            &generics,
469            "FromPyObject can be derived with at most one lifetime parameter.",
470        ));
471    }
472    Ok(lifetimes.into_iter().next())
473}
474
475/// Derive FromPyObject for enums and structs.
476///
477///   * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
478///   * At least one field, in case of `#[transparent]`, exactly one field
479///   * At least one variant for enums.
480///   * Fields of input structs and enums must implement `FromPyObject`
481///   * Derivation for structs with generic fields like `struct<T> Foo(T)`
482///     adds `T: FromPyObject` on the derived implementation.
483pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
484    let mut trait_generics = tokens.generics.clone();
485    let generics = &tokens.generics;
486    let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
487        lt.clone()
488    } else {
489        trait_generics.params.push(parse_quote!('source));
490        parse_quote!('source)
491    };
492    let mut where_clause: syn::WhereClause = parse_quote!(where);
493    for param in generics.type_params() {
494        let gen_ident = &param.ident;
495        where_clause
496            .predicates
497            .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
498    }
499    let derives = match &tokens.data {
500        syn::Data::Enum(en) => {
501            let en = Enum::new(en, &tokens.ident)?;
502            en.build()
503        }
504        syn::Data::Struct(st) => {
505            let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?;
506            Container::verify_struct_container_attrs(&attrs, &tokens.attrs)?;
507            let ident = &tokens.ident;
508            let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?;
509            st.build()
510        }
511        syn::Data::Union(_) => {
512            return Err(spanned_err(
513                tokens,
514                "#[derive(FromPyObject)] is not supported for unions.",
515            ))
516        }
517    };
518
519    let ident = &tokens.ident;
520    Ok(quote!(
521        #[automatically_derived]
522        impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
523            fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self>  {
524                #derives
525            }
526        }
527    ))
528}