Skip to main content

agent_sdk_providers/
structured.rs

1//! Schema-validated structured output.
2//!
//! This module implements a provider-agnostic runner that constrains a model's
//! final answer to a JSON Schema, validates the output, and re-prompts a
//! bounded number of times on mismatch before failing with a typed error.
6//!
7//! # How it works
8//!
9//! 1. The caller supplies a [`ChatRequest`] whose
10//!    [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format) is
11//!    set, plus a [`StructuredConfig`] bounding the retries.
12//! 2. The runner inspects the provider's
13//!    [`structured_output_support`](crate::LlmProvider::structured_output_support):
14//!    - [`Native`](crate::StructuredOutputSupport::Native) — the provider
15//!      already mapped `response_format` onto its wire request (`OpenAI` JSON
16//!      mode, Gemini `responseSchema`). The structured value is parsed from the
17//!      assistant's text output.
18//!    - [`ToolForcing`](crate::StructuredOutputSupport::ToolForcing) — the
19//!      runner injects a single forced "respond" tool whose `input_schema` is
20//!      the output schema (the Anthropic fallback) and reads the structured
21//!      value from that tool call's input.
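//! 3. The candidate value is validated against the schema with `jsonschema`.
//!    On success it is returned. On failure the runner appends the model's
//!    output plus a corrective follow-up describing the validation errors (a
//!    plain user message, or an error `tool_result` answering the forced tool
//!    call) and retries, up to [`StructuredConfig::max_retries`] times.
//!    Exhausting the budget yields [`StructuredOutputError::RetriesExhausted`]
//!    (or [`StructuredOutputError::NoStructuredOutput`] when the final attempt
//!    contains no structured value at all).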
27//!
28//! This mirrors the Claude SDK's `output_format` +
29//! `error_max_structured_output_retries` behaviour.
30
31use std::pin::Pin;
32
33use agent_sdk_foundation::llm::{
34    ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool,
35    ToolChoice, Usage,
36};
37use agent_sdk_foundation::types::ToolTier;
38use futures::{Stream, StreamExt};
39
40use crate::provider::{LlmProvider, StructuredOutputSupport};
41use crate::streaming::{StreamAccumulator, StreamDelta, StreamErrorKind};
42
/// The forced tool name used for the tool-forcing fallback (Anthropic).
///
/// The same name is used to register the tool, to find its call when
/// extracting the structured value, and to answer that call with a
/// `tool_result` on retries. Any caller-supplied tool with this name is
/// replaced by the forced one.
const RESPOND_TOOL_NAME: &str = "respond";
45
/// Retry budget for the structured-output re-prompt loop.
///
/// The default allows two re-prompts, i.e. at most three model calls.
#[derive(Debug, Clone, Copy)]
pub struct StructuredConfig {
    /// How many *re-prompts* may follow the first attempt. With `2` the runner
    /// makes at most three model calls (one initial + two retries) before it
    /// gives up with [`StructuredOutputError::RetriesExhausted`].
    ///
    /// Mirrors the Claude SDK `error_max_structured_output_retries`.
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Two re-prompts (three model calls in total).
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 2;
        StructuredConfig {
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
62
/// A successfully validated structured output and the response that produced it.
#[derive(Debug, Clone)]
pub struct StructuredOutput {
    /// The validated JSON value, guaranteed to satisfy the requested schema.
    ///
    /// On the tool-forcing path this is the input of the forced `respond`
    /// tool call; on the native path it is parsed from the assistant's text.
    pub value: serde_json::Value,
    /// The full provider response that produced [`value`](Self::value), so
    /// callers can still read usage, stop reason, and any leading text.
    ///
    /// When the value came from the streamed first attempt of
    /// [`run_structured_stream`], this response is rebuilt from the stream
    /// and its `id` is empty.
    pub response: ChatResponse,
    /// Number of re-prompts performed before the value validated (0 when the
    /// first attempt already satisfied the schema). Equals the zero-based
    /// index of the successful attempt.
    pub retries: u32,
}
75
/// Errors from the structured-output runner.
///
/// These are *typed* terminal outcomes — the runner never panics on a model
/// that fails to produce schema-valid output.
///
/// Only schema mismatches and missing structured values trigger a re-prompt;
/// every other variant ends the exchange on the attempt where it occurs.
#[derive(Debug, thiserror::Error)]
pub enum StructuredOutputError {
    /// The request did not carry a
    /// [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format).
    /// Raised before any provider call is made.
    #[error("structured output requested without a response_format on the request")]
    MissingResponseFormat,

    /// The schema supplied in the response format is not a valid JSON Schema.
    /// Carries the schema compiler's message; raised before any provider call.
    #[error("invalid output JSON schema: {0}")]
    InvalidSchema(String),

    /// The model produced no extractable structured value (no JSON text /
    /// no forced tool call) on the *final* attempt. Earlier attempts without
    /// a value are re-prompted instead.
    #[error("model produced no structured output to validate")]
    NoStructuredOutput,

    /// The provider returned a non-success outcome (rate limit, server error,
    /// invalid request). Not retried. The payload is a short label such as
    /// `rate limited` or `server error: <detail>`.
    #[error("provider returned a non-success outcome: {0}")]
    ProviderOutcome(String),

    /// The re-prompt budget was exhausted and the latest output still failed
    /// schema validation. Carries the final validation errors and the last
    /// candidate value for diagnostics.
    #[error(
        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
    )]
    RetriesExhausted {
        /// Total number of model calls made (initial + retries).
        attempts: u32,
        /// Human-readable concatenation of the final validation errors
        /// (individual errors joined with `"; "`).
        errors: String,
        /// The last candidate value the model produced, if any.
        last_value: Option<serde_json::Value>,
    },

    /// A transport-level error bubbled up from the provider (the `Err` side of
    /// its `chat` / `chat_stream` calls).
    #[error(transparent)]
    Transport(#[from] anyhow::Error),
}
120
121/// Run a bounded, schema-validated structured-output exchange against `provider`.
122///
123/// The `request`'s
124/// [`response_format`](agent_sdk_foundation::llm::ChatRequest::response_format) must
125/// be set; on success the returned [`StructuredOutput::value`] is guaranteed to
126/// satisfy that schema.
127///
128/// # Errors
129///
130/// Returns a [`StructuredOutputError`] when the request is missing a response
131/// format, the schema is invalid, the provider errors, or the model fails to
132/// produce schema-valid output within [`StructuredConfig::max_retries`].
133pub async fn run_structured(
134    provider: &dyn LlmProvider,
135    mut request: ChatRequest,
136    config: StructuredConfig,
137) -> Result<StructuredOutput, StructuredOutputError> {
138    let response_format = request
139        .response_format
140        .clone()
141        .ok_or(StructuredOutputError::MissingResponseFormat)?;
142
143    // Compile the validator once; reuse it across every retry.
144    let validator = jsonschema::validator_for(&response_format.schema)
145        .map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
146
147    let support = provider.structured_output_support();
148    if matches!(support, StructuredOutputSupport::ToolForcing) {
149        apply_tool_forcing(&mut request, &response_format);
150    }
151
152    let max_attempts = config.max_retries.saturating_add(1);
153    let mut last_value: Option<serde_json::Value> = None;
154    let mut last_errors = String::new();
155
156    for attempt in 0..max_attempts {
157        // Clone only when a retry may still follow; on the final attempt move the
158        // request in (no deep clone of the message history + attachments).
159        let attempt_request = if attempt + 1 == max_attempts {
160            std::mem::replace(&mut request, ChatRequest::new(String::new(), Vec::new()))
161        } else {
162            request.clone()
163        };
164        let outcome = provider.chat(attempt_request).await?;
165        let response = match outcome {
166            ChatOutcome::Success(response) => response,
167            ChatOutcome::RateLimited(_) => {
168                return Err(StructuredOutputError::ProviderOutcome(
169                    "rate limited".to_owned(),
170                ));
171            }
172            ChatOutcome::InvalidRequest(msg) => {
173                return Err(StructuredOutputError::ProviderOutcome(format!(
174                    "invalid request: {msg}"
175                )));
176            }
177            ChatOutcome::ServerError(msg) => {
178                return Err(StructuredOutputError::ProviderOutcome(format!(
179                    "server error: {msg}"
180                )));
181            }
182            // `ChatOutcome` is `#[non_exhaustive]`; an unrecognized outcome is
183            // surfaced as a provider failure rather than silently retried.
184            _ => {
185                return Err(StructuredOutputError::ProviderOutcome(
186                    "unrecognized provider outcome".to_owned(),
187                ));
188            }
189        };
190
191        let candidate = extract_candidate(&response, support);
192        let Some(value) = candidate else {
193            // No structured value at all. On the final attempt this is a hard
194            // failure; otherwise re-prompt asking for the structured answer.
195            if attempt + 1 >= max_attempts {
196                return Err(StructuredOutputError::NoStructuredOutput);
197            }
198            append_correction(
199                &mut request,
200                &response,
201                support,
202                "Your previous reply did not contain a structured answer. \
203                 Respond with a single JSON value that satisfies the requested schema.",
204            );
205            "missing structured output".clone_into(&mut last_errors);
206            continue;
207        };
208
209        let errors = collect_schema_errors(&validator, &value);
210
211        if errors.is_empty() {
212            return Ok(StructuredOutput {
213                value,
214                response,
215                retries: attempt,
216            });
217        }
218
219        last_errors = errors.join("; ");
220        last_value = Some(value);
221
222        if attempt + 1 < max_attempts {
223            let correction = format!(
224                "Your previous JSON output did not satisfy the schema. \
225                 Fix these validation errors and resend the full JSON value: {last_errors}"
226            );
227            append_correction(&mut request, &response, support, &correction);
228        }
229    }
230
231    Err(StructuredOutputError::RetriesExhausted {
232        attempts: max_attempts,
233        errors: last_errors,
234        last_value,
235    })
236}
237
/// An incremental update emitted by [`run_structured_stream`].
#[derive(Debug, Clone)]
pub enum StructuredStreamUpdate {
    /// A best-effort, not-yet-validated object parsed from the partial
    /// response as it streams in. Successive `Partial`s grow toward the final
    /// value; a consumer can render them as a live preview. Because the
    /// underlying JSON is incomplete, a partial may contain truncated string
    /// values and is never schema-validated.
    ///
    /// Partials are only produced while the first attempt streams. Retries run
    /// non-streaming, so when a retry succeeds the `Final` value may differ
    /// from the last `Partial` that was seen.
    Partial(serde_json::Value),
    /// The final, schema-validated structured output. Exactly one `Final` is
    /// emitted on success and it is always the last item in the stream.
    Final(StructuredOutput),
}

/// A stream of [`StructuredStreamUpdate`]s produced by [`run_structured_stream`].
///
/// Boxed, pinned, `Send`, and borrowing the provider for `'a`. An `Err` item is
/// terminal: the runner yields it as the last item and then stops.
pub type StructuredStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StructuredStreamUpdate, StructuredOutputError>> + Send + 'a>>;
255
256/// Streaming counterpart to [`run_structured`].
257///
258/// Drives the same bounded, schema-validated exchange, but streams the first
259/// attempt so callers see the structured object build incrementally
260/// ([`StructuredStreamUpdate::Partial`]) before the validated
261/// [`StructuredStreamUpdate::Final`] arrives. If the first (streamed) attempt
262/// fails schema validation, the bounded re-prompt loop continues
263/// non-streaming, reusing the exact retry machinery of [`run_structured`].
264///
265/// This is additive: existing [`run_structured`] callers are unaffected.
266///
267/// # Errors
268///
269/// The stream yields a single [`StructuredOutputError`] (and then ends) when
270/// the request lacks a response format, the schema is invalid, the provider
271/// errors, or the model never produces schema-valid output within
272/// [`StructuredConfig::max_retries`].
273pub fn run_structured_stream(
274    provider: &dyn LlmProvider,
275    request: ChatRequest,
276    config: StructuredConfig,
277) -> StructuredStream<'_> {
278    Box::pin(async_stream::stream! {
279        let mut request = request;
280        let Some(response_format) = request.response_format.clone() else {
281            yield Err(StructuredOutputError::MissingResponseFormat);
282            return;
283        };
284        let validator = match jsonschema::validator_for(&response_format.schema) {
285            Ok(validator) => validator,
286            Err(e) => {
287                yield Err(StructuredOutputError::InvalidSchema(e.to_string()));
288                return;
289            }
290        };
291
292        let support = provider.structured_output_support();
293        if matches!(support, StructuredOutputSupport::ToolForcing) {
294            apply_tool_forcing(&mut request, &response_format);
295        }
296
297        let max_attempts = config.max_retries.saturating_add(1);
298        let model = provider.model().to_owned();
299        let mut last_value: Option<serde_json::Value> = None;
300        let mut last_errors = String::new();
301
302        for attempt in 0..max_attempts {
303            // The first attempt streams (emitting partials); retries reuse the
304            // non-streaming path so the re-prompt machinery is shared verbatim.
305            let response = if attempt == 0 {
306                let mut attempt_stream =
307                    Box::pin(stream_first_attempt(provider, request.clone(), support, model.clone()));
308                let mut completed: Option<ChatResponse> = None;
309                while let Some(item) = attempt_stream.next().await {
310                    match item {
311                        StreamAttemptItem::Partial(value) => {
312                            yield Ok(StructuredStreamUpdate::Partial(value));
313                        }
314                        StreamAttemptItem::Complete(response) => completed = Some(response),
315                        StreamAttemptItem::Failed(error) => {
316                            yield Err(error);
317                            return;
318                        }
319                    }
320                }
321                // `stream_first_attempt` always terminates with `Complete` or
322                // `Failed`; the `None` arm is an unreachable safety net.
323                match completed {
324                    Some(response) => response,
325                    None => return,
326                }
327            } else {
328                match provider.chat(request.clone()).await {
329                    Ok(ChatOutcome::Success(response)) => response,
330                    Ok(other) => {
331                        yield Err(non_success_outcome_error(&other));
332                        return;
333                    }
334                    Err(e) => {
335                        yield Err(StructuredOutputError::Transport(e));
336                        return;
337                    }
338                }
339            };
340
341            let Some(value) = extract_candidate(&response, support) else {
342                if attempt + 1 >= max_attempts {
343                    yield Err(StructuredOutputError::NoStructuredOutput);
344                    return;
345                }
346                append_correction(
347                    &mut request,
348                    &response,
349                    support,
350                    "Your previous reply did not contain a structured answer. \
351                     Respond with a single JSON value that satisfies the requested schema.",
352                );
353                "missing structured output".clone_into(&mut last_errors);
354                continue;
355            };
356
357            let errors = collect_schema_errors(&validator, &value);
358
359            if errors.is_empty() {
360                yield Ok(StructuredStreamUpdate::Final(StructuredOutput {
361                    value,
362                    response,
363                    retries: attempt,
364                }));
365                return;
366            }
367
368            last_errors = errors.join("; ");
369            last_value = Some(value);
370
371            if attempt + 1 < max_attempts {
372                let correction = format!(
373                    "Your previous JSON output did not satisfy the schema. \
374                     Fix these validation errors and resend the full JSON value: {last_errors}"
375                );
376                append_correction(&mut request, &response, support, &correction);
377            }
378        }
379
380        yield Err(StructuredOutputError::RetriesExhausted {
381            attempts: max_attempts,
382            errors: last_errors,
383            last_value,
384        });
385    })
386}
387
/// An item produced by [`stream_first_attempt`]: an incremental partial, the
/// terminal accumulated response, or a typed failure.
///
/// `stream_first_attempt` yields any number of `Partial`s followed by exactly
/// one `Complete` or `Failed`.
enum StreamAttemptItem {
    /// A best-effort partial value re-parsed from the in-flight buffer
    /// (never schema-validated).
    Partial(serde_json::Value),
    /// The response accumulated from the fully-drained provider stream.
    Complete(ChatResponse),
    /// A terminal failure; nothing is yielded after it.
    Failed(StructuredOutputError),
}
395
/// Stream the first structured-output attempt, emitting de-duplicated partial
/// objects as the response builds and finishing with the accumulated response
/// (or a typed failure). Factored out of [`run_structured_stream`] so the
/// orchestration loop stays small.
///
/// Every provider delta is fed to the [`StreamAccumulator`] (to build the final
/// [`ChatResponse`]) and, when relevant, appended to a text buffer that is
/// re-parsed with [`partial_from_buffer`] after each delta. A new
/// [`StreamAttemptItem::Partial`] is yielded only when the re-parsed value
/// differs from the previously yielded one.
///
/// Termination:
/// - an `Err` item from `chat_stream` yields [`StreamAttemptItem::Failed`]
///   (as [`StructuredOutputError::Transport`]) immediately;
/// - otherwise the provider stream is drained, and if any
///   [`StreamDelta::Error`] was seen, the most recent one is reported as
///   [`StreamAttemptItem::Failed`];
/// - otherwise [`StreamAttemptItem::Complete`] carries the accumulated
///   response, tagged with `model`.
fn stream_first_attempt(
    provider: &dyn LlmProvider,
    request: ChatRequest,
    support: StructuredOutputSupport,
    model: String,
) -> impl Stream<Item = StreamAttemptItem> + Send + '_ {
    async_stream::stream! {
        let mut accumulator = StreamAccumulator::new();
        // Raw JSON text seen so far: assistant text (native) or the forced
        // `respond` tool's input deltas (tool forcing).
        let mut partial_buf = String::new();
        // Ids of streamed tool calls named `respond`; only their input deltas
        // are appended to `partial_buf` on the tool-forcing path.
        let mut respond_tool_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        // Last partial yielded, used to suppress duplicate partials.
        let mut last_partial: Option<serde_json::Value> = None;
        // Most recent in-band error delta; reported only after the stream is drained.
        let mut stream_error: Option<(String, StreamErrorKind)> = None;

        let mut stream = provider.chat_stream(request);
        while let Some(item) = stream.next().await {
            let delta = match item {
                Ok(delta) => delta,
                Err(e) => {
                    // Transport-level failure: stop immediately.
                    yield StreamAttemptItem::Failed(StructuredOutputError::Transport(e));
                    return;
                }
            };

            accumulate_partial_buffer(&delta, support, &mut partial_buf, &mut respond_tool_ids);
            if let StreamDelta::Error { message, kind } = &delta {
                stream_error = Some((message.clone(), *kind));
            }
            accumulator.apply(&delta);

            // Re-parse the whole buffer after every delta and surface it only
            // when it changed since the last partial that was yielded.
            if let Some(value) = partial_from_buffer(&partial_buf)
                && last_partial.as_ref() != Some(&value)
            {
                last_partial = Some(value.clone());
                yield StreamAttemptItem::Partial(value);
            }
        }

        if let Some((message, kind)) = stream_error {
            yield StreamAttemptItem::Failed(stream_error_to_outcome(&message, kind));
            return;
        }

        yield StreamAttemptItem::Complete(build_streamed_response(accumulator, model));
    }
}
446
447/// Append the relevant part of a streamed delta to the running partial buffer
448/// so [`partial_from_buffer`] can re-parse it. For native providers this is the
449/// model's text output; for tool-forcing it is the forced `respond` tool's
450/// input JSON.
451fn accumulate_partial_buffer(
452    delta: &StreamDelta,
453    support: StructuredOutputSupport,
454    buffer: &mut String,
455    respond_tool_ids: &mut std::collections::HashSet<String>,
456) {
457    match (support, delta) {
458        (StructuredOutputSupport::Native, StreamDelta::TextDelta { delta, .. }) => {
459            buffer.push_str(delta);
460        }
461        (StructuredOutputSupport::ToolForcing, StreamDelta::ToolUseStart { id, name, .. })
462            if name == RESPOND_TOOL_NAME =>
463        {
464            respond_tool_ids.insert(id.clone());
465        }
466        (StructuredOutputSupport::ToolForcing, StreamDelta::ToolInputDelta { id, delta, .. })
467            if respond_tool_ids.contains(id) =>
468        {
469            buffer.push_str(delta);
470        }
471        _ => {}
472    }
473}
474
475/// Map a recorded streaming error into the structured-output error surface.
476fn stream_error_to_outcome(message: &str, kind: StreamErrorKind) -> StructuredOutputError {
477    let label = match kind {
478        StreamErrorKind::RateLimited => "rate limited".to_owned(),
479        StreamErrorKind::InvalidRequest => format!("invalid request: {message}"),
480        _ => format!("server error: {message}"),
481    };
482    StructuredOutputError::ProviderOutcome(label)
483}
484
485/// Map a non-success [`ChatOutcome`] from a retry attempt into a typed error.
486fn non_success_outcome_error(outcome: &ChatOutcome) -> StructuredOutputError {
487    let label = match outcome {
488        ChatOutcome::RateLimited(_) => "rate limited".to_owned(),
489        ChatOutcome::InvalidRequest(msg) => format!("invalid request: {msg}"),
490        ChatOutcome::ServerError(msg) => format!("server error: {msg}"),
491        _ => "unrecognized provider outcome".to_owned(),
492    };
493    StructuredOutputError::ProviderOutcome(label)
494}
495
496/// Materialize a [`ChatResponse`] from a fully-consumed stream accumulator.
497fn build_streamed_response(mut accumulator: StreamAccumulator, model: String) -> ChatResponse {
498    let usage = accumulator.take_usage().unwrap_or(Usage {
499        input_tokens: 0,
500        output_tokens: 0,
501        cached_input_tokens: 0,
502        cache_creation_input_tokens: 0,
503    });
504    let stop_reason = accumulator.take_stop_reason();
505    ChatResponse {
506        id: String::new(),
507        content: accumulator.into_content_blocks(),
508        model,
509        stop_reason,
510        usage,
511    }
512}
513
514/// Best-effort parse of a partial JSON object/array from an in-flight buffer.
515///
516/// Returns the repaired value only when the buffer (after closing any open
517/// containers) parses to an object or array; otherwise `None` so the caller
518/// simply waits for more data.
519fn partial_from_buffer(buffer: &str) -> Option<serde_json::Value> {
520    let trimmed = buffer.trim_start();
521    // Tolerate a streamed leading ```json fence.
522    let body = trimmed
523        .strip_prefix("```")
524        .and_then(|rest| rest.split_once('\n').map(|(_, body)| body))
525        .unwrap_or(trimmed)
526        .trim();
527    if body.is_empty() {
528        return None;
529    }
530    let repaired = repair_partial_json(body);
531    serde_json::from_str::<serde_json::Value>(&repaired)
532        .ok()
533        .filter(|value| value.is_object() || value.is_array())
534}
535
/// Close any open strings and containers in a partial JSON fragment so it can parse.
///
/// A truncated string *value* is kept and closed with `"`, and a lone
/// trailing escape backslash is dropped first. A dangling separator (`,`) is
/// removed, and a dangling key (`"k":`) is completed with `null`. The output is
/// not guaranteed to parse (e.g. a half-typed key or literal); the caller
/// discards it in that case.
fn repair_partial_json(buffer: &str) -> String {
    // Scanner state: inside a string literal? was the previous char an
    // unconsumed backslash? which closers are still owed (innermost last)?
    let mut inside_string = false;
    let mut pending_escape = false;
    let mut closers: Vec<char> = Vec::new();

    for ch in buffer.chars() {
        if !inside_string {
            match ch {
                '"' => inside_string = true,
                '{' => closers.push('}'),
                '[' => closers.push(']'),
                '}' | ']' => {
                    closers.pop();
                }
                _ => {}
            }
        } else if pending_escape {
            pending_escape = false;
        } else if ch == '\\' {
            pending_escape = true;
        } else if ch == '"' {
            inside_string = false;
        }
    }

    let mut repaired = String::from(buffer);
    // A lone trailing backslash would escape the closing quote added below.
    if pending_escape {
        repaired.pop();
    }
    if inside_string {
        repaired.push('"');
    }
    let kept = repaired.trim_end().len();
    repaired.truncate(kept);

    if repaired.ends_with(',') {
        repaired.pop();
        let kept = repaired.trim_end().len();
        repaired.truncate(kept);
    } else if repaired.ends_with(':') {
        repaired.push_str(" null");
    }

    repaired.extend(closers.iter().rev());
    repaired
}
588
589/// Collect human-readable schema-validation errors for a candidate value
590/// (empty when it satisfies the schema).
591fn collect_schema_errors(
592    validator: &jsonschema::Validator,
593    value: &serde_json::Value,
594) -> Vec<String> {
595    validator
596        .iter_errors(value)
597        .map(|error| format!("at `{}`: {error}", error.instance_path()))
598        .collect()
599}
600
601/// Inject the forced "respond" tool for providers without native JSON mode.
602fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
603    let respond_tool = Tool {
604        name: RESPOND_TOOL_NAME.to_owned(),
605        description: format!(
606            "Return the final answer as structured data named `{}`. \
607             You MUST call this tool exactly once with arguments matching the schema.",
608            response_format.name
609        ),
610        input_schema: response_format.schema.clone(),
611        display_name: "Structured response".to_owned(),
612        tier: ToolTier::Observe,
613    };
614
615    match request.tools {
616        Some(ref mut tools) => {
617            tools.retain(|t| t.name != RESPOND_TOOL_NAME);
618            tools.push(respond_tool);
619        }
620        None => request.tools = Some(vec![respond_tool]),
621    }
622    request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
623}
624
625/// Pull the candidate structured value out of a response according to how the
626/// provider satisfied the request.
627fn extract_candidate(
628    response: &ChatResponse,
629    support: StructuredOutputSupport,
630) -> Option<serde_json::Value> {
631    match support {
632        StructuredOutputSupport::ToolForcing => {
633            response.content.iter().find_map(|block| match block {
634                ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
635                    Some(input.clone())
636                }
637                _ => None,
638            })
639        }
640        StructuredOutputSupport::Native => {
641            let text = response.first_text()?;
642            parse_json_text(text)
643        }
644    }
645}
646
647/// Parse a JSON value from model text output.
648///
649/// Native JSON mode returns a bare JSON document, but models occasionally wrap
650/// it in a fenced code block, so this strips a leading/trailing markdown fence
651/// before parsing.
652fn parse_json_text(text: &str) -> Option<serde_json::Value> {
653    let trimmed = text.trim();
654    let unfenced = strip_code_fence(trimmed);
655    serde_json::from_str(unfenced).ok()
656}
657
/// Strip a surrounding ```` ```json ... ``` ```` (or plain ```` ``` ````) fence.
///
/// Text that does not start with a fence is returned unchanged. An optional
/// language tag on the opening fence line is dropped, and a longer closing
/// fence (e.g. four backticks) is absorbed.
///
/// When the closing fence is missing, the body after the opening fence line is
/// returned, so the fence alone no longer prevents an otherwise complete
/// document from parsing. This matches how `partial_from_buffer` already
/// tolerates a leading fence while streaming.
fn strip_code_fence(text: &str) -> &str {
    let Some(rest) = text.strip_prefix("```") else {
        return text;
    };
    // Drop an optional language tag on the opening fence line.
    let body = rest.split_once('\n').map_or(rest, |(_, body)| body);
    match body.strip_suffix("```") {
        Some(inner) => inner.trim_end_matches('`').trim(),
        // Unterminated fence: keep the body rather than the fenced original.
        None => body.trim(),
    }
}
668
669/// Append the assistant's previous output plus a corrective user message so the
670/// next attempt sees the validation feedback.
671///
672/// For the tool-forcing path (Anthropic), the assistant turn carries a forced
673/// `respond` `ContentBlock::ToolUse`. The Anthropic Messages API rejects any
674/// conversation where a `tool_use` is not immediately followed by a matching
675/// `tool_result` in the next user message, so the correction is delivered as a
676/// `ToolResult` for that tool-use id (carrying the validation errors) rather than
677/// as plain user text — otherwise the very first re-prompt 400s. When no forced
678/// tool call is present (or for native providers) the correction is plain text.
679fn append_correction(
680    request: &mut ChatRequest,
681    previous: &ChatResponse,
682    support: StructuredOutputSupport,
683    correction: &str,
684) {
685    request
686        .messages
687        .push(Message::assistant_with_content(previous.content.clone()));
688
689    let respond_tool_use_id = if matches!(support, StructuredOutputSupport::ToolForcing) {
690        previous.content.iter().find_map(|block| match block {
691            ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => Some(id.clone()),
692            _ => None,
693        })
694    } else {
695        None
696    };
697
698    match respond_tool_use_id {
699        Some(tool_use_id) => {
700            request
701                .messages
702                .push(Message::tool_result(tool_use_id, correction, true));
703        }
704        None => request.messages.push(Message::user(correction)),
705    }
706}
707
708#[cfg(test)]
709mod tests {
710    use super::*;
711
712    use std::sync::Mutex;
713    use std::sync::atomic::{AtomicUsize, Ordering};
714
715    use agent_sdk_foundation::llm::{StopReason, Usage};
716    use anyhow::Result;
717    use async_trait::async_trait;
718
719    use crate::streaming::StreamBox;
720
721    /// A scripted provider: replays a fixed queue of [`ChatOutcome`]s and
722    /// reports a configurable [`StructuredOutputSupport`]. It also records every
723    /// request it receives so tests can assert on the re-prompt history and on
724    /// the tool-forcing injection.
725    struct ScriptedProvider {
726        provider_name: &'static str,
727        model: String,
728        support: StructuredOutputSupport,
729        outcomes: Mutex<std::collections::VecDeque<ChatOutcome>>,
730        seen_requests: Mutex<Vec<ChatRequest>>,
731        calls: AtomicUsize,
732    }
733
734    impl ScriptedProvider {
735        fn new(
736            provider_name: &'static str,
737            support: StructuredOutputSupport,
738            outcomes: Vec<ChatOutcome>,
739        ) -> Self {
740            Self {
741                provider_name,
742                model: "scripted-model".to_owned(),
743                support,
744                outcomes: Mutex::new(outcomes.into()),
745                seen_requests: Mutex::new(Vec::new()),
746                calls: AtomicUsize::new(0),
747            }
748        }
749
750        fn call_count(&self) -> usize {
751            self.calls.load(Ordering::SeqCst)
752        }
753    }
754
755    #[async_trait]
756    impl LlmProvider for ScriptedProvider {
757        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
758            self.calls.fetch_add(1, Ordering::SeqCst);
759            self.seen_requests
760                .lock()
761                .expect("seen_requests lock")
762                .push(request);
763            let outcome = self
764                .outcomes
765                .lock()
766                .expect("outcomes lock")
767                .pop_front()
768                .expect("ScriptedProvider: ran out of scripted outcomes");
769            Ok(outcome)
770        }
771
772        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
773            Box::pin(async_stream::stream! {
774                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
775            })
776        }
777
778        fn model(&self) -> &str {
779            &self.model
780        }
781
782        fn provider(&self) -> &'static str {
783            self.provider_name
784        }
785
786        fn structured_output_support(&self) -> StructuredOutputSupport {
787            self.support
788        }
789    }
790
791    fn person_schema() -> serde_json::Value {
792        serde_json::json!({
793            "type": "object",
794            "properties": {
795                "name": { "type": "string" },
796                "age": { "type": "integer", "minimum": 0 }
797            },
798            "required": ["name", "age"],
799            "additionalProperties": false
800        })
801    }
802
803    fn request_with_format() -> ChatRequest {
804        ChatRequest {
805            system: String::new(),
806            messages: vec![Message::user("Describe a person.")],
807            tools: None,
808            max_tokens: 256,
809            max_tokens_explicit: true,
810            session_id: None,
811            cached_content: None,
812            thinking: None,
813            tool_choice: None,
814            response_format: Some(ResponseFormat::new("person", person_schema())),
815            cache: None,
816        }
817    }
818
819    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
820        ChatOutcome::Success(ChatResponse {
821            id: "resp".to_owned(),
822            content,
823            model: "scripted-model".to_owned(),
824            stop_reason: Some(StopReason::EndTurn),
825            usage: Usage {
826                input_tokens: 1,
827                output_tokens: 1,
828                cached_input_tokens: 0,
829                cache_creation_input_tokens: 0,
830            },
831        })
832    }
833
834    fn text_block(text: &str) -> Vec<ContentBlock> {
835        vec![ContentBlock::Text {
836            text: text.to_owned(),
837        }]
838    }
839
840    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
841        vec![ContentBlock::ToolUse {
842            id: "call_1".to_owned(),
843            name: RESPOND_TOOL_NAME.to_owned(),
844            input,
845            thought_signature: None,
846        }]
847    }
848
849    // ── Happy path: native (OpenAI / Gemini) ──────────────────────────
850
851    #[tokio::test]
852    async fn native_happy_path_validates_json_text() -> Result<()> {
853        let provider = ScriptedProvider::new(
854            "openai",
855            StructuredOutputSupport::Native,
856            vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))],
857        );
858
859        let out = run_structured(
860            &provider,
861            request_with_format(),
862            StructuredConfig::default(),
863        )
864        .await?;
865
866        assert_eq!(out.value["name"], "Ada");
867        assert_eq!(out.value["age"], 36);
868        assert_eq!(out.retries, 0);
869        assert_eq!(provider.call_count(), 1);
870        Ok(())
871    }
872
873    #[tokio::test]
874    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
875        let provider = ScriptedProvider::new(
876            "gemini",
877            StructuredOutputSupport::Native,
878            vec![success(text_block(
879                "```json\n{\"name\": \"Grace\", \"age\": 45}\n```",
880            ))],
881        );
882
883        let out = run_structured(
884            &provider,
885            request_with_format(),
886            StructuredConfig::default(),
887        )
888        .await?;
889
890        assert_eq!(out.value["name"], "Grace");
891        Ok(())
892    }
893
894    // ── Happy path: tool-forcing fallback (Anthropic) ─────────────────
895
896    #[tokio::test]
897    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
898        let provider = ScriptedProvider::new(
899            "anthropic",
900            StructuredOutputSupport::ToolForcing,
901            vec![success(respond_tool_block(
902                serde_json::json!({"name": "Linus", "age": 54}),
903            ))],
904        );
905
906        let out = run_structured(
907            &provider,
908            request_with_format(),
909            StructuredConfig::default(),
910        )
911        .await?;
912
913        assert_eq!(out.value["name"], "Linus");
914        assert_eq!(out.retries, 0);
915
916        // The runner must have injected the forced respond tool.
917        let (has_respond_tool, forces_respond) = {
918            let seen = provider.seen_requests.lock().expect("seen lock");
919            let tools = seen[0].tools.as_ref().expect("tools injected");
920            (
921                tools.iter().any(|t| t.name == RESPOND_TOOL_NAME),
922                matches!(
923                    seen[0].tool_choice,
924                    Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
925                ),
926            )
927        };
928        assert!(has_respond_tool);
929        assert!(forces_respond);
930        Ok(())
931    }
932
933    // ── Mismatch → retry → success ────────────────────────────────────
934
935    #[tokio::test]
936    async fn mismatch_then_retry_succeeds() -> Result<()> {
937        let provider = ScriptedProvider::new(
938            "openai",
939            StructuredOutputSupport::Native,
940            vec![
941                // First attempt: `age` is a string, violating the schema.
942                success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
943                // Retry: corrected.
944                success(text_block(r#"{"name": "Ada", "age": 36}"#)),
945            ],
946        );
947
948        let out = run_structured(
949            &provider,
950            request_with_format(),
951            StructuredConfig { max_retries: 2 },
952        )
953        .await?;
954
955        assert_eq!(out.value["age"], 36);
956        assert_eq!(out.retries, 1);
957        assert_eq!(provider.call_count(), 2);
958
959        // The corrective re-prompt must have appended the prior answer + a
960        // user correction message.
961        let grew = {
962            let seen = provider.seen_requests.lock().expect("seen lock");
963            seen[1].messages.len() > seen[0].messages.len()
964        };
965        assert!(grew);
966        Ok(())
967    }
968
969    #[tokio::test]
970    async fn tool_forcing_retry_appends_tool_result_for_forced_tool_use() -> Result<()> {
971        use agent_sdk_foundation::llm::Content;
972
973        let provider = ScriptedProvider::new(
974            "anthropic",
975            StructuredOutputSupport::ToolForcing,
976            vec![
977                // First respond: invalid (missing required `age`).
978                success(respond_tool_block(serde_json::json!({"name": "x"}))),
979                // Retry: valid.
980                success(respond_tool_block(
981                    serde_json::json!({"name": "x", "age": 1}),
982                )),
983            ],
984        );
985
986        let out = run_structured(
987            &provider,
988            request_with_format(),
989            StructuredConfig { max_retries: 1 },
990        )
991        .await?;
992        assert_eq!(out.retries, 1);
993
994        // The retry request must be a valid Anthropic conversation: the appended
995        // assistant `respond` tool_use must be answered by a user tool_result with
996        // a matching tool_use_id — not a bare user text message (which 400s).
997        let seen = provider.seen_requests.lock().expect("seen lock");
998        let retry = &seen[1];
999
1000        let assistant_tool_use_id = retry
1001            .messages
1002            .iter()
1003            .find_map(|m| match &m.content {
1004                Content::Blocks(blocks) => blocks.iter().find_map(|b| match b {
1005                    ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => {
1006                        Some(id.clone())
1007                    }
1008                    _ => None,
1009                }),
1010                Content::Text(_) => None,
1011            })
1012            .expect("assistant respond tool_use present in retry");
1013
1014        let has_matching_result = retry.messages.iter().any(|m| match &m.content {
1015            Content::Blocks(blocks) => blocks.iter().any(|b| {
1016                matches!(
1017                    b,
1018                    ContentBlock::ToolResult { tool_use_id, .. }
1019                        if *tool_use_id == assistant_tool_use_id
1020                )
1021            }),
1022            Content::Text(_) => false,
1023        });
1024        drop(seen);
1025        assert!(
1026            has_matching_result,
1027            "retry must carry a tool_result for the forced tool_use id"
1028        );
1029        Ok(())
1030    }
1031
1032    // ── Retry exhaustion → typed error ────────────────────────────────
1033
1034    #[tokio::test]
1035    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
1036        let provider = ScriptedProvider::new(
1037            "anthropic",
1038            StructuredOutputSupport::ToolForcing,
1039            vec![
1040                success(respond_tool_block(serde_json::json!({"name": "x"}))),
1041                success(respond_tool_block(serde_json::json!({"name": "y"}))),
1042                success(respond_tool_block(serde_json::json!({"name": "z"}))),
1043            ],
1044        );
1045
1046        let err = run_structured(
1047            &provider,
1048            request_with_format(),
1049            StructuredConfig { max_retries: 2 },
1050        )
1051        .await
1052        .expect_err("schema never satisfied");
1053
1054        match err {
1055            StructuredOutputError::RetriesExhausted {
1056                attempts,
1057                last_value,
1058                ..
1059            } => {
1060                assert_eq!(attempts, 3, "1 initial + 2 retries");
1061                assert_eq!(
1062                    last_value.as_ref().and_then(|v| v["name"].as_str()),
1063                    Some("z")
1064                );
1065            }
1066            other => panic!("expected RetriesExhausted, got {other:?}"),
1067        }
1068        // initial + 2 retries == 3 calls.
1069        assert_eq!(provider.call_count(), 3);
1070        Ok(())
1071    }
1072
1073    #[tokio::test]
1074    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
1075        let provider = ScriptedProvider::new(
1076            "openai",
1077            StructuredOutputSupport::Native,
1078            vec![success(text_block(r#"{"name": "Ada"}"#))],
1079        );
1080
1081        let err = run_structured(
1082            &provider,
1083            request_with_format(),
1084            StructuredConfig { max_retries: 0 },
1085        )
1086        .await
1087        .expect_err("missing required `age`");
1088
1089        assert!(matches!(
1090            err,
1091            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
1092        ));
1093        assert_eq!(provider.call_count(), 1);
1094        Ok(())
1095    }
1096
1097    // ── Error surfaces ────────────────────────────────────────────────
1098
1099    #[tokio::test]
1100    async fn missing_response_format_is_typed_error() {
1101        let provider = ScriptedProvider::new(
1102            "openai",
1103            StructuredOutputSupport::Native,
1104            vec![success(text_block("{}"))],
1105        );
1106        let mut req = request_with_format();
1107        req.response_format = None;
1108
1109        let err = run_structured(&provider, req, StructuredConfig::default())
1110            .await
1111            .expect_err("no response format");
1112        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
1113    }
1114
1115    #[tokio::test]
1116    async fn invalid_schema_is_typed_error() {
1117        let provider = ScriptedProvider::new(
1118            "openai",
1119            StructuredOutputSupport::Native,
1120            vec![success(text_block("{}"))],
1121        );
1122        let mut req = request_with_format();
1123        // `type` must be a string/array, not a number — an invalid schema.
1124        req.response_format = Some(ResponseFormat::new("bad", serde_json::json!({"type": 123})));
1125
1126        let err = run_structured(&provider, req, StructuredConfig::default())
1127            .await
1128            .expect_err("invalid schema");
1129        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
1130    }
1131
1132    #[tokio::test]
1133    async fn provider_rate_limit_surfaces_as_typed_error() {
1134        let provider = ScriptedProvider::new(
1135            "openai",
1136            StructuredOutputSupport::Native,
1137            vec![ChatOutcome::RateLimited(None)],
1138        );
1139
1140        let err = run_structured(
1141            &provider,
1142            request_with_format(),
1143            StructuredConfig::default(),
1144        )
1145        .await
1146        .expect_err("rate limited");
1147        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
1148    }
1149
1150    #[tokio::test]
1151    async fn no_structured_output_on_final_attempt_errors() {
1152        // Native provider returns non-JSON prose every time.
1153        let provider = ScriptedProvider::new(
1154            "openai",
1155            StructuredOutputSupport::Native,
1156            vec![
1157                success(text_block("I cannot do that.")),
1158                success(text_block("Still prose, sorry.")),
1159            ],
1160        );
1161
1162        let err = run_structured(
1163            &provider,
1164            request_with_format(),
1165            StructuredConfig { max_retries: 1 },
1166        )
1167        .await
1168        .expect_err("never produced JSON");
1169        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
1170        assert_eq!(provider.call_count(), 2);
1171    }
1172
1173    // ── Streaming structured output ───────────────────────────────────
1174
    /// A provider that serves a fixed list of streaming deltas from
    /// `chat_stream`; every call gets its own clone of that list, replayed from
    /// the start.
    ///
    /// The non-streaming `chat` always answers with a `ServerError` outcome.
    /// None of the streaming tests below expect it to be called (presumably the
    /// runner only falls back to it on retries — confirm against
    /// `run_structured_stream`).
    struct StreamingProvider {
        /// Value reported by `provider()`.
        provider_name: &'static str,
        /// Value reported by `model()`.
        model: String,
        /// Capability reported by `structured_output_support()`.
        support: StructuredOutputSupport,
        /// Deltas cloned into each `chat_stream` response.
        deltas: Mutex<Vec<StreamDelta>>,
    }
1184
1185    impl StreamingProvider {
1186        fn new(
1187            provider_name: &'static str,
1188            support: StructuredOutputSupport,
1189            deltas: Vec<StreamDelta>,
1190        ) -> Self {
1191            Self {
1192                provider_name,
1193                model: "scripted-model".to_owned(),
1194                support,
1195                deltas: Mutex::new(deltas),
1196            }
1197        }
1198    }
1199
1200    #[async_trait]
1201    impl LlmProvider for StreamingProvider {
1202        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
1203            Ok(ChatOutcome::ServerError("chat() not used".to_owned()))
1204        }
1205
1206        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
1207            let deltas = self.deltas.lock().map(|d| d.clone()).unwrap_or_default();
1208            Box::pin(async_stream::stream! {
1209                for delta in deltas {
1210                    yield Ok(delta);
1211                }
1212            })
1213        }
1214
1215        fn model(&self) -> &str {
1216            &self.model
1217        }
1218
1219        fn provider(&self) -> &'static str {
1220            self.provider_name
1221        }
1222
1223        fn structured_output_support(&self) -> StructuredOutputSupport {
1224            self.support
1225        }
1226    }
1227
1228    async fn drive_stream(
1229        mut stream: StructuredStream<'_>,
1230    ) -> Result<(Vec<serde_json::Value>, Option<StructuredOutput>)> {
1231        let mut partials = Vec::new();
1232        let mut final_out = None;
1233        while let Some(update) = stream.next().await {
1234            match update? {
1235                StructuredStreamUpdate::Partial(value) => partials.push(value),
1236                StructuredStreamUpdate::Final(out) => final_out = Some(out),
1237            }
1238        }
1239        Ok((partials, final_out))
1240    }
1241
1242    #[tokio::test]
1243    async fn streaming_native_emits_partials_then_validated_final() -> Result<()> {
1244        let provider = StreamingProvider::new(
1245            "openai",
1246            StructuredOutputSupport::Native,
1247            vec![
1248                StreamDelta::TextDelta {
1249                    delta: r#"{"name": "Ada""#.to_owned(),
1250                    block_index: 0,
1251                },
1252                StreamDelta::TextDelta {
1253                    delta: r#", "age": 36}"#.to_owned(),
1254                    block_index: 0,
1255                },
1256                StreamDelta::Done {
1257                    stop_reason: Some(StopReason::EndTurn),
1258                },
1259            ],
1260        );
1261
1262        let stream = run_structured_stream(
1263            &provider,
1264            request_with_format(),
1265            StructuredConfig::default(),
1266        );
1267        let (partials, final_out) = drive_stream(stream).await?;
1268
1269        assert!(!partials.is_empty(), "expected at least one partial");
1270        // The first partial sees only the name before the age streamed in.
1271        assert_eq!(partials[0]["name"], "Ada");
1272        let final_out = final_out.expect("a validated final value");
1273        assert_eq!(final_out.value["name"], "Ada");
1274        assert_eq!(final_out.value["age"], 36);
1275        assert_eq!(final_out.retries, 0);
1276        Ok(())
1277    }
1278
1279    #[tokio::test]
1280    async fn streaming_tool_forcing_reads_partial_tool_input() -> Result<()> {
1281        let provider = StreamingProvider::new(
1282            "anthropic",
1283            StructuredOutputSupport::ToolForcing,
1284            vec![
1285                StreamDelta::ToolUseStart {
1286                    id: "call_1".to_owned(),
1287                    name: RESPOND_TOOL_NAME.to_owned(),
1288                    block_index: 0,
1289                    thought_signature: None,
1290                },
1291                StreamDelta::ToolInputDelta {
1292                    id: "call_1".to_owned(),
1293                    delta: r#"{"name": "Linus""#.to_owned(),
1294                    block_index: 0,
1295                },
1296                StreamDelta::ToolInputDelta {
1297                    id: "call_1".to_owned(),
1298                    delta: r#", "age": 54}"#.to_owned(),
1299                    block_index: 0,
1300                },
1301                StreamDelta::Done {
1302                    stop_reason: Some(StopReason::ToolUse),
1303                },
1304            ],
1305        );
1306
1307        let stream = run_structured_stream(
1308            &provider,
1309            request_with_format(),
1310            StructuredConfig::default(),
1311        );
1312        let (partials, final_out) = drive_stream(stream).await?;
1313
1314        assert_eq!(partials[0]["name"], "Linus");
1315        let final_out = final_out.expect("a validated final value");
1316        assert_eq!(final_out.value["age"], 54);
1317        Ok(())
1318    }
1319
1320    #[tokio::test]
1321    async fn streaming_missing_response_format_errors() {
1322        let provider =
1323            StreamingProvider::new("openai", StructuredOutputSupport::Native, Vec::new());
1324        let mut req = request_with_format();
1325        req.response_format = None;
1326
1327        let mut stream = run_structured_stream(&provider, req, StructuredConfig::default());
1328        let first = stream.next().await.expect("one item");
1329        assert!(matches!(
1330            first,
1331            Err(StructuredOutputError::MissingResponseFormat)
1332        ));
1333    }
1334
1335    #[test]
1336    fn partial_from_buffer_repairs_incomplete_json() {
1337        assert_eq!(
1338            partial_from_buffer(r#"{"name": "Ada""#).map(|v| v["name"].clone()),
1339            Some(serde_json::json!("Ada"))
1340        );
1341        assert_eq!(
1342            partial_from_buffer(r#"{"a": 1,"#),
1343            Some(serde_json::json!({"a": 1}))
1344        );
1345        assert_eq!(
1346            partial_from_buffer(r#"{"a":"#),
1347            Some(serde_json::json!({"a": null}))
1348        );
1349        assert!(partial_from_buffer("").is_none());
1350        assert!(partial_from_buffer("not json").is_none());
1351    }
1352}