1use crate::attributes::{DefaultAttribute, FromPyWithAttribute, RenamingRule};
2use crate::derive_attributes::{ContainerAttributes, FieldAttributes, FieldGetter};
3use crate::utils::{self, Ctx};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote, quote_spanned, ToTokens};
6use syn::{
7 ext::IdentExt, parse_quote, punctuated::Punctuated, spanned::Spanned, DataEnum, DeriveInput,
8 Fields, Ident, Result, Token,
9};
10
/// Representation of a Rust enum for which `FromPyObject` is being derived.
struct Enum<'a> {
    /// Identifier of the enum; its string form is the type name reported when
    /// no variant could be extracted.
    enum_ident: &'a Ident,
    /// One container per variant, in declaration order (the order extraction is attempted).
    variants: Vec<Container<'a>>,
}
16
17impl<'a> Enum<'a> {
18 fn new(
23 data_enum: &'a DataEnum,
24 ident: &'a Ident,
25 options: ContainerAttributes,
26 ) -> Result<Self> {
27 ensure_spanned!(
28 !data_enum.variants.is_empty(),
29 ident.span() => "cannot derive FromPyObject for empty enum"
30 );
31 let variants = data_enum
32 .variants
33 .iter()
34 .map(|variant| {
35 let mut variant_options = ContainerAttributes::from_attrs(&variant.attrs)?;
36 if let Some(rename_all) = &options.rename_all {
37 ensure_spanned!(
38 variant_options.rename_all.is_none(),
39 variant_options.rename_all.span() => "Useless variant `rename_all` - enum is already annotated with `rename_all"
40 );
41 variant_options.rename_all = Some(rename_all.clone());
42
43 }
44 let var_ident = &variant.ident;
45 Container::new(
46 &variant.fields,
47 parse_quote!(#ident::#var_ident),
48 variant_options,
49 )
50 })
51 .collect::<Result<Vec<_>>>()?;
52
53 Ok(Enum {
54 enum_ident: ident,
55 variants,
56 })
57 }
58
59 fn build(&self, ctx: &Ctx) -> TokenStream {
61 let Ctx { pyo3_path, .. } = ctx;
62 let mut var_extracts = Vec::new();
63 let mut variant_names = Vec::new();
64 let mut error_names = Vec::new();
65
66 for var in &self.variants {
67 let struct_derive = var.build(ctx);
68 let ext = quote!({
69 let maybe_ret = || -> #pyo3_path::PyResult<Self> {
70 #struct_derive
71 }();
72
73 match maybe_ret {
74 ok @ ::std::result::Result::Ok(_) => return ok,
75 ::std::result::Result::Err(err) => err
76 }
77 });
78
79 var_extracts.push(ext);
80 variant_names.push(var.path.segments.last().unwrap().ident.to_string());
81 error_names.push(&var.err_name);
82 }
83 let ty_name = self.enum_ident.to_string();
84 quote!(
85 let errors = [
86 #(#var_extracts),*
87 ];
88 ::std::result::Result::Err(
89 #pyo3_path::impl_::frompyobject::failed_to_extract_enum(
90 obj.py(),
91 #ty_name,
92 &[#(#variant_names),*],
93 &[#(#error_names),*],
94 &errors
95 )
96 )
97 )
98 }
99}
100
/// A named field of a struct (or struct-like enum variant) being extracted.
struct NamedStructField<'a> {
    /// The Rust field identifier.
    ident: &'a syn::Ident,
    /// How the value is fetched from the Python object (attribute or item access).
    /// `None` means attribute access using the (possibly renamed) field name.
    getter: Option<FieldGetter>,
    /// Optional user-supplied function used instead of the default extraction.
    from_py_with: Option<FromPyWithAttribute>,
    /// If present, the value used when fetching the field from the object fails.
    default: Option<DefaultAttribute>,
}
107
/// An element of a tuple struct (or tuple-like enum variant) being extracted.
struct TupleStructField {
    /// Optional user-supplied function used instead of the default extraction.
    from_py_with: Option<FromPyWithAttribute>,
}
111
/// The shape of a struct or enum variant, which determines the extraction code generated.
enum ContainerType<'a> {
    /// Struct with named fields; each field is fetched from the object via
    /// `getattr` or `get_item` and then extracted.
    Struct(Vec<NamedStructField<'a>>),
    /// `transparent` struct with a single named field (its identifier and optional
    /// `from_py_with`); the field is extracted directly from the object itself.
    StructNewtype(&'a syn::Ident, Option<FromPyWithAttribute>),
    /// Tuple struct with more than one element; the object is first extracted as a
    /// tuple and each element is then extracted individually.
    Tuple(Vec<TupleStructField>),
    /// Tuple struct with exactly one element (with or without `transparent`); the
    /// element is extracted directly from the object itself.
    TupleNewtype(Option<FromPyWithAttribute>),
}
134
/// Data needed to generate extraction code for one struct or one enum variant.
struct Container<'a> {
    /// Path used to construct the value: the struct identifier, or `Enum::Variant`.
    path: syn::Path,
    /// The shape of the container and its per-field options.
    ty: ContainerType<'a>,
    /// Name reported for this variant when enum extraction fails: the `annotation`
    /// attribute if given, otherwise the last segment of `path`.
    err_name: String,
    /// Renaming rule (from `rename_all`) applied to field names that have no explicit name.
    rename_rule: Option<RenamingRule>,
}
144
145impl<'a> Container<'a> {
146 fn new(fields: &'a Fields, path: syn::Path, options: ContainerAttributes) -> Result<Self> {
150 let style = match fields {
151 Fields::Unnamed(unnamed) if !unnamed.unnamed.is_empty() => {
152 ensure_spanned!(
153 options.rename_all.is_none(),
154 options.rename_all.span() => "`rename_all` is useless on tuple structs and variants."
155 );
156 let mut tuple_fields = unnamed
157 .unnamed
158 .iter()
159 .map(|field| {
160 let attrs = FieldAttributes::from_attrs(&field.attrs)?;
161 ensure_spanned!(
162 attrs.getter.is_none(),
163 field.span() => "`getter` is not permitted on tuple struct elements."
164 );
165 ensure_spanned!(
166 attrs.default.is_none(),
167 field.span() => "`default` is not permitted on tuple struct elements."
168 );
169 Ok(TupleStructField {
170 from_py_with: attrs.from_py_with,
171 })
172 })
173 .collect::<Result<Vec<_>>>()?;
174
175 if tuple_fields.len() == 1 {
176 let field = tuple_fields.pop().unwrap();
179 ContainerType::TupleNewtype(field.from_py_with)
180 } else if options.transparent.is_some() {
181 bail_spanned!(
182 fields.span() => "transparent structs and variants can only have 1 field"
183 );
184 } else {
185 ContainerType::Tuple(tuple_fields)
186 }
187 }
188 Fields::Named(named) if !named.named.is_empty() => {
189 let mut struct_fields = named
190 .named
191 .iter()
192 .map(|field| {
193 let ident = field
194 .ident
195 .as_ref()
196 .expect("Named fields should have identifiers");
197 let mut attrs = FieldAttributes::from_attrs(&field.attrs)?;
198
199 if let Some(ref from_item_all) = options.from_item_all {
200 if let Some(replaced) = attrs.getter.replace(FieldGetter::GetItem(parse_quote!(item), None))
201 {
202 match replaced {
203 FieldGetter::GetItem(item, Some(item_name)) => {
204 attrs.getter = Some(FieldGetter::GetItem(item, Some(item_name)));
205 }
206 FieldGetter::GetItem(_, None) => bail_spanned!(from_item_all.span() => "Useless `item` - the struct is already annotated with `from_item_all`"),
207 FieldGetter::GetAttr(_, _) => bail_spanned!(
208 from_item_all.span() => "The struct is already annotated with `from_item_all`, `attribute` is not allowed"
209 ),
210 }
211 }
212 }
213
214 Ok(NamedStructField {
215 ident,
216 getter: attrs.getter,
217 from_py_with: attrs.from_py_with,
218 default: attrs.default,
219 })
220 })
221 .collect::<Result<Vec<_>>>()?;
222 if struct_fields.iter().all(|field| field.default.is_some()) {
223 bail_spanned!(
224 fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
225 )
226 } else if options.transparent.is_some() {
227 ensure_spanned!(
228 struct_fields.len() == 1,
229 fields.span() => "transparent structs and variants can only have 1 field"
230 );
231 ensure_spanned!(
232 options.rename_all.is_none(),
233 options.rename_all.span() => "`rename_all` is not permitted on `transparent` structs and variants"
234 );
235 let field = struct_fields.pop().unwrap();
236 ensure_spanned!(
237 field.getter.is_none(),
238 field.ident.span() => "`transparent` structs may not have a `getter` for the inner field"
239 );
240 ContainerType::StructNewtype(field.ident, field.from_py_with)
241 } else {
242 ContainerType::Struct(struct_fields)
243 }
244 }
245 _ => bail_spanned!(
246 fields.span() => "cannot derive FromPyObject for empty structs and variants"
247 ),
248 };
249 let err_name = options.annotation.map_or_else(
250 || path.segments.last().unwrap().ident.to_string(),
251 |lit_str| lit_str.value(),
252 );
253
254 let v = Container {
255 path,
256 ty: style,
257 err_name,
258 rename_rule: options.rename_all.map(|v| v.value.rule),
259 };
260 Ok(v)
261 }
262
263 fn name(&self) -> String {
264 let mut value = String::new();
265 for segment in &self.path.segments {
266 if !value.is_empty() {
267 value.push_str("::");
268 }
269 value.push_str(&segment.ident.to_string());
270 }
271 value
272 }
273
274 fn build(&self, ctx: &Ctx) -> TokenStream {
276 match &self.ty {
277 ContainerType::StructNewtype(ident, from_py_with) => {
278 self.build_newtype_struct(Some(ident), from_py_with, ctx)
279 }
280 ContainerType::TupleNewtype(from_py_with) => {
281 self.build_newtype_struct(None, from_py_with, ctx)
282 }
283 ContainerType::Tuple(tups) => self.build_tuple_struct(tups, ctx),
284 ContainerType::Struct(tups) => self.build_struct(tups, ctx),
285 }
286 }
287
288 fn build_newtype_struct(
289 &self,
290 field_ident: Option<&Ident>,
291 from_py_with: &Option<FromPyWithAttribute>,
292 ctx: &Ctx,
293 ) -> TokenStream {
294 let Ctx { pyo3_path, .. } = ctx;
295 let self_ty = &self.path;
296 let struct_name = self.name();
297 if let Some(ident) = field_ident {
298 let field_name = ident.to_string();
299 if let Some(FromPyWithAttribute {
300 kw,
301 value: expr_path,
302 }) = from_py_with
303 {
304 let extractor = quote_spanned! { kw.span =>
305 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
306 };
307 quote! {
308 Ok(#self_ty {
309 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, obj, #struct_name, #field_name)?
310 })
311 }
312 } else {
313 quote! {
314 Ok(#self_ty {
315 #ident: #pyo3_path::impl_::frompyobject::extract_struct_field(obj, #struct_name, #field_name)?
316 })
317 }
318 }
319 } else if let Some(FromPyWithAttribute {
320 kw,
321 value: expr_path,
322 }) = from_py_with
323 {
324 let extractor = quote_spanned! { kw.span =>
325 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
326 };
327 quote! {
328 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, obj, #struct_name, 0).map(#self_ty)
329 }
330 } else {
331 quote! {
332 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(obj, #struct_name, 0).map(#self_ty)
333 }
334 }
335 }
336
337 fn build_tuple_struct(&self, struct_fields: &[TupleStructField], ctx: &Ctx) -> TokenStream {
338 let Ctx { pyo3_path, .. } = ctx;
339 let self_ty = &self.path;
340 let struct_name = &self.name();
341 let field_idents: Vec<_> = (0..struct_fields.len())
342 .map(|i| format_ident!("arg{}", i))
343 .collect();
344 let fields = struct_fields.iter().zip(&field_idents).enumerate().map(|(index, (field, ident))| {
345 if let Some(FromPyWithAttribute {
346 kw,
347 value: expr_path, ..
348 }) = &field.from_py_with {
349 let extractor = quote_spanned! { kw.span =>
350 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
351 };
352 quote! {
353 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field_with(#extractor, &#ident, #struct_name, #index)?
354 }
355 } else {
356 quote!{
357 #pyo3_path::impl_::frompyobject::extract_tuple_struct_field(&#ident, #struct_name, #index)?
358 }}
359 });
360
361 quote!(
362 match #pyo3_path::types::PyAnyMethods::extract(obj) {
363 ::std::result::Result::Ok((#(#field_idents),*)) => ::std::result::Result::Ok(#self_ty(#(#fields),*)),
364 ::std::result::Result::Err(err) => ::std::result::Result::Err(err),
365 }
366 )
367 }
368
369 fn build_struct(&self, struct_fields: &[NamedStructField<'_>], ctx: &Ctx) -> TokenStream {
370 let Ctx { pyo3_path, .. } = ctx;
371 let self_ty = &self.path;
372 let struct_name = self.name();
373 let mut fields: Punctuated<TokenStream, Token![,]> = Punctuated::new();
374 for field in struct_fields {
375 let ident = field.ident;
376 let field_name = ident.unraw().to_string();
377 let getter = match field
378 .getter
379 .as_ref()
380 .unwrap_or(&FieldGetter::GetAttr(parse_quote!(attribute), None))
381 {
382 FieldGetter::GetAttr(_, Some(name)) => {
383 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
384 }
385 FieldGetter::GetAttr(_, None) => {
386 let name = self
387 .rename_rule
388 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
389 let name = name.as_deref().unwrap_or(&field_name);
390 quote!(#pyo3_path::types::PyAnyMethods::getattr(obj, #pyo3_path::intern!(obj.py(), #name)))
391 }
392 FieldGetter::GetItem(_, Some(syn::Lit::Str(key))) => {
393 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #key)))
394 }
395 FieldGetter::GetItem(_, Some(key)) => {
396 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #key))
397 }
398 FieldGetter::GetItem(_, None) => {
399 let name = self
400 .rename_rule
401 .map(|rule| utils::apply_renaming_rule(rule, &field_name));
402 let name = name.as_deref().unwrap_or(&field_name);
403 quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #name)))
404 }
405 };
406 let extractor = if let Some(FromPyWithAttribute {
407 kw,
408 value: expr_path,
409 }) = &field.from_py_with
410 {
411 let extractor = quote_spanned! { kw.span =>
412 { let from_py_with: fn(_) -> _ = #expr_path; from_py_with }
413 };
414 quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#extractor, &#getter?, #struct_name, #field_name)?)
415 } else {
416 quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
417 };
418 let extracted = if let Some(default) = &field.default {
419 let default_expr = if let Some(default_expr) = &default.value {
420 default_expr.to_token_stream()
421 } else {
422 quote!(::std::default::Default::default())
423 };
424 quote!(if let ::std::result::Result::Ok(value) = #getter {
425 #extractor
426 } else {
427 #default_expr
428 })
429 } else {
430 quote!({
431 let value = #getter?;
432 #extractor
433 })
434 };
435
436 fields.push(quote!(#ident: #extracted));
437 }
438
439 quote!(::std::result::Result::Ok(#self_ty{#fields}))
440 }
441}
442
443fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::LifetimeParam>> {
444 let mut lifetimes = generics.lifetimes();
445 let lifetime = lifetimes.next();
446 ensure_spanned!(
447 lifetimes.next().is_none(),
448 generics.span() => "FromPyObject can be derived with at most one lifetime parameter"
449 );
450 Ok(lifetime)
451}
452
453pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {
462 let options = ContainerAttributes::from_attrs(&tokens.attrs)?;
463 let ctx = &Ctx::new(&options.krate, None);
464 let Ctx { pyo3_path, .. } = &ctx;
465
466 let (_, ty_generics, _) = tokens.generics.split_for_impl();
467 let mut trait_generics = tokens.generics.clone();
468 let lt_param = if let Some(lt) = verify_and_get_lifetime(&trait_generics)? {
469 lt.clone()
470 } else {
471 trait_generics.params.push(parse_quote!('py));
472 parse_quote!('py)
473 };
474 let (impl_generics, _, where_clause) = trait_generics.split_for_impl();
475
476 let mut where_clause = where_clause.cloned().unwrap_or_else(|| parse_quote!(where));
477 for param in trait_generics.type_params() {
478 let gen_ident = ¶m.ident;
479 where_clause
480 .predicates
481 .push(parse_quote!(#gen_ident: #pyo3_path::FromPyObject<'py>))
482 }
483
484 let derives = match &tokens.data {
485 syn::Data::Enum(en) => {
486 if options.transparent.is_some() || options.annotation.is_some() {
487 bail_spanned!(tokens.span() => "`transparent` or `annotation` is not supported \
488 at top level for enums");
489 }
490 let en = Enum::new(en, &tokens.ident, options)?;
491 en.build(ctx)
492 }
493 syn::Data::Struct(st) => {
494 if let Some(lit_str) = &options.annotation {
495 bail_spanned!(lit_str.span() => "`annotation` is unsupported for structs");
496 }
497 let ident = &tokens.ident;
498 let st = Container::new(&st.fields, parse_quote!(#ident), options)?;
499 st.build(ctx)
500 }
501 syn::Data::Union(_) => bail_spanned!(
502 tokens.span() => "#[derive(FromPyObject)] is not supported for unions"
503 ),
504 };
505
506 let ident = &tokens.ident;
507 Ok(quote!(
508 #[automatically_derived]
509 impl #impl_generics #pyo3_path::FromPyObject<#lt_param> for #ident #ty_generics #where_clause {
510 fn extract_bound(obj: &#pyo3_path::Bound<#lt_param, #pyo3_path::PyAny>) -> #pyo3_path::PyResult<Self> {
511 #derives
512 }
513 }
514 ))
515}