use crate::error::{Error, Result};
use crate::recursive::validate::{Score, Validate};
use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::marker::PhantomData;
/// Source of the textual instruction that describes the expected output format.
///
/// Implementors must be `Send + Sync` (supertrait bounds).
pub trait FormatInstruction: Send + Sync {
/// Returns the instruction text.
///
/// The `Cow` lets an implementor hand out a borrowed string or build an
/// owned one on demand.
fn instruction(&self) -> Cow<'_, str>;
}
/// Zero-sized format marker; it is the default `F` parameter of `TypedValidator`.
#[derive(Debug, Clone, Copy)]
pub struct DefaultFormat;
impl FormatInstruction for DefaultFormat {
    /// Returns the fixed "JSON only" instruction as a borrowed string; nothing
    /// is allocated.
    fn instruction(&self) -> Cow<'_, str> {
        const JSON_ONLY: &str = "Respond with valid JSON only. No markdown, no explanation.";
        Cow::Borrowed(JSON_ONLY)
    }
}
/// Format instruction that carries a schema description to embed in the
/// instruction text.
#[derive(Debug, Clone)]
pub struct SchemaFormat {
    /// Schema text; borrowed for `'static` literals, owned for strings built
    /// at runtime.
    schema: Cow<'static, str>,
}

impl SchemaFormat {
    /// Wraps a `'static` schema string without copying it.
    pub fn new(schema: &'static str) -> Self {
        let schema: Cow<'static, str> = schema.into();
        Self { schema }
    }

    /// Takes ownership of a schema string built at runtime.
    pub fn new_owned(schema: String) -> Self {
        let schema: Cow<'static, str> = schema.into();
        Self { schema }
    }
}
impl FormatInstruction for SchemaFormat {
    /// Builds the instruction by appending the stored schema to a fixed
    /// prefix. The result is always an owned string.
    fn instruction(&self) -> Cow<'_, str> {
        const PREFIX: &str = "Respond with valid JSON matching this schema:\n";

        let mut text = String::with_capacity(PREFIX.len() + self.schema.len());
        text.push_str(PREFIX);
        text.push_str(&self.schema);
        Cow::Owned(text)
    }
}
/// Validator parameterised by a target type `T` and a format instruction `F`.
///
/// No value of `T` is ever stored; `T` only selects the type that input is
/// checked against. `F` supplies the instruction text and defaults to
/// [`DefaultFormat`].
///
/// # Thread safety
///
/// The struct is `Send`/`Sync` exactly when `F` is, whatever `T` is. This
/// follows from the auto-trait rules alone: the only field that mentions `T`
/// is `PhantomData<fn() -> T>`, and function pointers are always
/// `Send + Sync`. No `unsafe impl` is therefore needed.
pub struct TypedValidator<T, F = DefaultFormat> {
    /// Provider of the instruction text.
    format: F,
    /// Type-level marker for `T`. Using `fn() -> T` rather than `T` keeps the
    /// struct's auto traits independent of `T`.
    _phantom: PhantomData<fn() -> T>,
}
/// Creates a [`TypedValidator`] for `T` that uses [`DefaultFormat`] as its
/// format instruction.
pub fn typed<T: DeserializeOwned>() -> TypedValidator<T, DefaultFormat> {
    let format = DefaultFormat;
    TypedValidator {
        format,
        _phantom: PhantomData,
    }
}
impl<T: DeserializeOwned> TypedValidator<T, DefaultFormat> {
    /// Replaces the default format with a [`SchemaFormat`] built from a
    /// `'static` schema string.
    pub fn schema(self, schema: &'static str) -> TypedValidator<T, SchemaFormat> {
        self.with_format(SchemaFormat::new(schema))
    }

    /// Replaces the default format with a [`SchemaFormat`] that takes
    /// ownership of a schema string built at runtime.
    pub fn schema_owned(self, schema: String) -> TypedValidator<T, SchemaFormat> {
        self.with_format(SchemaFormat::new_owned(schema))
    }
}
impl<T, F> TypedValidator<T, F>
where
    T: DeserializeOwned,
    F: FormatInstruction,
{
    /// Swaps the format instruction for `format`, keeping the target type `T`.
    ///
    /// The previous format value is dropped.
    pub fn with_format<F2: FormatInstruction>(self, format: F2) -> TypedValidator<T, F2> {
        let _phantom = PhantomData;
        TypedValidator { format, _phantom }
    }

    /// Returns the instruction text produced by the current format.
    pub fn instruction(&self) -> Cow<'_, str> {
        let Self { format, .. } = self;
        format.instruction()
    }
}
impl<T, F> Validate for TypedValidator<T, F>
where
    T: DeserializeOwned + Send + Sync,
    F: FormatInstruction,
{
    /// Scores `text` by whether the JSON extracted from it deserializes as `T`.
    ///
    /// Returns `Score::pass()` on success. On failure the score is `0.0` and
    /// the feedback contains the deserialization error followed by the format
    /// instruction.
    fn validate(&self, text: &str) -> Score<'static> {
        let candidate = extract_json(text);
        if let Err(parse_error) = serde_json::from_str::<T>(candidate) {
            let hint = self.format.instruction();
            return Score::with_feedback(
                0.0,
                format!("Invalid JSON: {}. {}", parse_error, hint),
            );
        }
        // The deserialized value itself is not needed; success is all that counts.
        Score::pass()
    }

    /// Fixed identifier of this validator.
    fn name(&self) -> &'static str {
        "typed_validator"
    }
}
/// Extracts the most likely JSON payload from free-form text.
///
/// The returned slice always borrows from `text`; nothing is parsed or
/// validated here. Strategies are tried in order:
///
/// 1. The body of the first closed `` ```json `` fence is returned as-is
///    (trimmed).
/// 2. The body of the first closed `` ``` `` fence is returned if, after
///    trimming, it starts with `{` or `[`.
/// 3. Otherwise the text is scanned from the first `{` or `[` to the closer
///    that balances it. Brackets inside JSON string literals are ignored, so
///    trailing prose or a second JSON value after the first one is not
///    swallowed. If no balanced closer exists, the span from that opener to
///    the last closer of the same kind is used instead.
/// 4. If none of the above applies, the trimmed input is returned unchanged.
///
/// Known limitation: the first opener wins, so a bracket in prose ahead of
/// the JSON (for example a `[1]` footnote marker) is what gets extracted.
pub fn extract_json(text: &str) -> &str {
    let trimmed = text.trim();

    // An explicitly tagged fence is trusted without inspecting its body.
    if let Some(body) = fenced_body(trimmed, "```json") {
        return body;
    }

    // A generic fence may hold any language, so require a JSON-looking body.
    if let Some(body) = fenced_body(trimmed, "```") {
        if body.starts_with('{') || body.starts_with('[') {
            return body;
        }
    }

    raw_json_span(trimmed).unwrap_or(trimmed)
}

/// Returns the trimmed body of the first fence opened by `opener` and closed
/// by `` ``` ``, or `None` if the opener or the closing fence is missing.
fn fenced_body<'a>(text: &'a str, opener: &str) -> Option<&'a str> {
    let start = text.find(opener)?;
    let after_fence = &text[start + opener.len()..];
    // Skip the rest of the opening line (e.g. a language tag) when there is one.
    let content_start = after_fence.find('\n').map_or(0, |i| i + 1);
    let content = &after_fence[content_start..];
    let end = content.find("```")?;
    Some(content[..end].trim())
}

/// Returns the span starting at the first `{` or `[` that most plausibly
/// holds a JSON value, or `None` if there is no opener or no usable closer.
fn raw_json_span(text: &str) -> Option<&str> {
    let start = text.find(|c: char| c == '{' || c == '[')?;
    let bytes = text.as_bytes();

    if let Some(end) = balanced_end(bytes, start) {
        // `end` indexes an ASCII closer, so `..=end` stays on a char boundary.
        return Some(&text[start..=end]);
    }

    // Best effort for unbalanced input: opener through the last same-kind closer.
    let closer = if bytes[start] == b'{' { '}' } else { ']' };
    let end = text.rfind(closer)?;
    if end > start {
        Some(&text[start..=end])
    } else {
        None
    }
}

/// Finds the index of the closer that balances the opener at `bytes[start]`.
///
/// Brackets inside double-quoted strings (with backslash escapes) are
/// ignored. Returns `None` on a mismatched closer or if the input ends before
/// the opener is balanced. Scanning bytes is UTF-8 safe because every
/// structural character is ASCII and never appears inside a multi-byte
/// sequence.
fn balanced_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut expected_closers: Vec<u8> = Vec::new();
    let mut in_string = false;
    let mut escaped = false;

    for (offset, &byte) in bytes[start..].iter().enumerate() {
        if in_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }

        match byte {
            b'"' => in_string = true,
            b'{' => expected_closers.push(b'}'),
            b'[' => expected_closers.push(b']'),
            b'}' | b']' => {
                if expected_closers.pop() != Some(byte) {
                    return None;
                }
                if expected_closers.is_empty() {
                    return Some(start + offset);
                }
            }
            _ => {}
        }
    }
    None
}
pub fn parse_output<T: DeserializeOwned>(text: &str) -> Result<T> {
let json_str = extract_json(text);
serde_json::from_str(json_str)
.map_err(|e| Error::module(format!("Failed to parse output as JSON: {}", e)))
}
// Unit tests for the typed validator, JSON extraction and output parsing.
#[cfg(test)]
mod tests {
use super::*;
use serde::Deserialize;
// Flat fixture with one string field and one integer field.
#[derive(Deserialize, Debug, PartialEq)]
struct SimpleStruct {
name: String,
value: i32,
}
// Fixture with a list field; the tests below only deserialize into it and
// never read its fields, hence `allow(dead_code)`.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct NestedStruct {
items: Vec<String>,
count: usize,
}
// Well-formed JSON matching the target type yields a perfect score.
#[test]
fn test_typed_validator_valid_json() {
let v = typed::<SimpleStruct>();
let score = v.validate(r#"{"name": "test", "value": 42}"#);
assert!(score.is_perfect());
}
// Non-JSON text scores 0.0 and the feedback mentions "Invalid JSON".
#[test]
fn test_typed_validator_invalid_json() {
let v = typed::<SimpleStruct>();
let score = v.validate("not json at all");
assert!((score.value - 0.0).abs() < f64::EPSILON);
assert!(score.feedback_str().unwrap().contains("Invalid JSON"));
}
// Valid JSON that does not fit the target type also scores 0.0.
#[test]
fn test_typed_validator_wrong_schema() {
let v = typed::<SimpleStruct>();
let score = v.validate(r#"{"wrong_field": true}"#);
assert!((score.value - 0.0).abs() < f64::EPSILON);
}
// With a schema attached, failure feedback contains the word "schema".
#[test]
fn test_typed_validator_with_schema() {
let v = typed::<SimpleStruct>().schema(r#"{"name": "string", "value": "number"}"#);
let score = v.validate(r#"{"name": "hello", "value": 1}"#);
assert!(score.is_perfect());
let score = v.validate("bad");
let feedback = score.feedback_str().unwrap();
assert!(feedback.contains("schema"));
}
// A ```json fence surrounded by prose is unwrapped to its body.
#[test]
fn test_extract_json_code_fence() {
let text = r#"Here is the JSON:
```json
{"name": "test", "value": 42}
```
That's the answer."#;
let extracted = extract_json(text);
assert_eq!(extracted, r#"{"name": "test", "value": 42}"#);
}
// An untagged fence is unwrapped when its body starts like JSON.
#[test]
fn test_extract_json_generic_fence() {
let text = r#"```
{"items": ["a", "b"], "count": 2}
```"#;
let extracted = extract_json(text);
assert_eq!(extracted, r#"{"items": ["a", "b"], "count": 2}"#);
}
// Unfenced JSON embedded in a sentence is cut out of the surrounding text.
#[test]
fn test_extract_json_raw() {
let text = r#"The answer is {"name": "raw", "value": 99} and that's it."#;
let extracted = extract_json(text);
assert_eq!(extracted, r#"{"name": "raw", "value": 99}"#);
}
// A bare top-level array is returned unchanged.
#[test]
fn test_extract_json_array() {
let text = r#"[1, 2, 3]"#;
let extracted = extract_json(text);
assert_eq!(extracted, "[1, 2, 3]");
}
// Text without any brace or bracket comes back as-is.
#[test]
fn test_extract_json_no_json() {
let text = "just plain text";
let extracted = extract_json(text);
assert_eq!(extracted, "just plain text");
}
// `parse_output` deserializes the body of a ```json fence.
#[test]
fn test_parse_output_from_fence() {
let text = r#"```json
{"name": "parsed", "value": 7}
```"#;
let result: SimpleStruct = parse_output(text).unwrap();
assert_eq!(result.name, "parsed");
assert_eq!(result.value, 7);
}
// `parse_output` deserializes plain JSON with no wrapping.
#[test]
fn test_parse_output_raw_json() {
let result: SimpleStruct = parse_output(r#"{"name": "raw", "value": 1}"#).unwrap();
assert_eq!(result.name, "raw");
assert_eq!(result.value, 1);
}
// `parse_output` returns an error for non-JSON input.
#[test]
fn test_parse_output_invalid() {
let result = parse_output::<SimpleStruct>("not json");
assert!(result.is_err());
}
// A target type containing a list field validates successfully.
#[test]
fn test_typed_nested() {
let v = typed::<NestedStruct>();
let score = v.validate(r#"{"items": ["hello", "world"], "count": 2}"#);
assert!(score.is_perfect());
}
// The default instruction mentions JSON.
#[test]
fn test_format_instruction_default() {
let v = typed::<SimpleStruct>();
let inst = v.instruction();
assert!(inst.contains("JSON"));
}
// The schema instruction mentions "schema" and embeds the schema text.
#[test]
fn test_format_instruction_schema() {
let v = typed::<SimpleStruct>().schema(r#"{"name": "string"}"#);
let inst = v.instruction();
assert!(inst.contains("schema"));
assert!(inst.contains(r#""name""#));
}
// Validation succeeds when prose surrounds the JSON object.
#[test]
fn test_typed_with_surrounding_text() {
let v = typed::<SimpleStruct>();
let text = r#"Sure! Here is the answer:
{"name": "extracted", "value": 100}
Hope this helps!"#;
let score = v.validate(text);
assert!(score.is_perfect());
}
// A schema string built at runtime can be attached via `schema_owned`.
#[test]
fn test_schema_owned() {
let schema = format!(r#"{{"name": "{}"}}"#, "string");
let v = typed::<SimpleStruct>().schema_owned(schema);
let inst = v.instruction();
assert!(inst.contains("name"));
}
}