use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::Value;
use tracing::debug;
use crate::ast::output::SchemaRef;
use crate::ast::StructuredOutputSpec;
use crate::error::NikaError;
use crate::event::{EventKind, EventLog};
use super::output::{extract_json, format_validation_errors, validate_schema_ref};
/// Asynchronous LLM call used by the retry (Layer 3) and repair (Layer 4) layers.
///
/// The callback receives a complete prompt and resolves to the model's raw text response,
/// or to a [`NikaError`] if the call itself fails.
pub type InferCallback = Arc<
    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<String, NikaError>> + Send>>
        + Send
        + Sync,
>;

/// Name reported for Layer 2: local JSON extraction followed by schema validation.
const LAYER_2_NAME: &str = "extract_validate";
/// Name reported for Layer 3: re-prompting the LLM with validation feedback.
const LAYER_3_NAME: &str = "retry_with_feedback";
/// Name reported for Layer 4: asking the LLM to repair the JSON against the schema.
const LAYER_4_NAME: &str = "llm_repair";
/// Schema-valid JSON produced by [`StructuredOutputEngine::validate`], together with
/// which layer produced it and how many attempts that took.
#[derive(Debug, Clone, PartialEq)]
pub struct StructuredOutputResult {
    /// The extracted JSON value; it has passed schema validation.
    pub value: Value,
    /// Number of the layer that succeeded (2, 3 or 4).
    pub layer: u8,
    /// Name of the layer that succeeded (e.g. `"extract_validate"`).
    pub layer_name: String,
    /// Attempts made across all layers, including the successful one.
    pub total_attempts: u32,
}
/// Recovers schema-valid JSON from LLM output through a chain of fallback layers.
///
/// Layer 2 extracts and validates JSON without calling the LLM. Layers 3 (retry with
/// feedback) and 4 (LLM repair) need an [`InferCallback`], and are switched on or off
/// through the [`StructuredOutputSpec`].
pub struct StructuredOutputEngine {
    /// Schema reference, layer toggles and retry budget.
    spec: StructuredOutputSpec,
    /// LLM callback for Layers 3 and 4; `None` makes those layers fail immediately.
    infer_fn: Option<InferCallback>,
    /// Prompt that produced the output; echoed at the top of Layer 3 retry prompts.
    original_prompt: Option<String>,
    /// Schema resolved from `spec`, cached after the first successful load.
    compiled_schema: Option<Arc<Value>>,
    /// Sink for attempt/success events.
    log: Arc<EventLog>,
}
impl StructuredOutputEngine {
    /// Creates an engine for `spec` that reports attempts and successes to `log`.
    ///
    /// Layers 3 and 4 fail immediately until an LLM is attached with
    /// [`Self::with_infer_callback`].
    pub fn new(spec: StructuredOutputSpec, log: Arc<EventLog>) -> Self {
        Self {
            spec,
            log,
            compiled_schema: None,
            infer_fn: None,
            original_prompt: None,
        }
    }

    /// Attaches the LLM callback used by Layer 3 (retry) and Layer 4 (repair).
    pub fn with_infer_callback(mut self, callback: InferCallback) -> Self {
        self.infer_fn = Some(callback);
        self
    }

    /// Sets the prompt that produced the output; it is placed at the top of Layer 3 retry prompts.
    pub fn with_original_prompt(mut self, prompt: String) -> Self {
        self.original_prompt = Some(prompt);
        self
    }

    /// Resolves the configured schema, reading and parsing it on first use and caching it.
    ///
    /// # Errors
    ///
    /// Returns [`NikaError::SchemaFailed`] if a file schema cannot be read or is not valid
    /// JSON. A failed load is not cached, so a later call tries again.
    pub async fn load_schema(&mut self) -> Result<Arc<Value>, NikaError> {
        if let Some(schema) = &self.compiled_schema {
            return Ok(Arc::clone(schema));
        }

        let schema = match &self.spec.schema {
            SchemaRef::Inline(v) => v.clone(),
            SchemaRef::File(path) => {
                let content = tokio::fs::read_to_string(path).await.map_err(|e| {
                    NikaError::SchemaFailed {
                        details: format!("Failed to read schema '{}': {}", path, e),
                    }
                })?;
                serde_json::from_str(&content).map_err(|e| NikaError::SchemaFailed {
                    details: format!("Invalid JSON in schema '{}': {}", path, e),
                })?
            }
        };

        let schema = Arc::new(schema);
        self.compiled_schema = Some(Arc::clone(&schema));
        Ok(schema)
    }

    /// Returns the schema reference this engine validates against.
    pub fn schema(&self) -> &SchemaRef {
        &self.spec.schema
    }

    /// Runs `raw_output` through the recovery layers until one yields schema-valid JSON.
    ///
    /// 1. Layer 2 (`extract_validate`): extract JSON from `raw_output` and validate it.
    /// 2. Layer 3 (`retry_with_feedback`), if enabled: re-prompt the LLM up to
    ///    `max_retries` times. Each retry shows the most recent invalid output (first
    ///    `raw_output`, then the previous LLM response) together with its validation errors.
    ///    Without an infer callback the layer is skipped after a single failed attempt.
    /// 3. Layer 4 (`llm_repair`), if enabled: ask the LLM to repair `raw_output`.
    ///
    /// Every attempt emits a `StructuredOutputAttempt` event. The winning layer also emits
    /// `StructuredOutputSuccess`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::load_schema`]. If every enabled layer fails, returns
    /// [`NikaError::StructuredOutputAllLayersFailed`] with the total attempt count and the
    /// validation errors of the original `raw_output`.
    pub async fn validate(
        &mut self,
        task_id: &str,
        raw_output: &str,
    ) -> Result<StructuredOutputResult, NikaError> {
        let task_id: Arc<str> = Arc::from(task_id);
        let schema = self.load_schema().await?;

        // Layer 2: the output exactly as received.
        let mut total_attempts: u32 = 1;
        if let Ok(value) = self
            .try_layer_2(&task_id, raw_output, &schema, total_attempts)
            .await
        {
            return Ok(self.finish_success(&task_id, 2, LAYER_2_NAME, value, total_attempts));
        }

        // Layer 3: re-prompt the LLM, always critiquing the latest invalid output so
        // successive retries do not resend an identical prompt.
        if self.spec.enable_retry_or_default() {
            let max_retries = self.spec.max_retries_or_default();
            let mut feedback_output = raw_output.to_owned();
            for retry in 1..=max_retries {
                total_attempts += 1;
                if let Ok(value) = self
                    .try_layer_3(&task_id, &mut feedback_output, &schema, retry, total_attempts)
                    .await
                {
                    return Ok(self.finish_success(
                        &task_id,
                        3,
                        LAYER_3_NAME,
                        value,
                        total_attempts,
                    ));
                }
                if self.infer_fn.is_none() {
                    // Without a callback every retry fails identically; one skip event is enough.
                    break;
                }
            }
        }

        // Layer 4: ask the LLM to repair the original output against the schema.
        if self.spec.enable_repair_or_default() {
            total_attempts += 1;
            if let Ok(value) = self
                .try_layer_4(&task_id, raw_output, &schema, total_attempts)
                .await
            {
                return Ok(self.finish_success(&task_id, 4, LAYER_4_NAME, value, total_attempts));
            }
        }

        Err(NikaError::StructuredOutputAllLayersFailed {
            task_id: task_id.to_string(),
            attempts: total_attempts,
            final_errors: self.collect_validation_errors(raw_output, &schema),
        })
    }

    /// Emits the success event for `layer` and packages `value` as the engine's result.
    fn finish_success(
        &self,
        task_id: &Arc<str>,
        layer: u8,
        layer_name: &str,
        value: Value,
        total_attempts: u32,
    ) -> StructuredOutputResult {
        self.emit_success(task_id, layer, layer_name, total_attempts);
        StructuredOutputResult {
            value,
            layer,
            layer_name: layer_name.to_string(),
            total_attempts,
        }
    }

    /// Layer 2: extracts JSON from `raw_output` and validates it against `schema`.
    async fn try_layer_2(
        &self,
        task_id: &Arc<str>,
        raw_output: &str,
        schema: &Value,
        attempt: u32,
    ) -> Result<Value, NikaError> {
        self.extract_and_validate(
            task_id,
            2,
            LAYER_2_NAME,
            attempt,
            raw_output,
            schema,
            |_, err| err.to_string(),
        )
        .await
    }

    /// Layer 3: asks the LLM to regenerate its answer, showing it `feedback_output` and that
    /// output's validation errors.
    ///
    /// Whenever the LLM responds, `feedback_output` is overwritten with the response, so a
    /// following retry critiques the newest attempt. Fails immediately, after emitting a skip
    /// event, when no infer callback is configured. LLM call errors are returned unchanged.
    async fn try_layer_3(
        &self,
        task_id: &Arc<str>,
        feedback_output: &mut String,
        schema: &Value,
        retry_num: u8,
        attempt: u32,
    ) -> Result<Value, NikaError> {
        let infer_fn = match &self.infer_fn {
            Some(f) => f,
            None => {
                debug!(
                    task_id = %task_id,
                    retry = retry_num,
                    "Layer 3 skipped: no infer callback configured"
                );
                self.emit_attempt(
                    task_id,
                    3,
                    LAYER_3_NAME,
                    attempt,
                    false,
                    Some(format!(
                        "retry {}: no infer callback - Layer 3 disabled",
                        retry_num
                    )),
                );
                return Err(NikaError::StructuredOutputValidationFailed {
                    task_id: task_id.to_string(),
                    layer: LAYER_3_NAME.to_string(),
                    attempt,
                    errors: vec!["Layer 3 requires infer callback".to_string()],
                });
            }
        };

        let validation_errors = self
            .collect_validation_errors(feedback_output.as_str(), schema)
            .join("\n");
        let original_prompt = self.original_prompt.as_deref().unwrap_or("");
        let retry_prompt = self.generate_retry_prompt(
            original_prompt,
            feedback_output.as_str(),
            &validation_errors,
        );
        debug!(
            task_id = %task_id,
            retry = retry_num,
            prompt_len = retry_prompt.len(),
            "Layer 3: calling LLM with retry prompt"
        );

        let response = match infer_fn(retry_prompt).await {
            Ok(output) => output,
            Err(e) => {
                self.emit_attempt(
                    task_id,
                    3,
                    LAYER_3_NAME,
                    attempt,
                    false,
                    Some(format!("retry {}: LLM call failed: {}", retry_num, e)),
                );
                return Err(e);
            }
        };
        debug!(
            task_id = %task_id,
            retry = retry_num,
            output_len = response.len(),
            "Layer 3: received LLM response"
        );

        // Keep this response as the feedback for a possible next retry.
        *feedback_output = response;
        let result = self
            .extract_and_validate(
                task_id,
                3,
                LAYER_3_NAME,
                attempt,
                feedback_output.as_str(),
                schema,
                |stage, err| format!("retry {retry_num}: {stage} failed: {err}"),
            )
            .await;
        if result.is_ok() {
            debug!(
                task_id = %task_id,
                retry = retry_num,
                "Layer 3: validation succeeded"
            );
        }
        result
    }

    /// Layer 4: asks the LLM to repair `raw_output` so that it matches `schema`.
    ///
    /// Fails immediately, after emitting a skip event, when no infer callback is configured.
    /// LLM call errors are returned unchanged.
    async fn try_layer_4(
        &self,
        task_id: &Arc<str>,
        raw_output: &str,
        schema: &Value,
        attempt: u32,
    ) -> Result<Value, NikaError> {
        let infer_fn = match &self.infer_fn {
            Some(f) => f,
            None => {
                debug!(
                    task_id = %task_id,
                    "Layer 4 skipped: no infer callback configured"
                );
                self.emit_attempt(
                    task_id,
                    4,
                    LAYER_4_NAME,
                    attempt,
                    false,
                    Some("no infer callback - Layer 4 disabled".to_string()),
                );
                return Err(NikaError::StructuredOutputValidationFailed {
                    task_id: task_id.to_string(),
                    layer: LAYER_4_NAME.to_string(),
                    attempt,
                    errors: vec!["Layer 4 requires infer callback".to_string()],
                });
            }
        };

        let repair_prompt = self.generate_repair_prompt(raw_output, schema);
        debug!(
            task_id = %task_id,
            prompt_len = repair_prompt.len(),
            "Layer 4: calling repair LLM"
        );

        let repaired_output = match infer_fn(repair_prompt).await {
            Ok(output) => output,
            Err(e) => {
                self.emit_attempt(
                    task_id,
                    4,
                    LAYER_4_NAME,
                    attempt,
                    false,
                    Some(format!("repair LLM call failed: {}", e)),
                );
                return Err(e);
            }
        };
        debug!(
            task_id = %task_id,
            output_len = repaired_output.len(),
            "Layer 4: received repair LLM response"
        );

        let result = self
            .extract_and_validate(
                task_id,
                4,
                LAYER_4_NAME,
                attempt,
                &repaired_output,
                schema,
                |stage, err| format!("repair {stage} failed: {err}"),
            )
            .await;
        if result.is_ok() {
            debug!(
                task_id = %task_id,
                "Layer 4: repair validation succeeded"
            );
        }
        result
    }

    /// Extracts JSON from `output`, validates it against `schema` and emits exactly one
    /// attempt event describing the outcome.
    ///
    /// `describe(stage, error)` formats the event's error text. `stage` is `"extraction"` or
    /// `"validation"`, which lets each layer keep its own message wording.
    ///
    /// # Errors
    ///
    /// [`NikaError::StructuredOutputExtractionFailed`] when no JSON can be extracted, and
    /// [`NikaError::StructuredOutputValidationFailed`] when the JSON violates the schema.
    // Each argument ends up on the emitted event or the returned error; bundling them into
    // a struct would only move the same fields around.
    #[allow(clippy::too_many_arguments)]
    async fn extract_and_validate(
        &self,
        task_id: &Arc<str>,
        layer: u8,
        layer_name: &str,
        attempt: u32,
        output: &str,
        schema: &Value,
        describe: impl Fn(&str, &str) -> String,
    ) -> Result<Value, NikaError> {
        let json_value = match extract_json(output) {
            Ok(v) => v,
            Err(e) => {
                self.emit_attempt(
                    task_id,
                    layer,
                    layer_name,
                    attempt,
                    false,
                    Some(describe("extraction", &e)),
                );
                return Err(NikaError::StructuredOutputExtractionFailed {
                    task_id: task_id.to_string(),
                    layer: layer_name.to_string(),
                    reason: e,
                });
            }
        };

        match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
            Ok(()) => {
                self.emit_attempt(task_id, layer, layer_name, attempt, true, None);
                Ok(json_value)
            }
            Err(e) => {
                let message = e.to_string();
                self.emit_attempt(
                    task_id,
                    layer,
                    layer_name,
                    attempt,
                    false,
                    Some(describe("validation", &message)),
                );
                Err(NikaError::StructuredOutputValidationFailed {
                    task_id: task_id.to_string(),
                    layer: layer_name.to_string(),
                    attempt,
                    errors: vec![message],
                })
            }
        }
    }

    /// Records one attempt of `layer` (successful or not) on the event log.
    fn emit_attempt(
        &self,
        task_id: &Arc<str>,
        layer: u8,
        layer_name: &str,
        attempt: u32,
        success: bool,
        error: Option<String>,
    ) {
        self.log.emit(EventKind::StructuredOutputAttempt {
            task_id: Arc::clone(task_id),
            layer,
            layer_name: layer_name.to_string(),
            attempt,
            success,
            error,
        });
    }

    /// Records that `layer` produced valid output after `total_attempts` attempts.
    fn emit_success(&self, task_id: &Arc<str>, layer: u8, layer_name: &str, total_attempts: u32) {
        self.log.emit(EventKind::StructuredOutputSuccess {
            task_id: Arc::clone(task_id),
            layer,
            layer_name: layer_name.to_string(),
            total_attempts,
        });
    }

    /// Lists why `raw_output` fails `schema`: one entry per line of the formatted validation
    /// report, or a single entry when no JSON can be extracted.
    fn collect_validation_errors(&self, raw_output: &str, schema: &Value) -> Vec<String> {
        match extract_json(raw_output) {
            Ok(value) => {
                let errors_str = format_validation_errors(&value, schema);
                errors_str.lines().map(|s| s.to_string()).collect()
            }
            Err(e) => vec![format!("JSON extraction failed: {}", e)],
        }
    }

    /// Builds the Layer 3 prompt. It contains the original prompt, the invalid output in a
    /// fenced block, the validation errors, and a request for a corrected response.
    pub fn generate_retry_prompt(
        &self,
        original_prompt: &str,
        invalid_output: &str,
        validation_errors: &str,
    ) -> String {
        format!(
            r#"{original_prompt}
Your previous response was invalid:
```
{invalid_output}
```
Validation errors:
{validation_errors}
Please provide a corrected response that matches the required JSON schema."#
        )
    }

    /// Builds the Layer 4 prompt, asking the LLM to return only corrected JSON.
    ///
    /// The schema is pretty-printed; if that fails it falls back to the compact form.
    pub fn generate_repair_prompt(&self, invalid_output: &str, schema: &Value) -> String {
        let schema_str =
            serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
        format!(
            r#"You are a JSON repair assistant. Fix the following invalid JSON to match the schema.
Invalid JSON:
```
{invalid_output}
```
Required schema:
```json
{schema_str}
```
Respond with ONLY the corrected JSON, no explanation."#
        )
    }
}
/// Validates `output` against `spec.schema` using Layer 2 only (extract + validate), with no
/// LLM retry or repair.
///
/// Emits the same events as the engine's Layer 2: one `StructuredOutputAttempt` (attempt 1),
/// followed on success by `StructuredOutputSuccess`.
///
/// # Errors
///
/// - [`NikaError::StructuredOutputExtractionFailed`] if no JSON can be extracted from `output`.
/// - [`NikaError::StructuredOutputValidationFailed`] if the JSON does not satisfy the schema.
pub async fn validate_structured_output(
    task_id: &str,
    output: &str,
    spec: &StructuredOutputSpec,
    log: &EventLog,
) -> Result<Value, NikaError> {
    let task_id: Arc<str> = Arc::from(task_id);
    let record_attempt = |success: bool, error: Option<String>| {
        log.emit(EventKind::StructuredOutputAttempt {
            task_id: Arc::clone(&task_id),
            layer: 2,
            layer_name: LAYER_2_NAME.to_string(),
            attempt: 1,
            success,
            error,
        });
    };

    let json_value = match extract_json(output) {
        Ok(v) => v,
        Err(e) => {
            record_attempt(false, Some(e.clone()));
            return Err(NikaError::StructuredOutputExtractionFailed {
                task_id: task_id.to_string(),
                layer: LAYER_2_NAME.to_string(),
                reason: e,
            });
        }
    };

    if let Err(e) = validate_schema_ref(&json_value, &spec.schema).await {
        let message = e.to_string();
        record_attempt(false, Some(message.clone()));
        return Err(NikaError::StructuredOutputValidationFailed {
            task_id: task_id.to_string(),
            layer: LAYER_2_NAME.to_string(),
            attempt: 1,
            errors: vec![message],
        });
    }

    // Mirror the engine's Layer 2: a successful attempt is logged before the success event.
    record_attempt(true, None);
    log.emit(EventKind::StructuredOutputSuccess {
        task_id: Arc::clone(&task_id),
        layer: 2,
        layer_name: LAYER_2_NAME.to_string(),
        total_attempts: 1,
    });
    Ok(json_value)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tempfile::NamedTempFile;

    /// Fresh event log shared between an engine and the test's assertions.
    fn create_test_log() -> Arc<EventLog> {
        Arc::new(EventLog::new())
    }

    /// Object schema requiring a string `name` and a non-negative integer `age`.
    fn create_user_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"]
        })
    }

    #[tokio::test]
    async fn layer2_valid_json_passes() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine
            .validate("test-task", r#"{"name": "Alice", "age": 30}"#)
            .await;
        assert!(result.is_ok());
        let r = result.unwrap();
        assert_eq!(r.layer, 2);
        assert_eq!(r.layer_name, "extract_validate");
        assert_eq!(r.value["name"], "Alice");
    }

    #[tokio::test]
    async fn layer2_markdown_wrapped_json_passes() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let output = r#"Here's the result:
```json
{"name": "Bob", "age": 25}
```
Hope this helps!"#;
        let result = engine.validate("test-task", output).await;
        assert!(result.is_ok());
        let r = result.unwrap();
        assert_eq!(r.value["name"], "Bob");
        assert_eq!(r.value["age"], 25);
    }

    #[tokio::test]
    async fn layer2_invalid_json_fails() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine.validate("test-task", r#"{"name": "Charlie"}"#).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            NikaError::StructuredOutputAllLayersFailed { .. }
        ));
    }

    #[tokio::test]
    async fn layer2_malformed_json_fails() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine.validate("test-task", "not json at all").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn load_schema_from_file() {
        let log = create_test_log();
        let mut schema_file = NamedTempFile::new().unwrap();
        writeln!(
            schema_file,
            r#"{{"type": "object", "properties": {{"x": {{"type": "number"}}}}}}"#
        )
        .unwrap();
        let path = schema_file.path().to_string_lossy().to_string();
        let spec = StructuredOutputSpec::with_file_schema(&path);
        let mut engine = StructuredOutputEngine::new(spec, log);
        let schema = engine.load_schema().await.unwrap();
        assert_eq!(schema["type"], "object");
    }

    #[tokio::test]
    async fn load_schema_file_not_found() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_file_schema("/nonexistent/schema.json");
        let mut engine = StructuredOutputEngine::new(spec, log);
        let result = engine.load_schema().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn events_emitted_on_success() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let _ = engine
            .validate("task-1", r#"{"name": "Test", "age": 20}"#)
            .await;
        let events = log.events();
        assert!(!events.is_empty());
        let has_attempt = events.iter().any(|e| {
            matches!(
                &e.kind,
                EventKind::StructuredOutputAttempt { success: true, .. }
            )
        });
        let has_success = events
            .iter()
            .any(|e| matches!(&e.kind, EventKind::StructuredOutputSuccess { .. }));
        assert!(has_attempt);
        assert!(has_success);
    }

    #[tokio::test]
    async fn events_emitted_on_failure() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let _ = engine.validate("task-2", "invalid").await;
        let events = log.events();
        assert!(!events.is_empty());
        let has_failed_attempt = events.iter().any(|e| {
            matches!(
                &e.kind,
                EventKind::StructuredOutputAttempt { success: false, .. }
            )
        });
        assert!(has_failed_attempt);
    }

    #[tokio::test]
    async fn layers_can_be_disabled() {
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(false);
        spec.enable_repair = Some(false);
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine
            .validate("task-3", r#"{"name": "Only name, no age"}"#)
            .await;
        assert!(result.is_err());
        let events = log.events();
        let attempt_count = events
            .iter()
            .filter(|e| matches!(&e.kind, EventKind::StructuredOutputAttempt { .. }))
            .count();
        assert_eq!(attempt_count, 1, "Only Layer 2 should have attempted");
    }

    #[test]
    fn generate_retry_prompt_includes_context() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let engine = StructuredOutputEngine::new(spec, log);
        let prompt = engine.generate_retry_prompt(
            "Generate a user object",
            r#"{"name": "Test"}"#,
            "missing required field: age",
        );
        assert!(prompt.contains("Generate a user object"));
        assert!(prompt.contains(r#"{"name": "Test"}"#));
        assert!(prompt.contains("missing required field: age"));
    }

    #[test]
    fn generate_repair_prompt_includes_schema() {
        let log = create_test_log();
        let schema = create_user_schema();
        let spec = StructuredOutputSpec::with_inline_schema(schema.clone());
        let engine = StructuredOutputEngine::new(spec, log);
        let prompt = engine.generate_repair_prompt(r#"{"broken": true}"#, &schema);
        assert!(prompt.contains(r#"{"broken": true}"#));
        assert!(prompt.contains("name"));
        assert!(prompt.contains("age"));
    }

    #[tokio::test]
    async fn standalone_validation_works() {
        let log = EventLog::new();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let result = validate_structured_output(
            "task-4",
            r#"{"name": "Standalone", "age": 42}"#,
            &spec,
            &log,
        )
        .await;
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["name"], "Standalone");
    }

    #[tokio::test]
    async fn standalone_validation_fails_on_invalid() {
        let log = EventLog::new();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let result =
            validate_structured_output("task-5", r#"{"invalid": true}"#, &spec, &log).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn handles_unicode_content() {
        let log = create_test_log();
        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        let mut engine = StructuredOutputEngine::new(spec, log);
        let result = engine
            .validate("task-unicode", r#"{"name": "日本語テスト", "age": 25}"#)
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().value["name"], "日本語テスト");
    }

    #[tokio::test]
    async fn handles_nested_objects() {
        let log = create_test_log();
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }
            },
            "required": ["user"]
        });
        let spec = StructuredOutputSpec::with_inline_schema(schema);
        let mut engine = StructuredOutputEngine::new(spec, log);
        let result = engine
            .validate("task-nested", r#"{"user": {"name": "Nested User"}}"#)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn handles_arrays() {
        let log = create_test_log();
        let schema = serde_json::json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": { "type": "integer" }
                },
                "required": ["id"]
            }
        });
        let spec = StructuredOutputSpec::with_inline_schema(schema);
        let mut engine = StructuredOutputEngine::new(spec, log);
        let result = engine
            .validate("task-array", r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#)
            .await;
        assert!(result.is_ok());
        let arr = result.unwrap().value;
        assert!(arr.is_array());
        assert_eq!(arr.as_array().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn layer3_actually_retries_llm() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();
        let callback: InferCallback = Arc::new(move |_prompt: String| {
            let count = call_count_clone.clone();
            Box::pin(async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok(r#"{"name": "Alice", "age": 30}"#.to_string())
                } else {
                    Ok(r#"{"name": "Bob", "age": 25}"#.to_string())
                }
            })
        });
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(true);
        spec.max_retries = Some(3);
        spec.enable_repair = Some(false);
        let mut engine = StructuredOutputEngine::new(spec, log.clone())
            .with_infer_callback(callback)
            .with_original_prompt("Generate a user object".to_string());
        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
        assert!(result.is_ok(), "Should succeed after Layer 3 retry");
        let r = result.unwrap();
        assert_eq!(r.layer, 3, "Should succeed at Layer 3");
        assert_eq!(r.layer_name, "retry_with_feedback");
        assert_eq!(r.value["name"], "Alice");
        // The first retry already returns valid JSON, so no further calls may happen.
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "Should have called LLM exactly once"
        );
    }

    #[tokio::test]
    async fn layer3_skipped_without_callback() {
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(true);
        spec.max_retries = Some(3);
        spec.enable_repair = Some(false);
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
        assert!(result.is_err(), "Should fail without callback");
        let events = log.events();
        let layer3_attempts = events.iter().filter(|e| {
            matches!(
                &e.kind,
                EventKind::StructuredOutputAttempt {
                    layer: 3,
                    success: false,
                    error: Some(err),
                    ..
                } if err.contains("no infer callback")
            )
        });
        assert!(
            layer3_attempts.count() > 0,
            "Should have Layer 3 attempt events showing no callback"
        );
    }

    #[tokio::test]
    async fn layer4_actually_repairs_json() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();
        let callback: InferCallback = Arc::new(move |prompt: String| {
            let count = call_count_clone.clone();
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                assert!(
                    prompt.contains("repair") || prompt.contains("schema"),
                    "Should receive repair prompt"
                );
                Ok(r#"{"name": "Repaired", "age": 25}"#.to_string())
            })
        });
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(false);
        spec.enable_repair = Some(true);
        let mut engine =
            StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
        let result = engine.validate("test-task", "totally broken json").await;
        assert!(result.is_ok(), "Should succeed after Layer 4 repair");
        let r = result.unwrap();
        assert_eq!(r.layer, 4, "Should succeed at Layer 4");
        assert_eq!(r.layer_name, "llm_repair");
        assert_eq!(r.value["name"], "Repaired");
        // Retries are disabled, so the only LLM call is the single repair request.
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "Should have called repair LLM exactly once"
        );
    }

    #[tokio::test]
    async fn layer4_skipped_without_callback() {
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(false);
        spec.enable_repair = Some(true);
        let mut engine = StructuredOutputEngine::new(spec, log.clone());
        let result = engine.validate("test-task", "broken json").await;
        assert!(result.is_err(), "Should fail without callback");
        let events = log.events();
        let layer4_attempts = events.iter().filter(|e| {
            matches!(
                &e.kind,
                EventKind::StructuredOutputAttempt {
                    layer: 4,
                    success: false,
                    error: Some(err),
                    ..
                } if err.contains("no infer callback")
            )
        });
        assert!(
            layer4_attempts.count() > 0,
            "Should have Layer 4 attempt event showing no callback"
        );
    }

    #[tokio::test]
    async fn max_retries_is_respected() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();
        let callback: InferCallback = Arc::new(move |_prompt: String| {
            let count = call_count_clone.clone();
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Ok(r#"{"still_invalid": true}"#.to_string())
            })
        });
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.max_retries = Some(3);
        spec.enable_retry = Some(true);
        spec.enable_repair = Some(false);
        let mut engine =
            StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
        assert!(result.is_err(), "Should fail after max retries");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "Should have retried exactly max_retries times"
        );
    }

    #[tokio::test]
    async fn layer3_layer4_chain_works() {
        let call_count = Arc::new(AtomicU32::new(0));
        let call_count_clone = call_count.clone();
        let callback: InferCallback = Arc::new(move |prompt: String| {
            let count = call_count_clone.clone();
            Box::pin(async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                if prompt.contains("JSON repair assistant") {
                    Ok(r#"{"name": "Repaired", "age": 42}"#.to_string())
                } else {
                    Ok(format!(
                        r#"{{"retry_attempt": {}, "still_invalid": true}}"#,
                        n
                    ))
                }
            })
        });
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.max_retries = Some(2);
        spec.enable_retry = Some(true);
        spec.enable_repair = Some(true);
        let mut engine = StructuredOutputEngine::new(spec, log.clone())
            .with_infer_callback(callback)
            .with_original_prompt("Generate user".to_string());
        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
        assert!(result.is_ok(), "Should succeed after Layer 4 repair");
        let r = result.unwrap();
        assert_eq!(r.layer, 4, "Should succeed at Layer 4");
        assert_eq!(r.value["name"], "Repaired");
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            3,
            "Should have made 2 retry calls + 1 repair call"
        );
    }

    #[tokio::test]
    async fn original_prompt_included_in_retry() {
        let captured_prompt = Arc::new(std::sync::Mutex::new(String::new()));
        let captured_prompt_clone = captured_prompt.clone();
        let callback: InferCallback = Arc::new(move |prompt: String| {
            let captured = captured_prompt_clone.clone();
            Box::pin(async move {
                *captured.lock().unwrap() = prompt.clone();
                Ok(r#"{"name": "Test", "age": 30}"#.to_string())
            })
        });
        let log = create_test_log();
        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
        spec.enable_retry = Some(true);
        spec.max_retries = Some(1);
        spec.enable_repair = Some(false);
        let mut engine = StructuredOutputEngine::new(spec, log.clone())
            .with_infer_callback(callback)
            .with_original_prompt("Generate a user object for testing".to_string());
        let _ = engine.validate("test-task", r#"{"invalid": true}"#).await;
        let prompt = captured_prompt.lock().unwrap().clone();
        assert!(
            prompt.contains("Generate a user object for testing"),
            "Retry prompt should include original prompt"
        );
        assert!(
            prompt.contains("invalid"),
            "Retry prompt should include the invalid output"
        );
    }
}