use proc_macro2::{Ident, TokenStream};
use quote::quote;
use std::iter;
use crate::context::Context;
use crate::types::UnionDefinition;
/// Generates the complete token stream for a Conjure union type.
///
/// The output starts with the `serde`/`fmt` imports used by the generated code. It is followed by:
/// - the public union enum with its `Serialize` impl;
/// - the `Deserialize` impl and its map visitor;
/// - the private `Variant_` discriminant type;
/// - the struct holding unknown members (empty when the context is exhaustive).
pub fn generate(ctx: &Context, def: &UnionDefinition) -> TokenStream {
    // Tuple elements are evaluated left to right, keeping the original generation order.
    let (enum_tokens, deserialize_tokens, variant_tokens, unknown_tokens) = (
        generate_enum(ctx, def),
        generate_deserialize(ctx, def),
        generate_variant(ctx, def),
        generate_unknown(ctx, def),
    );

    quote! {
        use conjure_object::serde::{ser, de};
        use conjure_object::serde::ser::SerializeMap as SerializeMap_;
        use conjure_object::private::{UnionField_, UnionTypeField_};
        use std::fmt;

        #enum_tokens
        #deserialize_tokens
        #variant_tokens
        #unknown_tokens
    }
}
/// Returns the Rust identifier of every union member, in declaration order.
///
/// Each identifier is derived from the member's field name via `Context::type_name`.
fn variants(ctx: &Context, def: &UnionDefinition) -> Vec<Ident> {
    let mut idents = Vec::new();
    for member in def.union_().iter() {
        idents.push(ctx.type_name(member.field_name()));
    }
    idents
}
/// Chooses the identifier used for the catch-all unknown variant and its payload struct.
///
/// This is `Unknown`, unless the union already has a member whose identifier is `Unknown`.
/// In that case `Unknown_` is used to avoid the name clash.
fn unknown(ctx: &Context, def: &UnionDefinition) -> TokenStream {
    let name_taken = variants(ctx, def)
        .iter()
        .any(|ident| ident == "Unknown");

    if name_taken {
        quote!(Unknown_)
    } else {
        quote!(Unknown)
    }
}
/// Generates the public union enum together with its `Serialize` impl.
///
/// Each union member becomes a tuple variant wrapping its (possibly boxed) Rust type. Each
/// variant carries the member's doc and deprecation attributes.
///
/// Unless the context is exhaustive, an extra unknown variant is appended.
///
/// Comparison and hashing traits are either derived directly or, when any member contains a
/// double, derived through `Educe`. Double-typed members are routed through `DoubleOps`.
///
/// A value serializes as a two-entry map: `"type"` holds the member name, and a second entry
/// keyed by that name holds the value.
fn generate_enum(ctx: &Context, def: &UnionDefinition) -> TokenStream {
    let name = ctx.type_name(def.type_name().name());
    let result = ctx.result_ident(def.type_name());
    let some = ctx.some_ident(def.type_name());
    let fields = def.union_();

    // Type-level attributes: the `derive` list always comes first, followed by the `educe`
    // configuration when any member holds a double.
    let any_double = fields.iter().any(|field| ctx.has_double(field.type_()));
    let derive_names: &[&str] = if any_double {
        &["Debug", "Clone", "conjure_object::private::Educe"]
    } else {
        &["Debug", "Clone", "PartialEq", "Eq", "PartialOrd", "Ord", "Hash"]
    };
    let derive_tokens = derive_names
        .iter()
        .map(|path| path.parse::<TokenStream>().unwrap());
    let mut type_attrs = vec![quote!(#[derive(#(#derive_tokens),*)])];
    if any_double {
        type_attrs.push(quote!(#[educe(PartialEq, Eq, PartialOrd, Ord, Hash)]));
    }

    // Per-member attributes and payload types, in declaration order.
    let docs = fields.iter().map(|field| ctx.docs(field.docs()));
    let deprecated = fields.iter().map(|field| ctx.deprecated(field.deprecated()));
    let allow_deprecated = fields
        .iter()
        .map(|field| ctx.allow_deprecated(field.deprecated()));
    let member_idents = &variants(ctx, def);
    let member_types = &fields
        .iter()
        .map(|field| {
            let is_double = ctx.is_double(field.type_());
            let ty = ctx.boxed_rust_type(def.type_name(), field.type_());
            if is_double {
                // Raw doubles are compared and hashed through `DoubleOps`.
                quote! {
                    #[educe(
                        PartialEq(trait = "conjure_object::private::DoubleOps"),
                        PartialOrd(trait = "conjure_object::private::DoubleOps"),
                        Ord(trait = "conjure_object::private::DoubleOps"),
                        Hash(trait = "conjure_object::private::DoubleOps"),
                    )]
                    #ty
                }
            } else {
                quote!(#ty)
            }
        })
        .collect::<Vec<_>>();

    // The unknown variant and its serialization arm exist only for non-exhaustive unions.
    let unknown_ident = unknown(ctx, def);
    let (unknown_variant, serialize_unknown) = if ctx.exhaustive() {
        (quote!(), quote!())
    } else {
        (
            quote! {
                #unknown_ident(#unknown_ident),
            },
            quote! {
                #name::#unknown_ident(value) => {
                    map.serialize_entry(&"type", &value.type_)?;
                    map.serialize_entry(&value.type_, &value.value)?;
                }
            },
        )
    };

    // Wire names of the members. The second alias gives the repetition a separate binding.
    let wire_names = &fields
        .iter()
        .map(|field| &field.field_name().0)
        .collect::<Vec<_>>();
    let wire_names_dup = wire_names;
    let enum_name_repeated = iter::repeat(&name);

    quote! {
        #(#type_attrs)*
        pub enum #name {
            #(
                #docs
                #deprecated
                #member_idents(#member_types),
            )*
            #unknown_variant
        }
        impl ser::Serialize for #name {
            fn serialize<S>(&self, s: S) -> #result<S::Ok, S::Error>
            where
                S: ser::Serializer
            {
                let mut map = s.serialize_map(#some(2))?;
                match self {
                    #(
                        #allow_deprecated
                        #enum_name_repeated::#member_idents(value) => {
                            map.serialize_entry(&"type", &#wire_names)?;
                            map.serialize_entry(&#wire_names_dup, value)?;
                        }
                    )*
                    #serialize_unknown
                }
                map.end()
            }
        }
    }
}
/// Generates the `Deserialize` impl for the union enum and its map visitor `Visitor_`.
///
/// The visitor accepts the two map entries in either order:
/// - `"type"` first, followed by the value keyed by the member name; or
/// - the member-keyed value first, followed by `"type"`.
///
/// The member named by `"type"` must match the value's key, and a third key is rejected with
/// `invalid_length`. Unless the context is exhaustive, unrecognized member names are captured
/// in the unknown variant.
///
/// On a mismatch, the entry read second is reported as `Unexpected` and the entry read first
/// as the expectation.
fn generate_deserialize(ctx: &Context, def: &UnionDefinition) -> TokenStream {
    let name = ctx.type_name(def.type_name().name());
    let result = ctx.result_ident(def.type_name());
    let expecting = format!("union {}", name);
    let some = ctx.some_ident(def.type_name());
    let none = ctx.none_ident(def.type_name());
    let ok = ctx.ok_ident(def.type_name());
    let err = ctx.err_ident(def.type_name());

    let variants = &variants(ctx, def);
    // NOTE(review): these aliases presumably exist so no binding is interpolated twice within a
    // single `quote!` repetition; confirm against the `quote` version in use before removing.
    let variants2 = variants;
    let variants3 = variants;
    let allow_deprecated = &def
        .union_()
        .iter()
        .map(|f| ctx.allow_deprecated(f.deprecated()))
        .collect::<Vec<_>>();
    // Repeated copies of non-repeating idents for interpolation inside the repetitions.
    let name_repeat = iter::repeat(&name);
    let name_repeat2 = iter::repeat(&name);
    let some_repeat = iter::repeat(&some);

    let unknown = unknown(ctx, def);
    // `"type"` was read first and named an unrecognized member; the following key must match it.
    let unknown_match1 = if ctx.exhaustive() {
        quote!()
    } else {
        quote! {
            (Variant_::#unknown(type_), #some(Variant_::#unknown(b))) => {
                if type_ == b {
                    let value = map.next_value()?;
                    #name::#unknown(#unknown { type_, value })
                } else {
                    // The key (read second) is the unexpected value; the `type` member (read
                    // first) is what was expected. This matches the known-member mismatch arm.
                    return #err(de::Error::invalid_value(de::Unexpected::Str(&b), &&*type_))
                }
            }
        }
    };
    // The value was read first under an unrecognized key; `"type"` is checked afterwards.
    let unknown_match2 = if ctx.exhaustive() {
        quote!()
    } else {
        quote! {
            Variant_::#unknown(type_) => {
                let value = map.next_value()?;
                #name::#unknown(#unknown { type_: type_.clone(), value })
            }
        }
    };

    quote! {
        impl<'de> de::Deserialize<'de> for #name {
            fn deserialize<D>(d: D) -> #result<#name, D::Error>
            where
                D: de::Deserializer<'de>
            {
                d.deserialize_map(Visitor_)
            }
        }
        struct Visitor_;
        impl<'de> de::Visitor<'de> for Visitor_ {
            type Value = #name;
            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.write_str(#expecting)
            }
            fn visit_map<A>(self, mut map: A) -> #result<#name, A::Error>
            where
                A: de::MapAccess<'de>
            {
                let v = match map.next_key::<UnionField_<Variant_>>()? {
                    #some(UnionField_::Type) => {
                        let variant = map.next_value()?;
                        let key = map.next_key()?;
                        match (variant, key) {
                            #(
                                #allow_deprecated
                                (Variant_::#variants, #some_repeat(Variant_::#variants2)) => {
                                    let value = map.next_value()?;
                                    #name_repeat::#variants3(value)
                                }
                            )*
                            #unknown_match1
                            (variant, #some(key)) => {
                                return #err(
                                    de::Error::invalid_value(de::Unexpected::Str(key.as_str()), &variant.as_str()),
                                );
                            }
                            (variant, #none) => return #err(de::Error::missing_field(variant.as_str())),
                        }
                    }
                    #some(UnionField_::Value(variant)) => {
                        let value = match &variant {
                            #(
                                Variant_::#variants => {
                                    let value = map.next_value()?;
                                    #allow_deprecated
                                    #name_repeat2::#variants2(value)
                                }
                            )*
                            #unknown_match2
                        };
                        if map.next_key::<UnionTypeField_>()?.is_none() {
                            return #err(de::Error::missing_field("type"));
                        }
                        let type_variant = map.next_value::<Variant_>()?;
                        if variant != type_variant {
                            return #err(
                                de::Error::invalid_value(de::Unexpected::Str(type_variant.as_str()), &variant.as_str()),
                            );
                        }
                        value
                    }
                    #none => return #err(de::Error::missing_field("type")),
                };
                if map.next_key::<UnionField_<Variant_>>()?.is_some() {
                    return #err(de::Error::invalid_length(3, &"type and value fields"));
                }
                #ok(v)
            }
        }
    }
}
/// Generates the private `Variant_` enum that identifies union members on the wire.
///
/// `Variant_` has one unit variant per member. When the context is not exhaustive, it also has
/// an unknown variant holding the unrecognized name.
///
/// It provides `as_str`, which returns the member name, or `"unknown variant"` for the unknown
/// variant. It also provides a string-based `Deserialize` impl.
///
/// In exhaustive mode, an unrecognized string is rejected with `unknown_variant`, listing the
/// known member names.
fn generate_variant(ctx: &Context, def: &UnionDefinition) -> TokenStream {
    let member_idents = &variants(ctx, def);
    let unknown_ident = unknown(ctx, def);
    let wire_names = &def
        .union_()
        .iter()
        .map(|field| &field.field_name().0)
        .collect::<Vec<_>>();
    let result = ctx.result_ident(def.type_name());
    let ok = ctx.ok_ident(def.type_name());

    // Pieces that differ between exhaustive and non-exhaustive unions:
    // the extra enum variant, its `as_str` arm, and the fallback arm of `visit_str`.
    let (unknown_variant, unknown_as_str, fallback_arm) = if ctx.exhaustive() {
        let err = ctx.err_ident(def.type_name());
        (
            quote!(),
            quote!(),
            quote! {
                value => return #err(de::Error::unknown_variant(value, &[#(#wire_names, )*])),
            },
        )
    } else {
        let box_ = ctx.box_ident(def.type_name());
        (
            quote!(#unknown_ident(#box_<str>)),
            quote! {
                Variant_::#unknown_ident(_) => "unknown variant",
            },
            quote! {
                value => Variant_::#unknown_ident(value.to_string().into_boxed_str()),
            },
        )
    };

    quote! {
        #[derive(PartialEq)]
        enum Variant_ {
            #(#member_idents,)*
            #unknown_variant
        }
        impl Variant_ {
            fn as_str(&self) -> &'static str {
                match self {
                    #(
                        Variant_::#member_idents => #wire_names,
                    )*
                    #unknown_as_str
                }
            }
        }
        impl<'de> de::Deserialize<'de> for Variant_ {
            fn deserialize<D>(d: D) -> #result<Variant_, D::Error>
            where
                D: de::Deserializer<'de>
            {
                d.deserialize_str(VariantVisitor_)
            }
        }
        struct VariantVisitor_;
        impl<'de> de::Visitor<'de> for VariantVisitor_ {
            type Value = Variant_;
            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.write_str("string")
            }
            fn visit_str<E>(self, value: &str) -> #result<Variant_, E>
            where
                E: de::Error,
            {
                let v = match value {
                    #(
                        #wire_names => Variant_::#member_idents,
                    )*
                    #fallback_arm
                };
                #ok(v)
            }
        }
    }
}
fn generate_unknown(ctx: &Context, def: &UnionDefinition) -> TokenStream {
if ctx.exhaustive() {
return quote!();
}
let doc = format!(
"An unknown variant of the {} union.",
ctx.type_name(def.type_name().name())
);
let unknown = unknown(ctx, def);
let box_ = ctx.box_ident(def.type_name());
quote! {
#[doc = #doc]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct #unknown {
type_: #box_<str>,
value: conjure_object::Any,
}
impl #unknown {
#[inline]
pub fn type_(&self) -> &str {
&self.type_
}
}
}
}