use proc_macro2::TokenStream;
use quote::quote_spanned;
use {
super::FragmentDeriveField,
crate::{generics_for_serde, schema::types as schema},
};
/// The `serde::Deserialize` implementation to emit for a fragment derive.
///
/// Selected by `DeserializeImpl::new`: `Spreading` when at least one field is
/// marked as a spread, `Standard` otherwise.
pub enum DeserializeImpl<'a> {
    /// At least one field is spread; the generated impl goes through
    /// `cynic::__private::Spreadable`.
    Spreading(SpreadingDeserializeImpl<'a>),
    /// No field is spread; the generated impl uses a map visitor.
    Standard(StandardDeserializeImpl<'a>),
}
/// Generator for the map-visitor based `Deserialize` impl, used when no field
/// of the fragment is a spread.
pub struct StandardDeserializeImpl<'a> {
    /// Generic parameters declared on the target struct.
    generics: &'a syn::Generics,
    /// Name of the struct the impl is generated for.
    target_struct: &'a syn::Ident,
    /// Processed per-field data, in the order supplied to `DeserializeImpl::new`.
    fields: Vec<Field>,
}
/// Generator for the `Spreadable`-based `Deserialize` impl, used when at least
/// one field of the fragment is a spread.
pub struct SpreadingDeserializeImpl<'a> {
    /// Generic parameters declared on the target struct.
    generics: &'a syn::Generics,
    /// Name of the struct the impl is generated for.
    target_struct: &'a syn::Ident,
    /// Processed per-field data, in the order supplied to `DeserializeImpl::new`.
    fields: Vec<Field>,
}
/// Summary of one fragment field, as needed by the deserialize generators.
struct Field {
    // --- Naming and type -------------------------------------------------
    /// Identifier of the field on the target Rust struct.
    rust_name: proc_macro2::Ident,
    /// Variant identifier used for this field in the generated
    /// `__FragmentDeriveField` enum.
    field_variant_name: proc_macro2::Ident,
    /// Response key for the field: the alias if one is set, otherwise the
    /// schema field's name. The generators require it for every non-spread field.
    serialized_name: Option<String>,
    /// Declared Rust type of the field.
    ty: syn::Type,

    // --- Attribute-derived flags ----------------------------------------
    /// Field is deserialized from the whole object via `spread_deserializer`.
    is_spread: bool,
    /// Field is decoded through `cynic::__private::Flattened<T>`.
    is_flattened: bool,
    /// A `recurse` attribute is present. In the standard impl, a missing value
    /// falls back to `Default::default()`.
    is_recurse: bool,
    /// A `feature` attribute is present. In the standard impl, a missing value
    /// falls back to `Default::default()`.
    is_feature_flagged: bool,
    /// The field may be skipped. In the standard impl, a missing value falls
    /// back to `Default::default()`.
    is_skippable: bool,
    /// The standard impl decodes the value as `Option<T>` and replaces `null`
    /// with `Default::default()`.
    has_default: bool,
}
impl<'a> DeserializeImpl<'a> {
pub fn new(
fields: &[(FragmentDeriveField, Option<schema::Field<'_>>)],
name: &'a syn::Ident,
generics: &'a syn::Generics,
) -> Self {
let spreading = fields.iter().any(|f| f.0.spread());
let target_struct = name;
let fields = fields
.iter()
.map(|(field, schema_field)| process_field(field, schema_field.as_ref()))
.collect();
match spreading {
true => DeserializeImpl::Spreading(SpreadingDeserializeImpl {
target_struct,
fields,
generics,
}),
false => DeserializeImpl::Standard(StandardDeserializeImpl {
target_struct,
fields,
generics,
}),
}
}
}
/// Emits the tokens of whichever concrete generator was selected.
impl quote::ToTokens for DeserializeImpl<'_> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Self::Spreading(spreading) => spreading.to_tokens(tokens),
            Self::Standard(standard) => standard.to_tokens(tokens),
        }
    }
}
/// Emits a `serde::Deserialize` impl for fragments without spread fields.
///
/// The generated impl declares a private `__FragmentDeriveField` identifier
/// enum (its `__Other` catch-all variant makes unknown keys get skipped) and a
/// `__Visitor` whose `visit_map` collects each field into an `Option` before
/// unwrapping them into the target struct.
///
/// The generated code is expanded inside the user's module, so:
/// * prelude items (`Result`, `Option`, `Some`/`None`, `Ok`/`Err`, `fmt`) are
///   referenced through absolute `::core` paths, which keeps a user-defined
///   alias such as `type Result<T> = ...` from breaking compilation;
/// * helper names are `__`-prefixed (`__Visitor`, `__V`), so they cannot
///   collide with a target struct named `Visitor` or a type parameter named `V`.
impl quote::ToTokens for StandardDeserializeImpl<'_> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        use quote::{TokenStreamExt, quote};

        let target_struct = &self.target_struct;

        // Every field of a non-spreading fragment maps to exactly one response key.
        let serialized_names = self
            .fields
            .iter()
            .map(|f| {
                proc_macro2::Literal::string(
                    f.serialized_name
                        .as_deref()
                        .expect("non-spread fields must have a serialized_name"),
                )
            })
            .collect::<Vec<_>>();

        let field_variant_names = self
            .fields
            .iter()
            .map(|f| &f.field_variant_name)
            .collect::<Vec<_>>();

        let field_names = self.fields.iter().map(|f| &f.rust_name).collect::<Vec<_>>();

        // How each field's value is read once its key has been seen.
        let field_decodes = self.fields.iter().map(|f| {
            let field_name = &f.rust_name;
            let ty = &f.ty;
            if f.is_flattened {
                quote! {
                    #field_name = ::core::option::Option::Some(
                        __map.next_value::<cynic::__private::Flattened<#ty>>()?.into_inner()
                    );
                }
            } else if f.has_default {
                // `null` decodes to `None`, which is replaced by `Default::default()`.
                quote! {
                    #field_name = ::core::option::Option::Some(
                        __map.next_value::<::core::option::Option<#ty>>()?.unwrap_or_default()
                    );
                }
            } else {
                quote! {
                    #field_name = ::core::option::Option::Some(__map.next_value()?);
                }
            }
        });

        let struct_name = self.target_struct.to_string();
        let expecting_str = proc_macro2::Literal::string(&format!("struct {struct_name}"));
        let struct_name = proc_macro2::Literal::string(&struct_name);

        let (_, ty_generics, _) = self.generics.split_for_impl();
        let generics_with_de = generics_for_serde::with_de_and_deserialize_bounds(self.generics);
        let (impl_generics, ty_generics_with_de, where_clause) = generics_with_de.split_for_impl();

        // Fields that may legitimately be absent fall back to `Default`; all
        // other fields report serde's standard "missing field" error.
        let field_unwraps = self
            .fields
            .iter()
            .zip(&serialized_names)
            .map(|(field, serialized_name)| {
                let rust_name = &field.rust_name;
                if field.is_recurse || field.is_feature_flagged || field.is_skippable {
                    let span = rust_name.span();
                    quote_spanned! { span =>
                        let #rust_name = #rust_name.unwrap_or_default();
                    }
                } else {
                    quote! {
                        let #rust_name = #rust_name.ok_or_else(|| {
                            cynic::serde::de::Error::missing_field(#serialized_name)
                        })?;
                    }
                }
            })
            .collect::<Vec<_>>();

        tokens.append_all(quote! {
            #[automatically_derived]
            impl #impl_generics cynic::serde::Deserialize<'de> for #target_struct #ty_generics #where_clause {
                fn deserialize<__D>(deserializer: __D) -> ::core::result::Result<Self, __D::Error>
                where
                    __D: cynic::serde::Deserializer<'de>,
                {
                    #[derive(cynic::serde::Deserialize)]
                    #[serde(field_identifier, crate="cynic::serde")]
                    #[allow(non_camel_case_types)]
                    enum __FragmentDeriveField {
                        #(
                            #[serde(rename = #serialized_names)]
                            #field_variant_names,
                        )*
                        #[serde(other)]
                        __Other
                    }

                    struct __Visitor #generics_with_de #where_clause {
                        marker: ::core::marker::PhantomData<#target_struct #ty_generics>,
                        lifetime: ::core::marker::PhantomData<&'de ()>,
                    }

                    impl #impl_generics cynic::serde::de::Visitor<'de> for __Visitor #ty_generics_with_de #where_clause {
                        type Value = #target_struct #ty_generics;

                        fn expecting(&self, formatter: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                            formatter.write_str(#expecting_str)
                        }

                        fn visit_map<__V>(self, mut __map: __V) -> ::core::result::Result<Self::Value, __V::Error>
                        where
                            __V: cynic::serde::de::MapAccess<'de>,
                        {
                            #(
                                let mut #field_names = ::core::option::Option::None;
                            )*

                            while let ::core::option::Option::Some(__key) = __map.next_key()? {
                                match __key {
                                    #(
                                        __FragmentDeriveField::#field_variant_names => {
                                            if #field_names.is_some() {
                                                return ::core::result::Result::Err(
                                                    cynic::serde::de::Error::duplicate_field(#serialized_names)
                                                );
                                            }
                                            #field_decodes
                                        }
                                    )*
                                    __FragmentDeriveField::__Other => {
                                        __map.next_value::<cynic::serde::de::IgnoredAny>()?;
                                    }
                                }
                            }

                            #(#field_unwraps)*

                            ::core::result::Result::Ok(#target_struct {
                                #(#field_names),*
                            })
                        }
                    }

                    const FIELDS: &[&str] = &[#(#serialized_names),*];
                    deserializer.deserialize_struct(
                        #struct_name,
                        FIELDS,
                        __Visitor {
                            marker: ::core::marker::PhantomData,
                            lifetime: ::core::marker::PhantomData,
                        },
                    )
                }
            }
        });
    }
}
/// Emits a `serde::Deserialize` impl for fragments that contain spread fields.
///
/// The input is first buffered into a `cynic::__private::Spreadable`.
/// * Spread fields are deserialized from `spread_deserializer()`, which covers
///   the whole object.
/// * Every other field is read by key with `deserialize_field`.
///
/// Field decoding mirrors the standard impl: flattened fields go through
/// `Flattened<T>`, and `has_default` fields are read as `Option<T>`, so a
/// `null` becomes `Default::default()`.
///
/// The where clause comes from the serde-augmented generics, as in the
/// standard impl, so that any bounds added there apply to this impl too.
/// Prelude items use absolute `::core` paths because the code is expanded in
/// the user's module.
impl quote::ToTokens for SpreadingDeserializeImpl<'_> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        use quote::{TokenStreamExt, quote};

        let target_struct = &self.target_struct;

        let field_inserts = self.fields.iter().map(|f| {
            let field_name = &f.rust_name;
            let field_ty = &f.ty;

            if f.is_spread {
                return quote! {
                    #field_name: <#field_ty as cynic::serde::Deserialize<'de>>::deserialize(
                        spreadable.spread_deserializer()
                    )?
                };
            }

            let serialized_name = proc_macro2::Literal::string(
                f.serialized_name
                    .as_deref()
                    .expect("non spread fields must have a serialized_name"),
            );

            if f.is_flattened {
                quote! {
                    #field_name: spreadable.deserialize_field::<
                        cynic::__private::Flattened<#field_ty>
                    >(#serialized_name)?.into_inner()
                }
            } else if f.has_default {
                // `null` decodes to `None`, which is replaced by `Default::default()`.
                quote! {
                    #field_name: spreadable.deserialize_field::<
                        ::core::option::Option<#field_ty>
                    >(#serialized_name)?.unwrap_or_default()
                }
            } else {
                quote! {
                    #field_name: spreadable.deserialize_field(#serialized_name)?
                }
            }
        });

        let (_, ty_generics, _) = self.generics.split_for_impl();
        let generics_with_de = generics_for_serde::with_de_and_deserialize_bounds(self.generics);
        // The augmented where clause keeps every original predicate and adds
        // any that `with_de_and_deserialize_bounds` introduced.
        let (impl_generics, _, where_clause) = generics_with_de.split_for_impl();

        tokens.append_all(quote! {
            #[automatically_derived]
            impl #impl_generics cynic::serde::Deserialize<'de> for #target_struct #ty_generics #where_clause {
                fn deserialize<__D>(deserializer: __D) -> ::core::result::Result<Self, __D::Error>
                where
                    __D: cynic::serde::Deserializer<'de>,
                {
                    let spreadable = cynic::__private::Spreadable::<__D::Error>::deserialize(deserializer)?;
                    ::core::result::Result::Ok(#target_struct {
                        #(#field_inserts,)*
                    })
                }
            }
        });
    }
}
/// Converts a parsed fragment field, together with its resolved schema field
/// if any, into the `Field` summary used by the deserialize generators.
///
/// The serialized name is the field's alias when present, otherwise the schema
/// field's name. It is `None` when neither is available.
///
/// # Panics
/// Panics if `field.ident()` returns `None`.
fn process_field(field: &FragmentDeriveField, schema_field: Option<&schema::Field<'_>>) -> Field {
    let ident = field.ident().unwrap();
    let raw = &field.raw_field;

    let serialized_name = match field.alias() {
        Some(alias) => Some(alias),
        None => schema_field.map(|resolved| resolved.name.as_str().to_string()),
    };

    Field {
        rust_name: ident.clone(),
        field_variant_name: ident.clone(),
        serialized_name,
        ty: raw.ty.clone(),
        is_spread: field.spread(),
        is_flattened: *raw.flatten,
        is_recurse: raw.recurse.is_some(),
        is_feature_flagged: raw.feature.is_some(),
        is_skippable: field.is_skippable(),
        has_default: field.has_default(),
    }
}