use std::collections::HashSet;
use openapiv3::ReferenceOr;
use indexmap::IndexMap;
use thiserror::Error;
/// Errors reported while converting a named schema between JSON and its
/// OpenAPI representation.
///
/// Each variant records the schema name (used in the message) and keeps the
/// underlying `serde_json::Error` as the error source.
#[derive(Debug, Error)]
pub enum SchemaConversionError {
/// Serializing the schema called `name` failed.
// NOTE(review): nothing in the visible part of this file constructs this
// variant; confirm it is still used elsewhere.
#[error("failed to serialize schema '{name}': {source}")]
Serialization {
/// Name of the schema that was being serialized.
name: String,
/// The `serde_json` failure that caused this error.
#[source]
source: serde_json::Error,
},
/// Deserializing the schema called `name` failed; produced by
/// `convert_json_value_to_openapi` when `serde_json::from_value` rejects
/// the transformed JSON.
#[error("failed to deserialize schema '{name}': {source}")]
Deserialization {
/// Name of the schema that was being deserialized.
name: String,
/// The `serde_json` failure that caused this error.
#[source]
source: serde_json::Error,
},
}
/// Collects schema components together with security and tag bookkeeping.
///
/// `Debug` is implemented by hand for this type and leaves `generator` out of
/// the output.
#[derive(Default)]
pub struct ComponentRegistry {
/// Schemas stored by name through `register`; returned by `into_components`.
components: IndexMap<String, openapiv3::Schema>,
/// `schemars` generator handed out by `generator_mut`; its definitions are
/// what `into_components_schemars` converts.
generator: schemars::SchemaGenerator,
/// Every security scheme name passed to `register_security`.
security_schemes: HashSet<String>,
/// Scopes accumulated by `register_security` until drained with
/// `drain_operation_scopes`.
operation_scopes: HashSet<String>,
/// The `join_all` flag from the most recent `register_security` call.
pub (crate) operation_scope_join_all: bool,
/// Security scheme names accumulated by `register_security` until drained
/// with `drain_operation_security`.
operation_security: HashSet<String>,
/// Set of tag strings; crate-visible and not modified by any method visible
/// in this part of the file.
pub (crate) tags: HashSet<String>,
}
/// Hand-written `Debug` that prints every field except `generator` and marks
/// the output as non-exhaustive.
impl std::fmt::Debug for ComponentRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that adding a field to the struct forces a decision
        // here about whether it should be printed.
        let Self {
            components,
            generator: _,
            security_schemes,
            operation_scopes,
            operation_scope_join_all,
            operation_security,
            tags,
        } = self;

        let mut out = f.debug_struct("ComponentRegistry");
        out.field("components", components);
        out.field("security_schemes", security_schemes);
        out.field("operation_scopes", operation_scopes);
        out.field("operation_scope_join_all", operation_scope_join_all);
        out.field("operation_security", operation_security);
        out.field("tags", tags);
        out.finish_non_exhaustive()
    }
}
impl ComponentRegistry {
pub fn new() -> Self {
Self {
components: IndexMap::new(),
generator: schemars::SchemaGenerator::default(),
security_schemes: HashSet::new(),
operation_scopes: HashSet::new(),
operation_security: HashSet::new(),
operation_scope_join_all: false,
tags: HashSet::new(),
}
}
pub fn register(&mut self, name: String, schema: openapiv3::Schema) -> String {
let ref_path = format!("#/components/schemas/{}", name);
self.components.insert(name, schema);
ref_path
}
pub fn register_security(&mut self, name: String, scopes: &[String], join_all: bool) {
self.security_schemes.insert(name.clone());
self.operation_security.insert(name);
self.operation_scope_join_all = join_all;
self.operation_scopes.extend(scopes.iter().cloned());
}
pub fn has_security_schemes(&self) -> bool {
!self.security_schemes.is_empty()
}
pub fn get_security_scheme_names(&self) -> Vec<String> {
self.security_schemes.iter().cloned().collect()
}
pub fn drain_operation_scopes(&mut self) -> impl Iterator<Item=String> + '_ {
self.operation_scopes.drain()
}
pub fn drain_operation_security(&mut self) -> impl Iterator<Item=String> + '_ {
self.operation_security.drain()
}
pub fn has_operation_security(&self) -> bool {
!self.operation_security.is_empty()
}
pub fn contains(&self, name: &str) -> bool {
self.components.contains_key(name)
}
pub fn into_components(self) -> IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>> {
self.components
.into_iter()
.map(|(k, v)| (k, openapiv3::ReferenceOr::Item(v)))
.collect()
}
pub fn into_components_schemars(mut self) -> Result<IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>>, SchemaConversionError> {
let definitions = std::mem::take(self.generator.definitions_mut());
let mut result = IndexMap::with_capacity(definitions.len());
for (name, json_schema) in definitions {
let openapi_schema = convert_json_value_to_openapi(json_schema, &name)?;
result.insert(name, openapi_schema);
}
Ok(result)
}
pub fn generator_mut(&mut self) -> &mut schemars::SchemaGenerator {
&mut self.generator
}
}
/// Converts one JSON schema value into an OpenAPI schema or reference.
///
/// If the value has a string `$ref` at its top level, the result is a
/// `ReferenceOr::Reference` whose target has every `#/$defs/` and
/// `#/definitions/` replaced by `#/components/schemas/`; the rest of the value
/// is ignored. Otherwise the value is passed through `transform_for_openapi`
/// and deserialized into an `openapiv3::Schema`.
///
/// # Errors
///
/// Returns `SchemaConversionError::Deserialization` carrying `name` when the
/// transformed JSON cannot be deserialized.
fn convert_json_value_to_openapi(
    mut json_value: serde_json::Value,
    name: &str,
) -> Result<ReferenceOr<openapiv3::Schema>, SchemaConversionError> {
    // Build the rewritten target as an owned string so the borrow of
    // `json_value` ends before it is mutated below.
    let rewritten_ref = json_value
        .get("$ref")
        .and_then(serde_json::Value::as_str)
        .map(|target| {
            target
                .replace("#/$defs/", "#/components/schemas/")
                .replace("#/definitions/", "#/components/schemas/")
        });
    if let Some(reference) = rewritten_ref {
        return Ok(ReferenceOr::Reference { reference });
    }

    transform_for_openapi(&mut json_value);

    match serde_json::from_value::<openapiv3::Schema>(json_value) {
        Ok(schema) => Ok(ReferenceOr::Item(schema)),
        Err(source) => Err(SchemaConversionError::Deserialization {
            name: name.to_string(),
            source,
        }),
    }
}
/// Rewrites a JSON Schema value in place so it is closer to what OpenAPI
/// expects, recursing through nested schemas.
///
/// For every schema object reached:
/// - a string `$ref` has `#/$defs/` and `#/definitions/` replaced by
///   `#/components/schemas/`, matching the rewrite applied to top-level
///   references in `convert_json_value_to_openapi`;
/// - an array-valued `type` is handed to `transform_type_array`;
/// - `properties`, `$defs` and `definitions` are treated as name → schema maps
///   and each contained schema is transformed;
/// - `items`, `additionalProperties` and `not` are transformed as schemas
///   (an array value has each element transformed);
/// - each element of `allOf`, `anyOf` and `oneOf` is transformed.
///
/// Values that are neither objects nor arrays (e.g. boolean schemas) are left
/// untouched.
fn transform_for_openapi(val: &mut serde_json::Value) {
    let map = match val {
        serde_json::Value::Object(map) => map,
        serde_json::Value::Array(elements) => {
            for element in elements {
                transform_for_openapi(element);
            }
            return;
        }
        _ => return,
    };

    // Nested references must point at the OpenAPI components section too,
    // otherwise they dangle once the schema is embedded in the document.
    if let Some(serde_json::Value::String(target)) = map.get_mut("$ref") {
        *target = target
            .replace("#/$defs/", "#/components/schemas/")
            .replace("#/definitions/", "#/components/schemas/");
    }

    // Cloned because `transform_type_array` needs the map mutably while
    // reading the original list of types.
    if let Some(type_val) = map.get("type").and_then(|v| v.as_array()).cloned() {
        transform_type_array(map, &type_val);
    }

    // These keywords hold a map of name -> schema: the map itself is not a
    // schema, so recurse into its values rather than into the map.
    for key in ["properties", "$defs", "definitions"] {
        if let Some(serde_json::Value::Object(named_schemas)) = map.get_mut(key) {
            for (_name, schema) in named_schemas.iter_mut() {
                transform_for_openapi(schema);
            }
        }
    }

    // These keywords hold a single schema (or, for `items`, possibly an array
    // of schemas, which the array branch above handles).
    for key in ["items", "additionalProperties", "not"] {
        if let Some(nested) = map.get_mut(key) {
            transform_for_openapi(nested);
        }
    }

    for key in ["allOf", "anyOf", "oneOf"] {
        if let Some(serde_json::Value::Array(schemas)) = map.get_mut(key) {
            for schema in schemas {
                transform_for_openapi(schema);
            }
        }
    }
}
/// Replaces an array-valued `type` on `map` with a single-type form.
///
/// `types` is the original content of the `type` array. Entries equal to the
/// string `"null"` are separated from the rest, then:
/// - no remaining entries: `map` is left unchanged;
/// - exactly one remaining entry: `type` is set to that entry;
/// - several remaining entries: `type` is removed and `anyOf` is set to one
///   `{"type": <entry>}` object per entry (replacing any existing `anyOf`).
///
/// In the last two cases `nullable: true` is also inserted when `"null"` was
/// present.
fn transform_type_array(map: &mut serde_json::Map<String, serde_json::Value>, types: &[serde_json::Value]) {
    fn is_null_type(candidate: &serde_json::Value) -> bool {
        candidate.as_str() == Some("null")
    }

    let nullable = types.iter().any(is_null_type);
    let concrete: Vec<&serde_json::Value> = types
        .iter()
        .filter(|candidate| !is_null_type(candidate))
        .collect();

    match concrete.as_slice() {
        // Nothing but "null" (or an empty array): leave the map as it is.
        [] => return,
        [only] => {
            map.insert("type".to_string(), (*only).clone());
        }
        several => {
            let alternatives: Vec<serde_json::Value> = several
                .iter()
                .map(|ty| serde_json::json!({"type": ty}))
                .collect();
            map.remove("type");
            map.insert("anyOf".to_string(), serde_json::Value::Array(alternatives));
        }
    }

    if nullable {
        map.insert("nullable".to_string(), serde_json::Value::Bool(true));
    }
}