1use agent_sdk_foundation::llm::{
32 ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool, ToolChoice,
33};
34use agent_sdk_foundation::types::ToolTier;
35
36use crate::provider::{LlmProvider, StructuredOutputSupport};
37
/// Name of the synthetic tool used for tool-forced structured output: the tool is injected
/// under this name, and the structured answer is read back from the arguments of a tool call
/// carrying this name.
const RESPOND_TOOL_NAME: &str = "respond";
40
/// Settings that bound how hard the structured-output driver tries to obtain a valid answer.
#[derive(Debug, Clone, Copy)]
pub struct StructuredConfig {
    /// How many corrective retries are allowed *after* the initial attempt; the total number
    /// of provider calls is therefore at most `max_retries + 1`.
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Returns a configuration that permits two retries (three attempts in total).
    fn default() -> Self {
        let max_retries = 2;
        Self { max_retries }
    }
}
57
/// Successful result of a structured-output run.
#[derive(Debug, Clone)]
pub struct StructuredOutput {
    /// The JSON value that passed validation against the requested schema.
    pub value: serde_json::Value,
    /// The provider response from which `value` was extracted.
    pub response: ChatResponse,
    /// Zero-based index of the attempt that succeeded, i.e. the number of retries that were
    /// needed (`0` when the first attempt was already valid).
    pub retries: u32,
}
70
/// Failure modes of a structured-output run.
#[derive(Debug, thiserror::Error)]
pub enum StructuredOutputError {
    /// The request carried no `response_format`, so there is no schema to enforce.
    #[error("structured output requested without a response_format on the request")]
    MissingResponseFormat,

    /// The schema in the response format could not be compiled into a validator; the payload
    /// is the compiler's error message.
    #[error("invalid output JSON schema: {0}")]
    InvalidSchema(String),

    /// The final attempt produced nothing that could be extracted as a candidate value.
    #[error("model produced no structured output to validate")]
    NoStructuredOutput,

    /// The provider answered with an outcome other than success; the payload describes it.
    #[error("provider returned a non-success outcome: {0}")]
    ProviderOutcome(String),

    /// Every attempt was used up and the last candidate still violated the schema.
    #[error(
        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
    )]
    RetriesExhausted {
        /// Total number of attempts made (initial attempt plus retries).
        attempts: u32,
        /// Validation errors of the last rejected candidate, joined with `"; "`.
        errors: String,
        /// The last candidate value that failed validation, if any.
        last_value: Option<serde_json::Value>,
    },

    /// An error returned by the provider call itself, passed through unchanged.
    #[error(transparent)]
    Transport(#[from] anyhow::Error),
}
115
116pub async fn run_structured(
129 provider: &dyn LlmProvider,
130 mut request: ChatRequest,
131 config: StructuredConfig,
132) -> Result<StructuredOutput, StructuredOutputError> {
133 let response_format = request
134 .response_format
135 .clone()
136 .ok_or(StructuredOutputError::MissingResponseFormat)?;
137
138 let validator = jsonschema::validator_for(&response_format.schema)
140 .map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
141
142 let support = provider.structured_output_support();
143 if matches!(support, StructuredOutputSupport::ToolForcing) {
144 apply_tool_forcing(&mut request, &response_format);
145 }
146
147 let max_attempts = config.max_retries.saturating_add(1);
148 let mut last_value: Option<serde_json::Value> = None;
149 let mut last_errors = String::new();
150
151 for attempt in 0..max_attempts {
152 let attempt_request = if attempt + 1 == max_attempts {
155 std::mem::replace(&mut request, ChatRequest::new(String::new(), Vec::new()))
156 } else {
157 request.clone()
158 };
159 let outcome = provider.chat(attempt_request).await?;
160 let response = match outcome {
161 ChatOutcome::Success(response) => response,
162 ChatOutcome::RateLimited => {
163 return Err(StructuredOutputError::ProviderOutcome(
164 "rate limited".to_owned(),
165 ));
166 }
167 ChatOutcome::InvalidRequest(msg) => {
168 return Err(StructuredOutputError::ProviderOutcome(format!(
169 "invalid request: {msg}"
170 )));
171 }
172 ChatOutcome::ServerError(msg) => {
173 return Err(StructuredOutputError::ProviderOutcome(format!(
174 "server error: {msg}"
175 )));
176 }
177 _ => {
180 return Err(StructuredOutputError::ProviderOutcome(
181 "unrecognized provider outcome".to_owned(),
182 ));
183 }
184 };
185
186 let candidate = extract_candidate(&response, support);
187 let Some(value) = candidate else {
188 if attempt + 1 >= max_attempts {
191 return Err(StructuredOutputError::NoStructuredOutput);
192 }
193 append_correction(
194 &mut request,
195 &response,
196 support,
197 "Your previous reply did not contain a structured answer. \
198 Respond with a single JSON value that satisfies the requested schema.",
199 );
200 "missing structured output".clone_into(&mut last_errors);
201 continue;
202 };
203
204 let errors: Vec<String> = validator
205 .iter_errors(&value)
206 .map(|error| format!("at `{}`: {error}", error.instance_path()))
207 .collect();
208
209 if errors.is_empty() {
210 return Ok(StructuredOutput {
211 value,
212 response,
213 retries: attempt,
214 });
215 }
216
217 last_errors = errors.join("; ");
218 last_value = Some(value);
219
220 if attempt + 1 < max_attempts {
221 let correction = format!(
222 "Your previous JSON output did not satisfy the schema. \
223 Fix these validation errors and resend the full JSON value: {last_errors}"
224 );
225 append_correction(&mut request, &response, support, &correction);
226 }
227 }
228
229 Err(StructuredOutputError::RetriesExhausted {
230 attempts: max_attempts,
231 errors: last_errors,
232 last_value,
233 })
234}
235
236fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
238 let respond_tool = Tool {
239 name: RESPOND_TOOL_NAME.to_owned(),
240 description: format!(
241 "Return the final answer as structured data named `{}`. \
242 You MUST call this tool exactly once with arguments matching the schema.",
243 response_format.name
244 ),
245 input_schema: response_format.schema.clone(),
246 display_name: "Structured response".to_owned(),
247 tier: ToolTier::Observe,
248 };
249
250 match request.tools {
251 Some(ref mut tools) => {
252 tools.retain(|t| t.name != RESPOND_TOOL_NAME);
253 tools.push(respond_tool);
254 }
255 None => request.tools = Some(vec![respond_tool]),
256 }
257 request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
258}
259
260fn extract_candidate(
263 response: &ChatResponse,
264 support: StructuredOutputSupport,
265) -> Option<serde_json::Value> {
266 match support {
267 StructuredOutputSupport::ToolForcing => {
268 response.content.iter().find_map(|block| match block {
269 ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
270 Some(input.clone())
271 }
272 _ => None,
273 })
274 }
275 StructuredOutputSupport::Native => {
276 let text = response.first_text()?;
277 parse_json_text(text)
278 }
279 }
280}
281
282fn parse_json_text(text: &str) -> Option<serde_json::Value> {
288 let trimmed = text.trim();
289 let unfenced = strip_code_fence(trimmed);
290 serde_json::from_str(unfenced).ok()
291}
292
/// Removes a Markdown code fence wrapped around `text` and returns the enclosed body.
///
/// `text` is expected to be trimmed already. Behaviour by case:
///
/// * No leading ```` ``` ````: `text` is returned unchanged.
/// * Leading fence and trailing fence: the first line (the rest of the opening fence plus any
///   info string such as `json`) and the closing fence are dropped, and the body is trimmed.
///   Closing fences longer than three backticks are tolerated.
/// * Leading fence but no trailing fence: the trimmed body after the opening line is returned.
///   Returning the input unchanged here would be pointless for JSON parsing, because text that
///   begins with backticks can never be valid JSON.
fn strip_code_fence(text: &str) -> &str {
    let Some(rest) = text.strip_prefix("```") else {
        return text;
    };

    // Everything up to the first newline belongs to the opening fence line (extra backticks
    // and/or an info string). Without a newline the fence and body share one line.
    let body = rest.split_once('\n').map_or(rest, |(_, body)| body);

    match body.strip_suffix("```") {
        // `trim_end_matches` swallows the surplus backticks of a 4+ backtick closing fence.
        Some(inner) => inner.trim_end_matches('`').trim(),
        // Unterminated fence: salvage the body rather than handing back unparsable input.
        None => body.trim(),
    }
}
303
304fn append_correction(
315 request: &mut ChatRequest,
316 previous: &ChatResponse,
317 support: StructuredOutputSupport,
318 correction: &str,
319) {
320 request
321 .messages
322 .push(Message::assistant_with_content(previous.content.clone()));
323
324 let respond_tool_use_id = if matches!(support, StructuredOutputSupport::ToolForcing) {
325 previous.content.iter().find_map(|block| match block {
326 ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => Some(id.clone()),
327 _ => None,
328 })
329 } else {
330 None
331 };
332
333 match respond_tool_use_id {
334 Some(tool_use_id) => {
335 request
336 .messages
337 .push(Message::tool_result(tool_use_id, correction, true));
338 }
339 None => request.messages.push(Message::user(correction)),
340 }
341}
342
// Unit tests for the structured-output driver. They run against a scripted in-memory provider,
// so no network access is involved.
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{StopReason, Usage};
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::streaming::StreamBox;

    /// Test double for [`LlmProvider`]: replays a fixed queue of outcomes, one per `chat`
    /// call, and records every request it receives.
    struct ScriptedProvider {
        /// Value reported by `provider()`.
        provider_name: &'static str,
        /// Value reported by `model()`.
        model: String,
        /// Value reported by `structured_output_support()`.
        support: StructuredOutputSupport,
        /// Outcomes still to be returned, in order.
        outcomes: Mutex<std::collections::VecDeque<ChatOutcome>>,
        /// Every request passed to `chat`, in call order.
        seen_requests: Mutex<Vec<ChatRequest>>,
        /// Number of `chat` calls made so far.
        calls: AtomicUsize,
    }

    impl ScriptedProvider {
        /// Builds a provider that will return `outcomes` in order.
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            outcomes: Vec<ChatOutcome>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                outcomes: Mutex::new(outcomes.into()),
                seen_requests: Mutex::new(Vec::new()),
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of times `chat` has been invoked.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for ScriptedProvider {
        /// Records the request and returns the next scripted outcome; panics if the script
        /// has run out.
        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.seen_requests
                .lock()
                .expect("seen_requests lock")
                .push(request);
            let outcome = self
                .outcomes
                .lock()
                .expect("outcomes lock")
                .pop_front()
                .expect("ScriptedProvider: ran out of scripted outcomes");
            Ok(outcome)
        }

        /// Streaming is not exercised by these tests; the stream yields a single error.
        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }

    /// Schema used by all tests: an object with required `name` (string) and `age`
    /// (non-negative integer) and no other properties.
    fn person_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"],
            "additionalProperties": false
        })
    }

    /// A one-message request whose response format is the `person` schema.
    fn request_with_format() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: vec![Message::user("Describe a person.")],
            tools: None,
            max_tokens: 256,
            max_tokens_explicit: true,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: Some(ResponseFormat::new("person", person_schema())),
        }
    }

    /// Wraps `content` in a successful outcome with placeholder metadata.
    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
        ChatOutcome::Success(ChatResponse {
            id: "resp".to_owned(),
            content,
            model: "scripted-model".to_owned(),
            stop_reason: Some(StopReason::EndTurn),
            usage: Usage {
                input_tokens: 1,
                output_tokens: 1,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        })
    }

    /// Response content consisting of a single text block.
    fn text_block(text: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text {
            text: text.to_owned(),
        }]
    }

    /// Response content consisting of a single `respond` tool call with `input` as arguments.
    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
        vec![ContentBlock::ToolUse {
            id: "call_1".to_owned(),
            name: RESPOND_TOOL_NAME.to_owned(),
            input,
            thought_signature: None,
        }]
    }

    // Native mode: schema-valid JSON text is accepted on the first call.
    #[tokio::test]
    async fn native_happy_path_validates_json_text() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Ada");
        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 0);
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    // Native mode: JSON wrapped in a Markdown code fence is still parsed.
    #[tokio::test]
    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
        let provider = ScriptedProvider::new(
            "gemini",
            StructuredOutputSupport::Native,
            vec![success(text_block(
                "```json\n{\"name\": \"Grace\", \"age\": 45}\n```",
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Grace");
        Ok(())
    }

    // Tool-forcing mode: the value comes from the `respond` tool call, and the request sent
    // to the provider carries the injected tool and a tool choice that forces it.
    #[tokio::test]
    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![success(respond_tool_block(
                serde_json::json!({"name": "Linus", "age": 54}),
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Linus");
        assert_eq!(out.retries, 0);

        // Inspect the recorded request inside a block so the lock guard is released before
        // the assertions run.
        let (has_respond_tool, forces_respond) = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            let tools = seen[0].tools.as_ref().expect("tools injected");
            (
                tools.iter().any(|t| t.name == RESPOND_TOOL_NAME),
                matches!(
                    seen[0].tool_choice,
                    Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
                ),
            )
        };
        assert!(has_respond_tool);
        assert!(forces_respond);
        Ok(())
    }

    // A schema violation (`age` is a string) on the first reply triggers one retry, which
    // succeeds; the retry request must contain more messages than the first request.
    #[tokio::test]
    async fn mismatch_then_retry_succeeds() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
                success(text_block(r#"{"name": "Ada", "age": 36}"#)),
            ],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await?;

        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 1);
        assert_eq!(provider.call_count(), 2);

        let grew = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            seen[1].messages.len() > seen[0].messages.len()
        };
        assert!(grew);
        Ok(())
    }

    // Tool-forcing retry: the retry request must replay the assistant's `respond` tool call
    // and contain a tool result that refers to the same tool-use id.
    #[tokio::test]
    async fn tool_forcing_retry_appends_tool_result_for_forced_tool_use() -> Result<()> {
        use agent_sdk_foundation::llm::Content;

        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(
                    serde_json::json!({"name": "x", "age": 1}),
                )),
            ],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await?;
        assert_eq!(out.retries, 1);

        let seen = provider.seen_requests.lock().expect("seen lock");
        let retry = &seen[1];

        // Id of the replayed `respond` tool call in the retry conversation.
        let assistant_tool_use_id = retry
            .messages
            .iter()
            .find_map(|m| match &m.content {
                Content::Blocks(blocks) => blocks.iter().find_map(|b| match b {
                    ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => {
                        Some(id.clone())
                    }
                    _ => None,
                }),
                Content::Text(_) => None,
            })
            .expect("assistant respond tool_use present in retry");

        let has_matching_result = retry.messages.iter().any(|m| match &m.content {
            Content::Blocks(blocks) => blocks.iter().any(|b| {
                matches!(
                    b,
                    ContentBlock::ToolResult { tool_use_id, .. }
                        if *tool_use_id == assistant_tool_use_id
                )
            }),
            Content::Text(_) => false,
        });
        // Release the lock guard before asserting.
        drop(seen);
        assert!(
            has_matching_result,
            "retry must carry a tool_result for the forced tool_use id"
        );
        Ok(())
    }

    // Three replies that all lack the required `age` exhaust the budget of two retries; the
    // error reports three attempts and carries the last rejected value.
    #[tokio::test]
    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(serde_json::json!({"name": "y"}))),
                success(respond_tool_block(serde_json::json!({"name": "z"}))),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await
        .expect_err("schema never satisfied");

        match err {
            StructuredOutputError::RetriesExhausted {
                attempts,
                last_value,
                ..
            } => {
                assert_eq!(attempts, 3, "1 initial + 2 retries");
                assert_eq!(
                    last_value.as_ref().and_then(|v| v["name"].as_str()),
                    Some("z")
                );
            }
            other => panic!("expected RetriesExhausted, got {other:?}"),
        }
        assert_eq!(provider.call_count(), 3);
        Ok(())
    }

    // With `max_retries: 0` a single invalid reply fails immediately after one call.
    #[tokio::test]
    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada"}"#))],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 0 },
        )
        .await
        .expect_err("missing required `age`");

        assert!(matches!(
            err,
            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
        ));
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    // A request without a response format is rejected with `MissingResponseFormat`.
    #[tokio::test]
    async fn missing_response_format_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = None;

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("no response format");
        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
    }

    // A schema whose `type` is a number is rejected with `InvalidSchema`.
    #[tokio::test]
    async fn invalid_schema_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = Some(ResponseFormat::new("bad", serde_json::json!({"type": 123})));

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("invalid schema");
        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
    }

    // A rate-limited outcome is reported as `ProviderOutcome`.
    #[tokio::test]
    async fn provider_rate_limit_surfaces_as_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![ChatOutcome::RateLimited],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await
        .expect_err("rate limited");
        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
    }

    // Prose on every attempt (two calls with `max_retries: 1`) ends in `NoStructuredOutput`.
    #[tokio::test]
    async fn no_structured_output_on_final_attempt_errors() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block("I cannot do that.")),
                success(text_block("Still prose, sorry.")),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await
        .expect_err("never produced JSON");
        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
        assert_eq!(provider.call_count(), 2);
    }
}