Skip to main content

agent_sdk_providers/
structured.rs

//! Schema-validated structured output.
//!
//! This module implements a provider-agnostic runner that constrains a model's
//! final answer to a JSON Schema, validates the output, and re-prompts a
//! bounded number of times on mismatch before failing with a typed error.
//!
//! # How it works
//!
//! 1. The caller supplies a [`ChatRequest`] whose
//!    [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format) is
//!    set, plus a [`StructuredConfig`] bounding the retries.
//! 2. The runner inspects the provider's
//!    [`structured_output_support`](crate::LlmProvider::structured_output_support):
//!    - [`Native`](crate::StructuredOutputSupport::Native) — the provider
//!      already mapped `response_format` onto its wire request (`OpenAI` JSON
//!      mode, Gemini `responseSchema`). The structured value is parsed from the
//!      assistant's text output.
//!    - [`ToolForcing`](crate::StructuredOutputSupport::ToolForcing) — the
//!      runner injects a single forced "respond" tool whose `input_schema` is
//!      the output schema (the Anthropic fallback) and reads the structured
//!      value from that tool call's input.
//! 3. The candidate value is validated against the schema with `jsonschema`.
//!    On success it is returned. On failure the runner appends the model's
//!    output plus a corrective user message describing the validation errors
//!    and retries, up to [`StructuredConfig::max_retries`] times. Exhausting
//!    the budget yields [`StructuredOutputError::RetriesExhausted`] when the
//!    last answer failed validation, or
//!    [`StructuredOutputError::NoStructuredOutput`] when the last answer
//!    contained no structured value at all.
//!
//! This mirrors the Claude SDK's `output_format` +
//! `error_max_structured_output_retries` behaviour.
30
31use agent_sdk_foundation::llm::{
32    ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool, ToolChoice,
33};
34use agent_sdk_foundation::types::ToolTier;
35
36use crate::provider::{LlmProvider, StructuredOutputSupport};
37
/// Name of the synthetic tool injected by the tool-forcing fallback
/// (Anthropic). The runner forces the model to call this tool and reads the
/// structured answer from the call's input.
const RESPOND_TOOL_NAME: &str = "respond";
40
/// Retry budget for the structured-output re-prompt loop.
#[derive(Clone, Copy, Debug)]
pub struct StructuredConfig {
    /// How many corrective *re-prompts* may follow the initial attempt. With a
    /// value of `2` the runner makes at most three model calls (1 initial +
    /// 2 retries) before giving up with
    /// [`StructuredOutputError::RetriesExhausted`].
    ///
    /// Counterpart of the Claude SDK `error_max_structured_output_retries`.
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Allows two re-prompts, i.e. up to three model calls in total.
    fn default() -> Self {
        let max_retries = 2;
        Self { max_retries }
    }
}
57
58/// A successfully validated structured output and the response that produced it.
59#[derive(Debug, Clone)]
60pub struct StructuredOutput {
61    /// The validated JSON value, guaranteed to satisfy the requested schema.
62    pub value: serde_json::Value,
63    /// The full provider response that produced [`value`](Self::value), so
64    /// callers can still read usage, stop reason, and any leading text.
65    pub response: ChatResponse,
66    /// Number of re-prompts performed before the value validated (0 when the
67    /// first attempt already satisfied the schema).
68    pub retries: u32,
69}
70
71/// Errors from the structured-output runner.
72///
73/// These are *typed* terminal outcomes — the runner never panics on a model
74/// that fails to produce schema-valid output.
75#[derive(Debug, thiserror::Error)]
76pub enum StructuredOutputError {
77    /// The request did not carry a
78    /// [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format).
79    #[error("structured output requested without a response_format on the request")]
80    MissingResponseFormat,
81
82    /// The schema supplied in the response format is not a valid JSON Schema.
83    #[error("invalid output JSON schema: {0}")]
84    InvalidSchema(String),
85
86    /// The model produced no extractable structured value (no JSON text /
87    /// no forced tool call).
88    #[error("model produced no structured output to validate")]
89    NoStructuredOutput,
90
91    /// The provider returned a non-success outcome (rate limit, server error,
92    /// invalid request).
93    #[error("provider returned a non-success outcome: {0}")]
94    ProviderOutcome(String),
95
96    /// The re-prompt budget was exhausted and the latest output still failed
97    /// schema validation. Carries the final validation errors and the last
98    /// candidate value for diagnostics.
99    #[error(
100        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
101    )]
102    RetriesExhausted {
103        /// Total number of model calls made (initial + retries).
104        attempts: u32,
105        /// Human-readable concatenation of the final validation errors.
106        errors: String,
107        /// The last candidate value the model produced, if any.
108        last_value: Option<serde_json::Value>,
109    },
110
111    /// A transport-level error bubbled up from the provider.
112    #[error(transparent)]
113    Transport(#[from] anyhow::Error),
114}
115
116/// Run a bounded, schema-validated structured-output exchange against `provider`.
117///
118/// The `request`'s
119/// [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format) must
120/// be set; on success the returned [`StructuredOutput::value`] is guaranteed to
121/// satisfy that schema.
122///
123/// # Errors
124///
125/// Returns a [`StructuredOutputError`] when the request is missing a response
126/// format, the schema is invalid, the provider errors, or the model fails to
127/// produce schema-valid output within [`StructuredConfig::max_retries`].
128pub async fn run_structured(
129    provider: &dyn LlmProvider,
130    mut request: ChatRequest,
131    config: StructuredConfig,
132) -> Result<StructuredOutput, StructuredOutputError> {
133    let response_format = request
134        .response_format
135        .clone()
136        .ok_or(StructuredOutputError::MissingResponseFormat)?;
137
138    // Compile the validator once; reuse it across every retry.
139    let validator = jsonschema::validator_for(&response_format.schema)
140        .map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
141
142    let support = provider.structured_output_support();
143    if matches!(support, StructuredOutputSupport::ToolForcing) {
144        apply_tool_forcing(&mut request, &response_format);
145    }
146
147    let max_attempts = config.max_retries.saturating_add(1);
148    let mut last_value: Option<serde_json::Value> = None;
149    let mut last_errors = String::new();
150
151    for attempt in 0..max_attempts {
152        let outcome = provider.chat(request.clone()).await?;
153        let response = match outcome {
154            ChatOutcome::Success(response) => response,
155            ChatOutcome::RateLimited => {
156                return Err(StructuredOutputError::ProviderOutcome(
157                    "rate limited".to_owned(),
158                ));
159            }
160            ChatOutcome::InvalidRequest(msg) => {
161                return Err(StructuredOutputError::ProviderOutcome(format!(
162                    "invalid request: {msg}"
163                )));
164            }
165            ChatOutcome::ServerError(msg) => {
166                return Err(StructuredOutputError::ProviderOutcome(format!(
167                    "server error: {msg}"
168                )));
169            }
170            // `ChatOutcome` is `#[non_exhaustive]`; an unrecognized outcome is
171            // surfaced as a provider failure rather than silently retried.
172            _ => {
173                return Err(StructuredOutputError::ProviderOutcome(
174                    "unrecognized provider outcome".to_owned(),
175                ));
176            }
177        };
178
179        let candidate = extract_candidate(&response, support);
180        let Some(value) = candidate else {
181            // No structured value at all. On the final attempt this is a hard
182            // failure; otherwise re-prompt asking for the structured answer.
183            if attempt + 1 >= max_attempts {
184                return Err(StructuredOutputError::NoStructuredOutput);
185            }
186            append_correction(
187                &mut request,
188                &response,
189                "Your previous reply did not contain a structured answer. \
190                 Respond with a single JSON value that satisfies the requested schema.",
191            );
192            "missing structured output".clone_into(&mut last_errors);
193            continue;
194        };
195
196        let errors: Vec<String> = validator
197            .iter_errors(&value)
198            .map(|error| format!("at `{}`: {error}", error.instance_path()))
199            .collect();
200
201        if errors.is_empty() {
202            return Ok(StructuredOutput {
203                value,
204                response,
205                retries: attempt,
206            });
207        }
208
209        last_errors = errors.join("; ");
210        last_value = Some(value);
211
212        if attempt + 1 < max_attempts {
213            let correction = format!(
214                "Your previous JSON output did not satisfy the schema. \
215                 Fix these validation errors and resend the full JSON value: {last_errors}"
216            );
217            append_correction(&mut request, &response, &correction);
218        }
219    }
220
221    Err(StructuredOutputError::RetriesExhausted {
222        attempts: max_attempts,
223        errors: last_errors,
224        last_value,
225    })
226}
227
228/// Inject the forced "respond" tool for providers without native JSON mode.
229fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
230    let respond_tool = Tool {
231        name: RESPOND_TOOL_NAME.to_owned(),
232        description: format!(
233            "Return the final answer as structured data named `{}`. \
234             You MUST call this tool exactly once with arguments matching the schema.",
235            response_format.name
236        ),
237        input_schema: response_format.schema.clone(),
238        display_name: "Structured response".to_owned(),
239        tier: ToolTier::Observe,
240    };
241
242    match request.tools {
243        Some(ref mut tools) => {
244            tools.retain(|t| t.name != RESPOND_TOOL_NAME);
245            tools.push(respond_tool);
246        }
247        None => request.tools = Some(vec![respond_tool]),
248    }
249    request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
250}
251
252/// Pull the candidate structured value out of a response according to how the
253/// provider satisfied the request.
254fn extract_candidate(
255    response: &ChatResponse,
256    support: StructuredOutputSupport,
257) -> Option<serde_json::Value> {
258    match support {
259        StructuredOutputSupport::ToolForcing => {
260            response.content.iter().find_map(|block| match block {
261                ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
262                    Some(input.clone())
263                }
264                _ => None,
265            })
266        }
267        StructuredOutputSupport::Native => {
268            let text = response.first_text()?;
269            parse_json_text(text)
270        }
271    }
272}
273
274/// Parse a JSON value from model text output.
275///
276/// Native JSON mode returns a bare JSON document, but models occasionally wrap
277/// it in a fenced code block, so this strips a leading/trailing markdown fence
278/// before parsing.
279fn parse_json_text(text: &str) -> Option<serde_json::Value> {
280    let trimmed = text.trim();
281    let unfenced = strip_code_fence(trimmed);
282    serde_json::from_str(unfenced).ok()
283}
284
/// Strip a surrounding ```` ```json ... ``` ```` (or plain ```` ``` ````) fence.
///
/// `text` is returned unchanged unless it both starts and ends with a fence.
/// Otherwise the fenced body is returned with surrounding whitespace removed:
///
/// - multi-line block: the whole opening fence line (the optional language
///   tag) is dropped;
/// - single-line block: a leading alphanumeric language tag is dropped only
///   when it is followed by whitespace and then the start of a JSON object,
///   array, or string, so bare payloads such as `true` are left intact.
fn strip_code_fence(text: &str) -> &str {
    let Some(inner) = text
        .strip_prefix("```")
        .and_then(|rest| rest.strip_suffix("```"))
    else {
        return text;
    };
    // Tolerate closing fences longer than three backticks.
    let inner = inner.trim_end_matches('`');

    let body = match inner.split_once('\n') {
        // Everything on the opening fence line is the info string.
        Some((_info, body)) => body,
        // The block sits on one line, so a language tag (if any) shares the
        // line with the payload and has to be told apart from it.
        None => match inner.trim_start().split_once(char::is_whitespace) {
            Some((tag, payload))
                if tag.chars().all(|c| c.is_ascii_alphanumeric())
                    && payload
                        .trim_start()
                        .starts_with(|c: char| matches!(c, '{' | '[' | '"')) =>
            {
                payload
            }
            _ => inner,
        },
    };
    body.trim()
}
295
296/// Append the assistant's previous output plus a corrective user message so the
297/// next attempt sees the validation feedback.
298fn append_correction(request: &mut ChatRequest, previous: &ChatResponse, correction: &str) {
299    request
300        .messages
301        .push(Message::assistant_with_content(previous.content.clone()));
302    request.messages.push(Message::user(correction));
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{StopReason, Usage};
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::streaming::StreamBox;

    /// Test double for [`LlmProvider`]: answers each `chat` call with the next
    /// pre-scripted [`ChatOutcome`] and advertises a fixed
    /// [`StructuredOutputSupport`]. Every incoming request is recorded so
    /// tests can inspect the re-prompt history and the tool-forcing injection.
    struct ScriptedProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        outcomes: Mutex<VecDeque<ChatOutcome>>,
        seen_requests: Mutex<Vec<ChatRequest>>,
        calls: AtomicUsize,
    }

    impl ScriptedProvider {
        /// Build a provider that will replay `outcomes` in order.
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            outcomes: Vec<ChatOutcome>,
        ) -> Self {
            Self {
                provider_name,
                model: String::from("scripted-model"),
                support,
                outcomes: Mutex::new(VecDeque::from(outcomes)),
                seen_requests: Mutex::default(),
                calls: AtomicUsize::default(),
            }
        }

        /// Number of `chat` calls received so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        /// Snapshot of every request received so far, in call order.
        fn recorded_requests(&self) -> Vec<ChatRequest> {
            self.seen_requests.lock().expect("seen lock").clone()
        }
    }

    #[async_trait]
    impl LlmProvider for ScriptedProvider {
        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.seen_requests
                .lock()
                .expect("seen_requests lock")
                .push(request);
            // Running past the script is a test bug, so it panics.
            let next = self
                .outcomes
                .lock()
                .expect("outcomes lock")
                .pop_front()
                .expect("ScriptedProvider: ran out of scripted outcomes");
            Ok(next)
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }

    /// Schema used by every test: an object with a required string `name` and
    /// a required non-negative integer `age`, and no other properties.
    fn person_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"],
            "additionalProperties": false
        })
    }

    /// A minimal request whose response format is the person schema.
    fn request_with_format() -> ChatRequest {
        let response_format = ResponseFormat::new("person", person_schema());
        ChatRequest {
            system: String::new(),
            messages: vec![Message::user("Describe a person.")],
            tools: None,
            max_tokens: 256,
            max_tokens_explicit: true,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: Some(response_format),
        }
    }

    /// Wrap `content` in a successful provider outcome.
    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
        let usage = Usage {
            input_tokens: 1,
            output_tokens: 1,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        };
        ChatOutcome::Success(ChatResponse {
            id: "resp".to_owned(),
            content,
            model: "scripted-model".to_owned(),
            stop_reason: Some(StopReason::EndTurn),
            usage,
        })
    }

    /// Response content consisting of a single text block.
    fn text_block(text: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text {
            text: text.to_owned(),
        }]
    }

    /// Response content consisting of a single "respond" tool call.
    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
        vec![ContentBlock::ToolUse {
            id: "call_1".to_owned(),
            name: RESPOND_TOOL_NAME.to_owned(),
            input,
            thought_signature: None,
        }]
    }

    /// Run the person request against `provider` with the given budget.
    async fn run_with(
        provider: &ScriptedProvider,
        config: StructuredConfig,
    ) -> std::result::Result<StructuredOutput, StructuredOutputError> {
        run_structured(provider, request_with_format(), config).await
    }

    // ── Happy path: native (OpenAI / Gemini) ──────────────────────────

    #[tokio::test]
    async fn native_happy_path_validates_json_text() -> Result<()> {
        let script = vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))];
        let provider = ScriptedProvider::new("openai", StructuredOutputSupport::Native, script);

        let output = run_with(&provider, StructuredConfig::default()).await?;

        assert_eq!(output.value["name"], "Ada");
        assert_eq!(output.value["age"], 36);
        assert_eq!(output.retries, 0);
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
        let fenced = "```json\n{\"name\": \"Grace\", \"age\": 45}\n```";
        let provider = ScriptedProvider::new(
            "gemini",
            StructuredOutputSupport::Native,
            vec![success(text_block(fenced))],
        );

        let output = run_with(&provider, StructuredConfig::default()).await?;

        assert_eq!(output.value["name"], "Grace");
        Ok(())
    }

    // ── Happy path: tool-forcing fallback (Anthropic) ─────────────────

    #[tokio::test]
    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
        let answer = serde_json::json!({"name": "Linus", "age": 54});
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![success(respond_tool_block(answer))],
        );

        let output = run_with(&provider, StructuredConfig::default()).await?;

        assert_eq!(output.value["name"], "Linus");
        assert_eq!(output.retries, 0);

        // The runner must have injected the forced respond tool.
        let requests = provider.recorded_requests();
        let first = &requests[0];
        let tools = first.tools.as_ref().expect("tools injected");
        let has_respond_tool = tools.iter().any(|tool| tool.name == RESPOND_TOOL_NAME);
        let forces_respond = matches!(
            first.tool_choice,
            Some(ToolChoice::Tool(ref chosen)) if chosen == RESPOND_TOOL_NAME
        );
        assert!(has_respond_tool);
        assert!(forces_respond);
        Ok(())
    }

    // ── Mismatch → retry → success ────────────────────────────────────

    #[tokio::test]
    async fn mismatch_then_retry_succeeds() -> Result<()> {
        let script = vec![
            // First attempt: `age` is a string, violating the schema.
            success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
            // Retry: corrected.
            success(text_block(r#"{"name": "Ada", "age": 36}"#)),
        ];
        let provider = ScriptedProvider::new("openai", StructuredOutputSupport::Native, script);

        let output = run_with(&provider, StructuredConfig { max_retries: 2 }).await?;

        assert_eq!(output.value["age"], 36);
        assert_eq!(output.retries, 1);
        assert_eq!(provider.call_count(), 2);

        // The corrective re-prompt must have extended the conversation with
        // the prior answer and a user correction message.
        let requests = provider.recorded_requests();
        let grew = requests[1].messages.len() > requests[0].messages.len();
        assert!(grew);
        Ok(())
    }

    // ── Retry exhaustion → typed error ────────────────────────────────

    #[tokio::test]
    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
        // Every answer omits the required `age`.
        let script = ["x", "y", "z"]
            .into_iter()
            .map(|name| success(respond_tool_block(serde_json::json!({"name": name}))))
            .collect();
        let provider =
            ScriptedProvider::new("anthropic", StructuredOutputSupport::ToolForcing, script);

        let err = run_with(&provider, StructuredConfig { max_retries: 2 })
            .await
            .expect_err("schema never satisfied");

        let (attempts, last_value) = match err {
            StructuredOutputError::RetriesExhausted {
                attempts,
                last_value,
                ..
            } => (attempts, last_value),
            other => panic!("expected RetriesExhausted, got {other:?}"),
        };
        assert_eq!(attempts, 3, "1 initial + 2 retries");
        assert_eq!(
            last_value.as_ref().and_then(|v| v["name"].as_str()),
            Some("z")
        );
        // initial + 2 retries == 3 calls.
        assert_eq!(provider.call_count(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
        let script = vec![success(text_block(r#"{"name": "Ada"}"#))];
        let provider = ScriptedProvider::new("openai", StructuredOutputSupport::Native, script);

        let err = run_with(&provider, StructuredConfig { max_retries: 0 })
            .await
            .expect_err("missing required `age`");

        assert!(matches!(
            err,
            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
        ));
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    // ── Error surfaces ────────────────────────────────────────────────

    #[tokio::test]
    async fn missing_response_format_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let request = ChatRequest {
            response_format: None,
            ..request_with_format()
        };

        let err = run_structured(&provider, request, StructuredConfig::default())
            .await
            .expect_err("no response format");
        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
    }

    #[tokio::test]
    async fn invalid_schema_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        // `type` must be a string/array, not a number — an invalid schema.
        let bad_format = ResponseFormat::new("bad", serde_json::json!({"type": 123}));
        let request = ChatRequest {
            response_format: Some(bad_format),
            ..request_with_format()
        };

        let err = run_structured(&provider, request, StructuredConfig::default())
            .await
            .expect_err("invalid schema");
        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
    }

    #[tokio::test]
    async fn provider_rate_limit_surfaces_as_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![ChatOutcome::RateLimited],
        );

        let err = run_with(&provider, StructuredConfig::default())
            .await
            .expect_err("rate limited");
        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
    }

    #[tokio::test]
    async fn no_structured_output_on_final_attempt_errors() {
        // Native provider returns non-JSON prose every time.
        let script = vec![
            success(text_block("I cannot do that.")),
            success(text_block("Still prose, sorry.")),
        ];
        let provider = ScriptedProvider::new("openai", StructuredOutputSupport::Native, script);

        let err = run_with(&provider, StructuredConfig { max_retries: 1 })
            .await
            .expect_err("never produced JSON");
        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
        assert_eq!(provider.call_count(), 2);
    }
}