use serde::de::{self, Deserializer, MapAccess, Visitor};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::fmt;
/// Reference to a JSON Schema, supplied either inline or as a path to a schema file.
///
/// Deserialization uses a custom `Deserialize` impl: a map becomes
/// [`SchemaRef::Inline`] and a string becomes [`SchemaRef::File`]. Serialization
/// is `untagged`, so each variant is written as its bare payload.
///
/// NOTE: because of that asymmetry, an `Inline` value that is not a map (e.g. a
/// JSON string or array) does not deserialize back to the same variant.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum SchemaRef {
    /// Schema document embedded directly in the configuration.
    Inline(JsonValue),
    /// Path to a schema file, stored verbatim (no resolution or existence check).
    File(String),
}
impl<'de> Deserialize<'de> for SchemaRef {
    /// Builds a [`SchemaRef`] from either a string or a map.
    ///
    /// * A string (borrowed or owned) becomes [`SchemaRef::File`], with the path kept as-is.
    /// * A map is collected into a [`JsonValue`] and becomes [`SchemaRef::Inline`].
    ///
    /// Other input types (numbers, booleans, sequences, ...) fall through to serde's
    /// default visitor methods, which report an invalid-type error. Because this
    /// uses `deserialize_any`, the input format must be self-describing (JSON, YAML).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Accepts a path string or an inline schema map.
        struct RefVisitor;

        impl<'de> Visitor<'de> for RefVisitor {
            type Value = SchemaRef;

            fn expecting(&self, out: &mut fmt::Formatter) -> fmt::Result {
                out.write_str("a JSON Schema object or a file path string")
            }

            fn visit_string<E>(self, path: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(SchemaRef::File(path))
            }

            fn visit_str<E>(self, path: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                // Borrowed strings take the same route as owned ones.
                self.visit_string(path.to_owned())
            }

            fn visit_map<A>(self, entries: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Re-wrap the map access so serde_json can collect it into a Value.
                let map_deserializer = de::value::MapAccessDeserializer::new(entries);
                JsonValue::deserialize(map_deserializer).map(SchemaRef::Inline)
            }
        }

        deserializer.deserialize_any(RefVisitor)
    }
}
/// Output settings: the declared [`OutputFormat`] plus an optional schema and retry limit.
///
/// Every field may be omitted from the input. Missing fields take their
/// [`Default`] values: text format, no schema, and no retry override.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(default)]
pub struct OutputPolicy {
    /// Declared output format; `text` when omitted.
    pub format: OutputFormat,
    /// Schema for the output, either a file path string or an inline schema map.
    /// Only takes effect when `format` is [`OutputFormat::Json`].
    pub schema: Option<SchemaRef>,
    /// Retry limit copied into a synthesized structured-output spec; `None` when omitted.
    pub max_retries: Option<u8>,
    /// Pre-built structured-output spec. Judging by the name, this is presumably the
    /// spec the policy was derived from; confirm against callers.
    ///
    /// Never read from input (`#[serde(skip)]`). When set and the policy is
    /// structured, `to_structured_spec` returns a clone of it instead of building one.
    #[serde(skip)]
    pub source_structured_spec: Option<super::structured::StructuredOutputSpec>,
}
impl OutputPolicy {
pub fn is_structured(&self) -> bool {
self.format == OutputFormat::Json && self.schema.is_some()
}
pub fn to_structured_spec(&self) -> Option<super::structured::StructuredOutputSpec> {
if !self.is_structured() {
return None;
}
if let Some(ref spec) = self.source_structured_spec {
return Some(spec.clone());
}
let schema = self.schema.clone().unwrap();
Some(super::structured::StructuredOutputSpec {
schema,
enable_extractor: None,
enable_tool_injection: None,
enable_retry: Some(true),
enable_repair: Some(true),
max_retries: self.max_retries,
repair_model: None,
})
}
}
/// Declared format of an [`OutputPolicy`]'s output.
///
/// Deserialized from the lowercase variant name (`text`, `json`, `yaml`,
/// `markdown`, `binary`); defaults to [`OutputFormat::Text`]. Only
/// [`OutputFormat::Json`] can make a policy structured (see
/// [`OutputPolicy::is_structured`]).
///
/// The enum has no payloads, so it is `Copy` and `Eq`. Comparisons and
/// by-value passing need no clones.
#[derive(Debug, Clone, Copy, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
    /// Plain text; the default.
    #[default]
    Text,
    /// JSON; combined with a schema, this enables structured output.
    Json,
    /// YAML.
    Yaml,
    /// Markdown.
    Markdown,
    /// Binary data.
    Binary,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde_yaml;

    /// Parses `yaml` into an [`OutputPolicy`], panicking with the parse error on failure.
    fn parse(yaml: &str) -> OutputPolicy {
        serde_yaml::from_str(yaml).expect("policy YAML should parse")
    }

    /// A pre-built spec whose settings differ from anything `to_structured_spec`
    /// would synthesize, so tests can tell which one was returned.
    fn source_spec() -> super::super::structured::StructuredOutputSpec {
        super::super::structured::StructuredOutputSpec {
            schema: SchemaRef::File("source.json".to_string()),
            enable_extractor: None,
            enable_tool_injection: None,
            enable_retry: Some(false),
            enable_repair: Some(false),
            max_retries: Some(1),
            repair_model: None,
        }
    }

    #[test]
    fn parse_text_format() {
        let policy = parse("format: text");
        assert_eq!(policy.format, OutputFormat::Text);
        assert!(policy.schema.is_none());
    }

    #[test]
    fn parse_json_with_schema_file() {
        let yaml = r#"
format: json
schema: .nika/schemas/result.json
"#;
        let policy = parse(yaml);
        assert_eq!(policy.format, OutputFormat::Json);
        assert!(
            matches!(policy.schema, Some(SchemaRef::File(ref p)) if p == ".nika/schemas/result.json")
        );
    }

    #[test]
    fn parse_json_with_inline_schema() {
        let yaml = r#"
format: json
schema:
  type: object
  properties:
    name:
      type: string
  required:
    - name
"#;
        let policy = parse(yaml);
        assert_eq!(policy.format, OutputFormat::Json);
        match &policy.schema {
            Some(SchemaRef::Inline(schema)) => {
                assert_eq!(schema["type"], "object");
                assert!(schema["properties"]["name"].is_object());
            }
            other => panic!("expected an inline schema, got {other:?}"),
        }
    }

    #[test]
    fn parse_max_retries() {
        let yaml = r#"
format: json
max_retries: 3
"#;
        let policy = parse(yaml);
        assert_eq!(policy.max_retries, Some(3));
    }

    #[test]
    fn default_is_text() {
        let policy = OutputPolicy::default();
        assert_eq!(policy.format, OutputFormat::Text);
        assert!(policy.schema.is_none());
        assert!(policy.max_retries.is_none());
    }

    #[test]
    fn rejects_schema_that_is_neither_string_nor_map() {
        let number: Result<OutputPolicy, _> = serde_yaml::from_str("format: json\nschema: 42");
        assert!(number.is_err());
        let list: Result<OutputPolicy, _> = serde_yaml::from_str("format: json\nschema: [a, b]");
        assert!(list.is_err());
    }

    #[test]
    fn rejects_unknown_format() {
        let result: Result<OutputPolicy, _> = serde_yaml::from_str("format: xml");
        assert!(result.is_err());
    }

    #[test]
    fn is_structured_true_when_json_with_schema() {
        let yaml = r#"
format: json
schema:
  type: object
"#;
        assert!(parse(yaml).is_structured());
    }

    #[test]
    fn is_structured_false_when_text() {
        assert!(!OutputPolicy::default().is_structured());
    }

    #[test]
    fn is_structured_false_when_json_without_schema() {
        assert!(!parse("format: json").is_structured());
    }

    #[test]
    fn is_structured_false_when_text_with_schema() {
        let yaml = r#"
format: text
schema:
  type: object
"#;
        assert!(!parse(yaml).is_structured());
    }

    #[test]
    fn to_structured_spec_returns_spec_when_structured() {
        let yaml = r#"
format: json
schema:
  type: object
  properties:
    name:
      type: string
max_retries: 5
"#;
        let spec = parse(yaml)
            .to_structured_spec()
            .expect("json + schema policy should yield a spec");
        assert!(matches!(spec.schema, SchemaRef::Inline(_)));
        assert_eq!(spec.max_retries, Some(5));
        assert_eq!(spec.enable_retry, Some(true));
        assert_eq!(spec.enable_repair, Some(true));
    }

    #[test]
    fn to_structured_spec_returns_none_when_not_structured() {
        assert!(OutputPolicy::default().to_structured_spec().is_none());
    }

    #[test]
    fn to_structured_spec_with_file_schema() {
        let yaml = r#"
format: json
schema: ./schemas/user.json
"#;
        let spec = parse(yaml)
            .to_structured_spec()
            .expect("json + schema policy should yield a spec");
        assert!(matches!(spec.schema, SchemaRef::File(ref p) if p == "./schemas/user.json"));
    }

    #[test]
    fn to_structured_spec_prefers_source_spec() {
        let policy = OutputPolicy {
            format: OutputFormat::Json,
            schema: Some(SchemaRef::File("policy.json".to_string())),
            max_retries: Some(7),
            source_structured_spec: Some(source_spec()),
        };
        let spec = policy
            .to_structured_spec()
            .expect("json + schema policy should yield a spec");
        // The stored spec is returned as-is; the policy's own schema/max_retries are not applied.
        assert!(matches!(spec.schema, SchemaRef::File(ref p) if p == "source.json"));
        assert_eq!(spec.max_retries, Some(1));
        assert_eq!(spec.enable_retry, Some(false));
        assert_eq!(spec.enable_repair, Some(false));
    }

    #[test]
    fn to_structured_spec_ignores_source_spec_when_not_structured() {
        let policy = OutputPolicy {
            format: OutputFormat::Text,
            schema: Some(SchemaRef::File("policy.json".to_string())),
            max_retries: None,
            source_structured_spec: Some(source_spec()),
        };
        assert!(policy.to_structured_spec().is_none());
    }
}