use heck::ToPascalCase;
use openapiv3::{Discriminator, ReferenceOr, Schema, SchemaKind, Type};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::schemas::{doc_attr, object_field_tokens};
use super::types::{ref_to_ident, schema_to_rust_type_ctx};
/// Generates a Rust enum for an OpenAPI `oneOf` schema named `name`.
///
/// Delegates to `generate_enum`. When `discriminator` is `Some`, the generated enum is
/// internally tagged on the discriminator's `propertyName`; otherwise it is untagged.
/// Types synthesized for inline branches are appended to `inline_types`.
#[must_use]
pub fn generate_one_of(
    name: &str,
    variants: &[ReferenceOr<Schema>],
    discriminator: Option<&Discriminator>,
    description: Option<&String>,
    inline_types: &mut Vec<TokenStream>,
) -> TokenStream {
    // `oneOf` and `anyOf` share one code path; only the discriminator differs.
    generate_enum(
        name,
        variants,
        discriminator,
        description,
        inline_types,
    )
}
#[must_use]
pub fn generate_any_of(
name: &str,
variants: &[ReferenceOr<Schema>],
description: Option<&String>,
inline_types: &mut Vec<TokenStream>,
) -> TokenStream {
generate_enum(name, variants, None, description, inline_types)
}
/// Generates a Rust struct for an OpenAPI `allOf` composition named `name`.
///
/// Every branch contributes to one struct:
/// * an inline object schema has its properties emitted directly as struct fields through
///   `object_field_tokens`, using that branch's own `required` list;
/// * a `$ref` branch becomes a `#[serde(flatten)]` field `inner_{n}` of the referenced type;
/// * any other inline schema becomes a flattened `inner_{n}` field whose type comes from
///   `schema_to_rust_type_ctx`, with synthesized types named after `{name}Inner{n}`.
///
/// `n` is a 1-based counter shared by the `$ref` and non-object inline branches, so the
/// `inner_*` fields are numbered in branch order. Helper types produced along the way are
/// appended to `inline_types`.
#[must_use]
pub fn generate_all_of(
    name: &str,
    variants: &[ReferenceOr<Schema>],
    description: Option<&String>,
    inline_types: &mut Vec<TokenStream>,
) -> TokenStream {
    let struct_ident = format_ident!("{}", name.to_pascal_case());
    let doc_tokens = doc_attr(&description.cloned());
    let mut members = Vec::<TokenStream>::new();
    let mut flattened = 0usize;

    for branch in variants {
        // Inline object branches are merged property-by-property rather than flattened.
        if let ReferenceOr::Item(schema) = branch {
            if let SchemaKind::Type(Type::Object(obj)) = &schema.schema_kind {
                for (prop_name, prop_schema) in &obj.properties {
                    let required = obj.required.contains(prop_name);
                    members.push(object_field_tokens(
                        prop_name,
                        &prop_schema.clone().unbox(),
                        required,
                        name,
                        inline_types,
                    ));
                }
                continue;
            }
        }

        // Everything else becomes a numbered, flattened member.
        flattened += 1;
        let field_ident = format_ident!("inner_{}", flattened);
        let member = match branch {
            ReferenceOr::Reference { reference } => {
                let field_ty = ref_to_ident(reference);
                quote! {
                    #[serde(flatten)]
                    pub #field_ident: #field_ty,
                }
            }
            ReferenceOr::Item(schema) => {
                let synth_parent = format!("{name}Inner{flattened}");
                let field_ty = schema_to_rust_type_ctx(
                    &ReferenceOr::Item(schema.clone()),
                    true,
                    Some(&synth_parent),
                    inline_types,
                );
                quote! {
                    #[serde(flatten)]
                    pub #field_ident: #field_ty,
                }
            }
        };
        members.push(member);
    }

    quote! {
        #doc_tokens
        #[derive(
            ::core::fmt::Debug,
            ::core::clone::Clone,
            ::serde::Serialize,
            ::serde::Deserialize,
        )]
        pub struct #struct_ident {
            #(#members)*
        }
    }
}
/// Shared implementation for `oneOf` / `anyOf`: emits `pub enum {Name}` with one variant
/// per branch, each built by `build_enum_variant`.
///
/// With a discriminator the enum is internally tagged (`#[serde(tag = "<propertyName>")]`).
/// Without one it is `#[serde(untagged)]`, which makes serde try the variants in declaration
/// order. Types synthesized for inline branches are appended to `inline_types`.
fn generate_enum(
    name: &str,
    variants: &[ReferenceOr<Schema>],
    discriminator: Option<&Discriminator>,
    description: Option<&String>,
    inline_types: &mut Vec<TokenStream>,
) -> TokenStream {
    let enum_ident = format_ident!("{}", name.to_pascal_case());
    let doc_tokens = doc_attr(&description.cloned());
    let representation = match discriminator {
        Some(d) => {
            let tag_name = &d.property_name;
            quote! { #[serde(tag = #tag_name)] }
        }
        None => quote! { #[serde(untagged)] },
    };

    let mut arms: Vec<TokenStream> = Vec::with_capacity(variants.len());
    for (position, branch) in variants.iter().enumerate() {
        arms.push(build_enum_variant(
            name,
            position,
            branch,
            discriminator,
            inline_types,
        ));
    }

    quote! {
        #doc_tokens
        #[derive(
            ::core::fmt::Debug,
            ::core::clone::Clone,
            ::serde::Serialize,
            ::serde::Deserialize,
        )]
        #representation
        pub enum #enum_ident {
            #(#arms,)*
        }
    }
}
/// Builds the tokens for one variant of a `oneOf`/`anyOf` enum.
///
/// * A `$ref` branch becomes a newtype variant named after the last `/`-separated segment
///   of the reference (PascalCased), wrapping the type produced by `ref_to_ident`.
/// * An inline branch becomes `Variant{n}` (1-based `n = idx + 1`), wrapping the type
///   produced by `schema_to_rust_type_ctx`. Synthesized inline types use the parent name
///   `{parent}Variant{n}` and are appended to `inline_types` by that helper.
///
/// When a discriminator is present, each `$ref` variant gets `#[serde(rename = "...")]`
/// carrying its tag value. That value is the explicit `mapping` key pointing at the
/// reference if there is one. Otherwise it is the bare schema name, which is OpenAPI's
/// implicit mapping.
fn build_enum_variant(
    parent: &str,
    idx: usize,
    branch: &ReferenceOr<Schema>,
    discriminator: Option<&Discriminator>,
    inline_types: &mut Vec<TokenStream>,
) -> TokenStream {
    match branch {
        ReferenceOr::Reference { reference } => {
            let target_name = reference.rsplit('/').next().unwrap_or(reference);
            let variant_ident = format_ident!("{}", target_name.to_pascal_case());
            let ty = ref_to_ident(reference);
            let rename_attr = match discriminator {
                Some(d) => {
                    // Without this fallback, serde would expect the PascalCased variant
                    // ident (e.g. `PetCat`) as the tag. OpenAPI's implicit mapping instead
                    // uses the schema name (e.g. `pet_cat`), so valid payloads would be
                    // rejected.
                    let tag_value = discriminator_key_for_ref(d, reference)
                        .unwrap_or_else(|| target_name.to_owned());
                    quote! { #[serde(rename = #tag_value)] }
                }
                None => quote! {},
            };
            quote! {
                #rename_attr
                #variant_ident(#ty)
            }
        }
        ReferenceOr::Item(schema) => {
            let variant_ident = format_ident!("Variant{}", idx + 1);
            let parent_for_synth = format!("{parent}Variant{}", idx + 1);
            let ty = schema_to_rust_type_ctx(
                &ReferenceOr::Item(schema.clone()),
                true,
                Some(&parent_for_synth),
                inline_types,
            );
            quote! {
                #variant_ident(#ty)
            }
        }
    }
}
/// Looks up the discriminator `mapping` key whose value points at `reference`.
///
/// A mapping value matches when it equals either the full reference string or its final
/// `/`-separated segment (the bare schema name). Returns the first matching key in the
/// mapping's iteration order, or `None` when no entry matches.
fn discriminator_key_for_ref(d: &Discriminator, reference: &str) -> Option<String> {
    let schema_name = reference.rsplit('/').next().unwrap_or(reference);
    for (key, target) in &d.mapping {
        if target == reference || target == schema_name {
            return Some(key.clone());
        }
    }
    None
}