use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::{Error, Result};
/// A named JSON Schema, as carried by [`ResponseFormat::json_schema`].
///
/// Prefer [`JsonSchemaSpec::new`], which rejects blank names and schemas that
/// are not JSON objects. The fields are public and the type derives
/// `Deserialize`, so values built by struct literal or by deserialization are
/// *not* validated.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct JsonSchemaSpec {
    /// Identifier for the schema; `new` requires it to be non-blank.
    pub name: String,
    /// The schema document; `new` requires a JSON object at the top level.
    pub schema: Value,
}
impl JsonSchemaSpec {
pub fn new(name: impl Into<String>, schema: Value) -> Result<Self> {
let name = name.into();
if name.trim().is_empty() {
return Err(Error::config("JsonSchemaSpec: name must be non-empty"));
}
if !schema.is_object() {
return Err(Error::config(
"JsonSchemaSpec: schema must be a JSON object at the top level",
));
}
Ok(Self { name, schema })
}
}
/// Mechanism used to obtain structured output; defaults to [`OutputStrategy::Auto`].
///
/// Serialized in `snake_case`: `"auto"`, `"native"`, `"tool"`, `"prompted"`.
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
///
/// NOTE(review): the variants are interpreted outside this module. Their names
/// suggest provider-native structured output, a tool/function call, and
/// prompt-based instructions respectively; confirm this against the consumer.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum OutputStrategy {
    /// The default strategy; serialized as `"auto"`.
    #[default]
    Auto,
    /// Serialized as `"native"`.
    Native,
    /// Serialized as `"tool"`.
    Tool,
    /// Serialized as `"prompted"`.
    Prompted,
}
/// A structured-output request.
///
/// It holds the schema to conform to, whether strict mode is requested, and
/// which [`OutputStrategy`] to use.
///
/// When deserializing, a missing `strict` defaults to `true` and a missing
/// `strategy` defaults to [`OutputStrategy::Auto`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// The named schema the response must follow.
    pub json_schema: JsonSchemaSpec,
    /// Whether strict mode is requested; see [`ResponseFormat::strict_preflight`].
    #[serde(default = "ResponseFormat::default_strict")]
    pub strict: bool,
    /// How the structured output should be obtained.
    #[serde(default)]
    pub strategy: OutputStrategy,
}
impl ResponseFormat {
    /// Builds a strict-mode format for `schema` using [`OutputStrategy::Auto`].
    ///
    /// Call [`ResponseFormat::strict_preflight`] to check the schema against
    /// the strict-mode structural rules before sending it.
    pub fn strict(schema: JsonSchemaSpec) -> Self {
        Self::from_schema(schema, true)
    }

    /// Builds a non-strict format for `schema` using [`OutputStrategy::Auto`].
    ///
    /// [`ResponseFormat::strict_preflight`] always succeeds for such a format.
    pub fn best_effort(schema: JsonSchemaSpec) -> Self {
        Self::from_schema(schema, false)
    }

    /// Returns `self` with its strategy replaced by `strategy`.
    #[must_use]
    pub const fn with_strategy(mut self, strategy: OutputStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Checks the schema against the structural rules strict mode imposes.
    ///
    /// Non-strict formats are not inspected and always yield `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the first [`StrictSchemaError`] encountered while walking the
    /// schema from its root, which is reported as path `$`.
    pub fn strict_preflight(&self) -> std::result::Result<(), StrictSchemaError> {
        if self.strict {
            check_strict(&self.json_schema.schema, "$")
        } else {
            Ok(())
        }
    }

    /// Shared constructor body: the given strictness with the `Auto` strategy.
    fn from_schema(json_schema: JsonSchemaSpec, strict: bool) -> Self {
        Self {
            json_schema,
            strict,
            strategy: OutputStrategy::Auto,
        }
    }

    /// Serde default for a missing `strict` field, so strict mode is opt-out.
    const fn default_strict() -> bool {
        true
    }
}
#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
#[non_exhaustive]
pub enum StrictSchemaError {
#[error("strict-mode schema requires `additionalProperties: false` at {path}")]
AdditionalPropertiesNotFalse {
path: String,
},
#[error("strict-mode schema at {path} declares properties not in `required`: {}", .missing.join(", "))]
RequiredMissingProperties {
path: String,
missing: Vec<String>,
},
}
/// Recursively checks `schema` against the structural rules strict mode
/// imposes, returning the first violation found.
///
/// These rules apply to every schema whose `type` is `"object"`, including a
/// type array that contains `"object"` such as `["object", "null"]`:
/// - `additionalProperties` must be the literal `false`;
/// - every key of `properties` must appear in `required`.
///
/// The walk descends into:
/// - `properties`;
/// - `items`, when the type includes `"array"`;
/// - every branch of `anyOf` / `allOf` / `oneOf`;
/// - every entry of `$defs` / `definitions`.
///
/// `$ref` targets are not resolved. Referenced definitions are instead checked
/// where they are declared. Nodes that are not JSON objects (e.g. boolean
/// schemas) are accepted as-is.
///
/// `path` locates `schema` within the root document. It is extended for each
/// child, so errors point at the offending node.
fn check_strict(schema: &Value, path: &str) -> std::result::Result<(), StrictSchemaError> {
    let Some(obj) = schema.as_object() else {
        return Ok(());
    };
    let declared_type = obj.get("type");
    if type_includes(declared_type, "object") {
        if !matches!(obj.get("additionalProperties"), Some(Value::Bool(false))) {
            return Err(StrictSchemaError::AdditionalPropertiesNotFalse {
                path: path.to_owned(),
            });
        }
        if let Some(Value::Object(properties)) = obj.get("properties") {
            // Non-string entries in `required` are ignored rather than rejected.
            let required: std::collections::BTreeSet<&str> = obj
                .get("required")
                .and_then(Value::as_array)
                .map(|arr| arr.iter().filter_map(Value::as_str).collect())
                .unwrap_or_default();
            let missing: Vec<String> = properties
                .keys()
                .filter(|k| !required.contains(k.as_str()))
                .cloned()
                .collect();
            if !missing.is_empty() {
                return Err(StrictSchemaError::RequiredMissingProperties {
                    path: path.to_owned(),
                    missing,
                });
            }
            for (name, sub) in properties {
                check_strict(sub, &format!("{path}.properties.{name}"))?;
            }
        }
    }
    // Independent of the object branch: a type array may name both kinds.
    if type_includes(declared_type, "array")
        && let Some(items) = obj.get("items")
    {
        check_strict(items, &format!("{path}.items"))?;
    }
    for keyword in ["anyOf", "allOf", "oneOf"] {
        if let Some(Value::Array(branches)) = obj.get(keyword) {
            for (i, sub) in branches.iter().enumerate() {
                check_strict(sub, &format!("{path}.{keyword}[{i}]"))?;
            }
        }
    }
    // Definitions are only reached through `$ref`, which this walk does not
    // follow. Check them where they are declared so non-conforming ones are
    // still reported.
    for keyword in ["$defs", "definitions"] {
        if let Some(Value::Object(defs)) = obj.get(keyword) {
            for (name, sub) in defs {
                check_strict(sub, &format!("{path}.{keyword}.{name}"))?;
            }
        }
    }
    Ok(())
}

/// Reports whether a schema's `type` keyword (`declared`) names `wanted`.
///
/// `wanted` may appear either as the single string value or as one entry of a
/// type array. A missing or non-string/non-array `type` never matches.
fn type_includes(declared: Option<&Value>, wanted: &str) -> bool {
    match declared {
        Some(Value::String(t)) => t == wanted,
        Some(Value::Array(types)) => types.iter().any(|t| t.as_str() == Some(wanted)),
        _ => false,
    }
}
#[cfg(test)]
// Unwrapping is the intended way for these tests to fail loudly on unexpected errors.
#[allow(clippy::unwrap_used)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Wraps `schema` in a strict `ResponseFormat` whose spec is named `"user"`.
    fn strict_format(schema: Value) -> ResponseFormat {
        ResponseFormat::strict(JsonSchemaSpec::new("user", schema).unwrap())
    }

    #[test]
    fn new_rejects_empty_name() {
        let err = JsonSchemaSpec::new("", json!({"type": "object"})).unwrap_err();
        assert!(format!("{err}").contains("name must be non-empty"));
    }

    #[test]
    fn new_rejects_whitespace_only_name() {
        let err = JsonSchemaSpec::new(" ", json!({"type": "object"})).unwrap_err();
        assert!(format!("{err}").contains("name must be non-empty"));
    }

    #[test]
    fn new_rejects_non_object_schema() {
        let err = JsonSchemaSpec::new("user", json!("not an object")).unwrap_err();
        assert!(format!("{err}").contains("must be a JSON object"));
        let err2 = JsonSchemaSpec::new("user", json!([1, 2, 3])).unwrap_err();
        assert!(format!("{err2}").contains("must be a JSON object"));
    }

    #[test]
    fn new_accepts_valid_object_schema() {
        let spec = JsonSchemaSpec::new(
            "user",
            json!({
                "type": "object",
                "properties": {"name": {"type": "string"}},
                "required": ["name"],
            }),
        )
        .unwrap();
        assert_eq!(spec.name, "user");
        assert!(spec.schema.is_object());
    }

    #[test]
    fn strict_constructor_sets_strict_flag() {
        let spec = JsonSchemaSpec::new("user", json!({"type": "object"})).unwrap();
        let format = ResponseFormat::strict(spec);
        assert!(format.strict);
    }

    #[test]
    fn best_effort_constructor_clears_strict_flag() {
        let spec = JsonSchemaSpec::new("user", json!({"type": "object"})).unwrap();
        let format = ResponseFormat::best_effort(spec);
        assert!(!format.strict);
    }

    #[test]
    fn round_trips_via_serde() {
        let spec = JsonSchemaSpec::new("user", json!({"type": "object"})).unwrap();
        let format = ResponseFormat::strict(spec);
        let json = serde_json::to_string(&format).unwrap();
        let back: ResponseFormat = serde_json::from_str(&json).unwrap();
        assert_eq!(format, back);
    }

    #[test]
    fn deserialize_defaults_strict_and_strategy() {
        let format: ResponseFormat = serde_json::from_value(json!({
            "json_schema": {"name": "user", "schema": {"type": "object"}},
        }))
        .unwrap();
        assert!(format.strict);
        assert_eq!(format.strategy, OutputStrategy::Auto);
    }

    #[test]
    fn with_strategy_overrides_default() {
        let format = strict_format(json!({"type": "object"})).with_strategy(OutputStrategy::Tool);
        assert_eq!(format.strategy, OutputStrategy::Tool);
        assert!(format.strict);
    }

    #[test]
    fn strategy_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_value(OutputStrategy::Prompted).unwrap(),
            json!("prompted")
        );
    }

    #[test]
    fn strict_preflight_accepts_conforming_schema() {
        let format = strict_format(json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "tags": {"type": "array", "items": {"type": "string"}},
            },
            "required": ["name", "tags"],
            "additionalProperties": false,
        }));
        assert_eq!(format.strict_preflight(), Ok(()));
    }

    #[test]
    fn strict_preflight_rejects_missing_additional_properties() {
        let format = strict_format(json!({"type": "object"}));
        assert_eq!(
            format.strict_preflight(),
            Err(StrictSchemaError::AdditionalPropertiesNotFalse {
                path: "$".to_owned(),
            })
        );
    }

    #[test]
    fn strict_preflight_rejects_additional_properties_true() {
        let format = strict_format(json!({"type": "object", "additionalProperties": true}));
        assert_eq!(
            format.strict_preflight(),
            Err(StrictSchemaError::AdditionalPropertiesNotFalse {
                path: "$".to_owned(),
            })
        );
    }

    #[test]
    fn strict_preflight_reports_unrequired_properties() {
        let format = strict_format(json!({
            "type": "object",
            "properties": {"a": {"type": "string"}, "b": {"type": "string"}},
            "required": ["a"],
            "additionalProperties": false,
        }));
        assert_eq!(
            format.strict_preflight(),
            Err(StrictSchemaError::RequiredMissingProperties {
                path: "$".to_owned(),
                missing: vec!["b".to_owned()],
            })
        );
    }

    #[test]
    fn strict_preflight_reports_nested_paths() {
        let format = strict_format(json!({
            "type": "object",
            "properties": {
                "list": {"type": "array", "items": {"type": "object"}},
            },
            "required": ["list"],
            "additionalProperties": false,
        }));
        assert_eq!(
            format.strict_preflight(),
            Err(StrictSchemaError::AdditionalPropertiesNotFalse {
                path: "$.properties.list.items".to_owned(),
            })
        );
    }

    #[test]
    fn strict_preflight_checks_combinator_branches() {
        let format = strict_format(json!({
            "anyOf": [{"type": "string"}, {"type": "object"}],
        }));
        assert_eq!(
            format.strict_preflight(),
            Err(StrictSchemaError::AdditionalPropertiesNotFalse {
                path: "$.anyOf[1]".to_owned(),
            })
        );
    }

    #[test]
    fn best_effort_skips_preflight() {
        let spec = JsonSchemaSpec::new("user", json!({"type": "object"})).unwrap();
        assert_eq!(ResponseFormat::best_effort(spec).strict_preflight(), Ok(()));
    }

    #[test]
    fn required_missing_message_lists_properties() {
        let err = StrictSchemaError::RequiredMissingProperties {
            path: "$".to_owned(),
            missing: vec!["a".to_owned(), "b".to_owned()],
        };
        assert_eq!(
            err.to_string(),
            "strict-mode schema at $ declares properties not in `required`: a, b"
        );
    }
}