use serde_json::Value;
/// Classification of how a schema's `anyOf` list should be interpreted.
///
/// Produced by [`analyze_anyof_semantics`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnyOfSemantics {
    /// Exactly two variants, one of which is `{"type": "null"}` while the
    /// other is not.
    Nullable,
    /// Every variant is a complex type according to [`is_complex_type`].
    Union,
    /// Any other `anyOf` list.
    FlexibleUnion,
}
/// Description of a discriminated union found in a schema.
///
/// Returned by [`detect_discriminated_union`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DiscriminatedUnionInfo {
    /// Name of the property that acts as the discriminator.
    pub field: String,
    /// Discriminator mappings.
    ///
    /// When the schema carries an explicit `discriminator.mapping` object,
    /// this holds its string-valued entries verbatim. When mappings are
    /// derived from `$ref` variants instead, each entry maps the referenced
    /// schema name to `schema:<name>`. The map is empty for implicitly
    /// detected discriminators.
    pub mappings: std::collections::HashMap<String, String>,
}
pub fn analyze_anyof_semantics(schema: &Value) -> Option<AnyOfSemantics> {
let variants = schema.get("anyOf")?.as_array()?;
if variants.len() == 2 {
let has_null = variants.iter().any(is_null_type);
let has_type = variants.iter().any(|v| !is_null_type(v));
if has_null && has_type {
return Some(AnyOfSemantics::Nullable);
}
}
if variants.iter().all(is_complex_type) {
return Some(AnyOfSemantics::Union);
}
Some(AnyOfSemantics::FlexibleUnion)
}
pub fn is_null_type(schema: &Value) -> bool {
schema.get("type").and_then(|t| t.as_str()) == Some("null")
}
pub fn is_complex_type(schema: &Value) -> bool {
schema.get("$ref").is_some()
|| schema.get("type").and_then(|t| t.as_str()) == Some("object")
|| schema.get("properties").is_some()
}
pub fn detect_discriminated_union(schema: &Value) -> Option<DiscriminatedUnionInfo> {
if let Some(discriminator) = schema.get("discriminator") {
return parse_explicit_discriminator(discriminator, schema);
}
if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
return detect_implicit_discriminator(variants);
}
None
}
fn parse_explicit_discriminator(
discriminator: &Value,
schema: &Value,
) -> Option<DiscriminatedUnionInfo> {
let field = discriminator.get("propertyName")?.as_str()?.to_string();
let mut mappings = std::collections::HashMap::new();
if let Some(mapping) = discriminator.get("mapping") {
if let Some(mapping_obj) = mapping.as_object() {
for (key, value) in mapping_obj {
if let Some(value_str) = value.as_str() {
mappings.insert(key.clone(), value_str.to_string());
}
}
}
} else {
if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
if let Some(variants_array) = variants.as_array() {
mappings = extract_variant_discriminator_mappings(variants_array, &field)?;
}
}
}
Some(DiscriminatedUnionInfo { field, mappings })
}
fn extract_variant_discriminator_mappings(
variants: &[Value],
_discriminator_field: &str,
) -> Option<std::collections::HashMap<String, String>> {
let mut mappings = std::collections::HashMap::new();
for variant in variants {
if let Some(ref_str) = variant.get("$ref").and_then(|r| r.as_str()) {
if let Some(schema_name) = extract_schema_name_from_ref(ref_str) {
mappings.insert(schema_name.clone(), format!("schema:{schema_name}"));
}
}
}
if mappings.is_empty() {
None
} else {
Some(mappings)
}
}
/// Extracts the schema name from a reference string such as
/// `#/components/schemas/Error`.
///
/// The name is the text after the last `/` (or the whole string when it
/// contains no `/`). Returns `None` when that final segment is empty, i.e.
/// for an empty reference or one ending in `/`, so callers never receive an
/// empty schema name.
fn extract_schema_name_from_ref(ref_str: &str) -> Option<String> {
    ref_str
        .rsplit('/')
        .next()
        // `rsplit` always yields at least one (possibly empty) segment, so
        // emptiness has to be rejected explicitly.
        .filter(|name| !name.is_empty())
        .map(String::from)
}
/// Attempts to detect a discriminator for a union that does not declare one.
///
/// `variants` is the value of a `oneOf` / `anyOf` key. Detection succeeds
/// only when [`extract_variant_refs`] accepts the variants and
/// [`find_common_discriminator_field`] yields a field name. The returned
/// info always has empty mappings.
fn detect_implicit_discriminator(variants: &Value) -> Option<DiscriminatedUnionInfo> {
    extract_variant_refs(variants)
        .and_then(|refs| find_common_discriminator_field(&refs))
        .map(|field| DiscriminatedUnionInfo {
            field,
            mappings: std::collections::HashMap::new(),
        })
}
fn extract_variant_refs(variants: &Value) -> Option<Vec<String>> {
let variants_array = variants.as_array()?;
let refs: Vec<String> = variants_array
.iter()
.filter_map(|v| v.get("$ref").and_then(|r| r.as_str()))
.map(|s| s.to_string())
.collect();
if refs.len() == variants_array.len() && refs.len() > 1 {
Some(refs)
} else {
None
}
}
/// Returns the discriminator field assumed to be shared by all variants.
///
/// The current implementation does not inspect `_variant_refs`; it always
/// answers with the field name `type`.
// NOTE(review): looks like a placeholder until referenced schemas can be
// resolved and compared — confirm.
fn find_common_discriminator_field(_variant_refs: &[String]) -> Option<String> {
    Some(String::from("type"))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `$ref` paired with a null type is classified as nullable.
    #[test]
    fn test_nullable_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/Error"},
                {"type": "null"}
            ]
        });

        let semantics = analyze_anyof_semantics(&schema);

        assert_eq!(semantics, Some(AnyOfSemantics::Nullable));
    }

    /// Two `$ref` variants are classified as a union.
    #[test]
    fn test_union_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/TextContent"},
                {"$ref": "#/components/schemas/ImageContent"}
            ]
        });

        let semantics = analyze_anyof_semantics(&schema);

        assert_eq!(semantics, Some(AnyOfSemantics::Union));
    }

    /// An explicit `discriminator.propertyName` is reported as the field.
    #[test]
    fn test_discriminated_union_detection() {
        let schema = json!({
            "oneOf": [
                {"$ref": "#/components/schemas/ResponseCreatedEvent"},
                {"$ref": "#/components/schemas/ResponseTextDeltaEvent"}
            ],
            "discriminator": {
                "propertyName": "type"
            }
        });

        let detected = detect_discriminated_union(&schema);

        assert!(detected.is_some());
        assert_eq!(detected.unwrap().field, "type");
    }
}