use serde::de::{self, Deserializer, MapAccess, Visitor};
use serde::Deserialize;
use serde_json::Value as JsonValue;
use std::fmt;
/// Source of a JSON Schema: embedded directly in the document, or referenced
/// by a file path.
///
/// When deserialized, a string becomes [`SchemaRef::File`] and a mapping
/// becomes [`SchemaRef::Inline`]; any other value shape is rejected.
///
/// `PartialEq` is derived so parsed policies can be compared directly
/// (both payload types already implement it).
#[derive(Debug, Clone, PartialEq)]
pub enum SchemaRef {
    /// Schema given inline, stored as the JSON value built from a mapping.
    Inline(JsonValue),
    /// Path to a schema file, stored verbatim (nothing here resolves or reads it).
    File(String),
}
impl<'de> Deserialize<'de> for SchemaRef {
    /// Builds a [`SchemaRef`] from either a string or a mapping.
    ///
    /// * a string (borrowed or owned) is kept verbatim as `SchemaRef::File`;
    /// * a mapping is collected into a `serde_json::Value` as `SchemaRef::Inline`;
    /// * anything else fails with an "expected a JSON Schema object or a file
    ///   path string" error.
    ///
    /// Dispatch uses `deserialize_any`, so the input format must be
    /// self-describing (YAML, JSON, ...).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Local visitor: its only job is to branch on the shape of the input.
        struct RefShapeVisitor;

        impl<'de> Visitor<'de> for RefShapeVisitor {
            type Value = SchemaRef;

            fn expecting(&self, out: &mut fmt::Formatter) -> fmt::Result {
                out.write_str("a JSON Schema object or a file path string")
            }

            fn visit_string<E>(self, path: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(SchemaRef::File(path))
            }

            fn visit_str<E>(self, path: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                // Borrowed strings take the same route as owned ones.
                self.visit_string(path.to_owned())
            }

            fn visit_map<A>(self, entries: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Re-wrap the map access so serde_json can consume the whole
                // mapping as a single `Value`.
                let as_deserializer = de::value::MapAccessDeserializer::new(entries);
                JsonValue::deserialize(as_deserializer).map(SchemaRef::Inline)
            }
        }

        deserializer.deserialize_any(RefShapeVisitor)
    }
}
/// Output settings: the expected format, an optional schema, and an optional
/// `max_retries` value.
///
/// Every field may be omitted from the source document; omitted fields take
/// their value from [`OutputPolicy::default`] (text format, no schema,
/// no `max_retries`).
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct OutputPolicy {
    /// Expected output format; [`OutputFormat::Text`] when omitted.
    pub format: OutputFormat,
    /// Optional schema, given inline (mapping) or as a file path (string).
    pub schema: Option<SchemaRef>,
    /// Optional retry count (fits in a `u8`, so 0..=255); `None` when omitted.
    pub max_retries: Option<u8>,
}
/// Format an output is expected to be in.
///
/// Deserialized from the lowercase variant name: `"text"`, `"json"`,
/// `"yaml"` or `"markdown"`. Any other string is rejected. [`OutputFormat::Text`]
/// is the default.
///
/// `Eq` and `Hash` are derived because equality on this fieldless enum is
/// total, so it can be used as a set or map key.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
    /// Plain text (`"text"`); the default.
    #[default]
    Text,
    /// JSON (`"json"`).
    Json,
    /// YAML (`"yaml"`).
    Yaml,
    /// Markdown (`"markdown"`).
    Markdown,
}
#[cfg(test)]
mod tests {
    // YAML parsing tests for `OutputPolicy`, `OutputFormat` and `SchemaRef`.
    use super::*;
    use crate::serde_yaml;

    #[test]
    fn parse_text_format() {
        let yaml = "format: text";
        let policy: OutputPolicy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(policy.format, OutputFormat::Text);
        assert!(policy.schema.is_none());
    }

    #[test]
    fn parse_remaining_formats() {
        let policy: OutputPolicy = serde_yaml::from_str("format: yaml").unwrap();
        assert_eq!(policy.format, OutputFormat::Yaml);
        let policy: OutputPolicy = serde_yaml::from_str("format: markdown").unwrap();
        assert_eq!(policy.format, OutputFormat::Markdown);
    }

    #[test]
    fn reject_unknown_format() {
        let result: Result<OutputPolicy, _> = serde_yaml::from_str("format: xml");
        assert!(result.is_err());
    }

    #[test]
    fn parse_json_with_schema_file() {
        let yaml = r#"
format: json
schema: .nika/schemas/result.json
"#;
        let policy: OutputPolicy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(policy.format, OutputFormat::Json);
        assert!(
            matches!(policy.schema, Some(SchemaRef::File(ref p)) if p == ".nika/schemas/result.json")
        );
    }

    #[test]
    fn quoted_numeric_schema_is_file_path() {
        // A quoted scalar is a string, so it must land in `File`, not fail.
        let policy: OutputPolicy = serde_yaml::from_str(r#"schema: "42""#).unwrap();
        assert!(matches!(policy.schema, Some(SchemaRef::File(ref p)) if p == "42"));
    }

    #[test]
    fn parse_json_with_inline_schema() {
        let yaml = r#"
format: json
schema:
  type: object
  properties:
    name:
      type: string
  required:
    - name
"#;
        let policy: OutputPolicy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(policy.format, OutputFormat::Json);
        // A single match: a wrong variant fails loudly instead of silently
        // skipping the content assertions below.
        let schema = match &policy.schema {
            Some(SchemaRef::Inline(schema)) => schema,
            other => panic!("expected an inline schema, got {other:?}"),
        };
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["name"].is_object());
        assert_eq!(schema["required"][0], "name");
    }

    #[test]
    fn null_schema_is_none() {
        let policy: OutputPolicy = serde_yaml::from_str("schema: null").unwrap();
        assert!(policy.schema.is_none());
    }

    #[test]
    fn reject_non_string_non_map_schema() {
        let number: Result<OutputPolicy, _> = serde_yaml::from_str("schema: 42");
        assert!(number.is_err());
        let sequence: Result<OutputPolicy, _> = serde_yaml::from_str("schema: [a, b]");
        assert!(sequence.is_err());
    }

    #[test]
    fn parse_max_retries() {
        let yaml = r#"
format: json
max_retries: 3
"#;
        let policy: OutputPolicy = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(policy.max_retries, Some(3));
    }

    #[test]
    fn reject_out_of_range_max_retries() {
        // `max_retries` is a `u8`; 300 does not fit.
        let result: Result<OutputPolicy, _> = serde_yaml::from_str("max_retries: 300");
        assert!(result.is_err());
    }

    #[test]
    fn default_is_text() {
        let policy = OutputPolicy::default();
        assert_eq!(policy.format, OutputFormat::Text);
        assert!(policy.schema.is_none());
        assert!(policy.max_retries.is_none());
    }
}