1use agent_sdk_foundation::llm::{
32 ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool, ToolChoice,
33};
34use agent_sdk_foundation::types::ToolTier;
35
36use crate::provider::{LlmProvider, StructuredOutputSupport};
37
/// Name of the synthetic tool injected in tool-forcing mode. The model's call to
/// this tool carries the structured answer as the tool call's `input`.
const RESPOND_TOOL_NAME: &str = "respond";
40
/// Tuning knobs for [`run_structured`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StructuredConfig {
    /// How many times to re-prompt the model after a missing or schema-invalid
    /// answer. The total number of provider calls is at most `max_retries + 1`
    /// (saturating at `u32::MAX`). `0` means a single attempt with no retries.
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Two retries, i.e. up to three attempts in total.
    fn default() -> Self {
        Self { max_retries: 2 }
    }
}
57
/// A schema-validated answer produced by [`run_structured`].
#[derive(Debug, Clone)]
pub struct StructuredOutput {
    /// The JSON value that passed validation against the request's
    /// `response_format` schema.
    pub value: serde_json::Value,
    /// The provider response from the successful attempt that `value` was
    /// extracted from.
    pub response: ChatResponse,
    /// Number of retries performed before success. `0` means the first attempt
    /// already validated.
    pub retries: u32,
}
70
/// Errors returned by [`run_structured`].
#[derive(Debug, thiserror::Error)]
pub enum StructuredOutputError {
    /// The request passed in had `response_format: None`, so there is no schema
    /// to validate against.
    #[error("structured output requested without a response_format on the request")]
    MissingResponseFormat,

    /// `response_format.schema` could not be compiled into a JSON Schema
    /// validator. Carries the compiler's error message.
    #[error("invalid output JSON schema: {0}")]
    InvalidSchema(String),

    /// The final attempt's response contained no extractable JSON value: no
    /// `respond` tool call in tool-forcing mode, or no text parseable as JSON in
    /// native mode.
    #[error("model produced no structured output to validate")]
    NoStructuredOutput,

    /// The provider returned a non-success outcome (rate limited, invalid
    /// request, server error, or an unrecognized outcome). Carries a short
    /// description. These outcomes are not retried.
    #[error("provider returned a non-success outcome: {0}")]
    ProviderOutcome(String),

    /// Every attempt was used up and the final attempt's value still failed
    /// schema validation.
    #[error(
        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
    )]
    RetriesExhausted {
        /// Total number of provider calls made (`max_retries + 1`, saturating).
        attempts: u32,
        /// Validation errors for the last candidate, joined with `"; "`.
        errors: String,
        /// The last candidate value that failed validation, if one was produced.
        last_value: Option<serde_json::Value>,
    },

    /// Error returned by the provider's `chat` call itself, propagated via `?`.
    #[error(transparent)]
    Transport(#[from] anyhow::Error),
}
115
116pub async fn run_structured(
129 provider: &dyn LlmProvider,
130 mut request: ChatRequest,
131 config: StructuredConfig,
132) -> Result<StructuredOutput, StructuredOutputError> {
133 let response_format = request
134 .response_format
135 .clone()
136 .ok_or(StructuredOutputError::MissingResponseFormat)?;
137
138 let validator = jsonschema::validator_for(&response_format.schema)
140 .map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
141
142 let support = provider.structured_output_support();
143 if matches!(support, StructuredOutputSupport::ToolForcing) {
144 apply_tool_forcing(&mut request, &response_format);
145 }
146
147 let max_attempts = config.max_retries.saturating_add(1);
148 let mut last_value: Option<serde_json::Value> = None;
149 let mut last_errors = String::new();
150
151 for attempt in 0..max_attempts {
152 let outcome = provider.chat(request.clone()).await?;
153 let response = match outcome {
154 ChatOutcome::Success(response) => response,
155 ChatOutcome::RateLimited => {
156 return Err(StructuredOutputError::ProviderOutcome(
157 "rate limited".to_owned(),
158 ));
159 }
160 ChatOutcome::InvalidRequest(msg) => {
161 return Err(StructuredOutputError::ProviderOutcome(format!(
162 "invalid request: {msg}"
163 )));
164 }
165 ChatOutcome::ServerError(msg) => {
166 return Err(StructuredOutputError::ProviderOutcome(format!(
167 "server error: {msg}"
168 )));
169 }
170 _ => {
173 return Err(StructuredOutputError::ProviderOutcome(
174 "unrecognized provider outcome".to_owned(),
175 ));
176 }
177 };
178
179 let candidate = extract_candidate(&response, support);
180 let Some(value) = candidate else {
181 if attempt + 1 >= max_attempts {
184 return Err(StructuredOutputError::NoStructuredOutput);
185 }
186 append_correction(
187 &mut request,
188 &response,
189 "Your previous reply did not contain a structured answer. \
190 Respond with a single JSON value that satisfies the requested schema.",
191 );
192 "missing structured output".clone_into(&mut last_errors);
193 continue;
194 };
195
196 let errors: Vec<String> = validator
197 .iter_errors(&value)
198 .map(|error| format!("at `{}`: {error}", error.instance_path()))
199 .collect();
200
201 if errors.is_empty() {
202 return Ok(StructuredOutput {
203 value,
204 response,
205 retries: attempt,
206 });
207 }
208
209 last_errors = errors.join("; ");
210 last_value = Some(value);
211
212 if attempt + 1 < max_attempts {
213 let correction = format!(
214 "Your previous JSON output did not satisfy the schema. \
215 Fix these validation errors and resend the full JSON value: {last_errors}"
216 );
217 append_correction(&mut request, &response, &correction);
218 }
219 }
220
221 Err(StructuredOutputError::RetriesExhausted {
222 attempts: max_attempts,
223 errors: last_errors,
224 last_value,
225 })
226}
227
228fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
230 let respond_tool = Tool {
231 name: RESPOND_TOOL_NAME.to_owned(),
232 description: format!(
233 "Return the final answer as structured data named `{}`. \
234 You MUST call this tool exactly once with arguments matching the schema.",
235 response_format.name
236 ),
237 input_schema: response_format.schema.clone(),
238 display_name: "Structured response".to_owned(),
239 tier: ToolTier::Observe,
240 };
241
242 match request.tools {
243 Some(ref mut tools) => {
244 tools.retain(|t| t.name != RESPOND_TOOL_NAME);
245 tools.push(respond_tool);
246 }
247 None => request.tools = Some(vec![respond_tool]),
248 }
249 request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
250}
251
252fn extract_candidate(
255 response: &ChatResponse,
256 support: StructuredOutputSupport,
257) -> Option<serde_json::Value> {
258 match support {
259 StructuredOutputSupport::ToolForcing => {
260 response.content.iter().find_map(|block| match block {
261 ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
262 Some(input.clone())
263 }
264 _ => None,
265 })
266 }
267 StructuredOutputSupport::Native => {
268 let text = response.first_text()?;
269 parse_json_text(text)
270 }
271 }
272}
273
274fn parse_json_text(text: &str) -> Option<serde_json::Value> {
280 let trimmed = text.trim();
281 let unfenced = strip_code_fence(trimmed);
282 serde_json::from_str(unfenced).ok()
283}
284
/// Removes a markdown code fence wrapping `text`, returning the inner body trimmed.
///
/// Behavior:
/// - If `text` has an opening ```` ``` ```` followed by a newline, the rest of
///   that first line is treated as the info string (e.g. `json`) and dropped.
/// - Extra trailing backticks from longer fences are also dropped.
/// - When `text` does not start with a fence, it is returned unchanged.
/// - When the fence is never closed, `text` is returned unchanged.
fn strip_code_fence(text: &str) -> &str {
    const FENCE: &str = "```";

    match text.strip_prefix(FENCE) {
        None => text,
        Some(after_open) => {
            let body = match after_open.split_once('\n') {
                Some((_info_string, body)) => body,
                None => after_open,
            };
            match body.strip_suffix(FENCE) {
                Some(inner) => inner.trim_end_matches('`').trim(),
                None => text,
            }
        }
    }
}
295
296fn append_correction(request: &mut ChatRequest, previous: &ChatResponse, correction: &str) {
299 request
300 .messages
301 .push(Message::assistant_with_content(previous.content.clone()));
302 request.messages.push(Message::user(correction));
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{StopReason, Usage};
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::streaming::StreamBox;

    /// Test double that replays a fixed queue of outcomes in order and records
    /// every request it receives.
    struct ScriptedProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        outcomes: Mutex<std::collections::VecDeque<ChatOutcome>>,
        seen_requests: Mutex<Vec<ChatRequest>>,
        calls: AtomicUsize,
    }

    impl ScriptedProvider {
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            outcomes: Vec<ChatOutcome>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                outcomes: Mutex::new(outcomes.into()),
                seen_requests: Mutex::new(Vec::new()),
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of `chat` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for ScriptedProvider {
        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.seen_requests
                .lock()
                .expect("seen_requests lock")
                .push(request);
            // Panics (failing the test) if the code under test calls the
            // provider more times than the script allows.
            let outcome = self
                .outcomes
                .lock()
                .expect("outcomes lock")
                .pop_front()
                .expect("ScriptedProvider: ran out of scripted outcomes");
            Ok(outcome)
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }

    /// Object schema requiring a string `name` and a non-negative integer `age`.
    fn person_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"],
            "additionalProperties": false
        })
    }

    fn request_with_format() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: vec![Message::user("Describe a person.")],
            tools: None,
            max_tokens: 256,
            max_tokens_explicit: true,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: Some(ResponseFormat::new("person", person_schema())),
        }
    }

    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
        ChatOutcome::Success(ChatResponse {
            id: "resp".to_owned(),
            content,
            model: "scripted-model".to_owned(),
            stop_reason: Some(StopReason::EndTurn),
            usage: Usage {
                input_tokens: 1,
                output_tokens: 1,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        })
    }

    fn text_block(text: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text {
            text: text.to_owned(),
        }]
    }

    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
        vec![ContentBlock::ToolUse {
            id: "call_1".to_owned(),
            name: RESPOND_TOOL_NAME.to_owned(),
            input,
            thought_signature: None,
        }]
    }

    #[tokio::test]
    async fn native_happy_path_validates_json_text() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Ada");
        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 0);
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
        let provider = ScriptedProvider::new(
            "gemini",
            StructuredOutputSupport::Native,
            vec![success(text_block(
                "```json\n{\"name\": \"Grace\", \"age\": 45}\n```",
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Grace");
        Ok(())
    }

    #[tokio::test]
    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![success(respond_tool_block(
                serde_json::json!({"name": "Linus", "age": 54}),
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Linus");
        assert_eq!(out.retries, 0);

        // Scope the lock so the guard is dropped before the asserts.
        let (has_respond_tool, forces_respond) = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            let tools = seen[0].tools.as_ref().expect("tools injected");
            (
                tools.iter().any(|t| t.name == RESPOND_TOOL_NAME),
                matches!(
                    seen[0].tool_choice,
                    Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
                ),
            )
        };
        assert!(has_respond_tool);
        assert!(forces_respond);
        Ok(())
    }

    #[tokio::test]
    async fn mismatch_then_retry_succeeds() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                // `age` has the wrong type: fails validation.
                success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
                // Corrected answer.
                success(text_block(r#"{"name": "Ada", "age": 36}"#)),
            ],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await?;

        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 1);
        assert_eq!(provider.call_count(), 2);

        // The retry must carry the correction, so its conversation is longer.
        let grew = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            seen[1].messages.len() > seen[0].messages.len()
        };
        assert!(grew);
        Ok(())
    }

    #[tokio::test]
    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(serde_json::json!({"name": "y"}))),
                success(respond_tool_block(serde_json::json!({"name": "z"}))),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await
        .expect_err("schema never satisfied");

        match err {
            StructuredOutputError::RetriesExhausted {
                attempts,
                last_value,
                ..
            } => {
                assert_eq!(attempts, 3, "1 initial + 2 retries");
                assert_eq!(
                    last_value.as_ref().and_then(|v| v["name"].as_str()),
                    Some("z")
                );
            }
            other => panic!("expected RetriesExhausted, got {other:?}"),
        }
        assert_eq!(provider.call_count(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada"}"#))],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 0 },
        )
        .await
        .expect_err("missing required `age`");

        assert!(matches!(
            err,
            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
        ));
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn missing_response_format_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = None;

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("no response format");
        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
    }

    #[tokio::test]
    async fn invalid_schema_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = Some(ResponseFormat::new("bad", serde_json::json!({"type": 123})));

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("invalid schema");
        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
    }

    #[tokio::test]
    async fn provider_rate_limit_surfaces_as_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![ChatOutcome::RateLimited],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await
        .expect_err("rate limited");
        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
    }

    #[tokio::test]
    async fn no_structured_output_on_final_attempt_errors() {
        // Neither reply contains parseable JSON.
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block("I cannot do that.")),
                success(text_block("Still prose, sorry.")),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await
        .expect_err("never produced JSON");
        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
        assert_eq!(provider.call_count(), 2);
    }

    #[test]
    fn strip_code_fence_handles_common_shapes() {
        // Fence with an info string on the opening line.
        assert_eq!(strip_code_fence("```json\n{\"a\": 1}\n```"), "{\"a\": 1}");
        // Single-line fence.
        assert_eq!(strip_code_fence("```{\"a\": 1}```"), "{\"a\": 1}");
        // Longer (four-backtick) fence.
        assert_eq!(strip_code_fence("````json\n[1]\n````"), "[1]");
        // Unterminated fence and plain text are returned unchanged.
        assert_eq!(strip_code_fence("```json\n{}"), "```json\n{}");
        assert_eq!(strip_code_fence("{\"a\": 1}"), "{\"a\": 1}");
    }

    #[test]
    fn tool_forcing_replaces_existing_respond_tool() {
        let mut req = request_with_format();
        req.tools = Some(vec![Tool {
            name: RESPOND_TOOL_NAME.to_owned(),
            description: "stale".to_owned(),
            input_schema: serde_json::json!({}),
            display_name: "stale".to_owned(),
            tier: ToolTier::Observe,
        }]);
        let format = req.response_format.clone().expect("format set");

        apply_tool_forcing(&mut req, &format);

        let tools = req.tools.as_ref().expect("tools present");
        assert_eq!(tools.len(), 1, "stale respond tool must be replaced, not duplicated");
        assert_eq!(tools[0].name, RESPOND_TOOL_NAME);
        assert_eq!(tools[0].input_schema, person_schema());
        assert_ne!(tools[0].description, "stale");
        assert!(matches!(
            req.tool_choice,
            Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
        ));
    }
}