use serde::{Deserialize, Serialize};
/// Requested shape of a model response.
///
/// Defaults to [`OutputFormat::Text`]. The JSON-producing variants are turned into
/// provider-specific settings by [`OutputFormat::to_openai_response_format`] and
/// [`OutputFormat::to_claude_system_suffix`], and model output can be checked with
/// `validate_json_response`.
#[derive(Debug, Clone, PartialEq, Default)]
#[derive(Serialize, Deserialize)]
pub enum OutputFormat {
    /// Free-form text; no JSON constraints are requested.
    #[default]
    Text,
    /// JSON output without a schema (OpenAI `"json_object"` mode).
    Json,
    /// JSON output that should conform to `schema`.
    JsonSchema {
        /// Schema identifier, forwarded to OpenAI as `json_schema.name`.
        name: String,
        /// JSON Schema document describing the expected response.
        schema: serde_json::Value,
        /// Strict-enforcement flag, forwarded to OpenAI as `json_schema.strict`.
        strict: bool,
    },
}
impl OutputFormat {
    /// Shorthand for [`OutputFormat::Json`].
    pub fn json() -> Self {
        Self::Json
    }

    /// Builds an [`OutputFormat::JsonSchema`] with `strict` enabled.
    pub fn json_schema(name: &str, schema: serde_json::Value) -> Self {
        Self::schema_variant(name, schema, true)
    }

    /// Builds an [`OutputFormat::JsonSchema`] with `strict` disabled.
    pub fn json_schema_lenient(name: &str, schema: serde_json::Value) -> Self {
        Self::schema_variant(name, schema, false)
    }

    /// Shared constructor behind the two public schema builders.
    fn schema_variant(name: &str, schema: serde_json::Value, strict: bool) -> Self {
        Self::JsonSchema {
            name: name.to_owned(),
            schema,
            strict,
        }
    }

    /// Returns `true` for every variant that requests JSON output.
    pub fn is_json(&self) -> bool {
        match self {
            Self::Text => false,
            Self::Json | Self::JsonSchema { .. } => true,
        }
    }

    /// Returns `true` only for [`OutputFormat::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text)
    }

    /// Builds the OpenAI `response_format` value for this format.
    ///
    /// Returns `None` for text, `{"type": "json_object"}` for plain JSON, and a
    /// `{"type": "json_schema", "json_schema": {name, schema, strict}}` object for the
    /// schema variant.
    pub fn to_openai_response_format(&self) -> Option<serde_json::Value> {
        match self {
            Self::Text => None,
            Self::Json => Some(serde_json::json!({ "type": "json_object" })),
            Self::JsonSchema {
                name,
                schema,
                strict,
            } => {
                let spec = serde_json::json!({
                    "name": name,
                    "schema": schema,
                    "strict": strict,
                });
                Some(serde_json::json!({
                    "type": "json_schema",
                    "json_schema": spec,
                }))
            }
        }
    }

    /// Builds text to append to a Claude system prompt that asks for JSON output.
    ///
    /// Returns `None` for text. For the schema variant, the schema is embedded
    /// pretty-printed; the schema `name` and `strict` flag are not used here.
    pub fn to_claude_system_suffix(&self) -> Option<String> {
        let suffix = match self {
            Self::Text => return None,
            Self::Json => String::from(
                "\n\nIMPORTANT: You MUST respond with valid JSON only. No markdown, no explanation, just JSON.",
            ),
            Self::JsonSchema { schema, .. } => {
                // Fall back to the compact `Display` form if pretty-printing errors.
                let rendered = serde_json::to_string_pretty(schema)
                    .unwrap_or_else(|_| schema.to_string());
                format!(
                    "\n\nIMPORTANT: You MUST respond with valid JSON matching this schema:\n```json\n{rendered}\n```\nRespond with JSON only, no markdown fences, no explanation."
                )
            }
        };
        Some(suffix)
    }
}
/// Parses `response` as JSON and checks it against `format`.
///
/// * [`OutputFormat::Text`] is rejected with `"Not in JSON mode"` without parsing.
/// * [`OutputFormat::Json`] only requires `response` to parse as JSON.
/// * [`OutputFormat::JsonSchema`] additionally requires every string listed in the
///   schema's top-level `"required"` array to be a key of the parsed value. No other
///   schema keywords (`type`, `properties`, ...) are enforced.
///
/// # Errors
///
/// Returns a human-readable message: `"Not in JSON mode"`, `"Invalid JSON: …"` when
/// parsing fails, or `"Missing required key: <key>"` naming the first absent key.
pub fn validate_json_response(
    response: &str,
    format: &OutputFormat,
) -> Result<serde_json::Value, String> {
    // Only the schema variant contributes extra checks; plain JSON has none.
    let schema = match format {
        OutputFormat::Text => return Err("Not in JSON mode".to_string()),
        OutputFormat::Json => None,
        OutputFormat::JsonSchema { schema, .. } => Some(schema),
    };

    let parsed: serde_json::Value =
        serde_json::from_str(response).map_err(|err| format!("Invalid JSON: {}", err))?;

    // Non-string entries in `required` are skipped. `Value::get` yields `None` on
    // non-object values, so for arrays/scalars any required key counts as missing.
    let first_missing = schema
        .and_then(|s| s.get("required"))
        .and_then(serde_json::Value::as_array)
        .into_iter()
        .flatten()
        .filter_map(serde_json::Value::as_str)
        .find(|key| parsed.get(*key).is_none());

    match first_missing {
        Some(key) => Err(format!("Missing required key: {}", key)),
        None => Ok(parsed),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema shared by the validation tests: an object requiring `name` and `age`.
    fn person_schema() -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        })
    }

    #[test]
    fn test_output_format_default_is_text() {
        let format = OutputFormat::default();
        assert_eq!(format, OutputFormat::Text);
    }

    #[test]
    fn test_output_format_is_json() {
        assert!(OutputFormat::Json.is_json());
        assert!(OutputFormat::JsonSchema {
            name: "test".to_string(),
            schema: json!({}),
            strict: true,
        }
        .is_json());
        // The negative case: plain text must not be reported as JSON.
        assert!(!OutputFormat::Text.is_json());
    }

    #[test]
    fn test_output_format_is_text() {
        assert!(OutputFormat::Text.is_text());
        assert!(!OutputFormat::Json.is_text());
        assert!(!OutputFormat::JsonSchema {
            name: "test".to_string(),
            schema: json!({}),
            strict: true,
        }
        .is_text());
    }

    #[test]
    fn test_json_constructor() {
        let format = OutputFormat::json();
        assert_eq!(format, OutputFormat::Json);
        assert!(format.is_json());
    }

    #[test]
    fn test_json_schema_constructor() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let format = OutputFormat::json_schema("person", schema.clone());
        match format {
            OutputFormat::JsonSchema {
                name,
                schema: s,
                strict,
            } => {
                assert_eq!(name, "person");
                assert_eq!(s, schema);
                assert!(strict);
            }
            _ => panic!("Expected JsonSchema variant"),
        }
    }

    #[test]
    fn test_json_schema_lenient_constructor() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        let format = OutputFormat::json_schema_lenient("result", schema.clone());
        match format {
            OutputFormat::JsonSchema {
                name,
                schema: s,
                strict,
            } => {
                assert_eq!(name, "result");
                assert_eq!(s, schema);
                assert!(!strict);
            }
            _ => panic!("Expected JsonSchema variant"),
        }
    }

    #[test]
    fn test_openai_format_text() {
        let format = OutputFormat::Text;
        assert!(format.to_openai_response_format().is_none());
    }

    #[test]
    fn test_openai_format_json() {
        let format = OutputFormat::Json;
        let result = format.to_openai_response_format().unwrap();
        assert_eq!(result, json!({"type": "json_object"}));
    }

    #[test]
    fn test_openai_format_json_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let format = OutputFormat::json_schema("person", schema.clone());
        let result = format.to_openai_response_format().unwrap();
        assert_eq!(result["type"], "json_schema");
        assert_eq!(result["json_schema"]["name"], "person");
        assert_eq!(result["json_schema"]["schema"], schema);
        assert_eq!(result["json_schema"]["strict"], true);
    }

    /// The lenient constructor must propagate `strict: false` into the OpenAI payload.
    #[test]
    fn test_openai_format_json_schema_lenient() {
        let schema = json!({ "type": "object" });
        let format = OutputFormat::json_schema_lenient("loose", schema.clone());
        let result = format.to_openai_response_format().unwrap();
        assert_eq!(
            result,
            json!({
                "type": "json_schema",
                "json_schema": { "name": "loose", "schema": schema, "strict": false }
            })
        );
    }

    #[test]
    fn test_claude_suffix_text() {
        let format = OutputFormat::Text;
        assert!(format.to_claude_system_suffix().is_none());
    }

    #[test]
    fn test_claude_suffix_json() {
        let format = OutputFormat::Json;
        let suffix = format.to_claude_system_suffix().unwrap();
        assert!(suffix.contains("valid JSON"));
        assert!(suffix.contains("IMPORTANT"));
        assert!(suffix.contains("No markdown"));
    }

    #[test]
    fn test_claude_suffix_json_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            },
            "required": ["name"]
        });
        let format = OutputFormat::json_schema("person", schema);
        let suffix = format.to_claude_system_suffix().unwrap();
        assert!(suffix.contains("valid JSON matching this schema"));
        assert!(suffix.contains("\"name\""));
        assert!(suffix.contains("IMPORTANT"));
    }

    #[test]
    fn test_validate_json_response_text_mode() {
        let format = OutputFormat::Text;
        let result = validate_json_response(r#"{"key": "value"}"#, &format);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Not in JSON mode");
    }

    #[test]
    fn test_validate_json_response_valid_json() {
        let format = OutputFormat::Json;
        let result = validate_json_response(r#"{"key": "value"}"#, &format);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["key"], "value");
    }

    #[test]
    fn test_validate_json_response_invalid_json() {
        let format = OutputFormat::Json;
        let result = validate_json_response("not valid json {", &format);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.starts_with("Invalid JSON:"));
    }

    #[test]
    fn test_validate_json_schema_with_required_keys() {
        let format = OutputFormat::json_schema("person", person_schema());
        let result = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &format);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["name"], "Alice");
        assert_eq!(value["age"], 30);
    }

    #[test]
    fn test_validate_json_schema_missing_required_key() {
        let format = OutputFormat::json_schema("person", person_schema());
        let result = validate_json_response(r#"{"name": "Alice"}"#, &format);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("Missing required key: age"));
    }

    /// When several keys are absent, the first one in `required` order is reported.
    #[test]
    fn test_validate_json_schema_reports_first_missing_key() {
        let format = OutputFormat::json_schema("person", person_schema());
        let err = validate_json_response("{}", &format).unwrap_err();
        assert_eq!(err, "Missing required key: name");
    }

    /// Parse errors take precedence over schema checks in schema mode too.
    #[test]
    fn test_validate_json_schema_invalid_json() {
        let format = OutputFormat::json_schema("person", person_schema());
        let err = validate_json_response(r#"{"name": "#, &format).unwrap_err();
        assert!(err.starts_with("Invalid JSON:"));
    }

    /// `required` only demands presence; an explicit `null` value satisfies it.
    #[test]
    fn test_validate_json_schema_null_value_counts_as_present() {
        let format = OutputFormat::json_schema("person", person_schema());
        let value = validate_json_response(r#"{"name": null, "age": 1}"#, &format).unwrap();
        assert!(value["name"].is_null());
    }

    /// A non-object response cannot carry keys, so required keys are reported missing.
    #[test]
    fn test_validate_json_schema_non_object_response() {
        let format = OutputFormat::json_schema("person", person_schema());
        let err = validate_json_response("[1, 2]", &format).unwrap_err();
        assert_eq!(err, "Missing required key: name");
    }

    /// Only `required` is enforced; other keywords such as `type` are not checked.
    #[test]
    fn test_validate_json_schema_only_checks_required_keys() {
        let format = OutputFormat::json_schema("anything", json!({ "type": "object" }));
        let value = validate_json_response("42", &format).unwrap();
        assert_eq!(value, json!(42));
    }

    /// Non-string entries in `required` are ignored rather than treated as errors.
    #[test]
    fn test_validate_json_schema_skips_non_string_required_entries() {
        let format = OutputFormat::json_schema("odd", json!({ "required": [1, null, "name"] }));
        assert!(validate_json_response(r#"{"name": "x"}"#, &format).is_ok());
    }

    #[test]
    fn test_output_format_serialize_roundtrip() {
        let formats = [
            OutputFormat::Text,
            OutputFormat::Json,
            OutputFormat::json_schema(
                "test",
                json!({
                    "type": "object",
                    "properties": {
                        "field": { "type": "string" }
                    }
                }),
            ),
            OutputFormat::json_schema_lenient("lenient", json!({ "type": "string" })),
        ];
        for original in formats {
            let serialized = serde_json::to_string(&original).unwrap();
            let deserialized: OutputFormat = serde_json::from_str(&serialized).unwrap();
            assert_eq!(original, deserialized);
        }
    }

    #[test]
    fn test_output_format_clone() {
        let original = OutputFormat::json_schema(
            "test",
            json!({
                "type": "object",
                "properties": {
                    "name": { "type": "string" }
                },
                "required": ["name"]
            }),
        );
        let cloned = original.clone();
        assert_eq!(original, cloned);
    }
}