Skip to main content

nika_engine/runtime/
structured_output.rs

1//! Structured Output Engine
2//!
3//! 5-layer defense system for ~99.99% JSON Schema compliance:
4//!
5//! - **Layer 0**: Tool Injection (DynamicSubmitTool via `submit_tool` module)
6//!   - Handled OUTSIDE the engine, in `executor/verbs.rs` (infer) and
7//!     `rig_agent_loop` (agent). Uses `tool_choice: Required` to force
8//!     provider-side schema enforcement.
9//! - **Layer 1**: rig Extractor (Rust types with JsonSchema via schemars — future)
10//! - **Layer 2**: Extract + Validate (extract JSON from output, validate against schema)
11//! - **Layer 3**: Retry with Feedback (re-prompt with validation errors)
12//! - **Layer 4**: LLM Repair (separate call to fix invalid JSON)
13//!
14//! Layer 0 is non-blocking: if tool injection fails (native provider, timeout),
15//! execution falls through to streaming + Layers 2-4.
16//!
17//! Each layer emits `StructuredOutputAttempt` events for observability.
18//! Success emits `StructuredOutputSuccess` with total attempt count.
19//!
20//! ## Usage
21//!
22//! ```rust,ignore
23//! use nika::runtime::StructuredOutputEngine;
24//! use nika::ast::StructuredOutputSpec;
25//!
26//! let spec = StructuredOutputSpec::with_file_schema("./schema.json");
//! let mut engine = StructuredOutputEngine::new(spec, event_log.clone());
28//!
29//! // Validate raw output (Layer 2 only without callback)
30//! let result = engine.validate("task-1", raw_output).await?;
31//!
32//! // With inference callback for full Layer 3 & 4 support
33//! let callback: InferCallback = Arc::new(move |prompt: String| {
34//!     let provider = provider.clone();
35//!     Box::pin(async move {
36//!         provider.infer(&prompt, None).await
37//!             .map_err(|e| NikaError::ProviderApiError { message: e.to_string() })
38//!     })
39//! });
//! let mut engine = engine.with_infer_callback(callback);
41//! ```
42
43use std::future::Future;
44use std::pin::Pin;
45use std::sync::Arc;
46use std::time::Instant;
47
48use serde_json::Value;
49use tracing::debug;
50
51use crate::ast::output::SchemaRef;
52use crate::ast::StructuredOutputSpec;
53use crate::error::NikaError;
54use crate::event::{EventKind, EventLog};
55
56use super::output::{extract_json, format_validation_errors, validate_schema_ref};
57
58/// Callback type for LLM inference during retry/repair (Layers 3 & 4)
59///
60/// This callback is invoked when the engine needs to re-call the LLM:
61/// - Layer 3: Retry with validation error feedback
62/// - Layer 4: Repair call to fix invalid JSON
63///
64/// The callback receives the prompt and returns the LLM response.
65pub type InferCallback = Arc<
66    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<String, NikaError>> + Send>> + Send + Sync,
67>;
68
// Layer names reported in `StructuredOutputAttempt` / `StructuredOutputSuccess`
// events and in `StructuredOutputResult::layer_name`.

/// Layer 2: extract JSON from the raw output and validate it against the schema.
const LAYER_2_NAME: &str = "extract_validate";
/// Layer 3: re-prompt the LLM with validation-error feedback.
const LAYER_3_NAME: &str = "retry_with_feedback";
/// Layer 4: ask the LLM to repair the invalid JSON.
const LAYER_4_NAME: &str = "llm_repair";
73
/// Rough token estimate: one token per four units of length, rounded up.
///
/// This is the usual "~4 characters per token" heuristic. Callers in this module
/// pass `str::len()` values, i.e. UTF-8 byte lengths, so non-ASCII text is
/// counted per byte rather than per character.
fn estimate_tokens(char_len: usize) -> u64 {
    let whole_quarters = char_len / 4;
    // Any leftover bytes/chars start one more (partial) token.
    let partial = usize::from(char_len % 4 != 0);
    (whole_quarters + partial) as u64
}
78
79/// Result of structured output validation
80#[derive(Debug, Clone)]
81pub struct StructuredOutputResult {
82    /// The validated JSON value
83    pub value: Value,
84    /// Which layer succeeded (1-4)
85    pub layer: u8,
86    /// Layer name
87    pub layer_name: String,
88    /// Total attempts across all layers
89    pub total_attempts: u32,
90}
91
92/// Post-processing structured output validation engine (Layers 2-4)
93///
94/// Attempts validation through multiple layers until success or exhaustion.
95/// All attempts are tracked via events for observability.
96///
97/// Layer 0 (DynamicSubmitTool injection) is handled externally in `executor/verbs.rs`
98/// and `rig_agent_loop/mod.rs` BEFORE this engine is invoked.
99///
100/// ## Layers (this engine)
101///
102/// - **Layer 2**: Extract + Validate - extracts JSON from raw output and validates against schema
103/// - **Layer 3**: Retry with Feedback - re-calls LLM with validation errors (requires `infer_fn`)
104/// - **Layer 4**: LLM Repair - calls repair model to fix invalid JSON (requires `infer_fn`)
105///
106/// Without `infer_fn`, only Layer 2 is functional. Layers 3 & 4 will emit warnings
107/// and gracefully skip to the next layer.
108pub struct StructuredOutputEngine {
109    /// Structured output specification (schema + layer config)
110    spec: StructuredOutputSpec,
111    /// Event log for observability
112    log: Arc<EventLog>,
113    /// Cached compiled schema (for validation speed, Arc for cheap cloning)
114    compiled_schema: Option<Arc<Value>>,
115    /// Cached example value when `from_example` is a file.
116    ///
117    /// Set during `load_schema()` so that `build_json_schema_instruction` can
118    /// inject the file-based example into the LLM prompt without async I/O.
119    cached_example: Option<Value>,
120    /// Callback for LLM inference in Layer 3 & 4
121    ///
122    /// When set, enables actual LLM retries and repairs instead of just re-validation.
123    infer_fn: Option<InferCallback>,
124    /// Original prompt for retry context
125    ///
126    /// Used by Layer 3 to construct the retry prompt with full context.
127    original_prompt: Option<String>,
128    /// Provider name for telemetry (e.g., "anthropic")
129    provider_name: Option<String>,
130    /// Model name for telemetry (e.g., "claude-3-haiku-20240307")
131    model_name: Option<String>,
132}
133
134impl StructuredOutputEngine {
135    /// Create a new engine with the given spec and event log
136    pub fn new(spec: StructuredOutputSpec, log: Arc<EventLog>) -> Self {
137        Self {
138            spec,
139            log,
140            compiled_schema: None,
141            cached_example: None,
142            infer_fn: None,
143            original_prompt: None,
144            provider_name: None,
145            model_name: None,
146        }
147    }
148
149    /// Set the inference callback for Layer 3 & 4
150    ///
151    /// This enables actual LLM retries and repairs. Without this callback,
152    /// only Layer 2 validation is functional.
153    ///
154    /// # Example
155    ///
156    /// ```rust,ignore
157    /// let callback: InferCallback = Arc::new(move |prompt: String| {
158    ///     let provider = provider.clone();
159    ///     Box::pin(async move {
160    ///         provider.infer(&prompt, None).await
161    ///             .map_err(|e| NikaError::ProviderApiError { message: e.to_string() })
162    ///     })
163    /// });
164    /// let engine = engine.with_infer_callback(callback);
165    /// ```
166    pub fn with_infer_callback(mut self, callback: InferCallback) -> Self {
167        self.infer_fn = Some(callback);
168        self
169    }
170
171    /// Set the original prompt for retry context
172    ///
173    /// Used by Layer 3 to construct the retry prompt with full context.
174    pub fn with_original_prompt(mut self, prompt: String) -> Self {
175        self.original_prompt = Some(prompt);
176        self
177    }
178
179    /// Set provider and model names for telemetry on Layer 3/4 LLM calls
180    pub fn with_provider_context(mut self, provider: String, model: String) -> Self {
181        self.provider_name = Some(provider);
182        self.model_name = Some(model);
183        self
184    }
185
186    /// Estimate cost using provider/model context (returns 0.0 if unknown)
187    fn estimate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
188        let provider = self.provider_name.as_deref().unwrap_or("");
189        let model = self.model_name.as_deref().unwrap_or("");
190        crate::provider::cost::ProviderKind::parse(provider)
191            .map(|pk| crate::provider::cost::calculate_cost(pk, model, input_tokens, output_tokens))
192            .unwrap_or(0.0)
193    }
194
195    /// Load and cache the schema for validation.
196    /// Returns an `Arc<Value>` for cheap cloning across async boundaries.
197    pub async fn load_schema(&mut self) -> Result<Arc<Value>, NikaError> {
198        if self.compiled_schema.is_none() {
199            let schema = if let Some(ref example_ref) = self.spec.from_example {
200                // from_example: load example then derive JSON Schema from it
201                let example_value = match example_ref {
202                    SchemaRef::Inline(v) => v.clone(),
203                    SchemaRef::File(path) => {
204                        let content = tokio::fs::read_to_string(path).await.map_err(|e| {
205                            NikaError::SchemaFailed {
206                                details: format!("Failed to read example '{}': {}", path, e),
207                            }
208                        })?;
209                        let parsed: Value = serde_json::from_str(&content).map_err(|e| {
210                            NikaError::SchemaFailed {
211                                details: format!("Invalid JSON in example '{}': {}", path, e),
212                            }
213                        })?;
214                        // Cache example for prompt injection
215                        self.cached_example = Some(parsed.clone());
216                        parsed
217                    }
218                };
219                if self.spec.strict == Some(true) {
220                    crate::ast::structured::json_to_schema_strict(&example_value)
221                } else {
222                    crate::ast::structured::json_to_schema(&example_value)
223                }
224            } else {
225                // Standard: load schema directly
226                match self.spec.schema.as_ref() {
227                    Some(SchemaRef::Inline(v)) => v.clone(),
228                    Some(SchemaRef::File(path)) => {
229                        let content = tokio::fs::read_to_string(path).await.map_err(|e| {
230                            NikaError::SchemaFailed {
231                                details: format!("Failed to read schema '{}': {}", path, e),
232                            }
233                        })?;
234                        serde_json::from_str(&content).map_err(|e| NikaError::SchemaFailed {
235                            details: format!("Invalid JSON in schema '{}': {}", path, e),
236                        })?
237                    }
238                    None => {
239                        return Err(NikaError::SchemaFailed {
240                            details: "No schema or from_example defined".to_string(),
241                        });
242                    }
243                }
244            };
245            self.compiled_schema = Some(Arc::new(schema));
246        }
247        self.compiled_schema
248            .clone()
249            .ok_or_else(|| NikaError::SchemaFailed {
250                details: "Schema compilation produced None (internal error)".to_string(),
251            })
252    }
253
254    /// Get the raw schema reference from the spec.
255    ///
256    /// Returns `None` when `from_example` is set (schema derived at runtime).
257    /// Use `load_schema()` to get the resolved, effective schema.
258    pub fn schema(&self) -> Option<&SchemaRef> {
259        self.spec.schema.as_ref()
260    }
261
262    /// Get the cached example value (set during `load_schema()` for file-based examples).
263    ///
264    /// Returns `Some` after `load_schema()` has been called when `from_example` is a file.
265    /// For inline examples, returns `None` (the value is in `spec.from_example` directly).
266    pub fn cached_example(&self) -> Option<&Value> {
267        self.cached_example.as_ref()
268    }
269
270    /// Validate raw output through the 4-layer defense system
271    ///
272    /// Returns the validated JSON value and metadata about which layer succeeded.
273    pub async fn validate(
274        &mut self,
275        task_id: &str,
276        raw_output: &str,
277    ) -> Result<StructuredOutputResult, NikaError> {
278        let task_id: Arc<str> = Arc::from(task_id);
279        let mut total_attempts: u32 = 0;
280
281        // Load schema for validation (Arc clone is cheap)
282        let schema = self.load_schema().await?;
283
284        // Layer 1: rig Extractor (skip for now - requires compile-time types)
285        // In future: use rig's Extractor with schemars-derived types
286        // For now, we rely on Layers 2-4 which work with runtime schemas
287
288        // Layer 2: Extract + Validate
289        // Extract JSON from the raw output and validate against the schema.
290        // This always runs — it's the core post-processing validation step.
291        {
292            total_attempts += 1;
293            let layer_result = self
294                .try_layer_2(&task_id, raw_output, &schema, total_attempts)
295                .await;
296
297            if let Ok(value) = layer_result {
298                self.emit_success(&task_id, 2, LAYER_2_NAME, total_attempts);
299                return Ok(StructuredOutputResult {
300                    value,
301                    layer: 2,
302                    layer_name: LAYER_2_NAME.to_string(),
303                    total_attempts,
304                });
305            }
306        }
307
308        // Layer 3: Retry with Feedback
309        // Track the latest LLM output so each retry sees the most recent attempt,
310        // not the stale original. try_layer_3 returns (Result, Option<latest_output>).
311        let mut current_output = raw_output.to_string();
312        if self.spec.enable_retry_or_default() {
313            let max_retries = self.spec.max_retries_or_default();
314            for retry in 1..=max_retries {
315                total_attempts += 1;
316                let (layer_result, llm_output) = self
317                    .try_layer_3(&task_id, &current_output, &schema, retry, total_attempts)
318                    .await;
319
320                // Update current_output with LLM's latest response for next retry
321                if let Some(output) = llm_output {
322                    current_output = output;
323                }
324
325                if let Ok(value) = layer_result {
326                    self.emit_success(&task_id, 3, LAYER_3_NAME, total_attempts);
327                    return Ok(StructuredOutputResult {
328                        value,
329                        layer: 3,
330                        layer_name: LAYER_3_NAME.to_string(),
331                        total_attempts,
332                    });
333                }
334            }
335        }
336
337        // Layer 4: LLM Repair — uses latest output (from Layer 3 retries if any)
338        if self.spec.enable_repair_or_default() {
339            total_attempts += 1;
340            let layer_result = self
341                .try_layer_4(&task_id, &current_output, &schema, total_attempts)
342                .await;
343
344            if let Ok(value) = layer_result {
345                self.emit_success(&task_id, 4, LAYER_4_NAME, total_attempts);
346                return Ok(StructuredOutputResult {
347                    value,
348                    layer: 4,
349                    layer_name: LAYER_4_NAME.to_string(),
350                    total_attempts,
351                });
352            }
353        }
354
355        // All layers failed — use latest output for error reporting
356        let errors = self.collect_validation_errors(&current_output, &schema);
357        Err(NikaError::StructuredOutputAllLayersFailed {
358            task_id: task_id.to_string(),
359            attempts: total_attempts,
360            final_errors: errors,
361        })
362    }
363
364    /// Layer 2: Provider-Native validation
365    ///
366    /// Extracts JSON from raw output and validates against schema.
367    /// The provider should have already been configured with tool_use/response_format.
368    async fn try_layer_2(
369        &self,
370        task_id: &Arc<str>,
371        raw_output: &str,
372        schema: &Value,
373        attempt: u32,
374    ) -> Result<Value, NikaError> {
375        // Extract JSON from potentially markdown-wrapped output
376        let json_value = match extract_json(raw_output) {
377            Ok(v) => v,
378            Err(e) => {
379                self.emit_attempt(task_id, 2, LAYER_2_NAME, attempt, false, Some(e.clone()));
380                return Err(NikaError::StructuredOutputExtractionFailed {
381                    task_id: task_id.to_string(),
382                    layer: LAYER_2_NAME.to_string(),
383                    reason: e,
384                });
385            }
386        };
387
388        // Validate against schema
389        match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
390            Ok(()) => {
391                self.emit_attempt(task_id, 2, LAYER_2_NAME, attempt, true, None);
392                Ok(json_value)
393            }
394            Err(e) => {
395                self.emit_attempt(
396                    task_id,
397                    2,
398                    LAYER_2_NAME,
399                    attempt,
400                    false,
401                    Some(e.to_string()),
402                );
403                Err(NikaError::StructuredOutputValidationFailed {
404                    task_id: task_id.to_string(),
405                    layer: LAYER_2_NAME.to_string(),
406                    attempt,
407                    errors: vec![e.to_string()],
408                })
409            }
410        }
411    }
412
413    /// Layer 3: Retry with Feedback
414    ///
415    /// Re-calls the LLM with validation error feedback to get corrected output.
416    /// Requires `infer_fn` callback to be set via `with_infer_callback()`.
417    ///
418    /// Without `infer_fn`, this layer is skipped with a warning.
419    /// Returns `(Result<Value>, Option<raw_llm_output>)` so the caller can
420    /// track the latest LLM response for subsequent retries.
421    async fn try_layer_3(
422        &self,
423        task_id: &Arc<str>,
424        raw_output: &str,
425        schema: &Value,
426        retry_num: u8,
427        attempt: u32,
428    ) -> (Result<Value, NikaError>, Option<String>) {
429        // Check if we have an inference callback
430        let infer_fn = match &self.infer_fn {
431            Some(f) => f,
432            None => {
433                // No callback - Layer 3 is disabled
434                debug!(
435                    task_id = %task_id,
436                    retry = retry_num,
437                    "Layer 3 skipped: no infer callback configured"
438                );
439                self.emit_attempt(
440                    task_id,
441                    3,
442                    LAYER_3_NAME,
443                    attempt,
444                    false,
445                    Some(format!(
446                        "retry {}: no infer callback - Layer 3 disabled",
447                        retry_num
448                    )),
449                );
450                return (
451                    Err(NikaError::StructuredOutputValidationFailed {
452                        task_id: task_id.to_string(),
453                        layer: LAYER_3_NAME.to_string(),
454                        attempt,
455                        errors: vec!["Layer 3 requires infer callback".to_string()],
456                    }),
457                    None,
458                );
459            }
460        };
461
462        // Collect validation errors from the raw output
463        let validation_errors = self
464            .collect_validation_errors(raw_output, schema)
465            .join("\n");
466
467        // Generate retry prompt with feedback
468        let original_prompt = self.original_prompt.as_deref().unwrap_or("");
469        let retry_prompt =
470            self.generate_retry_prompt(original_prompt, raw_output, &validation_errors);
471
472        let prompt_len = retry_prompt.len();
473
474        debug!(
475            task_id = %task_id,
476            retry = retry_num,
477            prompt_len,
478            "Layer 3: calling LLM with retry prompt"
479        );
480
481        // EMIT: ProviderCalled before the LLM retry call
482        self.log.emit(EventKind::ProviderCalled {
483            task_id: Arc::clone(task_id),
484            provider: self
485                .provider_name
486                .clone()
487                .unwrap_or_else(|| "unknown".to_string()),
488            model: self
489                .model_name
490                .clone()
491                .unwrap_or_else(|| "unknown".to_string()),
492            prompt_len,
493        });
494
495        // Actually call the LLM with the retry prompt
496        let infer_start = Instant::now();
497        let new_output = match infer_fn(retry_prompt).await {
498            Ok(output) => {
499                let elapsed = infer_start.elapsed();
500                let in_tok = estimate_tokens(prompt_len);
501                let out_tok = estimate_tokens(output.len());
502                let cost = self.estimate_cost(in_tok, out_tok);
503                // EMIT: ProviderResponded after successful LLM retry call
504                self.log.emit(EventKind::ProviderResponded {
505                    task_id: Arc::clone(task_id),
506                    request_id: None,
507                    input_tokens: in_tok,
508                    output_tokens: out_tok,
509                    cache_read_tokens: 0,
510                    ttft_ms: Some(elapsed.as_millis() as u64),
511                    finish_reason: "structured_output_retry".to_string(),
512                    cost_usd: cost,
513                });
514                output
515            }
516            Err(e) => {
517                self.emit_attempt(
518                    task_id,
519                    3,
520                    LAYER_3_NAME,
521                    attempt,
522                    false,
523                    Some(format!("retry {}: LLM call failed: {}", retry_num, e)),
524                );
525                return (Err(e), None);
526            }
527        };
528
529        debug!(
530            task_id = %task_id,
531            retry = retry_num,
532            output_len = new_output.len(),
533            "Layer 3: received LLM response"
534        );
535
536        // Extract JSON from the new output
537        let json_value = match extract_json(&new_output) {
538            Ok(v) => v,
539            Err(e) => {
540                self.emit_attempt(
541                    task_id,
542                    3,
543                    LAYER_3_NAME,
544                    attempt,
545                    false,
546                    Some(format!("retry {}: extraction failed: {}", retry_num, e)),
547                );
548                return (
549                    Err(NikaError::StructuredOutputExtractionFailed {
550                        task_id: task_id.to_string(),
551                        layer: LAYER_3_NAME.to_string(),
552                        reason: e,
553                    }),
554                    Some(new_output),
555                );
556            }
557        };
558
559        // Validate the new output against schema
560        match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
561            Ok(()) => {
562                debug!(
563                    task_id = %task_id,
564                    retry = retry_num,
565                    "Layer 3: validation succeeded"
566                );
567                self.emit_attempt(task_id, 3, LAYER_3_NAME, attempt, true, None);
568                (Ok(json_value), Some(new_output))
569            }
570            Err(e) => {
571                self.emit_attempt(
572                    task_id,
573                    3,
574                    LAYER_3_NAME,
575                    attempt,
576                    false,
577                    Some(format!("retry {}: validation failed: {}", retry_num, e)),
578                );
579                (
580                    Err(NikaError::StructuredOutputValidationFailed {
581                        task_id: task_id.to_string(),
582                        layer: LAYER_3_NAME.to_string(),
583                        attempt,
584                        errors: vec![e.to_string()],
585                    }),
586                    Some(new_output),
587                )
588            }
589        }
590    }
591
592    /// Layer 4: LLM Repair
593    ///
594    /// Calls a repair LLM to fix invalid JSON.
595    /// Requires `infer_fn` callback to be set via `with_infer_callback()`.
596    ///
597    /// The repair prompt includes the invalid output and schema, asking the LLM
598    /// to return only the corrected JSON.
599    ///
600    /// Without `infer_fn`, this layer is skipped with a warning.
601    async fn try_layer_4(
602        &self,
603        task_id: &Arc<str>,
604        raw_output: &str,
605        schema: &Value,
606        attempt: u32,
607    ) -> Result<Value, NikaError> {
608        // Check if we have an inference callback
609        let infer_fn = match &self.infer_fn {
610            Some(f) => f,
611            None => {
612                // No callback - Layer 4 is disabled
613                debug!(
614                    task_id = %task_id,
615                    "Layer 4 skipped: no infer callback configured"
616                );
617                self.emit_attempt(
618                    task_id,
619                    4,
620                    LAYER_4_NAME,
621                    attempt,
622                    false,
623                    Some("no infer callback - Layer 4 disabled".to_string()),
624                );
625                return Err(NikaError::StructuredOutputValidationFailed {
626                    task_id: task_id.to_string(),
627                    layer: LAYER_4_NAME.to_string(),
628                    attempt,
629                    errors: vec!["Layer 4 requires infer callback".to_string()],
630                });
631            }
632        };
633
634        // Generate repair prompt
635        let repair_prompt = self.generate_repair_prompt(raw_output, schema);
636        let prompt_len = repair_prompt.len();
637
638        debug!(
639            task_id = %task_id,
640            prompt_len,
641            "Layer 4: calling repair LLM"
642        );
643
644        // EMIT: ProviderCalled before the LLM repair call
645        self.log.emit(EventKind::ProviderCalled {
646            task_id: Arc::clone(task_id),
647            provider: self
648                .provider_name
649                .clone()
650                .unwrap_or_else(|| "unknown".to_string()),
651            model: self
652                .model_name
653                .clone()
654                .unwrap_or_else(|| "unknown".to_string()),
655            prompt_len,
656        });
657
658        // Call the LLM to repair the JSON
659        let infer_start = Instant::now();
660        let repaired_output = match infer_fn(repair_prompt).await {
661            Ok(output) => {
662                let elapsed = infer_start.elapsed();
663                let in_tok = estimate_tokens(prompt_len);
664                let out_tok = estimate_tokens(output.len());
665                let cost = self.estimate_cost(in_tok, out_tok);
666                // EMIT: ProviderResponded after successful LLM repair call
667                self.log.emit(EventKind::ProviderResponded {
668                    task_id: Arc::clone(task_id),
669                    request_id: None,
670                    input_tokens: in_tok,
671                    output_tokens: out_tok,
672                    cache_read_tokens: 0,
673                    ttft_ms: Some(elapsed.as_millis() as u64),
674                    finish_reason: "structured_output_repair".to_string(),
675                    cost_usd: cost,
676                });
677                output
678            }
679            Err(e) => {
680                self.emit_attempt(
681                    task_id,
682                    4,
683                    LAYER_4_NAME,
684                    attempt,
685                    false,
686                    Some(format!("repair LLM call failed: {}", e)),
687                );
688                return Err(e);
689            }
690        };
691
692        debug!(
693            task_id = %task_id,
694            output_len = repaired_output.len(),
695            "Layer 4: received repair LLM response"
696        );
697
698        // Extract JSON from the repaired output
699        let json_value = match extract_json(&repaired_output) {
700            Ok(v) => v,
701            Err(e) => {
702                self.emit_attempt(
703                    task_id,
704                    4,
705                    LAYER_4_NAME,
706                    attempt,
707                    false,
708                    Some(format!("repair extraction failed: {}", e)),
709                );
710                return Err(NikaError::StructuredOutputExtractionFailed {
711                    task_id: task_id.to_string(),
712                    layer: LAYER_4_NAME.to_string(),
713                    reason: e,
714                });
715            }
716        };
717
718        // Validate the repaired output against schema
719        match validate_schema_ref(&json_value, &SchemaRef::Inline(schema.clone())).await {
720            Ok(()) => {
721                debug!(
722                    task_id = %task_id,
723                    "Layer 4: repair validation succeeded"
724                );
725                self.emit_attempt(task_id, 4, LAYER_4_NAME, attempt, true, None);
726                Ok(json_value)
727            }
728            Err(e) => {
729                self.emit_attempt(
730                    task_id,
731                    4,
732                    LAYER_4_NAME,
733                    attempt,
734                    false,
735                    Some(format!("repair validation failed: {}", e)),
736                );
737                Err(NikaError::StructuredOutputValidationFailed {
738                    task_id: task_id.to_string(),
739                    layer: LAYER_4_NAME.to_string(),
740                    attempt,
741                    errors: vec![e.to_string()],
742                })
743            }
744        }
745    }
746
747    /// Emit a StructuredOutputAttempt event
748    fn emit_attempt(
749        &self,
750        task_id: &Arc<str>,
751        layer: u8,
752        layer_name: &str,
753        attempt: u32,
754        success: bool,
755        error: Option<String>,
756    ) {
757        self.log.emit(EventKind::StructuredOutputAttempt {
758            task_id: Arc::clone(task_id),
759            layer,
760            layer_name: layer_name.to_string(),
761            attempt,
762            success,
763            error,
764        });
765    }
766
767    /// Emit a StructuredOutputSuccess event
768    fn emit_success(&self, task_id: &Arc<str>, layer: u8, layer_name: &str, total_attempts: u32) {
769        self.log.emit(EventKind::StructuredOutputSuccess {
770            task_id: Arc::clone(task_id),
771            layer,
772            layer_name: layer_name.to_string(),
773            total_attempts,
774        });
775    }
776
777    /// Collect validation errors for the final failure message
778    fn collect_validation_errors(&self, raw_output: &str, schema: &Value) -> Vec<String> {
779        match extract_json(raw_output) {
780            Ok(value) => {
781                let errors_str = format_validation_errors(&value, schema);
782                errors_str.lines().map(|s| s.to_string()).collect()
783            }
784            Err(e) => vec![format!("JSON extraction failed: {}", e)],
785        }
786    }
787
788    /// Generate a retry prompt with validation feedback
789    ///
790    /// Used by Layer 3 to construct the re-prompt with error context.
791    pub fn generate_retry_prompt(
792        &self,
793        original_prompt: &str,
794        invalid_output: &str,
795        validation_errors: &str,
796    ) -> String {
797        format!(
798            r#"{original_prompt}
799
800Your previous response was invalid:
801```
802{invalid_output}
803```
804
805Validation errors:
806{validation_errors}
807
808Please provide a corrected response that matches the required JSON schema."#
809        )
810    }
811
812    /// Generate a repair prompt for Layer 4
813    ///
814    /// Used by the executor to construct the repair LLM call.
815    pub fn generate_repair_prompt(&self, invalid_output: &str, schema: &Value) -> String {
816        let schema_str =
817            serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
818
819        format!(
820            r#"You are a JSON repair assistant. Fix the following invalid JSON to match the schema.
821
822Invalid JSON:
823```
824{invalid_output}
825```
826
827Required schema:
828```json
829{schema_str}
830```
831
832Respond with ONLY the corrected JSON, no explanation."#
833        )
834    }
835}
836
837// ═══════════════════════════════════════════════════════════════════════════
838// STANDALONE VALIDATION FUNCTIONS
839// ═══════════════════════════════════════════════════════════════════════════
840
841/// Quick validation without the full engine (for simple cases)
842///
843/// Validates output against a schema without retry or repair.
844/// Useful for one-shot validation in exec: or fetch: tasks.
845///
846/// Correctly handles `from_example`: derives the JSON Schema at call time.
847/// For inline examples this is synchronous; for file-based examples it reads the file.
848pub async fn validate_structured_output(
849    task_id: &str,
850    output: &str,
851    spec: &StructuredOutputSpec,
852    log: &EventLog,
853) -> Result<Value, NikaError> {
854    let task_id: Arc<str> = Arc::from(task_id);
855
856    // Extract JSON
857    let json_value = extract_json(output).map_err(|e| {
858        log.emit(EventKind::StructuredOutputAttempt {
859            task_id: Arc::clone(&task_id),
860            layer: 2,
861            layer_name: LAYER_2_NAME.to_string(),
862            attempt: 1,
863            success: false,
864            error: Some(e.clone()),
865        });
866        NikaError::StructuredOutputExtractionFailed {
867            task_id: task_id.to_string(),
868            layer: LAYER_2_NAME.to_string(),
869            reason: e,
870        }
871    })?;
872
873    // Resolve the effective schema — honours from_example (derives schema at runtime)
874    let effective_schema = if let Some(ref example_ref) = spec.from_example {
875        let example_value = match example_ref {
876            SchemaRef::Inline(v) => v.clone(),
877            SchemaRef::File(path) => {
878                let content =
879                    tokio::fs::read_to_string(path)
880                        .await
881                        .map_err(|e| NikaError::SchemaFailed {
882                            details: format!("Failed to read example '{}': {}", path, e),
883                        })?;
884                serde_json::from_str(&content).map_err(|e| NikaError::SchemaFailed {
885                    details: format!("Invalid JSON in example '{}': {}", path, e),
886                })?
887            }
888        };
889        if spec.strict == Some(true) {
890            SchemaRef::Inline(crate::ast::structured::json_to_schema_strict(
891                &example_value,
892            ))
893        } else {
894            SchemaRef::Inline(crate::ast::structured::json_to_schema(&example_value))
895        }
896    } else {
897        match spec.schema.clone() {
898            Some(schema) => schema,
899            None => {
900                return Err(NikaError::SchemaFailed {
901                    details: "No schema or from_example defined".to_string(),
902                });
903            }
904        }
905    };
906
907    // Validate
908    validate_schema_ref(&json_value, &effective_schema)
909        .await
910        .map_err(|e| {
911            log.emit(EventKind::StructuredOutputAttempt {
912                task_id: Arc::clone(&task_id),
913                layer: 2,
914                layer_name: LAYER_2_NAME.to_string(),
915                attempt: 1,
916                success: false,
917                error: Some(e.to_string()),
918            });
919            NikaError::StructuredOutputValidationFailed {
920                task_id: task_id.to_string(),
921                layer: LAYER_2_NAME.to_string(),
922                attempt: 1,
923                errors: vec![e.to_string()],
924            }
925        })?;
926
927    log.emit(EventKind::StructuredOutputSuccess {
928        task_id: Arc::clone(&task_id),
929        layer: 2,
930        layer_name: LAYER_2_NAME.to_string(),
931        total_attempts: 1,
932    });
933
934    Ok(json_value)
935}
936
937#[cfg(test)]
938mod tests {
939    use super::*;
940    use std::io::Write;
941    use tempfile::NamedTempFile;
942
943    fn create_test_log() -> Arc<EventLog> {
944        Arc::new(EventLog::new())
945    }
946
947    fn create_user_schema() -> Value {
948        serde_json::json!({
949            "type": "object",
950            "properties": {
951                "name": { "type": "string" },
952                "age": { "type": "integer", "minimum": 0 }
953            },
954            "required": ["name", "age"]
955        })
956    }
957
958    // ═══════════════════════════════════════════════════════════════════════════
959    // LAYER 2 TESTS (Provider-Native)
960    // ═══════════════════════════════════════════════════════════════════════════
961
962    #[tokio::test]
963    async fn layer2_valid_json_passes() {
964        let log = create_test_log();
965        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
966        let mut engine = StructuredOutputEngine::new(spec, log.clone());
967
968        let result = engine
969            .validate("test-task", r#"{"name": "Alice", "age": 30}"#)
970            .await;
971
972        assert!(result.is_ok());
973        let r = result.unwrap();
974        assert_eq!(r.layer, 2);
975        assert_eq!(r.layer_name, "extract_validate");
976        assert_eq!(r.value["name"], "Alice");
977    }
978
979    #[tokio::test]
980    async fn layer2_markdown_wrapped_json_passes() {
981        let log = create_test_log();
982        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
983        let mut engine = StructuredOutputEngine::new(spec, log.clone());
984
985        let output = r#"Here's the result:
986```json
987{"name": "Bob", "age": 25}
988```
989Hope this helps!"#;
990
991        let result = engine.validate("test-task", output).await;
992
993        assert!(result.is_ok());
994        let r = result.unwrap();
995        assert_eq!(r.value["name"], "Bob");
996        assert_eq!(r.value["age"], 25);
997    }
998
999    #[tokio::test]
1000    async fn layer2_invalid_json_fails() {
1001        let log = create_test_log();
1002        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1003        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1004
1005        // Missing required 'age' field
1006        let result = engine.validate("test-task", r#"{"name": "Charlie"}"#).await;
1007
1008        assert!(result.is_err());
1009        let err = result.unwrap_err();
1010        assert!(matches!(
1011            err,
1012            NikaError::StructuredOutputAllLayersFailed { .. }
1013        ));
1014    }
1015
1016    #[tokio::test]
1017    async fn layer2_malformed_json_fails() {
1018        let log = create_test_log();
1019        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1020        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1021
1022        let result = engine.validate("test-task", "not json at all").await;
1023
1024        assert!(result.is_err());
1025    }
1026
1027    // ═══════════════════════════════════════════════════════════════════════════
1028    // SCHEMA LOADING TESTS
1029    // ═══════════════════════════════════════════════════════════════════════════
1030
1031    #[tokio::test]
1032    async fn load_schema_from_file() {
1033        let log = create_test_log();
1034
1035        let mut schema_file = NamedTempFile::new().unwrap();
1036        writeln!(
1037            schema_file,
1038            r#"{{"type": "object", "properties": {{"x": {{"type": "number"}}}}}}"#
1039        )
1040        .unwrap();
1041        let path = schema_file.path().to_string_lossy().to_string();
1042
1043        let spec = StructuredOutputSpec::with_file_schema(&path);
1044        let mut engine = StructuredOutputEngine::new(spec, log);
1045
1046        let schema = engine.load_schema().await.unwrap();
1047        assert_eq!(schema["type"], "object");
1048    }
1049
1050    #[tokio::test]
1051    async fn load_schema_file_not_found() {
1052        let log = create_test_log();
1053        let spec = StructuredOutputSpec::with_file_schema("/nonexistent/schema.json");
1054        let mut engine = StructuredOutputEngine::new(spec, log);
1055
1056        let result = engine.load_schema().await;
1057        assert!(result.is_err());
1058    }
1059
1060    // ═══════════════════════════════════════════════════════════════════════════
1061    // FROM_EXAMPLE TESTS (engine layer)
1062    // ═══════════════════════════════════════════════════════════════════════════
1063
1064    #[tokio::test]
1065    async fn load_schema_from_example_inline() {
1066        let log = create_test_log();
1067        let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1068            "name": "alice",
1069            "score": 42
1070        }));
1071        let mut engine = StructuredOutputEngine::new(spec, log);
1072        let schema = engine.load_schema().await.unwrap();
1073        assert_eq!(schema["type"], "object");
1074        assert_eq!(schema["properties"]["name"]["type"], "string");
1075        assert_eq!(schema["properties"]["score"]["type"], "integer");
1076    }
1077
1078    #[tokio::test]
1079    async fn load_schema_from_example_file() {
1080        let mut example_file = NamedTempFile::new().unwrap();
1081        writeln!(example_file, r#"{{"title":"hello","count":1}}"#).unwrap();
1082        let path = example_file.path().to_string_lossy().to_string();
1083
1084        let spec = StructuredOutputSpec::with_example_file(&path);
1085        let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1086        let schema = engine.load_schema().await.unwrap();
1087        assert_eq!(schema["type"], "object");
1088        assert_eq!(schema["properties"]["title"]["type"], "string");
1089        assert_eq!(schema["properties"]["count"]["type"], "integer");
1090    }
1091
1092    #[tokio::test]
1093    async fn load_schema_from_example_file_not_found() {
1094        let spec = StructuredOutputSpec::with_example_file("/nonexistent/example.json");
1095        let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1096        let result = engine.load_schema().await;
1097        assert!(result.is_err());
1098        let err = result.unwrap_err().to_string();
1099        assert!(err.contains("Failed to read example"), "got: {err}");
1100    }
1101
1102    #[tokio::test]
1103    async fn validate_with_example_inline_passes_valid_json() {
1104        let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1105            "name": "x",
1106            "score": 0
1107        }));
1108        let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1109        let result = engine
1110            .validate("t1", r#"{"name": "bob", "score": 99}"#)
1111            .await;
1112        assert!(result.is_ok(), "expected ok, got: {:?}", result);
1113    }
1114
1115    #[tokio::test]
1116    async fn validate_with_example_inline_fails_wrong_type() {
1117        let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1118            "name": "x",
1119            "score": 0
1120        }));
1121        // score should be integer — send string instead
1122        let mut engine = StructuredOutputEngine::new(spec, create_test_log());
1123        let result = engine
1124            .validate("t2", r#"{"name": "bob", "score": "not-a-number"}"#)
1125            .await;
1126        assert!(result.is_err(), "expected validation failure on wrong type");
1127    }
1128
1129    #[tokio::test]
1130    async fn validate_structured_output_from_example_inline_passes() {
1131        let log = create_test_log();
1132        let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1133            "name": "x",
1134            "score": 0
1135        }));
1136        let result =
1137            validate_structured_output("t3", r#"{"name":"alice","score":42}"#, &spec, &log).await;
1138        assert!(result.is_ok(), "expected ok, got: {:?}", result);
1139    }
1140
1141    #[tokio::test]
1142    async fn validate_structured_output_from_example_inline_rejects_invalid() {
1143        let log = create_test_log();
1144        let spec = StructuredOutputSpec::with_example_inline(serde_json::json!({
1145            "name": "x",
1146            "score": 0
1147        }));
1148        // This used to silently pass with spec.schema ({}) — now it must fail
1149        let result =
1150            validate_structured_output("t4", r#"{"anything":"goes","random":true}"#, &spec, &log)
1151                .await;
1152        assert!(
1153            result.is_err(),
1154            "validate_structured_output must reject missing required fields"
1155        );
1156    }
1157
1158    // ═══════════════════════════════════════════════════════════════════════════
1159    // EVENT EMISSION TESTS
1160    // ═══════════════════════════════════════════════════════════════════════════
1161
1162    #[tokio::test]
1163    async fn events_emitted_on_success() {
1164        let log = create_test_log();
1165        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1166        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1167
1168        let _ = engine
1169            .validate("task-1", r#"{"name": "Test", "age": 20}"#)
1170            .await;
1171
1172        let events = log.events();
1173        assert!(!events.is_empty());
1174
1175        // Should have attempt + success events
1176        let has_attempt = events.iter().any(|e| {
1177            matches!(
1178                &e.kind,
1179                EventKind::StructuredOutputAttempt { success: true, .. }
1180            )
1181        });
1182        let has_success = events
1183            .iter()
1184            .any(|e| matches!(&e.kind, EventKind::StructuredOutputSuccess { .. }));
1185
1186        assert!(has_attempt);
1187        assert!(has_success);
1188    }
1189
1190    #[tokio::test]
1191    async fn events_emitted_on_failure() {
1192        let log = create_test_log();
1193        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1194        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1195
1196        let _ = engine.validate("task-2", "invalid").await;
1197
1198        let events = log.events();
1199        assert!(!events.is_empty());
1200
1201        // Should have failed attempt events
1202        let has_failed_attempt = events.iter().any(|e| {
1203            matches!(
1204                &e.kind,
1205                EventKind::StructuredOutputAttempt { success: false, .. }
1206            )
1207        });
1208        assert!(has_failed_attempt);
1209    }
1210
1211    // ═══════════════════════════════════════════════════════════════════════════
1212    // LAYER TOGGLE TESTS
1213    // ═══════════════════════════════════════════════════════════════════════════
1214
1215    #[tokio::test]
1216    async fn layers_can_be_disabled() {
1217        let log = create_test_log();
1218        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1219        spec.enable_retry = Some(false);
1220        spec.enable_repair = Some(false);
1221
1222        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1223
1224        // Invalid JSON should fail fast with only Layer 2 enabled
1225        let result = engine
1226            .validate("task-3", r#"{"name": "Only name, no age"}"#)
1227            .await;
1228
1229        assert!(result.is_err());
1230
1231        // Check attempt count - should be just 1 (Layer 2 only)
1232        let events = log.events();
1233        let attempt_count = events
1234            .iter()
1235            .filter(|e| matches!(&e.kind, EventKind::StructuredOutputAttempt { .. }))
1236            .count();
1237        assert_eq!(attempt_count, 1, "Only Layer 2 should have attempted");
1238    }
1239
1240    // ═══════════════════════════════════════════════════════════════════════════
1241    // RETRY PROMPT GENERATION TESTS
1242    // ═══════════════════════════════════════════════════════════════════════════
1243
1244    #[test]
1245    fn generate_retry_prompt_includes_context() {
1246        let log = create_test_log();
1247        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1248        let engine = StructuredOutputEngine::new(spec, log);
1249
1250        let prompt = engine.generate_retry_prompt(
1251            "Generate a user object",
1252            r#"{"name": "Test"}"#,
1253            "missing required field: age",
1254        );
1255
1256        assert!(prompt.contains("Generate a user object"));
1257        assert!(prompt.contains(r#"{"name": "Test"}"#));
1258        assert!(prompt.contains("missing required field: age"));
1259    }
1260
1261    #[test]
1262    fn generate_repair_prompt_includes_schema() {
1263        let log = create_test_log();
1264        let schema = create_user_schema();
1265        let spec = StructuredOutputSpec::with_inline_schema(schema.clone());
1266        let engine = StructuredOutputEngine::new(spec, log);
1267
1268        let prompt = engine.generate_repair_prompt(r#"{"broken": true}"#, &schema);
1269
1270        assert!(prompt.contains(r#"{"broken": true}"#));
1271        assert!(prompt.contains("name"));
1272        assert!(prompt.contains("age"));
1273    }
1274
1275    // ═══════════════════════════════════════════════════════════════════════════
1276    // STANDALONE VALIDATION TESTS
1277    // ═══════════════════════════════════════════════════════════════════════════
1278
1279    #[tokio::test]
1280    async fn standalone_validation_works() {
1281        let log = EventLog::new();
1282        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1283
1284        let result = validate_structured_output(
1285            "task-4",
1286            r#"{"name": "Standalone", "age": 42}"#,
1287            &spec,
1288            &log,
1289        )
1290        .await;
1291
1292        assert!(result.is_ok());
1293        let value = result.unwrap();
1294        assert_eq!(value["name"], "Standalone");
1295    }
1296
1297    #[tokio::test]
1298    async fn standalone_validation_fails_on_invalid() {
1299        let log = EventLog::new();
1300        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1301
1302        let result =
1303            validate_structured_output("task-5", r#"{"invalid": true}"#, &spec, &log).await;
1304
1305        assert!(result.is_err());
1306    }
1307
1308    // ═══════════════════════════════════════════════════════════════════════════
1309    // EDGE CASES
1310    // ═══════════════════════════════════════════════════════════════════════════
1311
1312    #[tokio::test]
1313    async fn handles_unicode_content() {
1314        let log = create_test_log();
1315        let spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1316        let mut engine = StructuredOutputEngine::new(spec, log);
1317
1318        let result = engine
1319            .validate("task-unicode", r#"{"name": "日本語テスト", "age": 25}"#)
1320            .await;
1321
1322        assert!(result.is_ok());
1323        assert_eq!(result.unwrap().value["name"], "日本語テスト");
1324    }
1325
1326    #[tokio::test]
1327    async fn handles_nested_objects() {
1328        let log = create_test_log();
1329        let schema = serde_json::json!({
1330            "type": "object",
1331            "properties": {
1332                "user": {
1333                    "type": "object",
1334                    "properties": {
1335                        "name": { "type": "string" }
1336                    },
1337                    "required": ["name"]
1338                }
1339            },
1340            "required": ["user"]
1341        });
1342        let spec = StructuredOutputSpec::with_inline_schema(schema);
1343        let mut engine = StructuredOutputEngine::new(spec, log);
1344
1345        let result = engine
1346            .validate("task-nested", r#"{"user": {"name": "Nested User"}}"#)
1347            .await;
1348
1349        assert!(result.is_ok());
1350    }
1351
1352    #[tokio::test]
1353    async fn handles_arrays() {
1354        let log = create_test_log();
1355        let schema = serde_json::json!({
1356            "type": "array",
1357            "items": {
1358                "type": "object",
1359                "properties": {
1360                    "id": { "type": "integer" }
1361                },
1362                "required": ["id"]
1363            }
1364        });
1365        let spec = StructuredOutputSpec::with_inline_schema(schema);
1366        let mut engine = StructuredOutputEngine::new(spec, log);
1367
1368        let result = engine
1369            .validate("task-array", r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#)
1370            .await;
1371
1372        assert!(result.is_ok());
1373        let arr = result.unwrap().value;
1374        assert!(arr.is_array());
1375        assert_eq!(arr.as_array().unwrap().len(), 3);
1376    }
1377
1378    // ═══════════════════════════════════════════════════════════════════════════
1379    // LAYER 3 TESTS
1380    // ═══════════════════════════════════════════════════════════════════════════
1381
1382    use std::sync::atomic::{AtomicU32, Ordering};
1383
1384    #[tokio::test]
1385    async fn layer3_actually_retries_llm() {
1386        let call_count = Arc::new(AtomicU32::new(0));
1387        let call_count_clone = call_count.clone();
1388
1389        // Mock callback that returns valid JSON on second call
1390        let callback: InferCallback = Arc::new(move |_prompt: String| {
1391            let count = call_count_clone.clone();
1392            Box::pin(async move {
1393                let n = count.fetch_add(1, Ordering::SeqCst);
1394                if n == 0 {
1395                    // First call from Layer 3 retry: return valid JSON
1396                    Ok(r#"{"name": "Alice", "age": 30}"#.to_string())
1397                } else {
1398                    // Shouldn't be called more than once if first retry succeeds
1399                    Ok(r#"{"name": "Bob", "age": 25}"#.to_string())
1400                }
1401            })
1402        });
1403
1404        let log = create_test_log();
1405        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1406        spec.enable_retry = Some(true);
1407        spec.max_retries = Some(3);
1408        spec.enable_repair = Some(false); // Disable Layer 4
1409
1410        let mut engine = StructuredOutputEngine::new(spec, log.clone())
1411            .with_infer_callback(callback)
1412            .with_original_prompt("Generate a user object".to_string());
1413
1414        // Invalid JSON should trigger Layer 3 retry
1415        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1416
1417        assert!(result.is_ok(), "Should succeed after Layer 3 retry");
1418        let r = result.unwrap();
1419        assert_eq!(r.layer, 3, "Should succeed at Layer 3");
1420        assert_eq!(r.layer_name, "retry_with_feedback");
1421        assert_eq!(r.value["name"], "Alice");
1422        assert!(
1423            call_count.load(Ordering::SeqCst) >= 1,
1424            "Should have called LLM at least once"
1425        );
1426    }
1427
1428    #[tokio::test]
1429    async fn layer3_skipped_without_callback() {
1430        let log = create_test_log();
1431        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1432        spec.enable_retry = Some(true);
1433        spec.max_retries = Some(3);
1434        spec.enable_repair = Some(false);
1435
1436        // No callback - Layer 3 should be skipped
1437        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1438
1439        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1440
1441        assert!(result.is_err(), "Should fail without callback");
1442
1443        // Check that Layer 3 attempts were made but failed due to no callback
1444        let events = log.events();
1445        let layer3_attempts = events.iter().filter(|e| {
1446            matches!(
1447                &e.kind,
1448                EventKind::StructuredOutputAttempt {
1449                    layer: 3,
1450                    success: false,
1451                    error: Some(err),
1452                    ..
1453                } if err.contains("no infer callback")
1454            )
1455        });
1456        assert!(
1457            layer3_attempts.count() > 0,
1458            "Should have Layer 3 attempt events showing no callback"
1459        );
1460    }
1461
1462    // ═══════════════════════════════════════════════════════════════════════════
1463    // LAYER 4 TESTS
1464    // ═══════════════════════════════════════════════════════════════════════════
1465
1466    #[tokio::test]
1467    async fn layer4_actually_repairs_json() {
1468        let call_count = Arc::new(AtomicU32::new(0));
1469        let call_count_clone = call_count.clone();
1470
1471        // Mock callback that returns repaired JSON
1472        let callback: InferCallback = Arc::new(move |prompt: String| {
1473            let count = call_count_clone.clone();
1474            Box::pin(async move {
1475                count.fetch_add(1, Ordering::SeqCst);
1476                // Verify we received a repair prompt
1477                assert!(
1478                    prompt.contains("repair") || prompt.contains("schema"),
1479                    "Should receive repair prompt"
1480                );
1481                // Return valid JSON
1482                Ok(r#"{"name": "Repaired", "age": 25}"#.to_string())
1483            })
1484        });
1485
1486        let log = create_test_log();
1487        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1488        spec.enable_retry = Some(false); // Skip Layer 3
1489        spec.enable_repair = Some(true);
1490
1491        let mut engine =
1492            StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
1493
1494        let result = engine.validate("test-task", "totally broken json").await;
1495
1496        assert!(result.is_ok(), "Should succeed after Layer 4 repair");
1497        let r = result.unwrap();
1498        assert_eq!(r.layer, 4, "Should succeed at Layer 4");
1499        assert_eq!(r.layer_name, "llm_repair");
1500        assert_eq!(r.value["name"], "Repaired");
1501        assert!(
1502            call_count.load(Ordering::SeqCst) >= 1,
1503            "Should have called repair LLM"
1504        );
1505    }
1506
1507    #[tokio::test]
1508    async fn layer4_skipped_without_callback() {
1509        let log = create_test_log();
1510        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1511        spec.enable_retry = Some(false);
1512        spec.enable_repair = Some(true);
1513
1514        // No callback - Layer 4 should be skipped
1515        let mut engine = StructuredOutputEngine::new(spec, log.clone());
1516
1517        let result = engine.validate("test-task", "broken json").await;
1518
1519        assert!(result.is_err(), "Should fail without callback");
1520
1521        // Check that Layer 4 attempt was made but failed due to no callback
1522        let events = log.events();
1523        let layer4_attempts = events.iter().filter(|e| {
1524            matches!(
1525                &e.kind,
1526                EventKind::StructuredOutputAttempt {
1527                    layer: 4,
1528                    success: false,
1529                    error: Some(err),
1530                    ..
1531                } if err.contains("no infer callback")
1532            )
1533        });
1534        assert!(
1535            layer4_attempts.count() > 0,
1536            "Should have Layer 4 attempt event showing no callback"
1537        );
1538    }
1539
1540    // ═══════════════════════════════════════════════════════════════════════════
1541    // MAX_RETRIES TESTS
1542    // ═══════════════════════════════════════════════════════════════════════════
1543
1544    #[tokio::test]
1545    async fn max_retries_is_respected() {
1546        let call_count = Arc::new(AtomicU32::new(0));
1547        let call_count_clone = call_count.clone();
1548
1549        // Mock callback that always returns invalid JSON
1550        let callback: InferCallback = Arc::new(move |_prompt: String| {
1551            let count = call_count_clone.clone();
1552            Box::pin(async move {
1553                count.fetch_add(1, Ordering::SeqCst);
1554                // Always return invalid JSON (missing age)
1555                Ok(r#"{"still_invalid": true}"#.to_string())
1556            })
1557        });
1558
1559        let log = create_test_log();
1560        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1561        spec.max_retries = Some(3);
1562        spec.enable_retry = Some(true);
1563        spec.enable_repair = Some(false); // Skip Layer 4
1564
1565        let mut engine =
1566            StructuredOutputEngine::new(spec, log.clone()).with_infer_callback(callback);
1567
1568        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1569
1570        assert!(result.is_err(), "Should fail after max retries");
1571        assert_eq!(
1572            call_count.load(Ordering::SeqCst),
1573            3,
1574            "Should have retried exactly max_retries times"
1575        );
1576    }
1577
1578    #[tokio::test]
1579    async fn layer3_layer4_chain_works() {
1580        let call_count = Arc::new(AtomicU32::new(0));
1581        let call_count_clone = call_count.clone();
1582
1583        // Mock callback:
1584        // - Layer 3 retries all fail (return invalid JSON)
1585        // - Layer 4 repair succeeds (return valid JSON)
1586        // Note: Detect Layer 4 by "JSON repair assistant" which is unique to repair prompt
1587        let callback: InferCallback = Arc::new(move |prompt: String| {
1588            let count = call_count_clone.clone();
1589            Box::pin(async move {
1590                let n = count.fetch_add(1, Ordering::SeqCst);
1591                if prompt.contains("JSON repair assistant") {
1592                    // Layer 4 repair call - succeed
1593                    Ok(r#"{"name": "Repaired", "age": 42}"#.to_string())
1594                } else {
1595                    // Layer 3 retry calls - always fail
1596                    Ok(format!(
1597                        r#"{{"retry_attempt": {}, "still_invalid": true}}"#,
1598                        n
1599                    ))
1600                }
1601            })
1602        });
1603
1604        let log = create_test_log();
1605        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1606        spec.max_retries = Some(2);
1607        spec.enable_retry = Some(true);
1608        spec.enable_repair = Some(true);
1609
1610        let mut engine = StructuredOutputEngine::new(spec, log.clone())
1611            .with_infer_callback(callback)
1612            .with_original_prompt("Generate user".to_string());
1613
1614        let result = engine.validate("test-task", r#"{"invalid": true}"#).await;
1615
1616        assert!(result.is_ok(), "Should succeed after Layer 4 repair");
1617        let r = result.unwrap();
1618        assert_eq!(r.layer, 4, "Should succeed at Layer 4");
1619        assert_eq!(r.value["name"], "Repaired");
1620        // Should have: 2 Layer 3 retries + 1 Layer 4 repair = 3 calls
1621        assert_eq!(
1622            call_count.load(Ordering::SeqCst),
1623            3,
1624            "Should have made 2 retry calls + 1 repair call"
1625        );
1626    }
1627
1628    #[tokio::test]
1629    async fn original_prompt_included_in_retry() {
1630        let captured_prompt = Arc::new(std::sync::Mutex::new(String::new()));
1631        let captured_prompt_clone = captured_prompt.clone();
1632
1633        let callback: InferCallback = Arc::new(move |prompt: String| {
1634            let captured = captured_prompt_clone.clone();
1635            Box::pin(async move {
1636                *captured.lock().unwrap() = prompt.clone();
1637                // Return valid JSON
1638                Ok(r#"{"name": "Test", "age": 30}"#.to_string())
1639            })
1640        });
1641
1642        let log = create_test_log();
1643        let mut spec = StructuredOutputSpec::with_inline_schema(create_user_schema());
1644        spec.enable_retry = Some(true);
1645        spec.max_retries = Some(1);
1646        spec.enable_repair = Some(false);
1647
1648        let mut engine = StructuredOutputEngine::new(spec, log.clone())
1649            .with_infer_callback(callback)
1650            .with_original_prompt("Generate a user object for testing".to_string());
1651
1652        let _ = engine.validate("test-task", r#"{"invalid": true}"#).await;
1653
1654        let prompt = captured_prompt.lock().unwrap().clone();
1655        assert!(
1656            prompt.contains("Generate a user object for testing"),
1657            "Retry prompt should include original prompt"
1658        );
1659        assert!(
1660            prompt.contains("invalid"),
1661            "Retry prompt should include the invalid output"
1662        );
1663    }
1664}