use crate::error::AgentError;
use crate::schema::SchemaWarning;
use crate::types::OutputSchema;
use serde_json::Value;
/// Default extraction prompt text.
///
/// Asks for the final output as JSON matching the required schema, with no
/// surrounding text or markdown. The trailing `\` in the literal joins the two
/// source lines, so the resulting string contains no newline.
pub const DEFAULT_EXTRACTION_PROMPT: &str = "Provide the final output as valid JSON matching \
the required schema. Output ONLY the JSON, no additional text or markdown formatting.";
/// Bookkeeping for a structured-output extraction.
///
/// Every field starts as `None` (via `Default`) and is returned to `None` by
/// `reset`.
#[derive(Debug, Default)]
pub(crate) struct ExtractionState {
/// Text stored by `set_primary_output`, if any.
primary_output: Option<String>,
/// JSON value stored by `record_success`; emptied by `take_result`.
result: Option<Value>,
/// Schema warnings stored by `set_schema_warnings`; `None` when the supplied
/// list was empty, emptied by `take_schema_warnings`.
schema_warnings: Option<Vec<SchemaWarning>>,
}
impl ExtractionState {
    /// Returns the state to its freshly-constructed form (all fields `None`).
    pub(super) fn reset(&mut self) {
        *self = Self::default();
    }

    /// Stores `output` as the primary output, replacing any previous value.
    pub(super) fn set_primary_output(&mut self, output: String) {
        self.primary_output = Some(output);
    }

    /// Borrows the stored primary output, or `None` when nothing is stored.
    pub(super) fn primary_output(&self) -> Option<&str> {
        self.primary_output.as_deref()
    }

    /// Stores `warnings`, normalising an empty list to `None` so that
    /// `take_schema_warnings` never yields an empty vector.
    pub(super) fn set_schema_warnings(&mut self, warnings: Vec<SchemaWarning>) {
        let has_warnings = !warnings.is_empty();
        self.schema_warnings = has_warnings.then_some(warnings);
    }

    /// Stores `value` as the extraction result, replacing any previous value.
    pub(super) fn record_success(&mut self, value: Value) {
        self.result = Some(value);
    }

    /// Moves the stored result out, leaving `None` behind.
    pub(super) fn take_result(&mut self) -> Option<Value> {
        self.result.take()
    }

    /// Moves the stored schema warnings out, leaving `None` behind.
    pub(super) fn take_schema_warnings(&mut self) -> Option<Vec<SchemaWarning>> {
        self.schema_warnings.take()
    }
}
/// Outcome of validating one response text (see `validate_response_text`).
#[derive(Debug, Clone, PartialEq)]
pub(super) enum ExtractionValidation {
/// The text parsed as JSON and was accepted; carries the value after any
/// wrapper-object unwrapping.
Passed(Value),
/// The text was rejected. `error` describes the problem and `retry_prompt`
/// is the follow-up prompt built from that error by `retry_prompt_for_error`.
Failed { error: String, retry_prompt: String },
}
/// Parses `content` as JSON and validates it against the output schema.
///
/// Surrounding whitespace and markdown code fences are stripped before
/// parsing, and a single-key wrapper object named after the schema is
/// unwrapped via `unwrap_named_object_wrapper`.
///
/// Returns `Ok(ExtractionValidation::Passed(..))` with the normalised value,
/// or `Ok(ExtractionValidation::Failed { .. })` when the text is not valid
/// JSON or (with the `jsonschema` feature) violates `compiled_schema`.
///
/// # Errors
///
/// With the `jsonschema` feature enabled, returns
/// `AgentError::InvalidOutputSchema` when a validator cannot be built from
/// `compiled_schema`.
pub(super) fn validate_response_text(
    content: &str,
    output_schema: &OutputSchema,
    compiled_schema: &Value,
) -> Result<ExtractionValidation, AgentError> {
    let candidate = strip_code_fences(content.trim());
    let normalized = match serde_json::from_str::<Value>(candidate) {
        Ok(value) => unwrap_named_object_wrapper(value, output_schema),
        Err(parse_error) => {
            let message = format!("Invalid JSON: {parse_error}");
            return Ok(invalid_validation(message));
        }
    };

    #[cfg(feature = "jsonschema")]
    {
        let validator = jsonschema::Validator::new(compiled_schema).map_err(|build_error| {
            AgentError::InvalidOutputSchema(build_error.to_string())
        })?;
        if let Err(violation) = validator.validate(&normalized) {
            let message = format!("Schema validation failed: {violation}");
            return Ok(invalid_validation(message));
        }
    }

    // Without the feature there is nothing to validate against, so any
    // parseable JSON is accepted and a warning is logged.
    #[cfg(not(feature = "jsonschema"))]
    {
        let _ = compiled_schema;
        tracing::warn!(
            "Structured output schema validation unavailable \
             (jsonschema feature disabled). Accepting parsed JSON without schema validation."
        );
    }

    Ok(ExtractionValidation::Passed(normalized))
}
fn invalid_validation(error: String) -> ExtractionValidation {
let retry_prompt = retry_prompt_for_error(&error);
ExtractionValidation::Failed {
error,
retry_prompt,
}
}
/// Builds the follow-up prompt sent after an invalid output.
///
/// The prompt quotes `error` verbatim and then repeats the instruction to
/// answer with JSON only.
pub(super) fn retry_prompt_for_error(error: &str) -> String {
    const LEAD_IN: &str = "The previous output was invalid: ";
    const INSTRUCTIONS: &str = ". Please provide valid JSON matching the schema. \
        Output ONLY the JSON, no additional text.";

    [LEAD_IN, error, INSTRUCTIONS].concat()
}
/// Removes a surrounding markdown code fence from `content`.
///
/// The input is trimmed first. If it then starts with a triple-backtick
/// fence, the fence and an optional `json` info string are removed; the info
/// string is matched ASCII-case-insensitively so that fences tagged `JSON`
/// or `Json` are handled the same as `json`. A trailing closing fence is
/// removed when present, and an unterminated fence is tolerated.
///
/// Returns a trimmed sub-slice of `content`; text without a leading fence is
/// returned trimmed but otherwise untouched.
pub(super) fn strip_code_fences(content: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_TAG: &str = "json";

    let trimmed = content.trim();
    let Some(after_fence) = trimmed.strip_prefix(FENCE) else {
        return trimmed;
    };

    // `get` yields `None` when the text is shorter than the tag or the cut
    // would split a multi-byte character, so the slice below cannot panic.
    let body = match after_fence.get(..JSON_TAG.len()) {
        Some(tag) if tag.eq_ignore_ascii_case(JSON_TAG) => &after_fence[JSON_TAG.len()..],
        _ => after_fence,
    };

    let body = body.trim();
    body.strip_suffix(FENCE).map_or(body, str::trim)
}
/// Unwraps a `{ "<schema name>": { ... } }` envelope around the real payload.
///
/// `parsed` is returned unchanged unless all of the following hold:
///
/// * the schema has a name, and `parsed` is an object whose only key is that
///   name with an object value;
/// * the name is not itself listed in the schema's top-level `required` or
///   `properties` (otherwise the wrapper is part of the expected shape);
/// * the inner object looks like the payload: it contains every entry of a
///   non-empty `required` list, or at least one declared property.
///
/// When unwrapping, the inner object is moved out of `parsed` rather than
/// cloned.
pub(super) fn unwrap_named_object_wrapper(parsed: Value, output_schema: &OutputSchema) -> Value {
    let Some(wrapper_key) = output_schema.name.as_deref() else {
        return parsed;
    };
    let Value::Object(outer) = &parsed else {
        return parsed;
    };
    if outer.len() != 1 {
        return parsed;
    }
    let Some(Value::Object(inner)) = outer.get(wrapper_key) else {
        return parsed;
    };

    let schema = output_schema.schema.as_value();
    // Non-string entries in `required` are ignored.
    let required = schema
        .get("required")
        .and_then(Value::as_array)
        .map(|names| names.iter().filter_map(Value::as_str).collect::<Vec<_>>())
        .unwrap_or_default();
    let properties = schema.get("properties").and_then(Value::as_object);

    let wrapper_is_declared = required.iter().any(|key| *key == wrapper_key)
        || properties.is_some_and(|props| props.contains_key(wrapper_key));
    if wrapper_is_declared {
        return parsed;
    }

    // At this point the outer object holds exactly one key, `wrapper_key`,
    // which the schema does not declare. The outer object therefore matches
    // no declared property and satisfies `required` only when that list is
    // empty, so only the inner object needs inspecting.
    let inner_satisfies_required =
        !required.is_empty() && required.iter().all(|key| inner.contains_key(*key));
    let inner_matches_properties = properties
        .is_some_and(|props| props.keys().any(|key| inner.contains_key(key.as_str())));
    if !(inner_satisfies_required || inner_matches_properties) {
        return parsed;
    }

    // Take ownership of the payload instead of deep-cloning it. The fallback
    // arms cannot be reached after the checks above; they hand back the
    // original value rather than panicking.
    match parsed {
        Value::Object(mut outer) => match outer.remove(wrapper_key) {
            Some(unwrapped) => unwrapped,
            None => Value::Object(outer),
        },
        other => other,
    }
}
#[cfg(test)]
// The validation tests `panic!` on the unexpected enum variant so the failure
// message includes the offending value.
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use crate::types::OutputSchema;
    use serde_json::json;

    /// Text without a fence is returned as-is.
    #[test]
    fn test_strip_code_fences_no_fences() {
        assert_eq!(
            strip_code_fences(r#"{"name": "test"}"#),
            r#"{"name": "test"}"#
        );
    }

    /// A fence tagged `json` is removed together with its tag.
    #[test]
    fn test_strip_code_fences_json_fence() {
        let input = r#"```json
{"name": "test"}
```"#;
        assert_eq!(strip_code_fences(input), r#"{"name": "test"}"#);
    }

    /// An untagged fence is removed.
    #[test]
    fn test_strip_code_fences_plain_fence() {
        let input = r#"```
{"name": "test"}
```"#;
        assert_eq!(strip_code_fences(input), r#"{"name": "test"}"#);
    }

    /// Blank lines around the fence do not prevent stripping.
    #[test]
    fn test_strip_code_fences_with_whitespace() {
        let input = r#"
```json
{"name": "test"}
```
"#;
        assert_eq!(strip_code_fences(input), r#"{"name": "test"}"#);
    }

    /// A wrapper keyed by the schema name is unwrapped when the inner object
    /// carries the required field.
    #[test]
    fn test_unwrap_named_object_wrapper_when_inner_matches_schema()
    -> Result<(), Box<dyn std::error::Error>> {
        let schema = OutputSchema::new(json!({
            "type": "object",
            "properties": { "response": { "type": "string" } },
            "required": ["response"]
        }))?
        .with_name("advisor");
        let parsed = json!({
            "advisor": {
                "response": "hello"
            }
        });
        let normalized = unwrap_named_object_wrapper(parsed, &schema);
        assert_eq!(normalized, json!({"response": "hello"}));
        Ok(())
    }

    /// When the schema itself declares the wrapper key, the value is left
    /// untouched.
    #[test]
    fn test_unwrap_named_object_wrapper_preserves_declared_wrapper_key()
    -> Result<(), Box<dyn std::error::Error>> {
        let schema = OutputSchema::new(json!({
            "type": "object",
            "properties": { "advisor": { "type": "object" } },
            "required": ["advisor"]
        }))?
        .with_name("advisor");
        let parsed = json!({
            "advisor": {
                "response": "hello"
            }
        });
        let normalized = unwrap_named_object_wrapper(parsed.clone(), &schema);
        assert_eq!(normalized, parsed);
        Ok(())
    }

    /// Unparseable text yields `Failed` with a retry prompt that quotes the
    /// error.
    #[test]
    fn test_validate_response_text_returns_retry_prompt_for_invalid_json()
    -> Result<(), Box<dyn std::error::Error>> {
        let schema = OutputSchema::new(json!({
            "type": "object",
            "properties": { "answer": { "type": "string" } },
            "required": ["answer"]
        }))?;
        let result = validate_response_text("not json {{{", &schema, schema.schema.as_value())?;
        match result {
            ExtractionValidation::Failed {
                error,
                retry_prompt,
            } => {
                assert!(error.contains("Invalid JSON"));
                assert!(retry_prompt.contains(&error));
                assert!(retry_prompt.contains("Output ONLY the JSON"));
            }
            ExtractionValidation::Passed(value) => {
                panic!("expected invalid JSON failure, got {value:?}")
            }
        }
        Ok(())
    }

    /// A value of the wrong type is rejected by schema validation.
    ///
    /// Gated on the `jsonschema` feature: without it `validate_response_text`
    /// skips schema validation and accepts any parseable JSON, so this
    /// expectation cannot hold.
    #[cfg(feature = "jsonschema")]
    #[test]
    fn test_validate_response_text_rejects_schema_mismatch()
    -> Result<(), Box<dyn std::error::Error>> {
        let schema = OutputSchema::new(json!({
            "type": "object",
            "properties": { "count": { "type": "integer" } },
            "required": ["count"]
        }))?;
        let result =
            validate_response_text(r#"{"count":"wrong"}"#, &schema, schema.schema.as_value())?;
        match result {
            ExtractionValidation::Failed { error, .. } => {
                assert!(error.contains("Schema validation failed"));
            }
            ExtractionValidation::Passed(value) => {
                panic!("expected schema failure, got {value:?}")
            }
        }
        Ok(())
    }

    /// A wrapped payload is unwrapped before validation and then passes.
    #[test]
    fn test_validate_response_text_accepts_unwrapped_schema_match()
    -> Result<(), Box<dyn std::error::Error>> {
        let schema = OutputSchema::new(json!({
            "type": "object",
            "properties": { "response": { "type": "string" } },
            "required": ["response"]
        }))?
        .with_name("advisor");
        let result = validate_response_text(
            r#"{"advisor":{"response":"hello"}}"#,
            &schema,
            schema.schema.as_value(),
        )?;
        assert_eq!(
            result,
            ExtractionValidation::Passed(json!({"response": "hello"}))
        );
        Ok(())
    }
}