use std::{
collections::{BTreeMap, BTreeSet},
ops::{Deref, DerefMut},
str::FromStr,
};
use anyhow::{format_err, Context, Result};
use json_patch::merge as json_merge_patch;
use log::warn;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use topological_sort::TopologicalSort;
use crate::openapi::{
ref_or::RefOr,
schema::{AdditionalProperties, BasicSchema, PrimitiveSchema, Type},
serde_helpers::{default_as_true, deserialize_enum_helper},
};
use super::{
ref_or::{split_interface_ref, InterfaceRef},
schema::{Discriminator, Nullable, OneOf, Schema},
Scope, Transpile,
};
/// The forms in which every interface is emitted as an OpenAPI schema.
///
/// `Get` is the plain form of the type. The other variants describe request
/// parameters: creating a value with `POST`, replacing one with `PUT`, and
/// patching one with `PATCH`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterfaceVariant {
/// The type itself; has no fragment and no schema-name suffix.
Get,
/// Parameters used to `POST` a new value.
Post,
/// Parameters used to `PUT` a value.
Put,
/// Parameters used to `PATCH` a value (merge-patch form, per the name).
MergePatch,
}
/// Every [`InterfaceVariant`], in the order in which schemas are generated for
/// each emitted interface.
const INTERFACE_VARIANTS: &[InterfaceVariant] = &[
InterfaceVariant::Get,
InterfaceVariant::Post,
InterfaceVariant::Put,
InterfaceVariant::MergePatch,
];
impl InterfaceVariant {
pub fn to_fragment_str(self) -> &'static str {
match self {
InterfaceVariant::Get => "",
InterfaceVariant::Post => "#Post",
InterfaceVariant::Put => "#Put",
InterfaceVariant::MergePatch => "#MergePatch",
}
}
pub fn to_schema_suffix_str(self) -> &'static str {
let s = self.to_fragment_str();
if s.is_empty() {
s
} else {
assert!(s.starts_with('#'));
&s[1..]
}
}
}
impl FromStr for InterfaceVariant {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"" => Ok(InterfaceVariant::Get),
"#Post" => Ok(InterfaceVariant::Post),
"#Put" => Ok(InterfaceVariant::Put),
"#MergePatch" => Ok(InterfaceVariant::MergePatch),
_ => Err(format_err!("unknown interface variety: {:?}", s)),
}
}
}
/// All interfaces declared in a document, keyed by interface name.
///
/// `#[serde(transparent)]` makes this (de)serialize as the bare map.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(transparent)]
pub struct Interfaces(BTreeMap<String, Interface>);
impl Interfaces {
    /// Returns `true` if no interfaces are defined.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Expand every `$includes` interface into a concrete interface, and return
    /// all interfaces keyed by name, ready to transpile.
    ///
    /// Interfaces are visited in dependency order, so an `$includes` base is
    /// always expanded before any interface that includes it. To expand an
    /// inclusion:
    ///
    /// 1. The expanded base is serialized to JSON.
    /// 2. The inclusion's remaining fields are applied as a JSON Merge Patch.
    /// 3. The result is re-parsed as a [`BasicInterface`].
    ///
    /// The inclusion's own `emit` flag always replaces the base's.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    ///
    /// - an `$includes` names an interface that is not defined;
    /// - an `$includes` names a `oneOf` interface, which has no members to merge;
    /// - `$includes` references form a cycle;
    /// - a merged interface does not parse as a [`BasicInterface`].
    fn expand_includes_interfaces(
        &self,
    ) -> Result<BTreeMap<&str, Box<dyn TranspileInterface>>> {
        // Record each `$includes` as an edge base -> includer, so that the sort
        // yields every base before the interfaces that include it.
        let mut sort = TopologicalSort::<&str>::new();
        for (name, interface) in &self.0 {
            if let Interface::Includes(inclusion) = interface {
                if !self.0.contains_key(&inclusion.base) {
                    return Err(format_err!(
                        "interface {:?} includes {:?}, but that interface isn't defined",
                        name,
                        inclusion.base,
                    ));
                }
                sort.add_dependency(inclusion.base.as_str(), name.as_str());
            } else {
                sort.insert(name.as_str());
            }
        }

        // Member-based interfaces (including already-expanded inclusions),
        // kept so that later inclusions can use them as a base.
        let mut expanded: BTreeMap<&str, BasicInterface> = BTreeMap::new();
        let mut interfaces: BTreeMap<&str, Box<dyn TranspileInterface>> =
            BTreeMap::new();
        for name in sort {
            let interface = self
                .0
                .get(name)
                .expect("interface should always be in hash table");
            match interface {
                Interface::Includes(inclusion) => {
                    // The sort guarantees the base was visited already. It is
                    // missing from `expanded` only if it is a `oneOf` interface.
                    // Merging onto that would silently discard the base, so
                    // reject it.
                    let base = expanded.get(inclusion.base.as_str()).ok_or_else(|| {
                        format_err!(
                            "interface {:?} includes {:?}, but only interfaces with `members` (not `oneOf`) can be included",
                            name,
                            inclusion.base,
                        )
                    })?;
                    let mut doc = serde_json::to_value(base)?;
                    let patch = Value::Object(inclusion.merge_patch.clone());
                    json_merge_patch(&mut doc, &patch);
                    let mut reparsed = serde_json::from_value::<BasicInterface>(doc)
                        .with_context(|| format!("error parsing merged {:?}", name))?;
                    reparsed.emit = inclusion.emit;
                    expanded.insert(name, reparsed.clone());
                    interfaces.insert(name, Box::new(reparsed));
                }
                Interface::Basic(base) => {
                    expanded.insert(name, base.clone());
                    interfaces.insert(name, Box::new(base.clone()));
                }
                Interface::OneOf(one_of) => {
                    interfaces.insert(name, Box::new(one_of.clone()));
                }
            }
        }

        // `TopologicalSort` stops yielding once every remaining item still has
        // an unvisited dependency, which happens exactly when `$includes` form
        // a cycle. Report those interfaces instead of silently dropping them.
        if interfaces.len() != self.0.len() {
            let unresolved = self
                .0
                .keys()
                .filter(|name| !interfaces.contains_key(name.as_str()))
                .collect::<Vec<_>>();
            return Err(format_err!(
                "could not expand interfaces {:?}: their `$includes` references form a cycle",
                unresolved,
            ));
        }
        Ok(interfaces)
    }
}
impl Deref for Interfaces {
type Target = BTreeMap<String, Interface>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Interfaces {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Transpile for Interfaces {
    type Output = BTreeMap<String, Schema>;

    /// Transpile all interfaces into OpenAPI schemas, keyed by schema name.
    ///
    /// Each interface with `emit` set produces one schema per
    /// [`InterfaceVariant`]. `$includes` are expanded first. A variant schema
    /// that would match only the empty object is skipped with a warning.
    ///
    /// # Errors
    ///
    /// Fails if expansion fails, if any interface has an invalid
    /// discriminator, if schema generation fails, or if two schemas would
    /// share a name.
    fn transpile(&self, scope: &Scope) -> anyhow::Result<Self::Output> {
        let interfaces = self.expand_includes_interfaces()?;

        // Collected from every interface, emitted or not, because `oneOf`
        // interfaces look up their members' discriminators here.
        let discriminators = interfaces
            .iter()
            .filter_map(|(&name, interface)| {
                interface
                    .discriminator_info()
                    .map(|info| info.map(|info| (name.to_owned(), info)))
                    .transpose()
            })
            .collect::<Result<BTreeMap<_, _>>>()?;

        let mut schemas = BTreeMap::new();
        let emitted = interfaces
            .into_iter()
            .filter(|(_, interface)| interface.should_emit());
        for (name, interface) in emitted {
            for &variant in INTERFACE_VARIANTS {
                let schema_name = interface.schema_variant_name(name, variant);
                let schema = interface.generate_schema_variant(
                    scope,
                    &discriminators,
                    name,
                    variant,
                )?;
                if schema.matches_only_empty_object() {
                    warn!(
                        "output schema {} would match only empty objects, skipping",
                        schema_name
                    );
                    continue;
                }
                let previous = schemas.insert(schema_name.clone(), schema);
                if previous.is_some() {
                    return Err(format_err!(
                        "generated multiple schemas named {:?}",
                        &schema_name
                    ));
                }
            }
        }
        Ok(schemas)
    }
}
/// One interface definition, in one of three forms.
///
/// Serialization is untagged. Deserialization is implemented by hand, and
/// picks the form from the keys that are present (`$includes`, then `oneOf`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
#[serde(untagged)]
// NOTE(review): presumably allowed because `BasicInterface` is much larger
// than the other variants; confirm that boxing it was deliberately skipped.
#[allow(clippy::large_enum_variant)]
pub enum Interface {
/// An interface built from a base interface plus a merge patch.
Includes(IncludesInterface),
/// An interface defined directly by its members.
Basic(BasicInterface),
/// A discriminated `oneOf` over other interfaces.
OneOf(OneOfInterface),
}
impl<'de> Deserialize<'de> for Interface {
    /// Deserialize an interface, choosing its form by key presence:
    ///
    /// - a mapping containing `$includes` becomes [`Interface::Includes`];
    /// - otherwise, one containing `oneOf` becomes [`Interface::OneOf`];
    /// - anything else becomes [`Interface::Basic`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde_yaml::{Mapping, Value};

        let yaml = Mapping::deserialize(deserializer)?;
        let has_key = |key: &str| yaml.contains_key(&Value::String(key.to_owned()));
        let is_includes = has_key("$includes");
        let is_one_of = has_key("oneOf");
        let doc = Value::Mapping(yaml);

        if is_includes {
            Ok(Interface::Includes(deserialize_enum_helper::<D, _>(
                "`$includes` interface",
                doc,
            )?))
        } else if is_one_of {
            Ok(Interface::OneOf(deserialize_enum_helper::<D, _>(
                "oneOf interface",
                doc,
            )?))
        } else {
            Ok(Interface::Basic(deserialize_enum_helper::<D, _>(
                "interface",
                doc,
            )?))
        }
    }
}
/// The discriminator declared by a member-based interface, used by `oneOf`
/// interfaces to build their discriminator mapping.
struct DiscriminatorInfo {
/// Name of the member that holds the discriminator.
member_name: String,
/// The string `const` that this interface's discriminator member must have.
value: String,
}
/// Common behaviour for interfaces that can be turned into OpenAPI schemas.
trait TranspileInterface {
    /// Whether schemas should be generated for this interface. Defaults to
    /// `true`.
    fn should_emit(&self) -> bool {
        true
    }

    /// This interface's discriminator, if it declares one. Defaults to
    /// `Ok(None)`.
    fn discriminator_info(&self) -> Result<Option<DiscriminatorInfo>> {
        Ok(None)
    }

    /// Schema name for `variant`: `name` followed by the variant's suffix
    /// (for example `Foo` + `Post` -> `FooPost`).
    fn schema_variant_name(&self, name: &str, variant: InterfaceVariant) -> String {
        [name, variant.to_schema_suffix_str()].concat()
    }

    /// Build the schema for `variant` of the interface called `name`.
    ///
    /// `interface_discriminators` maps interface names to their
    /// discriminators, for interfaces that refer to other interfaces.
    fn generate_schema_variant(
        &self,
        scope: &Scope,
        interface_discriminators: &BTreeMap<String, DiscriminatorInfo>,
        name: &str,
        variant: InterfaceVariant,
    ) -> Result<Schema>;
}
/// An interface built from another interface plus a merge patch.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncludesInterface {
/// Name of the interface to start from (the `$includes` key).
#[serde(rename = "$includes")]
base: String,
/// Whether to generate schemas for this interface (default from
/// `default_as_true`). It always overrides the base's `emit` after merging.
#[serde(default = "default_as_true")]
emit: bool,
/// Every other key, applied as a JSON merge patch to the expanded base.
#[serde(flatten)]
merge_patch: Map<String, Value>,
}
/// An interface defined directly by its members.
///
/// `$includes` expansion serializes this type, merge-patches it, and parses it
/// back, so its serialized form doubles as the merge target.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct BasicInterface {
/// Whether to generate schemas for this interface (default from
/// `default_as_true`).
#[serde(default = "default_as_true")]
emit: bool,
/// Named members; each may become a property of the generated object schemas.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
members: BTreeMap<String, Member>,
/// Schema for properties not listed in `members`. When absent, additional
/// properties are rejected.
#[serde(default, skip_serializing_if = "Option::is_none")]
additional_members: Option<Member>,
/// Name of the member whose string `const` identifies this interface inside
/// a `oneOf`.
#[serde(default, skip_serializing_if = "Option::is_none")]
discriminator_member_name: Option<String>,
/// Description for generated schemas; non-`Get` variants get a prefix.
#[serde(default)]
description: Option<String>,
/// Title copied into every generated schema.
#[serde(default, skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// Example value, attached only to the `Post` variant's schema.
#[serde(default)]
example: Option<Value>,
}
impl TranspileInterface for BasicInterface {
    fn should_emit(&self) -> bool {
        self.emit
    }

    /// Describe this interface's discriminator, if it declares one.
    ///
    /// Returns `Ok(None)` when `discriminatorMemberName` is unset. Otherwise,
    /// the named member must satisfy all of the following:
    ///
    /// - it exists;
    /// - it is `required: true`;
    /// - it is initializable;
    /// - it is not mutable;
    /// - it has an inline primitive schema whose `const` is a string.
    ///
    /// That string becomes the discriminator value.
    ///
    /// # Errors
    ///
    /// Fails, with a message naming the member, if any of those conditions
    /// does not hold.
    fn discriminator_info(&self) -> Result<Option<DiscriminatorInfo>> {
        let discr = match &self.discriminator_member_name {
            Some(discr) => discr,
            None => return Ok(None),
        };
        let member = self.members.get(discr).ok_or_else(|| {
            format_err!(
                "discriminatorMemberName {:?} not present in `members:`",
                discr
            )
        })?;
        if !member.required || !member.is_initializable() || member.mutable {
            return Err(format_err!(
                "discriminator member {:?} must be `initializable: true`, `required: true`, `mutable: false`",
                discr
            ));
        }
        let schema = match &member.schema {
            RefOr::Value(BasicSchema::Primitive(schema)) => schema,
            _ => {
                return Err(format_err!(
                    "discriminator member {:?} must have a simple schema with `type`",
                    discr
                ))
            }
        };
        let value = schema.r#const.as_ref().ok_or_else(|| {
            format_err!(
                "discriminator member {:?} must have a `schema.const` value",
                discr
            )
        })?;
        let value_str = value.as_str().ok_or_else(|| {
            format_err!(
                "discriminator member {:?} must have a `schema.const` containing a string, not {}",
                discr,
                value
            )
        })?;
        Ok(Some(DiscriminatorInfo {
            member_name: discr.to_owned(),
            value: value_str.to_owned(),
        }))
    }

    /// Build the object schema for `variant` of the interface called `name`.
    ///
    /// Each member contributes a property (and possibly a `required` entry)
    /// according to its flags and the variant; the discriminator member is
    /// always present and required. `additionalMembers` controls
    /// `additionalProperties`. The description is prefixed for non-`Get`
    /// variants, and `example` is attached only to `Post`.
    ///
    /// # Errors
    ///
    /// Fails if `additionalMembers` is marked `required`, or if any member
    /// schema fails to transpile.
    fn generate_schema_variant(
        &self,
        scope: &Scope,
        _interface_discriminators: &BTreeMap<String, DiscriminatorInfo>,
        name: &str,
        variant: InterfaceVariant,
    ) -> Result<Schema> {
        let mut types = BTreeSet::new();
        types.insert(Type::Object);

        let mut required = vec![];
        let mut properties = BTreeMap::new();
        // `member_name`, not `name`, so the interface name stays visible below.
        for (member_name, member) in &self.members {
            let is_discriminator =
                Some(member_name) == self.discriminator_member_name.as_ref();
            if let Some(schema) =
                member.schema_for(scope, variant, is_discriminator)?
            {
                properties.insert(member_name.to_owned(), schema);
                if member.is_required_for(variant, is_discriminator) {
                    required.push(member_name.to_owned());
                }
            }
        }

        let additional_properties = match &self.additional_members {
            Some(additional_members) if additional_members.required => {
                // Report the key as users write it (`rename_all = "camelCase"`).
                return Err(format_err!(
                    "cannot use `required` with `additionalMembers` in {}",
                    name,
                ));
            }
            Some(additional_members) => {
                match additional_members.schema_for(scope, variant, false)? {
                    Some(schema) => AdditionalProperties::Schema(schema),
                    None => AdditionalProperties::Bool(false),
                }
            }
            None => AdditionalProperties::Bool(false),
        };

        let description = self.description.as_ref().map(|desc| match variant {
            InterfaceVariant::Get => desc.clone(),
            InterfaceVariant::Post => format!(
                "(Parameters used to POST a new value of the `{}` type.)\n\n{}",
                name, desc
            ),
            InterfaceVariant::Put => format!(
                "(Parameters used to PUT a value of the `{}` type.)\n\n{}",
                name, desc
            ),
            InterfaceVariant::MergePatch => format!(
                "(Parameters used to PATCH the `{}` type.)\n\n{}",
                name, desc
            ),
        });
        let title = self.title.clone();
        let example = if variant == InterfaceVariant::Post {
            self.example.clone()
        } else {
            None
        };

        let schema = PrimitiveSchema {
            types,
            required,
            properties,
            additional_properties,
            items: None,
            nullable: None,
            description,
            title,
            r#const: None,
            example,
            unknown_fields: BTreeMap::default(),
        };
        Ok(RefOr::Value(BasicSchema::Primitive(Box::new(schema))))
    }
}
/// One member (property) of a [`BasicInterface`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct Member {
/// Whether the member is required. This applies to `Get`, to `Post` when the
/// member is initializable, and to `Put` when it is mutable.
#[serde(default)]
required: bool,
/// Whether the member appears in `Put` and `MergePatch` schemas.
#[serde(default)]
mutable: bool,
/// Whether the member appears in `Post` schemas. When unset, it follows
/// `mutable`.
#[serde(default)]
initializable: Option<bool>,
/// The member's schema.
schema: Schema,
}
impl Member {
    /// Whether this member may be supplied in a `Post` body. An explicit
    /// `initializable` wins; otherwise it follows `mutable`.
    fn is_initializable(&self) -> bool {
        match self.initializable {
            Some(explicit) => explicit,
            None => self.mutable,
        }
    }

    /// Whether this member belongs in the `required` list for `variant`.
    /// The discriminator is always required; nothing is required in
    /// `MergePatch`.
    fn is_required_for(
        &self,
        variant: InterfaceVariant,
        is_discriminator: bool,
    ) -> bool {
        is_discriminator
            || match variant {
                InterfaceVariant::Get => self.required,
                InterfaceVariant::Post => self.required && self.is_initializable(),
                InterfaceVariant::Put => self.required && self.mutable,
                InterfaceVariant::MergePatch => false,
            }
    }

    /// The property schema for this member in `variant`, or `None` when the
    /// member does not appear in that variant.
    ///
    /// The discriminator always appears. Optional members of `MergePatch`
    /// schemas are widened so that they also accept `null`.
    fn schema_for(
        &self,
        scope: &Scope,
        variant: InterfaceVariant,
        is_discriminator: bool,
    ) -> Result<Option<Schema>> {
        let scope = scope.with_variant(variant);
        let present = is_discriminator
            || match variant {
                InterfaceVariant::Get => true,
                InterfaceVariant::Post => self.is_initializable(),
                InterfaceVariant::Put | InterfaceVariant::MergePatch => self.mutable,
            };
        if !present {
            return Ok(None);
        }

        let schema = self.schema.transpile(&scope)?;
        let accepts_null = variant == InterfaceVariant::MergePatch
            && !is_discriminator
            && !self.required;
        if accepts_null {
            Ok(Some(
                schema.new_schema_matching_current_or_null_for_merge_patch(&scope),
            ))
        } else {
            Ok(Some(schema))
        }
    }
}
/// An interface that is exactly one of several other interfaces, told apart
/// by a shared discriminator member.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
pub struct OneOfInterface {
/// Description copied into every generated schema.
#[serde(default)]
description: Option<String>,
/// Title copied into every generated schema.
#[serde(default, skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// References to the alternative interfaces. Each must declare a
/// discriminator.
one_of: Vec<InterfaceRef>,
}
impl TranspileInterface for OneOfInterface {
    /// Generate the `oneOf` schema for `variant`.
    ///
    /// Each referenced interface becomes a `$ref` alternative. A discriminator
    /// maps each interface's discriminator value to that interface's schema
    /// for the same `variant`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    ///
    /// - a referenced interface does not exist or has no discriminator;
    /// - two referenced interfaces share a discriminator value;
    /// - the referenced interfaces disagree on the discriminator member name;
    /// - `oneOf` is empty.
    fn generate_schema_variant(
        &self,
        scope: &Scope,
        interface_discriminators: &BTreeMap<String, DiscriminatorInfo>,
        name: &str,
        variant: InterfaceVariant,
    ) -> Result<Schema> {
        let scope = scope.with_variant(variant);
        let schemas = self
            .one_of
            .iter()
            .map(|interface_ref| Ok(RefOr::Ref(interface_ref.transpile(&scope)?)))
            .collect::<Result<Vec<_>>>()?;

        // All referenced interfaces must agree on a single discriminator member.
        let mut discriminator_member_names = BTreeSet::default();
        // Discriminator value -> interface claiming it. This one map both
        // detects duplicate values and supplies the discriminator mapping.
        let mut discriminator_value_to_interface_map = BTreeMap::default();
        for interface_ref in &self.one_of {
            let (base, _fragment) = split_interface_ref(&interface_ref.target);
            let discr_info = interface_discriminators.get(base).ok_or_else(|| {
                format_err!(
                    "interface {:?} referred to by {:?} does not exist, or does not have a discriminatorMember",
                    base,
                    name
                )
            })?;
            discriminator_member_names.insert(discr_info.member_name.clone());
            if let Some(existing_type) = discriminator_value_to_interface_map
                .insert(discr_info.value.clone(), base.to_owned())
            {
                return Err(format_err!(
                    "interface {iface} includes conflicting discriminator values {current}.{member} = {value:?} and {existing}.{member} = {value:?}",
                    iface = name,
                    existing = existing_type,
                    current = base,
                    member = discr_info.member_name,
                    value = discr_info.value
                ));
            }
        }

        if discriminator_member_names.is_empty() {
            return Err(format_err!("interface {} includes no types", name));
        } else if discriminator_member_names.len() > 1 {
            return Err(format_err!(
                "interface {} includes interfaces with multiple, conflicting discriminator names: {:?}",
                name, discriminator_member_names,
            ));
        }
        let property_name = discriminator_member_names
            .into_iter()
            .next()
            .expect("should always have a value");

        let mapping = discriminator_value_to_interface_map
            .into_iter()
            .map(|(value, iface)| {
                let target = format!(
                    "#/components/schemas/{}",
                    self.schema_variant_name(&iface, variant)
                );
                (value, target)
            })
            .collect::<BTreeMap<_, _>>();
        let discriminator = Discriminator {
            property_name,
            mapping,
            unknown_fields: Default::default(),
        };
        Ok(Schema::Value(BasicSchema::OneOf(OneOf {
            r#type: Some(Type::Object),
            schemas,
            description: self.description.clone(),
            title: self.title.clone(),
            discriminator: Some(discriminator),
            nullable: None,
            unknown_fields: Default::default(),
        })))
    }
}
/// Transpiling `examples/oneof_example.yml` must reproduce
/// `examples/oneof_example_output.yml` exactly.
#[test]
fn parses_one_of_example() {
    use crate::openapi::OpenApi;
    use pretty_assertions::assert_eq;
    use std::path::Path;

    let examples = Path::new("./examples");
    let input = OpenApi::from_path(&examples.join("oneof_example.yml")).unwrap();
    let actual = input.transpile(&Scope::default()).unwrap();
    // Print the generated document so that a failing run shows the full output.
    println!("{}", serde_yaml::to_string(&actual).unwrap());
    let expected =
        OpenApi::from_path(&examples.join("oneof_example_output.yml")).unwrap();
    assert_eq!(actual, expected);
}