1use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::time::Instant;
47
48use serde_json::Value;
49use tracing::debug;
50
51use crate::ast::output::SchemaRef;
52use crate::ast::StructuredOutputSpec;
53use crate::error::NikaError;
54use crate::event::{EventKind, EventLog};
55
56use super::output::{extract_json, format_validation_errors, validate_schema_ref};
57
/// Shared, thread-safe handle to an asynchronous inference function.
///
/// The callback takes an owned prompt `String` and returns a boxed `Send`
/// future resolving to the model's raw text output, or a [`NikaError`] if the
/// call failed. The engine invokes it from the retry (layer 3) and repair
/// (layer 4) paths.
pub type InferCallback = Arc<
    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<String, NikaError>> + Send>> + Send + Sync,
>;
68
/// Name reported in events and errors for layer 2 (JSON extraction + schema validation).
const LAYER_2_NAME: &str = "extract_validate";
/// Name reported in events and errors for layer 3 (re-prompting the LLM with validation feedback).
const LAYER_3_NAME: &str = "retry_with_feedback";
/// Name reported in events and errors for layer 4 (asking the LLM to repair the invalid output).
const LAYER_4_NAME: &str = "llm_repair";
73
/// Rough token estimate for a text of length `char_len`: one token per four
/// length units, rounded up, so any non-empty text counts as at least one
/// token and an empty text counts as zero.
fn estimate_tokens(char_len: usize) -> u64 {
    let full_groups = char_len / 4;
    let has_partial_group = char_len % 4 != 0;
    // `full_groups` is at most usize::MAX / 4, so adding one cannot overflow.
    (full_groups + usize::from(has_partial_group)) as u64
}
78
/// Outcome of a successful structured-output validation.
#[derive(Debug, Clone)]
pub struct StructuredOutputResult {
    /// The JSON value that passed schema validation.
    pub value: Value,
    /// Number of the layer that produced the valid value (2, 3 or 4).
    pub layer: u8,
    /// Name of that layer, e.g. `"extract_validate"`.
    pub layer_name: String,
    /// Attempts made across all layers, including the successful one.
    pub total_attempts: u32,
}
91
/// Validates raw LLM output against a JSON schema, escalating through
/// extraction/validation (layer 2), retry with feedback (layer 3) and LLM
/// repair (layer 4).
pub struct StructuredOutputEngine {
    /// Specification providing the schema (or example) and the retry/repair settings.
    spec: StructuredOutputSpec,
    /// Event log that receives attempt, success and provider events.
    log: Arc<EventLog>,
    /// Schema resolved by `load_schema`; `None` until the first successful load.
    compiled_schema: Option<Arc<Value>>,
    /// Parsed example JSON, recorded by `load_schema` when the example is read from a file.
    cached_example: Option<Value>,
    /// Inference callback used by layers 3 and 4; those layers fail without it.
    infer_fn: Option<InferCallback>,
    /// Original task prompt, prepended to the layer-3 retry prompt (empty when unset).
    original_prompt: Option<String>,
    /// Provider name used for provider events and cost estimation.
    provider_name: Option<String>,
    /// Model name used for provider events and cost estimation.
    model_name: Option<String>,
}
133
134impl StructuredOutputEngine {
135 pub fn new(spec: StructuredOutputSpec, log: Arc<EventLog>) -> Self {
137 Self {
138 spec,
139 log,
140 compiled_schema: None,
141 cached_example: None,
142 infer_fn: None,
143 original_prompt: None,
144 provider_name: None,
145 model_name: None,
146 }
147 }
148
149 pub fn with_infer_callback(mut self, callback: InferCallback) -> Self {
167 self.infer_fn = Some(callback);
168 self
169 }
170
171 pub fn with_original_prompt(mut self, prompt: String) -> Self {
175 self.original_prompt = Some(prompt);
176 self
177 }
178
179 pub fn with_provider_context(mut self, provider: String, model: String) -> Self {
181 self.provider_name = Some(provider);
182 self.model_name = Some(model);
183 self
184 }
185
186 fn estimate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
188 let provider = self.provider_name.as_deref().unwrap_or("");
189 let model = self.model_name.as_deref().unwrap_or("");
190 crate::provider::cost::ProviderKind::parse(provider)
191 .map(|pk| crate::provider::cost::calculate_cost(pk, model, input_tokens, output_tokens))
192 .unwrap_or(0.0)
193 }
194
195 pub async fn load_schema(&mut self) -> Result<Arc<Value>, NikaError> {
198 if self.compiled_schema.is_none() {
199 let schema = if let Some(ref example_ref) = self.spec.from_example {
200 let example_value = match example_ref {
202 SchemaRef::Inline(v) => v.clone(),
203 SchemaRef::File(path) => {
204 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
205 NikaError::SchemaFailed {
206 details: format!("Failed to read example '{}': {}", path, e),
207 }
208 })?;
209 let parsed: Value = serde_json::from_str(&content).map_err(|e| {
210 NikaError::SchemaFailed {
211 details: format!("Invalid JSON in example '{}': {}", path, e),
212 }
213 })?;
214 self.cached_example = Some(parsed.clone());
216 parsed
217 }
218 };
219 if self.spec.strict == Some(true) {
220 crate::ast::structured::json_to_schema_strict(&example_value)
221 } else {
222 crate::ast::structured::json_to_schema(&example_value)
223 }
224 } else {
225 match self.spec.schema.as_ref() {
227 Some(SchemaRef::Inline(v)) => v.clone(),
228 Some(SchemaRef::File(path)) => {
229 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
230 NikaError::SchemaFailed {
231 details: format!("Failed to read schema '{}': {}", path, e),
232 }
233 })?;
234 serde_json::from_str(&content).map_err(|e| NikaError::SchemaFailed {
235 details: format!("Invalid JSON in schema '{}': {}", path, e),
236 })?
237 }
238 None => {
239 return Err(NikaError::SchemaFailed {
240 details: "No schema or from_example defined".to_string(),
241 });
242 }
243 }
244 };
245 self.compiled_schema = Some(Arc::new(schema));
246 }
247 self.compiled_schema
248 .clone()
249 .ok_or_else(|| NikaError::SchemaFailed {
250 details: "Schema compilation produced None (internal error)".to_string(),
251 })
252 }
253
254 pub fn schema(&self) -> Option<&SchemaRef> {
259 self.spec.schema.as_ref()
260 }
261
262 pub fn cached_example(&self) -> Option<&Value> {
267 self.cached_example.as_ref()
268 }
269
270 pub async fn validate(
274 &mut self,
275 task_id: &str,
276 raw_output: &str,
277 ) -> Result<StructuredOutputResult, NikaError> {
278 let task_id: Arc<str> = Arc::from(task_id);
279 let mut total_attempts: u32 = 0;
280
281 let schema = self.load_schema().await?;
283
284 {
292 total_attempts += 1;
293 let layer_result = self
294 .try_layer_2(&task_id, raw_output, &schema, total_attempts)
295 .await;
296
297 if let Ok(value) = layer_result {
298 self.emit_success(&task_id, 2, LAYER_2_NAME, total_attempts);
299 return Ok(StructuredOutputResult {
300 value,
301 layer: 2,
302 layer_name: LAYER_2_NAME.to_string(),
303 total_attempts,
304 });
305 }
306 }
307
308 let mut current_output = raw_output.to_string();
312 if self.spec.enable_retry_or_default() {
313 let max_retries = self.spec.max_retries_or_default();
314 for retry in 1..=max_retries {
315 total_attempts += 1;
316 let (layer_result, llm_output) = self
317 .try_layer_3(&task_id, ¤t_output, &schema, retry, total_attempts)
318 .await;
319
320 if let Some(output) = llm_output {
322 current_output = output;
323 }
324
325 if let Ok(value) = layer_result {
326 self.emit_success(&task_id, 3, LAYER_3_NAME, total_attempts);
327 return Ok(StructuredOutputResult {
328 value,
329 layer: 3,
330 layer_name: LAYER_3_NAME.to_string(),
331 total_attempts,
332 });
333 }
334 }
335 }
336
337 if self.spec.enable_repair_or_default() {
339 total_attempts += 1;
340 let layer_result = self
341 .try_layer_4(&task_id, ¤t_output, &schema, total_attempts)
342 .await;
343
344 if let Ok(value) = layer_result {
345 self.emit_success(&task_id, 4, LAYER_4_NAME, total_attempts);
346 return Ok(StructuredOutputResult {
347 value,
348 layer: 4,
349 layer_name: LAYER_4_NAME.to_string(),
350 total_attempts,
351 });
352 }
353 }
354
355 let errors = self.collect_validation_errors(¤t_output, &schema);
357 Err(NikaError::StructuredOutputAllLayersFailed {
358 task_id: task_id.to_string(),
359 attempts: total_attempts,
360 final_errors: errors,
361 })
362 }
363
364 async fn try_layer_2(
369 &self,
370 task_id: &Arc<str>,
371 raw_output: &str,
372 schema: &Value,
373 attempt: u32,
374 ) -> Result<Value, NikaError> {
375 let json_value = match extract_json(raw_output) {
377 Ok(v) => v,
378 Err(e) => {
379 self.emit_attempt(task_id, 2, LAYER_2_NAME, attempt, false, Some(e.clone()));
380 return Err(NikaError::StructuredOutputExtractionFailed {
381 task_id: task_id.to_string(),
382 layer: LAYER_2_NAME.to_string(),
383 reason: e,
384 });
385 }
386 };
387
388 match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
390 Ok(()) => {
391 self.emit_attempt(task_id, 2, LAYER_2_NAME, attempt, true, None);
392 Ok(json_value)
393 }
394 Err(e) => {
395 self.emit_attempt(
396 task_id,
397 2,
398 LAYER_2_NAME,
399 attempt,
400 false,
401 Some(e.to_string()),
402 );
403 Err(NikaError::StructuredOutputValidationFailed {
404 task_id: task_id.to_string(),
405 layer: LAYER_2_NAME.to_string(),
406 attempt,
407 errors: vec![e.to_string()],
408 })
409 }
410 }
411 }
412
413 async fn try_layer_3(
422 &self,
423 task_id: &Arc<str>,
424 raw_output: &str,
425 schema: &Value,
426 retry_num: u8,
427 attempt: u32,
428 ) -> (Result<Value, NikaError>, Option<String>) {
429 let infer_fn = match &self.infer_fn {
431 Some(f) => f,
432 None => {
433 debug!(
435 task_id = %task_id,
436 retry = retry_num,
437 "Layer 3 skipped: no infer callback configured"
438 );
439 self.emit_attempt(
440 task_id,
441 3,
442 LAYER_3_NAME,
443 attempt,
444 false,
445 Some(format!(
446 "retry {}: no infer callback - Layer 3 disabled",
447 retry_num
448 )),
449 );
450 return (
451 Err(NikaError::StructuredOutputValidationFailed {
452 task_id: task_id.to_string(),
453 layer: LAYER_3_NAME.to_string(),
454 attempt,
455 errors: vec!["Layer 3 requires infer callback".to_string()],
456 }),
457 None,
458 );
459 }
460 };
461
462 let validation_errors = self
464 .collect_validation_errors(raw_output, schema)
465 .join("\n");
466
467 let original_prompt = self.original_prompt.as_deref().unwrap_or("");
469 let retry_prompt =
470 self.generate_retry_prompt(original_prompt, raw_output, &validation_errors);
471
472 let prompt_len = retry_prompt.len();
473
474 debug!(
475 task_id = %task_id,
476 retry = retry_num,
477 prompt_len,
478 "Layer 3: calling LLM with retry prompt"
479 );
480
481 self.log.emit(EventKind::ProviderCalled {
483 task_id: Arc::clone(task_id),
484 provider: self
485 .provider_name
486 .clone()
487 .unwrap_or_else(|| "unknown".to_string()),
488 model: self
489 .model_name
490 .clone()
491 .unwrap_or_else(|| "unknown".to_string()),
492 prompt_len,
493 });
494
495 let infer_start = Instant::now();
497 let new_output = match infer_fn(retry_prompt).await {
498 Ok(output) => {
499 let elapsed = infer_start.elapsed();
500 let in_tok = estimate_tokens(prompt_len);
501 let out_tok = estimate_tokens(output.len());
502 let cost = self.estimate_cost(in_tok, out_tok);
503 self.log.emit(EventKind::ProviderResponded {
505 task_id: Arc::clone(task_id),
506 request_id: None,
507 input_tokens: in_tok,
508 output_tokens: out_tok,
509 cache_read_tokens: 0,
510 ttft_ms: Some(elapsed.as_millis() as u64),
511 finish_reason: "structured_output_retry".to_string(),
512 cost_usd: cost,
513 });
514 output
515 }
516 Err(e) => {
517 self.emit_attempt(
518 task_id,
519 3,
520 LAYER_3_NAME,
521 attempt,
522 false,
523 Some(format!("retry {}: LLM call failed: {}", retry_num, e)),
524 );
525 return (Err(e), None);
526 }
527 };
528
529 debug!(
530 task_id = %task_id,
531 retry = retry_num,
532 output_len = new_output.len(),
533 "Layer 3: received LLM response"
534 );
535
536 let json_value = match extract_json(&new_output) {
538 Ok(v) => v,
539 Err(e) => {
540 self.emit_attempt(
541 task_id,
542 3,
543 LAYER_3_NAME,
544 attempt,
545 false,
546 Some(format!("retry {}: extraction failed: {}", retry_num, e)),
547 );
548 return (
549 Err(NikaError::StructuredOutputExtractionFailed {
550 task_id: task_id.to_string(),
551 layer: LAYER_3_NAME.to_string(),
552 reason: e,
553 }),
554 Some(new_output),
555 );
556 }
557 };
558
559 match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
561 Ok(()) => {
562 debug!(
563 task_id = %task_id,
564 retry = retry_num,
565 "Layer 3: validation succeeded"
566 );
567 self.emit_attempt(task_id, 3, LAYER_3_NAME, attempt, true, None);
568 (Ok(json_value), Some(new_output))
569 }
570 Err(e) => {
571 self.emit_attempt(
572 task_id,
573 3,
574 LAYER_3_NAME,
575 attempt,
576 false,
577 Some(format!("retry {}: validation failed: {}", retry_num, e)),
578 );
579 (
580 Err(NikaError::StructuredOutputValidationFailed {
581 task_id: task_id.to_string(),
582 layer: LAYER_3_NAME.to_string(),
583 attempt,
584 errors: vec![e.to_string()],
585 }),
586 Some(new_output),
587 )
588 }
589 }
590 }
591
592 async fn try_layer_4(
602 &self,
603 task_id: &Arc<str>,
604 raw_output: &str,
605 schema: &Value,
606 attempt: u32,
607 ) -> Result<Value, NikaError> {
608 let infer_fn = match &self.infer_fn {
610 Some(f) => f,
611 None => {
612 debug!(
614 task_id = %task_id,
615 "Layer 4 skipped: no infer callback configured"
616 );
617 self.emit_attempt(
618 task_id,
619 4,
620 LAYER_4_NAME,
621 attempt,
622 false,
623 Some("no infer callback - Layer 4 disabled".to_string()),
624 );
625 return Err(NikaError::StructuredOutputValidationFailed {
626 task_id: task_id.to_string(),
627 layer: LAYER_4_NAME.to_string(),
628 attempt,
629 errors: vec!["Layer 4 requires infer callback".to_string()],
630 });
631 }
632 };
633
634 let repair_prompt = self.generate_repair_prompt(raw_output, schema);
636 let prompt_len = repair_prompt.len();
637
638 debug!(
639 task_id = %task_id,
640 prompt_len,
641 "Layer 4: calling repair LLM"
642 );
643
644 self.log.emit(EventKind::ProviderCalled {
646 task_id: Arc::clone(task_id),
647 provider: self
648 .provider_name
649 .clone()
650 .unwrap_or_else(|| "unknown".to_string()),
651 model: self
652 .model_name
653 .clone()
654 .unwrap_or_else(|| "unknown".to_string()),
655 prompt_len,
656 });
657
658 let infer_start = Instant::now();
660 let repaired_output = match infer_fn(repair_prompt).await {
661 Ok(output) => {
662 let elapsed = infer_start.elapsed();
663 let in_tok = estimate_tokens(prompt_len);
664 let out_tok = estimate_tokens(output.len());
665 let cost = self.estimate_cost(in_tok, out_tok);
666 self.log.emit(EventKind::ProviderResponded {
668 task_id: Arc::clone(task_id),
669 request_id: None,
670 input_tokens: in_tok,
671 output_tokens: out_tok,
672 cache_read_tokens: 0,
673 ttft_ms: Some(elapsed.as_millis() as u64),
674 finish_reason: "structured_output_repair".to_string(),
675 cost_usd: cost,
676 });
677 output
678 }
679 Err(e) => {
680 self.emit_attempt(
681 task_id,
682 4,
683 LAYER_4_NAME,
684 attempt,
685 false,
686 Some(format!("repair LLM call failed: {}", e)),
687 );
688 return Err(e);
689 }
690 };
691
692 debug!(
693 task_id = %task_id,
694 output_len = repaired_output.len(),
695 "Layer 4: received repair LLM response"
696 );
697
698 let json_value = match extract_json(&repaired_output) {
700 Ok(v) => v,
701 Err(e) => {
702 self.emit_attempt(
703 task_id,
704 4,
705 LAYER_4_NAME,
706 attempt,
707 false,
708 Some(format!("repair extraction failed: {}", e)),
709 );
710 return Err(NikaError::StructuredOutputExtractionFailed {
711 task_id: task_id.to_string(),
712 layer: LAYER_4_NAME.to_string(),
713 reason: e,
714 });
715 }
716 };
717
718 match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
720 Ok(()) => {
721 debug!(
722 task_id = %task_id,
723 "Layer 4: repair validation succeeded"
724 );
725 self.emit_attempt(task_id, 4, LAYER_4_NAME, attempt, true, None);
726 Ok(json_value)
727 }
728 Err(e) => {
729 self.emit_attempt(
730 task_id,
731 4,
732 LAYER_4_NAME,
733 attempt,
734 false,
735 Some(format!("repair validation failed: {}", e)),
736 );
737 Err(NikaError::StructuredOutputValidationFailed {
738 task_id: task_id.to_string(),
739 layer: LAYER_4_NAME.to_string(),
740 attempt,
741 errors: vec![e.to_string()],
742 })
743 }
744 }
745 }
746
747 fn emit_attempt(
749 &self,
750 task_id: &Arc<str>,
751 layer: u8,
752 layer_name: &str,
753 attempt: u32,
754 success: bool,
755 error: Option<String>,
756 ) {
757 self.log.emit(EventKind::StructuredOutputAttempt {
758 task_id: Arc::clone(task_id),
759 layer,
760 layer_name: layer_name.to_string(),
761 attempt,
762 success,
763 error,
764 });
765 }
766
767 fn emit_success(&self, task_id: &Arc<str>, layer: u8, layer_name: &str, total_attempts: u32) {
769 self.log.emit(EventKind::StructuredOutputSuccess {
770 task_id: Arc::clone(task_id),
771 layer,
772 layer_name: layer_name.to_string(),
773 total_attempts,
774 });
775 }
776
777 fn collect_validation_errors(&self, raw_output: &str, schema: &Value) -> Vec<String> {
779 match extract_json(raw_output) {
780 Ok(value) => {
781 let errors_str = format_validation_errors(&value, schema);
782 errors_str.lines().map(|s| s.to_string()).collect()
783 }
784 Err(e) => vec![format!("JSON extraction failed: {}", e)],
785 }
786 }
787
    /// Builds the layer-3 retry prompt.
    ///
    /// The prompt consists of the original prompt, the previous invalid
    /// output inside a fenced block, the validation errors, and a closing
    /// request for a corrected response matching the JSON schema. The
    /// arguments are interpolated verbatim (no escaping).
    pub fn generate_retry_prompt(
        &self,
        original_prompt: &str,
        invalid_output: &str,
        validation_errors: &str,
    ) -> String {
        format!(
            r#"{original_prompt}

Your previous response was invalid:
```
{invalid_output}
```

Validation errors:
{validation_errors}

Please provide a corrected response that matches the required JSON schema."#
        )
    }
811
812 pub fn generate_repair_prompt(&self, invalid_output: &str, schema: &Value) -> String {
816 let schema_str =
817 serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
818
819 format!(
820 r#"You are a JSON repair assistant. Fix the following invalid JSON to match the schema.
821
822Invalid JSON:
823```
824{invalid_output}
825```
826
827Required schema:
828```json
829{schema_str}
830```
831
832Respond with ONLY the corrected JSON, no explanation."#
833 )
834 }
835}
836
837pub async fn validate_structured_output(
849 task_id: &str,
850 output: &str,
851 spec: &StructuredOutputSpec,
852 log: &EventLog,
853) -> Result<Value, NikaError> {
854 let task_id: Arc<str> = Arc::from(task_id);
855
856 let json_value = extract_json(output).map_err(|e| {
858 log.emit(EventKind::StructuredOutputAttempt {
859 task_id: Arc::clone(&task_id),
860 layer: 2,
861 layer_name: LAYER_2_NAME.to_string(),
862 attempt: 1,
863 success: false,
864 error: Some(e.clone()),
865 });
866 NikaError::StructuredOutputExtractionFailed {
867 task_id: task_id.to_string(),
868 layer: LAYER_2_NAME.to_string(),
869 reason: e,
870 }
871 })?;
872
873 let effective_schema = if let Some(ref example_ref) = spec.from_example {
875 let example_value = match example_ref {
876 SchemaRef::Inline(v) => v.clone(),
877 SchemaRef::File(path) => {
878 let content =
879 tokio::fs::read_to_string(path)
880 .await
881 .map_err(|e| NikaError::SchemaFailed {
882 details: format!("Failed to read example '{}': {}", path, e),
883 })?;
884 serde_json::from_str(&content).map_err(|e| NikaError::SchemaFailed {
885 details: format!("Invalid JSON in example '{}': {}", path, e),
886 })?
887 }
888 };
889 if spec.strict == Some(true) {
890 SchemaRef::Inline(crate::ast::structured::json_to_schema_strict(
891 &example_value,
892 ))
893 } else {
894 SchemaRef::Inline(crate::ast::structured::json_to_schema(&example_value))
895 }
896 } else {
897 match spec.schema.clone() {
898 Some(schema) => schema,
899 None => {
900 return Err(NikaError::SchemaFailed {
901 details: "No schema or from_example defined".to_string(),
902 });
903 }
904 }
905 };
906
907 validate_schema_ref(&json_value, &effective_schema)
909 .await
910 .map_err(|e| {
911 log.emit(EventKind::StructuredOutputAttempt {
912 task_id: Arc::clone(&task_id),
913 layer: 2,
914 layer_name: LAYER_2_NAME.to_string(),
915 attempt: 1,
916 success: false,
917 error: Some(e.to_string()),
918 });
919 NikaError::StructuredOutputValidationFailed {
920 task_id: task_id.to_string(),
921 layer: LAYER_2_NAME.to_string(),
922 attempt: 1,
923 errors: vec![e.to_string()],
924 }
925 })?;
926
927 log.emit(EventKind::StructuredOutputSuccess {
928 task_id: Arc::clone(&task_id),
929 layer: 2,
930 layer_name: LAYER_2_NAME.to_string(),
931 total_attempts: 1,
932 });
933
934 Ok(json_value)
935}
936
937#[cfg(test)]
938mod tests {
939 use super::*;
940 use std::io::Write;
941 use tempfile::NamedTempFile;
942
943 fn create_test_log() -> Arc<EventLog> {
944 Arc::new(EventLog::new())
945 }
946
947 fn create_user_schema() -> Value {
948 serde_json::json!({
949 "type": "object",
950 "properties": {
951 "name": { "type": "string" },
952 "age": { "type": "integer", "minimum": 0 }
953 },
954 "required": ["name", "age"]
955 })
956 }
957
958 #[tokio::test]
963 async fn layer2_valid_json_passes() {
964 let log = create_test_log();
965 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
966 let mut engine = StructuredOutputEngine::new(spec, log.clone());
967
968 let result = engine
969 .validate("test-task", r#"{"name": "Alice", "age": 30}"#)
970 .await;
971
972 assert!(result.is_ok());
973 let r = result.unwrap();
974 assert_eq!(r.layer, 2);
975 assert_eq!(r.layer_name, "extract_validate");
976 assert_eq!(r.value["name"], "Alice");
977 }
978
979 #[tokio::test]
980 async fn layer2_markdown_wrapped_json_passes() {
981 let log = create_test_log();
982 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
983 let mut engine = StructuredOutputEngine::new(spec, log.clone());
984
985 let output = r#"Here's the result:
986```json
987{"name": "Bob", "age": 25}
988```
989Hope this helps!"#;
990
991 let result = engine.validate("test-task", output).await;
992
993 assert!(result.is_ok());
994 let r = result.unwrap();
995 assert_eq!(r.value["name"], "Bob");
996 assert_eq!(r.value["age"], 25);
997 }
998
999 #[tokio::test]
1000 async fn layer2_invalid_json_fails() {
1001 let log = create_test_log();
1002 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1003 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1004
1005 let result = engine.validate("test-task", r#"{"name": "Charlie"}"#).await;
1007
1008 assert!(result.is_err());
1009 let err = result.unwrap_err();
1010 assert!(matches!(
1011 err,
1012 NikaError::StructuredOutputAllLayersFailed { .. }
1013 ));
1014 }
1015
1016 #[tokio::test]
1017 async fn layer2_malformed_json_fails() {
1018 let log = create_test_log();
1019 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1020 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1021
1022 let result = engine.validate("test-task", "not json at all").await;
1023
1024 assert!(result.is_err());
1025 }
1026
1027 #[tokio::test]
1032 async fn load_schema_from_file() {
1033 let log = create_test_log();
1034
1035 let mut schema_file = NamedTempFile::new().unwrap();
1036 writeln!(
1037 schema_file,
1038 r#"{{"type": "object", "properties": {{"x": {{"type": "number"}}}}}}"#
1039 )
1040 .unwrap();
1041 let path = schema_file.path().to_string_lossy().to_string();
1042
1043 let spec = StructuredOutputSpec::with_file_schema(&path);
1044 let mut engine = StructuredOutputEngine::new(spec, log);
1045
1046 let schema = engine.load_schema().await.unwrap();
1047 assert_eq!(schema["type"], "object");
1048 }
1049
1050 #[tokio::test]
1051 async fn load_schema_file_not_found() {
1052 let log = create_test_log();
1053 let spec = StructuredOutputSpec::with_file_schema("/nonexistent/schema.json");
1054 let mut engine = StructuredOutputEngine::new(spec, log);
1055
1056 let result = engine.load_schema().await;
1057 assert!(result.is_err());
1058 }
1059
1060 #[tokio::test]
1065 async fn load_schema_from_example_inline() {
1066 let log = create_test_log();
1067 let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1068 "name": "alice",
1069 "score": 42
1070 }));
1071 let mut engine = StructuredOutputEngine::new(spec, log);
1072 let schema = engine.load_schema().await.unwrap();
1073 assert_eq!(schema["type"], "object");
1074 assert_eq!(schema["properties"]["name"]["type"], "string");
1075 assert_eq!(schema["properties"]["score"]["type"], "integer");
1076 }
1077
1078 #[tokio::test]
1079 async fn load_schema_from_example_file() {
1080 let mut example_file = NamedTempFile::new().unwrap();
1081 writeln!(example_file, r#"{{"title":"hello","count":1}}"#).unwrap();
1082 let path = example_file.path().to_string_lossy().to_string();
1083
1084 let spec = StructuredOutputSpec::with_example_file(&path);
1085 let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1086 let schema = engine.load_schema().await.unwrap();
1087 assert_eq!(schema["type"], "object");
1088 assert_eq!(schema["properties"]["title"]["type"], "string");
1089 assert_eq!(schema["properties"]["count"]["type"], "integer");
1090 }
1091
1092 #[tokio::test]
1093 async fn load_schema_from_example_file_not_found() {
1094 let spec = StructuredOutputSpec::with_example_file("/nonexistent/example.json");
1095 let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1096 let result = engine.load_schema().await;
1097 assert!(result.is_err());
1098 let err = result.unwrap_err().to_string();
1099 assert!(err.contains("Failed to read example"), "got: {err}");
1100 }
1101
1102 #[tokio::test]
1103 async fn validate_with_example_inline_passes_valid_json() {
1104 let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1105 "name": "x",
1106 "score": 0
1107 }));
1108 let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1109 let result = engine
1110 .validate("t1", r#"{"name": "bob", "score": 99}"#)
1111 .await;
1112 assert!(result.is_ok(), "expected ok, got: {:?}", result);
1113 }
1114
1115 #[tokio::test]
1116 async fn validate_with_example_inline_fails_wrong_type() {
1117 let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1118 "name": "x",
1119 "score": 0
1120 }));
1121 let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1123 let result = engine
1124 .validate("t2", r#"{"name": "bob", "score": "not-a-number"}"#)
1125 .await;
1126 assert!(result.is_err(), "expected validation failure on wrong type");
1127 }
1128
1129 #[tokio::test]
1130 async fn validate_structured_output_from_example_inline_passes() {
1131 let log = create_test_log();
1132 let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1133 "name": "x",
1134 "score": 0
1135 }));
1136 let result =
1137 validate_structured_output("t3", r#"{"name":"alice","score":42}"#, &spec, &log).await;
1138 assert!(result.is_ok(), "expected ok, got: {:?}", result);
1139 }
1140
1141 #[tokio::test]
1142 async fn validate_structured_output_from_example_inline_rejects_invalid() {
1143 let log = create_test_log();
1144 let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1145 "name": "x",
1146 "score": 0
1147 }));
1148 let result =
1150 validate_structured_output("t4", r#"{"anything":"goes","random":true}"#, &spec, &log)
1151 .await;
1152 assert!(
1153 result.is_err(),
1154 "validate_structured_output must reject missing required fields"
1155 );
1156 }
1157
1158 #[tokio::test]
1163 async fn events_emitted_on_success() {
1164 let log = create_test_log();
1165 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1166 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1167
1168 let _ = engine
1169 .validate("task-1", r#"{"name": "Test", "age": 20}"#)
1170 .await;
1171
1172 let events = log.events();
1173 assert!(!events.is_empty());
1174
1175 let has_attempt = events.iter().any(|e| {
1177 matches!(
1178 &e.kind,
1179 EventKind::StructuredOutputAttempt { success: true, .. }
1180 )
1181 });
1182 let has_success = events
1183 .iter()
1184 .any(|e| matches!(&e.kind, EventKind::StructuredOutputSuccess { .. }));
1185
1186 assert!(has_attempt);
1187 assert!(has_success);
1188 }
1189
1190 #[tokio::test]
1191 async fn events_emitted_on_failure() {
1192 let log = create_test_log();
1193 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1194 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1195
1196 let _ = engine.validate("task-2", "invalid").await;
1197
1198 let events = log.events();
1199 assert!(!events.is_empty());
1200
1201 let has_failed_attempt = events.iter().any(|e| {
1203 matches!(
1204 &e.kind,
1205 EventKind::StructuredOutputAttempt { success: false, .. }
1206 )
1207 });
1208 assert!(has_failed_attempt);
1209 }
1210
1211 #[tokio::test]
1216 async fn layers_can_be_disabled() {
1217 let log = create_test_log();
1218 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1219 spec.enable_retry = Some(false);
1220 spec.enable_repair = Some(false);
1221
1222 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1223
1224 let result = engine
1226 .validate("task-3", r#"{"name": "Only name, no age"}"#)
1227 .await;
1228
1229 assert!(result.is_err());
1230
1231 let events = log.events();
1233 let attempt_count = events
1234 .iter()
1235 .filter(|e| matches!(&e.kind, EventKind::StructuredOutputAttempt { .. }))
1236 .count();
1237 assert_eq!(attempt_count, 1, "Only Layer 2 should have attempted");
1238 }
1239
1240 #[test]
1245 fn generate_retry_prompt_includes_context() {
1246 let log = create_test_log();
1247 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1248 let engine = StructuredOutputEngine::new(spec, log);
1249
1250 let prompt = engine.generate_retry_prompt(
1251 "Generate a user object",
1252 r#"{"name": "Test"}"#,
1253 "missing required field: age",
1254 );
1255
1256 assert!(prompt.contains("Generate a user object"));
1257 assert!(prompt.contains(r#"{"name": "Test"}"#));
1258 assert!(prompt.contains("missing required field: age"));
1259 }
1260
1261 #[test]
1262 fn generate_repair_prompt_includes_schema() {
1263 let log = create_test_log();
1264 let schema = create_user_schema();
1265 let spec = StructuredOutputSpec::with_inline_schema(schema.clone());
1266 let engine = StructuredOutputEngine::new(spec, log);
1267
1268 let prompt = engine.generate_repair_prompt(r#"{"broken": true}"#, &schema);
1269
1270 assert!(prompt.contains(r#"{"broken": true}"#));
1271 assert!(prompt.contains("name"));
1272 assert!(prompt.contains("age"));
1273 }
1274
1275 #[tokio::test]
1280 async fn standalone_validation_works() {
1281 let log = EventLog::new();
1282 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1283
1284 let result = validate_structured_output(
1285 "task-4",
1286 r#"{"name": "Standalone", "age": 42}"#,
1287 &spec,
1288 &log,
1289 )
1290 .await;
1291
1292 assert!(result.is_ok());
1293 let value = result.unwrap();
1294 assert_eq!(value["name"], "Standalone");
1295 }
1296
1297 #[tokio::test]
1298 async fn standalone_validation_fails_on_invalid() {
1299 let log = EventLog::new();
1300 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1301
1302 let result =
1303 validate_structured_output("task-5", r#"{"invalid": true}"#, &spec, &log).await;
1304
1305 assert!(result.is_err());
1306 }
1307
1308 #[tokio::test]
1313 async fn handles_unicode_content() {
1314 let log = create_test_log();
1315 let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1316 let mut engine = StructuredOutputEngine::new(spec, log);
1317
1318 let result = engine
1319 .validate("task-unicode", r#"{"name": "日本語テスト", "age": 25}"#)
1320 .await;
1321
1322 assert!(result.is_ok());
1323 assert_eq!(result.unwrap().value["name"], "日本語テスト");
1324 }
1325
1326 #[tokio::test]
1327 async fn handles_nested_objects() {
1328 let log = create_test_log();
1329 let schema = serde_json::json!({
1330 "type": "object",
1331 "properties": {
1332 "user": {
1333 "type": "object",
1334 "properties": {
1335 "name": { "type": "string" }
1336 },
1337 "required": ["name"]
1338 }
1339 },
1340 "required": ["user"]
1341 });
1342 let spec = StructuredOutputSpec::with_inline_schema(schema);
1343 let mut engine = StructuredOutputEngine::new(spec, log);
1344
1345 let result = engine
1346 .validate("task-nested", r#"{"user": {"name": "Nested User"}}"#)
1347 .await;
1348
1349 assert!(result.is_ok());
1350 }
1351
1352 #[tokio::test]
1353 async fn handles_arrays() {
1354 let log = create_test_log();
1355 let schema = serde_json::json!({
1356 "type": "array",
1357 "items": {
1358 "type": "object",
1359 "properties": {
1360 "id": { "type": "integer" }
1361 },
1362 "required": ["id"]
1363 }
1364 });
1365 let spec = StructuredOutputSpec::with_inline_schema(schema);
1366 let mut engine = StructuredOutputEngine::new(spec, log);
1367
1368 let result = engine
1369 .validate("task-array", r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#)
1370 .await;
1371
1372 assert!(result.is_ok());
1373 let arr = result.unwrap().value;
1374 assert!(arr.is_array());
1375 assert_eq!(arr.as_array().unwrap().len(), 3);
1376 }
1377
1378 use std::sync::atomic::{AtomicU32, Ordering};
1383
1384 #[tokio::test]
1385 async fn layer3_actually_retries_llm() {
1386 let call_count = Arc::new(AtomicU32::new(0));
1387 let call_count_clone = call_count.clone();
1388
1389 let callback: InferCallback = Arc::new(move |_prompt: String| {
1391 let count = call_count_clone.clone();
1392 Box::pin(async move {
1393 let n = count.fetch_add(1, Ordering::SeqCst);
1394 if n == 0 {
1395 Ok(r#"{"name": "Alice", "age": 30}"#.to_string())
1397 } else {
1398 Ok(r#"{"name": "Bob", "age": 25}"#.to_string())
1400 }
1401 })
1402 });
1403
1404 let log = create_test_log();
1405 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1406 spec.enable_retry = Some(true);
1407 spec.max_retries = Some(3);
1408 spec.enable_repair = Some(false); let mut engine = StructuredOutputEngine::new(spec, log.clone())
1411 .with_infer_callback(callback)
1412 .with_original_prompt("Generate a user object".to_string());
1413
1414 let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1416
1417 assert!(result.is_ok(), "Should succeed after Layer 3 retry");
1418 let r = result.unwrap();
1419 assert_eq!(r.layer, 3, "Should succeed at Layer 3");
1420 assert_eq!(r.layer_name, "retry_with_feedback");
1421 assert_eq!(r.value["name"], "Alice");
1422 assert!(
1423 call_count.load(Ordering::SeqCst) >= 1,
1424 "Should have called LLM at least once"
1425 );
1426 }
1427
1428 #[tokio::test]
1429 async fn layer3_skipped_without_callback() {
1430 let log = create_test_log();
1431 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1432 spec.enable_retry = Some(true);
1433 spec.max_retries = Some(3);
1434 spec.enable_repair = Some(false);
1435
1436 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1438
1439 let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1440
1441 assert!(result.is_err(), "Should fail without callback");
1442
1443 let events = log.events();
1445 let layer3_attempts = events.iter().filter(|e| {
1446 matches!(
1447 &e.kind,
1448 EventKind::StructuredOutputAttempt {
1449 layer: 3,
1450 success: false,
1451 error: Some(err),
1452 ..
1453 } if err.contains("no infer callback")
1454 )
1455 });
1456 assert!(
1457 layer3_attempts.count() > 0,
1458 "Should have Layer 3 attempt events showing no callback"
1459 );
1460 }
1461
1462 #[tokio::test]
1467 async fn layer4_actually_repairs_json() {
1468 let call_count = Arc::new(AtomicU32::new(0));
1469 let call_count_clone = call_count.clone();
1470
1471 let callback: InferCallback = Arc::new(move |prompt: String| {
1473 let count = call_count_clone.clone();
1474 Box::pin(async move {
1475 count.fetch_add(1, Ordering::SeqCst);
1476 assert!(
1478 prompt.contains("repair") || prompt.contains("schema"),
1479 "Should receive repair prompt"
1480 );
1481 Ok(r#"{"name": "Repaired", "age": 25}"#.to_string())
1483 })
1484 });
1485
1486 let log = create_test_log();
1487 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1488 spec.enable_retry = Some(false); spec.enable_repair = Some(true);
1490
1491 let mut engine =
1492 StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
1493
1494 let result = engine.validate("test-task", "totally broken json").await;
1495
1496 assert!(result.is_ok(), "Should succeed after Layer 4 repair");
1497 let r = result.unwrap();
1498 assert_eq!(r.layer, 4, "Should succeed at Layer 4");
1499 assert_eq!(r.layer_name, "llm_repair");
1500 assert_eq!(r.value["name"], "Repaired");
1501 assert!(
1502 call_count.load(Ordering::SeqCst) >= 1,
1503 "Should have called repair LLM"
1504 );
1505 }
1506
1507 #[tokio::test]
1508 async fn layer4_skipped_without_callback() {
1509 let log = create_test_log();
1510 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1511 spec.enable_retry = Some(false);
1512 spec.enable_repair = Some(true);
1513
1514 let mut engine = StructuredOutputEngine::new(spec, log.clone());
1516
1517 let result = engine.validate("test-task", "broken json").await;
1518
1519 assert!(result.is_err(), "Should fail without callback");
1520
1521 let events = log.events();
1523 let layer4_attempts = events.iter().filter(|e| {
1524 matches!(
1525 &e.kind,
1526 EventKind::StructuredOutputAttempt {
1527 layer: 4,
1528 success: false,
1529 error: Some(err),
1530 ..
1531 } if err.contains("no infer callback")
1532 )
1533 });
1534 assert!(
1535 layer4_attempts.count() > 0,
1536 "Should have Layer 4 attempt event showing no callback"
1537 );
1538 }
1539
1540 #[tokio::test]
1545 async fn max_retries_is_respected() {
1546 let call_count = Arc::new(AtomicU32::new(0));
1547 let call_count_clone = call_count.clone();
1548
1549 let callback: InferCallback = Arc::new(move |_prompt: String| {
1551 let count = call_count_clone.clone();
1552 Box::pin(async move {
1553 count.fetch_add(1, Ordering::SeqCst);
1554 Ok(r#"{"still_invalid": true}"#.to_string())
1556 })
1557 });
1558
1559 let log = create_test_log();
1560 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1561 spec.max_retries = Some(3);
1562 spec.enable_retry = Some(true);
1563 spec.enable_repair = Some(false); let mut engine =
1566 StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
1567
1568 let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1569
1570 assert!(result.is_err(), "Should fail after max retries");
1571 assert_eq!(
1572 call_count.load(Ordering::SeqCst),
1573 3,
1574 "Should have retried exactly max_retries times"
1575 );
1576 }
1577
1578 #[tokio::test]
1579 async fn layer3_layer4_chain_works() {
1580 let call_count = Arc::new(AtomicU32::new(0));
1581 let call_count_clone = call_count.clone();
1582
1583 let callback: InferCallback = Arc::new(move |prompt: String| {
1588 let count = call_count_clone.clone();
1589 Box::pin(async move {
1590 let n = count.fetch_add(1, Ordering::SeqCst);
1591 if prompt.contains("JSON repair assistant") {
1592 Ok(r#"{"name": "Repaired", "age": 42}"#.to_string())
1594 } else {
1595 Ok(format!(
1597 r#"{{"retry_attempt": {}, "still_invalid": true}}"#,
1598 n
1599 ))
1600 }
1601 })
1602 });
1603
1604 let log = create_test_log();
1605 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1606 spec.max_retries = Some(2);
1607 spec.enable_retry = Some(true);
1608 spec.enable_repair = Some(true);
1609
1610 let mut engine = StructuredOutputEngine::new(spec, log.clone())
1611 .with_infer_callback(callback)
1612 .with_original_prompt("Generate user".to_string());
1613
1614 let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1615
1616 assert!(result.is_ok(), "Should succeed after Layer 4 repair");
1617 let r = result.unwrap();
1618 assert_eq!(r.layer, 4, "Should succeed at Layer 4");
1619 assert_eq!(r.value["name"], "Repaired");
1620 assert_eq!(
1622 call_count.load(Ordering::SeqCst),
1623 3,
1624 "Should have made 2 retry calls + 1 repair call"
1625 );
1626 }
1627
1628 #[tokio::test]
1629 async fn original_prompt_included_in_retry() {
1630 let captured_prompt = Arc::new(std::sync::Mutex::new(String::new()));
1631 let captured_prompt_clone = captured_prompt.clone();
1632
1633 let callback: InferCallback = Arc::new(move |prompt: String| {
1634 let captured = captured_prompt_clone.clone();
1635 Box::pin(async move {
1636 *captured.lock().unwrap() = prompt.clone();
1637 Ok(r#"{"name": "Test", "age": 30}"#.to_string())
1639 })
1640 });
1641
1642 let log = create_test_log();
1643 let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1644 spec.enable_retry = Some(true);
1645 spec.max_retries = Some(1);
1646 spec.enable_repair = Some(false);
1647
1648 let mut engine = StructuredOutputEngine::new(spec, log.clone())
1649 .with_infer_callback(callback)
1650 .with_original_prompt("Generate a user object for testing".to_string());
1651
1652 let _ = engine.validate("test-task", r#"{"invalid": true}"#).await;
1653
1654 let prompt = captured_prompt.lock().unwrap().clone();
1655 assert!(
1656 prompt.contains("Generate a user object for testing"),
1657 "Retry prompt should include original prompt"
1658 );
1659 assert!(
1660 prompt.contains("invalid"),
1661 "Retry prompt should include the invalid output"
1662 );
1663 }
1664}