1use proc_macro2::{Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::punctuated::Punctuated;
4use syn::spanned::Spanned;
5use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, MetaList, Result};
6
/// Describes the derivation input of an enum: one `Container` per variant.
#[derive(Debug)]
struct Enum<'a> {
    /// Identifier of the enum being derived.
    enum_ident: &'a Ident,
    /// One container per variant, in declaration order.
    variants: Vec<Container<'a>>,
}
13
14impl<'a> Enum<'a> {
15 fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result<Self> {
20 if data_enum.variants.is_empty() {
21 return Err(spanned_err(
22 &ident,
23 "Cannot derive FromPyObject for empty enum.",
24 ));
25 }
26 let vars = data_enum
27 .variants
28 .iter()
29 .map(|variant| {
30 let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?;
31 let var_ident = &variant.ident;
32 Container::new(
33 &variant.fields,
34 parse_quote!(#ident::#var_ident),
35 attrs,
36 true,
37 )
38 })
39 .collect::<Result<Vec<_>>>()?;
40
41 Ok(Enum {
42 enum_ident: ident,
43 variants: vars,
44 })
45 }
46
47 fn build(&self) -> TokenStream {
49 let mut var_extracts = Vec::new();
50 let mut error_names = String::new();
51 for (i, var) in self.variants.iter().enumerate() {
52 let struct_derive = var.build();
53 let ext = quote!(
54 let maybe_ret = || -> ::pyo3::PyResult<Self> {
55 #struct_derive
56 }();
57 if maybe_ret.is_ok() {
58 return maybe_ret
59 }
60 );
61
62 var_extracts.push(ext);
63 error_names.push_str(&var.err_name);
64 if i < self.variants.len() - 1 {
65 error_names.push_str(", ");
66 }
67 }
68 let error_names = if self.variants.len() > 1 {
69 format!("Union[{}]", error_names)
70 } else {
71 error_names
72 };
73 quote!(
74 #(#var_extracts)*
75 let type_name = obj.get_type().name();
76 let from = obj
77 .repr()
78 .map(|s| format!("{} ({})", s.to_string_lossy(), type_name))
79 .unwrap_or_else(|_| type_name.to_string());
80 let err_msg = format!("Can't convert {} to {}", from, #error_names);
81 Err(::pyo3::exceptions::PyTypeError::new_err(err_msg))
82 )
83 }
84}
85
/// The field layout of a container, which determines how its extraction code is generated.
#[derive(Debug)]
enum ContainerType<'a> {
    /// Container with named fields, e.g. `Foo { a: T, b: U }`.
    ///
    /// Each field is paired with the attribute describing how it is read from the object
    /// (`getattr` with the field name when no attribute is given).
    Struct(Vec<(&'a Ident, FieldAttribute)>),
    /// `#[pyo3(transparent)]` container with a single named field; the field is
    /// extracted directly from the object itself.
    StructNewtype(&'a Ident),
    /// Container with more than one unnamed field, e.g. `Foo(T, U)`; extracted from a
    /// Python tuple of exactly this length.
    Tuple(usize),
    /// Container with a single unnamed field (transparent or not); the field is
    /// extracted directly from the object itself.
    TupleNewtype,
}
108
/// Describes a struct or an enum variant for which extraction code is generated.
#[derive(Debug)]
struct Container<'a> {
    /// Path used to construct the value in generated code: the struct's ident, or
    /// `Enum::Variant` for enum variants.
    path: syn::Path,
    /// Field layout of the container.
    ty: ContainerType<'a>,
    /// Name shown in conversion error messages: the `annotation` attribute if present,
    /// otherwise the last segment of `path`.
    err_name: String,
    /// Whether this container is an enum variant rather than a standalone struct.
    is_enum_variant: bool,
}
119
120impl<'a> Container<'a> {
121 fn new(
125 fields: &'a Fields,
126 path: syn::Path,
127 attrs: Vec<ContainerAttribute>,
128 is_enum_variant: bool,
129 ) -> Result<Self> {
130 if fields.is_empty() {
131 return Err(spanned_err(
132 fields,
133 "Cannot derive FromPyObject for empty structs and variants.",
134 ));
135 }
136 let transparent = attrs
137 .iter()
138 .any(|attr| *attr == ContainerAttribute::Transparent);
139 if transparent {
140 Self::check_transparent_len(fields)?;
141 }
142 let style = match (fields, transparent) {
143 (Fields::Unnamed(_), true) => ContainerType::TupleNewtype,
144 (Fields::Unnamed(unnamed), false) => {
145 if unnamed.unnamed.len() == 1 {
146 ContainerType::TupleNewtype
147 } else {
148 ContainerType::Tuple(unnamed.unnamed.len())
149 }
150 }
151 (Fields::Named(named), true) => {
152 let field = named
153 .named
154 .iter()
155 .next()
156 .expect("Check for len 1 is done above");
157 let ident = field
158 .ident
159 .as_ref()
160 .expect("Named fields should have identifiers");
161 ContainerType::StructNewtype(ident)
162 }
163 (Fields::Named(named), false) => {
164 let mut fields = Vec::new();
165 for field in named.named.iter() {
166 let ident = field
167 .ident
168 .as_ref()
169 .expect("Named fields should have identifiers");
170 let attr = FieldAttribute::parse_attrs(&field.attrs)?
171 .unwrap_or_else(|| FieldAttribute::GetAttr(None));
172 fields.push((ident, attr))
173 }
174 ContainerType::Struct(fields)
175 }
176 (Fields::Unit, _) => {
177 return Err(spanned_err(
179 &fields,
180 "Cannot derive FromPyObject for Unit structs and variants",
181 ));
182 }
183 };
184 let err_name = attrs
185 .iter()
186 .find_map(|a| a.annotation())
187 .unwrap_or_else(|| path.segments.last().unwrap().ident.to_string());
188
189 let v = Container {
190 path,
191 ty: style,
192 err_name,
193 is_enum_variant,
194 };
195 Ok(v)
196 }
197
198 fn verify_struct_container_attrs(
199 attrs: &'a [ContainerAttribute],
200 original: &[Attribute],
201 ) -> Result<()> {
202 for attr in attrs {
203 match attr {
204 ContainerAttribute::Transparent => continue,
205 ContainerAttribute::ErrorAnnotation(_) => {
206 let span = original
207 .iter()
208 .map(|a| a.span())
209 .fold(None, |mut acc: Option<Span>, span| {
210 if let Some(all) = acc.as_mut() {
211 all.join(span)
212 } else {
213 Some(span)
214 }
215 })
216 .unwrap_or_else(Span::call_site);
217 return Err(syn::Error::new(
218 span,
219 "Annotating error messages for structs is \
220 not supported. Remove the annotation attribute.",
221 ));
222 }
223 }
224 }
225 Ok(())
226 }
227
228 fn build(&self) -> TokenStream {
230 match &self.ty {
231 ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)),
232 ContainerType::TupleNewtype => self.build_newtype_struct(None),
233 ContainerType::Tuple(len) => self.build_tuple_struct(*len),
234 ContainerType::Struct(tups) => self.build_struct(tups),
235 }
236 }
237
238 fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream {
239 let self_ty = &self.path;
240 if let Some(ident) = field_ident {
241 quote!(
242 Ok(#self_ty{#ident: obj.extract()?})
243 )
244 } else {
245 quote!(Ok(#self_ty(obj.extract()?)))
246 }
247 }
248
249 fn build_tuple_struct(&self, len: usize) -> TokenStream {
250 let self_ty = &self.path;
251 let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
252 for i in 0..len {
253 fields.push(quote!(slice[#i].extract()?));
254 }
255 let msg = if self.is_enum_variant {
256 quote!(format!(
257 "Expected tuple of length {}, but got length {}.",
258 #len,
259 s.len()
260 ))
261 } else {
262 quote!("")
263 };
264 quote!(
265 let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?;
266 if s.len() != #len {
267 return Err(::pyo3::exceptions::PyValueError::new_err(#msg))
268 }
269 let slice = s.as_slice();
270 Ok(#self_ty(#fields))
271 )
272 }
273
274 fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream {
275 let self_ty = &self.path;
276 let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
277 for (ident, attr) in tups {
278 let ext_fn = match attr {
279 FieldAttribute::GetAttr(Some(name)) => quote!(getattr(#name)),
280 FieldAttribute::GetAttr(None) => quote!(getattr(stringify!(#ident))),
281 FieldAttribute::GetItem(Some(key)) => quote!(get_item(#key)),
282 FieldAttribute::GetItem(None) => quote!(get_item(stringify!(#ident))),
283 };
284 fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
285 }
286 quote!(Ok(#self_ty{#fields}))
287 }
288
289 fn check_transparent_len(fields: &Fields) -> Result<()> {
290 if fields.len() != 1 {
291 return Err(spanned_err(
292 fields,
293 "Transparent structs and variants can only have 1 field",
294 ));
295 }
296 Ok(())
297 }
298}
299
/// Attributes for `#[derive(FromPyObject)]` that apply to containers (structs and enum
/// variants), written as `#[pyo3(...)]`.
#[derive(Clone, Debug, PartialEq)]
enum ContainerAttribute {
    /// `#[pyo3(transparent)]`: the container must have exactly one field, which is
    /// extracted directly from the object.
    Transparent,
    /// `#[pyo3(annotation = "...")]`: name shown for this container in conversion error
    /// messages. Rejected on standalone structs.
    ErrorAnnotation(String),
}
308
309impl ContainerAttribute {
310 fn annotation(&self) -> Option<String> {
312 match self {
313 ContainerAttribute::ErrorAnnotation(s) => Some(s.to_string()),
314 _ => None,
315 }
316 }
317
318 fn parse_attrs(value: &[Attribute]) -> Result<Vec<Self>> {
322 let mut attrs = Vec::new();
323 let list = get_pyo3_meta_list(value)?;
324 for meta in list.nested {
325 if let syn::NestedMeta::Meta(metaitem) = &meta {
326 match metaitem {
327 Meta::Path(p) if p.is_ident("transparent") => {
328 attrs.push(ContainerAttribute::Transparent);
329 continue;
330 }
331 Meta::NameValue(nv) if nv.path.is_ident("annotation") => {
332 if let syn::Lit::Str(s) = &nv.lit {
333 attrs.push(ContainerAttribute::ErrorAnnotation(s.value()))
334 } else {
335 return Err(spanned_err(&nv.lit, "Expected string literal."));
336 }
337 continue;
338 }
339 _ => {} }
341 }
342
343 return Err(spanned_err(meta, "Unrecognized `pyo3` container attribute"));
344 }
345 Ok(attrs)
346 }
347}
348
/// How a named field's value is obtained from the Python object, set with
/// `#[pyo3(...)]` on the field.
#[derive(Clone, Debug)]
enum FieldAttribute {
    /// `#[pyo3(item)]` or `#[pyo3(item(key))]`: read with `obj.get_item(key)`; `None`
    /// uses the field name as a string key.
    GetItem(Option<syn::Lit>),
    /// `#[pyo3(attribute)]` or `#[pyo3(attribute("name"))]`: read with
    /// `obj.getattr(name)`; `None` uses the field name. This is also the default when a
    /// field has no attribute.
    GetAttr(Option<syn::LitStr>),
}
355
356impl FieldAttribute {
357 fn parse_attrs(attrs: &[Attribute]) -> Result<Option<Self>> {
361 let list = get_pyo3_meta_list(attrs)?;
362 let metaitem = match list.nested.len() {
363 0 => return Ok(None),
364 1 => list.nested.into_iter().next().unwrap(),
365 _ => {
366 return Err(spanned_err(
367 list.nested,
368 "Only one of `item`, `attribute` can be provided, possibly with an \
369 additional argument: `item(\"key\")` or `attribute(\"name\").",
370 ))
371 }
372 };
373 let meta = match metaitem {
374 syn::NestedMeta::Meta(meta) => meta,
375 syn::NestedMeta::Lit(lit) => {
376 return Err(spanned_err(
377 lit,
378 "Expected `attribute` or `item`, not a literal.",
379 ))
380 }
381 };
382 let path = meta.path();
383 if path.is_ident("attribute") {
384 Ok(Some(FieldAttribute::GetAttr(Self::attribute_arg(meta)?)))
385 } else if path.is_ident("item") {
386 Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?)))
387 } else {
388 Err(spanned_err(meta, "Expected `attribute` or `item`."))
389 }
390 }
391
392 fn attribute_arg(meta: Meta) -> syn::Result<Option<syn::LitStr>> {
393 let arg_list = match meta {
394 Meta::List(list) => list,
395 Meta::Path(_) => return Ok(None),
396 Meta::NameValue(nv) => {
397 let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`";
398 return Err(spanned_err(nv, err_msg));
399 }
400 };
401 let arg_msg = "Expected a single string literal argument.";
402 let first = match arg_list.nested.len() {
403 1 => arg_list.nested.first().unwrap(),
404 _ => return Err(spanned_err(arg_list, arg_msg)),
405 };
406 if let syn::NestedMeta::Lit(syn::Lit::Str(litstr)) = first {
407 if litstr.value().is_empty() {
408 return Err(spanned_err(litstr, "Attribute name cannot be empty."));
409 }
410 return Ok(Some(parse_quote!(#litstr)));
411 }
412 Err(spanned_err(first, arg_msg))
413 }
414
415 fn item_arg(meta: Meta) -> syn::Result<Option<syn::Lit>> {
416 let arg_list = match meta {
417 Meta::List(list) => list,
418 Meta::Path(_) => return Ok(None),
419 Meta::NameValue(nv) => {
420 return Err(spanned_err(
421 nv,
422 "Expected a literal or no argument: `pyo3(item(\"key\") or `pyo3(item)`",
423 ))
424 }
425 };
426 let arg_msg = "Expected a single literal argument.";
427 if arg_list.nested.is_empty() {
428 return Err(spanned_err(arg_list, arg_msg));
429 } else if arg_list.nested.len() > 1 {
430 return Err(spanned_err(arg_list.nested, arg_msg));
431 }
432 let first = arg_list.nested.first().unwrap();
433 if let syn::NestedMeta::Lit(lit) = first {
434 return Ok(Some(parse_quote!(#lit)));
435 }
436 Err(spanned_err(first, arg_msg))
437 }
438}
439
440fn spanned_err<T: ToTokens>(tokens: T, msg: &str) -> syn::Error {
441 syn::Error::new_spanned(tokens, msg)
442}
443
444fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result<MetaList> {
446 let mut list: Punctuated<syn::NestedMeta, syn::Token![,]> = Punctuated::new();
447 for value in attrs {
448 match value.parse_meta()? {
449 Meta::List(ml) if value.path.is_ident("pyo3") => {
450 for meta in ml.nested {
451 list.push(meta);
452 }
453 }
454 _ => continue,
455 }
456 }
457 Ok(MetaList {
458 path: parse_quote!(pyo3),
459 paren_token: syn::token::Paren::default(),
460 nested: list,
461 })
462}
463
464fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeDef>> {
465 let lifetimes = generics.lifetimes().collect::<Vec<_>>();
466 if lifetimes.len() > 1 {
467 return Err(spanned_err(
468 &generics,
469 "FromPyObject can be derived with at most one lifetime parameter.",
470 ));
471 }
472 Ok(lifetimes.into_iter().next())
473}
474
475pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
484 let mut trait_generics = tokens.generics.clone();
485 let generics = &tokens.generics;
486 let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? {
487 lt.clone()
488 } else {
489 trait_generics.params.push(parse_quote!('source));
490 parse_quote!('source)
491 };
492 let mut where_clause: syn::WhereClause = parse_quote!(where);
493 for param in generics.type_params() {
494 let gen_ident = ¶m.ident;
495 where_clause
496 .predicates
497 .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>))
498 }
499 let derives = match &tokens.data {
500 syn::Data::Enum(en) => {
501 let en = Enum::new(en, &tokens.ident)?;
502 en.build()
503 }
504 syn::Data::Struct(st) => {
505 let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?;
506 Container::verify_struct_container_attrs(&attrs, &tokens.attrs)?;
507 let ident = &tokens.ident;
508 let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?;
509 st.build()
510 }
511 syn::Data::Union(_) => {
512 return Err(spanned_err(
513 tokens,
514 "#[derive(FromPyObject)] is not supported for unions.",
515 ))
516 }
517 };
518
519 let ident = &tokens.ident;
520 Ok(quote!(
521 #[automatically_derived]
522 impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause {
523 fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult<Self> {
524 #derives
525 }
526 }
527 ))
528}