use std::{
collections::{BTreeMap, HashMap, HashSet},
str::FromStr,
};
use crate::{
attribute::{self, Attribute},
datamodel::DataModel,
json::schema::{AnyOfItemType, DataType, DataTypeItemType, Item, ReferenceItemType},
markdown::frontmatter::FrontMatter,
object::{Enumeration, Object},
option::AttrOption,
validation::BASIC_TYPES,
};
use super::schema::{self, PrimitiveType};
/// JSON Schema dialect URI written into the root schema's `$schema` keyword.
const SCHEMA: &str = "https://json-schema.org/draft/2020-12/schema";
/// Builds a JSON Schema for the object named `root` in `model`.
///
/// Every object and enumeration reachable from `root` (transitively, through attribute
/// data types) is emitted under the schema's definitions, and `$schema` is set to the
/// draft 2020-12 dialect.
///
/// Post-processing (prefix expansion, `$id`, enum reference typing and the `openai`
/// rewrites) only runs when `model.config` is present. With `openai` set, that pass strips
/// `$schema`/`$id` and property options, lists every property as required (optional ones
/// become nullable via `anyOf`) and disables additional properties.
///
/// # Errors
/// Returns an error if `root`, or any type referenced from it, cannot be found or
/// converted.
pub fn to_json_schema(
    model: &DataModel,
    root: &str,
    openai: bool,
) -> Result<schema::SchemaObject, String> {
    let root_object = retrieve_object(model, root)?;
    let mut schema_object = schema::SchemaObject::try_from(root_object)?;

    let mut used_types = HashSet::new();
    let mut used_enums = HashSet::new();
    collect_definitions(root_object, model, &mut used_types, &mut used_enums)?;
    let definitions = collect_definitions_from_model(model, &used_types, &used_enums)?;

    schema_object.schema = Some(SCHEMA.to_string());
    schema_object.definitions = definitions;

    // Borrow the front matter; cloning it is unnecessary because post-processing only reads it.
    if let Some(config) = &model.config {
        post_process_schema(&mut schema_object, config, openai, &used_enums)?;
    }

    Ok(schema_object)
}
/// Looks up the object named `name` in `model`.
///
/// # Errors
/// Returns `"Object {name} not found"` when no object has that name.
fn retrieve_object<'a>(model: &'a DataModel, name: &str) -> Result<&'a Object, String> {
    model
        .objects
        .iter()
        .find(|obj| obj.name == name)
        // Build the message lazily so successful lookups do not allocate.
        .ok_or_else(|| format!("Object {name} not found"))
}
/// Looks up the enumeration named `name` in `model`.
///
/// # Errors
/// Returns `"Enum {name} not found"` when no enumeration has that name.
fn retrieve_enum<'a>(model: &'a DataModel, name: &str) -> Result<&'a Enumeration, String> {
    model
        .enums
        .iter()
        .find(|e| e.name == name)
        // Build the message lazily so successful lookups do not allocate.
        .ok_or_else(|| format!("Enum {name} not found"))
}
/// Converts every collected object and enumeration name into its schema definition.
///
/// Objects are converted first, then enumerations; if a name appears in both sets the
/// enumeration definition wins.
///
/// # Errors
/// Fails on the first name that cannot be found in `model` or cannot be converted.
fn collect_definitions_from_model(
    model: &DataModel,
    used_types: &HashSet<String>,
    used_enums: &HashSet<String>,
) -> Result<BTreeMap<String, schema::SchemaType>, String> {
    let object_definitions = used_types
        .iter()
        .map(|name| -> Result<(String, schema::SchemaType), String> {
            let object = retrieve_object(model, name)?;
            Ok((name.clone(), schema::SchemaType::try_from(object)?))
        });
    let enum_definitions = used_enums
        .iter()
        .map(|name| -> Result<(String, schema::SchemaType), String> {
            let enumeration = retrieve_enum(model, name)?;
            Ok((name.clone(), schema::SchemaType::try_from(enumeration)?))
        });
    object_definitions.chain(enum_definitions).collect()
}
/// Recursively records every non-basic data type referenced by `object`'s attributes.
///
/// Object names go into `used_types` (and are descended into); enumeration names go into
/// `used_enums`. Names already in `used_types` are skipped, which also stops recursion on
/// self-referencing objects.
///
/// # Errors
/// Returns `"Object or enumeration {dtype} not found"` for a type name that is neither a
/// basic type, an object, nor an enumeration of `model`.
fn collect_definitions(
    object: &Object,
    model: &DataModel,
    used_types: &mut HashSet<String>,
    used_enums: &mut HashSet<String>,
) -> Result<(), String> {
    let referenced = object.attributes.iter().flat_map(|attr| attr.dtypes.iter());
    for dtype in referenced {
        let already_handled =
            BASIC_TYPES.contains(&dtype.as_str()) || used_types.contains(dtype);
        if already_handled {
            continue;
        }

        if let Some(nested) = model.objects.iter().find(|candidate| candidate.name == *dtype) {
            used_types.insert(dtype.clone());
            collect_definitions(nested, model, used_types, used_enums)?;
            continue;
        }

        match model.enums.iter().find(|candidate| candidate.name == *dtype) {
            Some(enumeration) => {
                used_enums.insert(enumeration.name.clone());
            }
            None => return Err(format!("Object or enumeration {dtype} not found")),
        }
    }
    Ok(())
}
/// Expands compact `prefix:term` IRIs in each property's `term` using `prefixes`.
///
/// A term is rewritten to `<expansion><local part>` only when it contains a `:` and the
/// part before the first `:` is a key of `prefixes`; every other term is left untouched.
fn resolve_prefixes(schema: &mut schema::SchemaObject, prefixes: &HashMap<String, String>) {
    for property in schema.properties.values_mut() {
        // A term without a colon is not a compact IRI and is never rewritten, so an
        // empty-string prefix mapping cannot clobber it.
        let expanded = property
            .term
            .as_deref()
            .and_then(|term| term.split_once(':'))
            .and_then(|(prefix, local)| {
                prefixes
                    .get(prefix)
                    .map(|expansion| format!("{expansion}{local}"))
            });
        if let Some(expanded) = expanded {
            property.term = Some(expanded);
        }
    }
}
/// Post-processes the root schema and every object definition it contains.
///
/// Sets the root `$id` to `config.repo`, then applies `post_process_object` to the root
/// and to each `SchemaType::Object` definition (enum definitions are left as-is).
///
/// # Errors
/// Propagates the first error returned by `post_process_object`.
fn post_process_schema(
    schema_object: &mut schema::SchemaObject,
    config: &FrontMatter,
    openai: bool,
    used_enums: &HashSet<String>,
) -> Result<(), String> {
    schema_object.id = Some(config.repo.clone());
    post_process_object(schema_object, config, openai, used_enums)?;

    schema_object
        .definitions
        .values_mut()
        .filter_map(|definition| match definition {
            schema::SchemaType::Object(object_definition) => Some(object_definition),
            _ => None,
        })
        .try_for_each(|object_definition| {
            post_process_object(object_definition, config, openai, used_enums)
        })
}
/// Applies per-object post-processing to `object`.
///
/// * Expands compact `term` IRIs when `config` declares prefixes.
/// * In `openai` mode, clears `$schema`/`$id`, drops property options, and lists every
///   property as required (previously optional ones are made nullable).
/// * For properties whose `$ref` names one of `used_enums`, clears `type` in `openai` mode
///   and otherwise sets it to `string`.
///
/// Currently always returns `Ok(())`; the `Result` is kept so callers remain unchanged.
fn post_process_object(
    object: &mut schema::SchemaObject,
    config: &FrontMatter,
    openai: bool,
    used_enums: &HashSet<String>,
) -> Result<(), String> {
    if let Some(prefixes) = &config.prefixes {
        resolve_prefixes(object, prefixes);
    }

    if openai {
        object.schema = None;
        object.id = None;
        remove_options(object);
        set_required_and_nullable(object);
    }

    for property in object.properties.values_mut() {
        // The referenced definition name is the last `/`-separated segment of the `$ref`.
        // `rsplit` always yields at least one segment, so there is no failure case and no
        // error message has to be built for each property.
        let references_enum = property
            .reference
            .as_deref()
            .and_then(|reference| reference.rsplit('/').next())
            .map_or(false, |name| used_enums.contains(name));

        if references_enum {
            property.dtype = if openai {
                None
            } else {
                Some(schema::DataType::String)
            };
        }
    }

    Ok(())
}
fn remove_options(schema: &mut schema::SchemaObject) {
for (_, property) in schema.properties.iter_mut() {
property.options = HashMap::new();
}
}
/// Rewrites `schema` so that every property is listed as required.
///
/// Every property first has its `$ref` siblings cleaned and its `oneOf` items turned into
/// `anyOf`. Properties that were not already required are then made nullable and appended
/// to `required`. Finally, additional properties are disabled.
fn set_required_and_nullable(schema: &mut schema::SchemaObject) {
    let already_required = &schema.required;
    let mut newly_required: Vec<String> = Vec::new();

    for (name, property) in schema.properties.iter_mut() {
        clean_reference_property(property);
        convert_one_of_to_any_of(property);

        if already_required.contains(name) {
            continue;
        }
        newly_required.push(name.clone());
        make_property_nullable(property);
    }

    finalize_schema_requirements(schema, newly_required);
}
/// Clears `type`, `title` and `description` on a property that carries a `$ref`.
fn clean_reference_property(property: &mut schema::Property) {
    if property.reference.is_none() {
        return;
    }
    property.dtype = None;
    property.title = None;
    property.description = None;
}
/// Rewrites a property's `items: {oneOf: [...]}` into `items: {anyOf: [...]}`.
///
/// Any other `items` value (or none) is left unchanged.
fn convert_one_of_to_any_of(property: &mut schema::Property) {
    if let Some(Item::OneOfItem(one_of)) = &mut property.items {
        // Move the variants out of the item that is about to be replaced instead of
        // deep-cloning them.
        let any_of = std::mem::take(&mut one_of.one_of);
        property.items = Some(Item::AnyOfItem(AnyOfItemType { any_of }));
    }
}
fn make_property_nullable(property: &mut schema::Property) {
let mut any_of = vec![Item::DataTypeItem(DataTypeItemType {
dtype: DataType::Null,
})];
handle_property_data_type(property, &mut any_of);
handle_property_reference(property, &mut any_of);
handle_property_one_of(property, &mut any_of);
if !matches!(property.dtype, Some(DataType::Array)) {
property.any_of = Some(any_of);
}
}
/// Folds a property's `type` into the nullable `anyOf` list being built.
///
/// The caller (`make_property_nullable`) seeds `any_of` with the `null` alternative.
/// Scalar types are pushed into `any_of`, and `Multiple` types contribute their filtered
/// members. Objects contribute nothing here; when the property has a `$ref`,
/// `handle_property_reference` adds it. In all of these cases the property's own `type` is
/// cleared. Array properties keep `type: array` and add nothing.
fn handle_property_data_type(property: &mut schema::Property, any_of: &mut Vec<Item>) {
    let keep_dtype = match &property.dtype {
        None => return,
        // `null` is already the first entry of `any_of`; adding it again would only create
        // a duplicate alternative.
        Some(DataType::Array) => true,
        Some(DataType::Object) => false,
        Some(DataType::Multiple(data_types)) => {
            add_multiple_data_types(any_of, data_types);
            false
        }
        Some(other) => {
            any_of.push(Item::DataTypeItem(DataTypeItemType {
                dtype: other.clone(),
            }));
            false
        }
    };

    if !keep_dtype {
        property.dtype = None;
    }
}
/// Appends each of `data_types` that is not an object (or is an array) to `any_of` as a
/// `{"type": ...}` alternative, preserving order.
fn add_multiple_data_types(any_of: &mut Vec<Item>, data_types: &[DataType]) {
    let accepted = data_types
        .iter()
        .filter(|candidate| candidate.is_not_object() || candidate.is_array())
        .map(|candidate| {
            Item::DataTypeItem(DataTypeItemType {
                dtype: candidate.clone(),
            })
        });
    any_of.extend(accepted);
}
/// Moves a property's `$ref` into `any_of` and clears the fields that must not sit next to
/// it (`type`, `title`, `description`).
fn handle_property_reference(property: &mut schema::Property, any_of: &mut Vec<Item>) {
    let Some(reference) = property.reference.take() else {
        return;
    };
    any_of.push(Item::ReferenceItem(ReferenceItemType { reference }));
    property.dtype = None;
    property.title = None;
    property.description = None;
}
/// Moves every `oneOf` alternative of the property into `any_of`, leaving `oneOf` unset.
fn handle_property_one_of(property: &mut schema::Property, any_of: &mut Vec<Item>) {
    if let Some(alternatives) = property.one_of.take() {
        any_of.extend(alternatives);
    }
}
/// Appends `new_required` to the schema's `required` list, sorts it, and disables
/// additional properties.
fn finalize_schema_requirements(schema: &mut schema::SchemaObject, new_required: Vec<String>) {
    let required = &mut schema.required;
    required.extend(new_required);
    // Equal `String`s are indistinguishable, so an unstable sort gives the same order.
    required.sort_unstable();
    schema.additional_properties = false;
}
impl TryFrom<&Enumeration> for schema::SchemaType {
    type Error = String;

    /// Wraps the definition produced by `schema::EnumObject::try_from` in `SchemaType::Enum`.
    fn try_from(enumeration: &Enumeration) -> Result<Self, Self::Error> {
        schema::EnumObject::try_from(enumeration).map(schema::SchemaType::Enum)
    }
}
impl TryFrom<&Object> for schema::SchemaType {
    type Error = String;

    /// Wraps the definition produced by `schema::SchemaObject::try_from` in
    /// `SchemaType::Object`.
    fn try_from(obj: &Object) -> Result<Self, Self::Error> {
        schema::SchemaObject::try_from(obj).map(schema::SchemaType::Object)
    }
}
impl TryFrom<&Object> for schema::SchemaObject {
    type Error = String;

    /// Builds an object schema titled after `obj`, with one property per attribute and the
    /// `required` list taken from attributes flagged `required` (in attribute order).
    ///
    /// Fails on the first attribute that cannot be converted into a property.
    fn try_from(obj: &Object) -> Result<Self, Self::Error> {
        let mut properties = BTreeMap::new();
        let mut required = Vec::new();

        for attr in obj.attributes.iter() {
            let property = schema::Property::try_from(attr)?;
            properties.insert(attr.name.clone(), property);
            if attr.required {
                required.push(attr.name.clone());
            }
        }

        Ok(schema::SchemaObject {
            title: obj.name.clone(),
            dtype: Some(schema::DataType::Object),
            description: Some(obj.docstring.clone()),
            properties,
            definitions: BTreeMap::new(),
            required,
            schema: None,
            id: None,
            additional_properties: false,
        })
    }
}
impl TryFrom<&Enumeration> for schema::EnumObject {
    type Error = String;

    /// Builds a string-typed enum definition whose allowed values are the values of
    /// `enumeration.mappings`, in that map's iteration order.
    fn try_from(enumeration: &Enumeration) -> Result<Self, Self::Error> {
        let mut enum_values: Vec<String> = Vec::new();
        for value in enumeration.mappings.values() {
            enum_values.push(value.clone());
        }

        Ok(schema::EnumObject {
            title: enumeration.name.clone(),
            dtype: Some(schema::DataType::String),
            description: Some(enumeration.docstring.clone()),
            enum_values,
        })
    }
}
/// Converts a parsed [`Attribute`] into a JSON-schema [`schema::Property`].
///
/// Mapping performed here:
/// * enum attributes get no `type` (`dtype` stays `None`); if they have exactly one data
///   type they also get a `$ref` to `#/$defs/<name>`;
/// * attributes whose resolved type is `Object` and that have exactly one data type get a
///   `$ref` to their definition;
/// * non-array attributes with several data types drop `type` and list the alternatives in
///   `one_of`; array attributes describe their element type(s) via `items`.
///
/// # Errors
/// Returns an error if the attribute's data type or one of its option values cannot be
/// converted.
impl TryFrom<&Attribute> for schema::Property {
    type Error = String;
    fn try_from(attr: &Attribute) -> Result<Self, Self::Error> {
        // Enum attributes are given no `type` here.
        let mut dtype = (!attr.is_enum)
            .then(|| schema::DataType::try_from(attr))
            .transpose()?;
        // Option values are inferred as number/bool/integer/string primitives.
        let options: HashMap<String, PrimitiveType> = attr
            .options
            .iter()
            .map(|o| -> Result<(String, PrimitiveType), String> {
                Ok((o.key().to_string(), o.try_into()?))
            })
            .collect::<Result<HashMap<String, PrimitiveType>, String>>()?;
        // Only single-typed enum/object attributes are expressed as a plain `$ref`.
        // NOTE(review): for an array of enums (`is_array && is_enum`, one dtype) this sets
        // `$ref` while `items` is also set and `type` is unset, so the property never
        // declares `type: array` — confirm this is intended.
        let reference: Option<String> = if (attr.is_enum
            || matches!(dtype, Some(schema::DataType::Object)))
            && attr.dtypes.len() == 1
        {
            Some(format!("#/$defs/{}", attr.dtypes[0]))
        } else {
            None
        };
        // `Some` only for array attributes.
        let items: Option<schema::Item> = attr.into();
        // For non-arrays: the alternatives (an empty list when there is a single dtype).
        let one_of = (!attr.is_array).then(|| attr.into());
        let description = (!attr.docstring.is_empty()).then(|| attr.docstring.clone());
        // Enum attributes carry an empty `enum_values` list; all others carry `None`.
        let enum_values = if attr.is_enum { Some(Vec::new()) } else { None };
        // A single `type` cannot describe several alternatives; they live in `one_of`.
        if attr.dtypes.len() > 1 && !attr.is_array {
            dtype = None;
        }
        // Defaults are converted against the final `dtype` (see `process_default`).
        let default: Option<PrimitiveType> = if let Some(default) = attr.default.clone() {
            process_default(default, &dtype)
        } else {
            None
        };
        Ok(schema::Property {
            title: Some(attr.name.clone()),
            dtype,
            default,
            description,
            term: attr.term.clone(),
            reference,
            options,
            one_of,
            items,
            enum_values,
            any_of: None,
            all_of: None,
            examples: Vec::new(),
        })
    }
}
/// Converts an attribute default into a JSON primitive for a property of type `dtype`.
///
/// For string-typed properties the default's string form is used with surrounding `"`
/// characters trimmed; if the default has no string form, `None` is returned. All other
/// types use the default's `Into<PrimitiveType>` conversion.
fn process_default(
    default: attribute::DataType,
    dtype: &Option<schema::DataType>,
) -> Option<PrimitiveType> {
    match dtype {
        Some(schema::DataType::String) => {
            let text = default.as_string()?;
            Some(PrimitiveType::String(text.trim_matches('"').to_string()))
        }
        _ => Some(default.into()),
    }
}
impl TryFrom<&Attribute> for schema::DataType {
    type Error = String;

    /// Resolves an attribute's JSON type: `Array` for array attributes, otherwise the type
    /// parsed from its first declared data type.
    ///
    /// Fails when a non-array attribute declares no data types, or when its first data
    /// type cannot be converted.
    fn try_from(attr: &Attribute) -> Result<Self, Self::Error> {
        if attr.is_array {
            return Ok(schema::DataType::Array);
        }
        match attr.dtypes.first() {
            Some(first) => schema::DataType::try_from(first),
            None => Err(format!("No data types found for attribute: {}", attr.name)),
        }
    }
}
/// Builds the `items` schema for an array attribute; non-array attributes yield `None`.
///
/// Arrays with several element types produce `items: {oneOf: [...]}`. Arrays with a single
/// element type use that type directly. An array attribute that declares no data type at
/// all yields `None` instead of panicking on an out-of-bounds index.
impl From<&Attribute> for Option<schema::Item> {
    fn from(attr: &Attribute) -> Self {
        if !attr.is_array {
            return None;
        }
        let one_of: Vec<schema::Item> = attr.into();
        if one_of.is_empty() {
            // Zero or one declared element type; `first()` avoids panicking on zero.
            attr.dtypes.first().map(|dtype| process_dtype(dtype))
        } else {
            Some(schema::Item::OneOfItem(schema::OneOfItemType { one_of }))
        }
    }
}
/// Lists one item per declared data type of a multi-typed attribute.
///
/// Returns an empty list when the attribute has exactly one data type.
impl From<&Attribute> for Vec<schema::Item> {
    fn from(attr: &Attribute) -> Self {
        match attr.dtypes.len() {
            1 => Vec::new(),
            _ => attr
                .dtypes
                .iter()
                .map(|dtype| process_dtype(dtype))
                .collect(),
        }
    }
}
/// Maps a data-type name to a schema item.
///
/// Names that parse as a basic `schema::DataType` become `{"type": ...}`; any other name
/// becomes a `$ref` to `#/$defs/<name>`.
fn process_dtype(dtype: &str) -> schema::Item {
    if let Ok(basic_type) = schema::DataType::from_str(dtype) {
        return schema::Item::DataTypeItem(schema::DataTypeItemType { dtype: basic_type });
    }
    schema::Item::ReferenceItem(schema::ReferenceItemType {
        reference: format!("#/$defs/{dtype}"),
    })
}
impl TryFrom<&AttrOption> for PrimitiveType {
    type Error = String;

    /// Infers a JSON primitive from an option's textual value.
    ///
    /// Integers are tried before floats, because every integer literal also parses as
    /// `f64`; with the opposite order `"5"` would become `5.0` and `Integer` could never be
    /// produced. Non-finite floats (`"nan"`, `"inf"`, `"infinity"`, ...) are kept as
    /// strings, because JSON has no representation for them. Anything that is neither a
    /// number nor a boolean is returned as a string. Never fails.
    fn try_from(option: &AttrOption) -> Result<Self, Self::Error> {
        let value = option.value();
        if let Ok(int_val) = value.parse::<i64>() {
            return Ok(PrimitiveType::Integer(int_val));
        }
        if let Ok(float_val) = value.parse::<f64>() {
            if float_val.is_finite() {
                return Ok(PrimitiveType::Number(float_val));
            }
        }
        if let Ok(bool_val) = value.parse::<bool>() {
            return Ok(PrimitiveType::Boolean(bool_val));
        }
        Ok(PrimitiveType::String(value))
    }
}
#[cfg(test)]
mod tests {
    use serde_json::{json, Value};
    use super::*;
    use crate::attribute::Attribute;

    /// Builds a minimal, non-enum attribute named `test_attribute`.
    fn attribute_with(is_array: bool, dtypes: &[&str]) -> Attribute {
        Attribute {
            name: "test_attribute".to_string(),
            is_array,
            is_id: false,
            dtypes: dtypes.iter().map(|dtype| dtype.to_string()).collect(),
            docstring: String::new(),
            options: vec![],
            term: None,
            required: false,
            default: None,
            xml: None,
            is_enum: false,
            position: None,
            import_prefix: None,
        }
    }

    /// Converts `attr` into a property and serializes it to a JSON value.
    fn property_json(attr: &Attribute) -> Value {
        let property: schema::Property =
            schema::Property::try_from(attr).expect("Failed to convert Attribute to Property");
        serde_json::to_value(&property).expect("Failed to serialize Property to JSON")
    }

    #[test]
    fn test_attribute_with_multiple_types() {
        let attr = attribute_with(false, &["string", "RefType"]);
        let expected_json = json!({
            "title": "test_attribute",
            "oneOf": [
                {"type": "string"},
                {"$ref": "#/$defs/RefType"},
            ]
        });
        assert_eq!(property_json(&attr), expected_json);
    }

    #[test]
    fn test_array_attribute() {
        let attr = attribute_with(true, &["string", "RefType"]);
        let expected_json = json!({
            "title": "test_attribute",
            "type": "array",
            "items": {
                "oneOf": [
                    {"type": "string"},
                    {"$ref": "#/$defs/RefType"}
                ]
            }
        });
        assert_eq!(property_json(&attr), expected_json);
    }
}