1use std::pin::Pin;
32
33use agent_sdk_foundation::llm::{
34 ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool,
35 ToolChoice, Usage,
36};
37use agent_sdk_foundation::types::ToolTier;
38use futures::{Stream, StreamExt};
39
40use crate::provider::{LlmProvider, StructuredOutputSupport};
41use crate::streaming::{StreamAccumulator, StreamDelta, StreamErrorKind};
42
/// Name of the synthetic tool injected by [`apply_tool_forcing`] for providers that use
/// tool forcing; the input of a call to this tool is read back as the structured answer.
const RESPOND_TOOL_NAME: &str = "respond";
45
/// Tuning knobs for [`run_structured`] and [`run_structured_stream`].
#[derive(Debug, Clone, Copy)]
pub struct StructuredConfig {
    /// Number of corrective calls allowed after the first attempt. The total number of
    /// provider calls is `max_retries + 1` (saturating at `u32::MAX`).
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Two corrective retries, i.e. at most three provider calls.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 2;
        StructuredConfig {
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
62
/// A schema-validated structured answer together with the response that produced it.
#[derive(Debug, Clone)]
pub struct StructuredOutput {
    /// The candidate value that passed validation against `response_format.schema`.
    pub value: serde_json::Value,
    /// The provider response of the successful attempt. For a streamed first attempt this is
    /// reassembled from the stream deltas and has an empty `id`.
    pub response: ChatResponse,
    /// Number of corrective retries made before success (`0` = first attempt succeeded).
    pub retries: u32,
}
75
/// Failure modes of [`run_structured`] and [`run_structured_stream`].
#[derive(Debug, thiserror::Error)]
pub enum StructuredOutputError {
    /// `request.response_format` was `None`, so there is no schema to validate against.
    #[error("structured output requested without a response_format on the request")]
    MissingResponseFormat,

    /// `jsonschema` could not compile the schema in `response_format`; carries its message.
    #[error("invalid output JSON schema: {0}")]
    InvalidSchema(String),

    /// The final attempt's response contained no candidate value: no `respond` tool call in
    /// tool-forcing mode, or a first text block that did not parse as JSON in native mode.
    #[error("model produced no structured output to validate")]
    NoStructuredOutput,

    /// The provider returned a non-success outcome, or reported an error mid-stream. Carries a
    /// short description such as "rate limited" or "server error: ...". Never retried.
    #[error("provider returned a non-success outcome: {0}")]
    ProviderOutcome(String),

    /// All attempts were used and the final candidate still failed schema validation.
    #[error(
        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
    )]
    RetriesExhausted {
        /// Total number of provider calls made (initial attempt plus retries).
        attempts: u32,
        /// Validation messages for the last failing candidate, joined with `"; "`.
        errors: String,
        /// The most recent candidate value that failed validation, if any.
        last_value: Option<serde_json::Value>,
    },

    /// An error returned by the provider call itself (`chat` / `chat_stream`), wrapped as-is.
    #[error(transparent)]
    Transport(#[from] anyhow::Error),
}
120
121pub async fn run_structured(
134 provider: &dyn LlmProvider,
135 mut request: ChatRequest,
136 config: StructuredConfig,
137) -> Result<StructuredOutput, StructuredOutputError> {
138 let response_format = request
139 .response_format
140 .clone()
141 .ok_or(StructuredOutputError::MissingResponseFormat)?;
142
143 let validator = jsonschema::validator_for(&response_format.schema)
145 .map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
146
147 let support = provider.structured_output_support();
148 if matches!(support, StructuredOutputSupport::ToolForcing) {
149 apply_tool_forcing(&mut request, &response_format);
150 }
151
152 let max_attempts = config.max_retries.saturating_add(1);
153 let mut last_value: Option<serde_json::Value> = None;
154 let mut last_errors = String::new();
155
156 for attempt in 0..max_attempts {
157 let attempt_request = if attempt + 1 == max_attempts {
160 std::mem::replace(&mut request, ChatRequest::new(String::new(), Vec::new()))
161 } else {
162 request.clone()
163 };
164 let outcome = provider.chat(attempt_request).await?;
165 let response = match outcome {
166 ChatOutcome::Success(response) => response,
167 ChatOutcome::RateLimited(_) => {
168 return Err(StructuredOutputError::ProviderOutcome(
169 "rate limited".to_owned(),
170 ));
171 }
172 ChatOutcome::InvalidRequest(msg) => {
173 return Err(StructuredOutputError::ProviderOutcome(format!(
174 "invalid request: {msg}"
175 )));
176 }
177 ChatOutcome::ServerError(msg) => {
178 return Err(StructuredOutputError::ProviderOutcome(format!(
179 "server error: {msg}"
180 )));
181 }
182 _ => {
185 return Err(StructuredOutputError::ProviderOutcome(
186 "unrecognized provider outcome".to_owned(),
187 ));
188 }
189 };
190
191 let candidate = extract_candidate(&response, support);
192 let Some(value) = candidate else {
193 if attempt + 1 >= max_attempts {
196 return Err(StructuredOutputError::NoStructuredOutput);
197 }
198 append_correction(
199 &mut request,
200 &response,
201 support,
202 "Your previous reply did not contain a structured answer. \
203 Respond with a single JSON value that satisfies the requested schema.",
204 );
205 "missing structured output".clone_into(&mut last_errors);
206 continue;
207 };
208
209 let errors = collect_schema_errors(&validator, &value);
210
211 if errors.is_empty() {
212 return Ok(StructuredOutput {
213 value,
214 response,
215 retries: attempt,
216 });
217 }
218
219 last_errors = errors.join("; ");
220 last_value = Some(value);
221
222 if attempt + 1 < max_attempts {
223 let correction = format!(
224 "Your previous JSON output did not satisfy the schema. \
225 Fix these validation errors and resend the full JSON value: {last_errors}"
226 );
227 append_correction(&mut request, &response, support, &correction);
228 }
229 }
230
231 Err(StructuredOutputError::RetriesExhausted {
232 attempts: max_attempts,
233 errors: last_errors,
234 last_value,
235 })
236}
237
/// One item produced by [`run_structured_stream`].
#[derive(Debug, Clone)]
pub enum StructuredStreamUpdate {
    /// A best-effort snapshot parsed from the in-flight first attempt. It is NOT
    /// schema-validated and may differ from the eventual `Final` value (for example when a
    /// retry is needed).
    Partial(serde_json::Value),
    /// The validated result; this is the last item of a successful stream.
    Final(StructuredOutput),
}
251
/// Boxed stream returned by [`run_structured_stream`], borrowing the provider for `'a`.
pub type StructuredStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StructuredStreamUpdate, StructuredOutputError>> + Send + 'a>>;
255
256pub fn run_structured_stream(
274 provider: &dyn LlmProvider,
275 request: ChatRequest,
276 config: StructuredConfig,
277) -> StructuredStream<'_> {
278 Box::pin(async_stream::stream! {
279 let mut request = request;
280 let Some(response_format) = request.response_format.clone() else {
281 yield Err(StructuredOutputError::MissingResponseFormat);
282 return;
283 };
284 let validator = match jsonschema::validator_for(&response_format.schema) {
285 Ok(validator) => validator,
286 Err(e) => {
287 yield Err(StructuredOutputError::InvalidSchema(e.to_string()));
288 return;
289 }
290 };
291
292 let support = provider.structured_output_support();
293 if matches!(support, StructuredOutputSupport::ToolForcing) {
294 apply_tool_forcing(&mut request, &response_format);
295 }
296
297 let max_attempts = config.max_retries.saturating_add(1);
298 let model = provider.model().to_owned();
299 let mut last_value: Option<serde_json::Value> = None;
300 let mut last_errors = String::new();
301
302 for attempt in 0..max_attempts {
303 let response = if attempt == 0 {
306 let mut attempt_stream =
307 Box::pin(stream_first_attempt(provider, request.clone(), support, model.clone()));
308 let mut completed: Option<ChatResponse> = None;
309 while let Some(item) = attempt_stream.next().await {
310 match item {
311 StreamAttemptItem::Partial(value) => {
312 yield Ok(StructuredStreamUpdate::Partial(value));
313 }
314 StreamAttemptItem::Complete(response) => completed = Some(response),
315 StreamAttemptItem::Failed(error) => {
316 yield Err(error);
317 return;
318 }
319 }
320 }
321 match completed {
324 Some(response) => response,
325 None => return,
326 }
327 } else {
328 match provider.chat(request.clone()).await {
329 Ok(ChatOutcome::Success(response)) => response,
330 Ok(other) => {
331 yield Err(non_success_outcome_error(&other));
332 return;
333 }
334 Err(e) => {
335 yield Err(StructuredOutputError::Transport(e));
336 return;
337 }
338 }
339 };
340
341 let Some(value) = extract_candidate(&response, support) else {
342 if attempt + 1 >= max_attempts {
343 yield Err(StructuredOutputError::NoStructuredOutput);
344 return;
345 }
346 append_correction(
347 &mut request,
348 &response,
349 support,
350 "Your previous reply did not contain a structured answer. \
351 Respond with a single JSON value that satisfies the requested schema.",
352 );
353 "missing structured output".clone_into(&mut last_errors);
354 continue;
355 };
356
357 let errors = collect_schema_errors(&validator, &value);
358
359 if errors.is_empty() {
360 yield Ok(StructuredStreamUpdate::Final(StructuredOutput {
361 value,
362 response,
363 retries: attempt,
364 }));
365 return;
366 }
367
368 last_errors = errors.join("; ");
369 last_value = Some(value);
370
371 if attempt + 1 < max_attempts {
372 let correction = format!(
373 "Your previous JSON output did not satisfy the schema. \
374 Fix these validation errors and resend the full JSON value: {last_errors}"
375 );
376 append_correction(&mut request, &response, support, &correction);
377 }
378 }
379
380 yield Err(StructuredOutputError::RetriesExhausted {
381 attempts: max_attempts,
382 errors: last_errors,
383 last_value,
384 });
385 })
386}
387
/// Items produced by [`stream_first_attempt`] for a single streamed provider call.
enum StreamAttemptItem {
    /// A newly parsed (unvalidated) snapshot of the in-flight structured payload.
    Partial(serde_json::Value),
    /// The stream finished cleanly; carries the reassembled response.
    Complete(ChatResponse),
    /// The stream failed (transport error or a reported stream error).
    Failed(StructuredOutputError),
}
395
/// Drives one `provider.chat_stream` call for the first structured attempt.
///
/// Yields zero or more [`StreamAttemptItem::Partial`] values, each distinct from the
/// previous one, followed by exactly one terminal item:
/// - [`StreamAttemptItem::Failed`] on a transport error (yielded immediately) or if any
///   `StreamDelta::Error` was seen (yielded after the stream ends);
/// - otherwise [`StreamAttemptItem::Complete`] with a response rebuilt from the deltas.
fn stream_first_attempt(
    provider: &dyn LlmProvider,
    request: ChatRequest,
    support: StructuredOutputSupport,
    model: String,
) -> impl Stream<Item = StreamAttemptItem> + Send + '_ {
    async_stream::stream! {
        let mut accumulator = StreamAccumulator::new();
        // Raw text of the structured payload seen so far (text deltas in native mode,
        // `respond` tool-input deltas in tool-forcing mode).
        let mut partial_buf = String::new();
        // Ids of `respond` tool calls whose input deltas belong in `partial_buf`.
        let mut respond_tool_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        // Last partial emitted, used to suppress duplicate partial updates.
        let mut last_partial: Option<serde_json::Value> = None;
        // Only the most recent stream error is kept; it is reported once the stream ends.
        let mut stream_error: Option<(String, StreamErrorKind)> = None;

        let mut stream = provider.chat_stream(request);
        while let Some(item) = stream.next().await {
            let delta = match item {
                Ok(delta) => delta,
                Err(e) => {
                    yield StreamAttemptItem::Failed(StructuredOutputError::Transport(e));
                    return;
                }
            };

            accumulate_partial_buffer(&delta, support, &mut partial_buf, &mut respond_tool_ids);
            if let StreamDelta::Error { message, kind } = &delta {
                stream_error = Some((message.clone(), *kind));
            }
            accumulator.apply(&delta);

            // Re-parses the whole accumulated buffer after every delta; only emits when the
            // parsed value actually changed.
            if let Some(value) = partial_from_buffer(&partial_buf)
                && last_partial.as_ref() != Some(&value)
            {
                last_partial = Some(value.clone());
                yield StreamAttemptItem::Partial(value);
            }
        }

        if let Some((message, kind)) = stream_error {
            yield StreamAttemptItem::Failed(stream_error_to_outcome(&message, kind));
            return;
        }

        yield StreamAttemptItem::Complete(build_streamed_response(accumulator, model));
    }
}
446
447fn accumulate_partial_buffer(
452 delta: &StreamDelta,
453 support: StructuredOutputSupport,
454 buffer: &mut String,
455 respond_tool_ids: &mut std::collections::HashSet<String>,
456) {
457 match (support, delta) {
458 (StructuredOutputSupport::Native, StreamDelta::TextDelta { delta, .. }) => {
459 buffer.push_str(delta);
460 }
461 (StructuredOutputSupport::ToolForcing, StreamDelta::ToolUseStart { id, name, .. })
462 if name == RESPOND_TOOL_NAME =>
463 {
464 respond_tool_ids.insert(id.clone());
465 }
466 (StructuredOutputSupport::ToolForcing, StreamDelta::ToolInputDelta { id, delta, .. })
467 if respond_tool_ids.contains(id) =>
468 {
469 buffer.push_str(delta);
470 }
471 _ => {}
472 }
473}
474
475fn stream_error_to_outcome(message: &str, kind: StreamErrorKind) -> StructuredOutputError {
477 let label = match kind {
478 StreamErrorKind::RateLimited => "rate limited".to_owned(),
479 StreamErrorKind::InvalidRequest => format!("invalid request: {message}"),
480 _ => format!("server error: {message}"),
481 };
482 StructuredOutputError::ProviderOutcome(label)
483}
484
485fn non_success_outcome_error(outcome: &ChatOutcome) -> StructuredOutputError {
487 let label = match outcome {
488 ChatOutcome::RateLimited(_) => "rate limited".to_owned(),
489 ChatOutcome::InvalidRequest(msg) => format!("invalid request: {msg}"),
490 ChatOutcome::ServerError(msg) => format!("server error: {msg}"),
491 _ => "unrecognized provider outcome".to_owned(),
492 };
493 StructuredOutputError::ProviderOutcome(label)
494}
495
496fn build_streamed_response(mut accumulator: StreamAccumulator, model: String) -> ChatResponse {
498 let usage = accumulator.take_usage().unwrap_or(Usage {
499 input_tokens: 0,
500 output_tokens: 0,
501 cached_input_tokens: 0,
502 cache_creation_input_tokens: 0,
503 });
504 let stop_reason = accumulator.take_stop_reason();
505 ChatResponse {
506 id: String::new(),
507 content: accumulator.into_content_blocks(),
508 model,
509 stop_reason,
510 usage,
511 }
512}
513
514fn partial_from_buffer(buffer: &str) -> Option<serde_json::Value> {
520 let trimmed = buffer.trim_start();
521 let body = trimmed
523 .strip_prefix("```")
524 .and_then(|rest| rest.split_once('\n').map(|(_, body)| body))
525 .unwrap_or(trimmed)
526 .trim();
527 if body.is_empty() {
528 return None;
529 }
530 let repaired = repair_partial_json(body);
531 serde_json::from_str::<serde_json::Value>(&repaired)
532 .ok()
533 .filter(|value| value.is_object() || value.is_array())
534}
535
/// Closes a truncated JSON prefix so it has a chance to parse.
///
/// The buffer is scanned once to track string state and the stack of open `{`/`[`. Then,
/// in order:
/// 1. A dangling escape backslash is dropped.
/// 2. An open string gets its closing quote.
/// 3. Trailing whitespace is trimmed.
/// 4. A trailing comma is removed, or a trailing colon gets a `null` value.
/// 5. The missing closers are appended, innermost first.
///
/// The output is not guaranteed to be valid JSON (e.g. a dangling object key or a
/// half-written literal); callers must still parse it.
fn repair_partial_json(buffer: &str) -> String {
    let mut closers: Vec<char> = Vec::new();
    let mut inside_string = false;
    let mut pending_escape = false;

    for c in buffer.chars() {
        match (inside_string, pending_escape, c) {
            // The character after a backslash is consumed verbatim.
            (true, true, _) => pending_escape = false,
            (true, false, '\\') => pending_escape = true,
            (true, false, '"') => inside_string = false,
            (true, false, _) => {}
            (false, _, '"') => inside_string = true,
            (false, _, '{') => closers.push('}'),
            (false, _, '[') => closers.push(']'),
            (false, _, '}' | ']') => {
                closers.pop();
            }
            (false, _, _) => {}
        }
    }

    let mut repaired = String::from(buffer);
    if pending_escape {
        // A lone trailing backslash would escape the quote appended below.
        repaired.pop();
    }
    if inside_string {
        repaired.push('"');
    }
    let kept = repaired.trim_end().len();
    repaired.truncate(kept);

    match repaired.chars().last() {
        Some(',') => {
            repaired.pop();
            let kept = repaired.trim_end().len();
            repaired.truncate(kept);
        }
        Some(':') => repaired.push_str(" null"),
        _ => {}
    }

    repaired.extend(closers.iter().rev());
    repaired
}
588
589fn collect_schema_errors(
592 validator: &jsonschema::Validator,
593 value: &serde_json::Value,
594) -> Vec<String> {
595 validator
596 .iter_errors(value)
597 .map(|error| format!("at `{}`: {error}", error.instance_path()))
598 .collect()
599}
600
601fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
603 let respond_tool = Tool {
604 name: RESPOND_TOOL_NAME.to_owned(),
605 description: format!(
606 "Return the final answer as structured data named `{}`. \
607 You MUST call this tool exactly once with arguments matching the schema.",
608 response_format.name
609 ),
610 input_schema: response_format.schema.clone(),
611 display_name: "Structured response".to_owned(),
612 tier: ToolTier::Observe,
613 };
614
615 match request.tools {
616 Some(ref mut tools) => {
617 tools.retain(|t| t.name != RESPOND_TOOL_NAME);
618 tools.push(respond_tool);
619 }
620 None => request.tools = Some(vec![respond_tool]),
621 }
622 request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
623}
624
625fn extract_candidate(
628 response: &ChatResponse,
629 support: StructuredOutputSupport,
630) -> Option<serde_json::Value> {
631 match support {
632 StructuredOutputSupport::ToolForcing => {
633 response.content.iter().find_map(|block| match block {
634 ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
635 Some(input.clone())
636 }
637 _ => None,
638 })
639 }
640 StructuredOutputSupport::Native => {
641 let text = response.first_text()?;
642 parse_json_text(text)
643 }
644 }
645}
646
647fn parse_json_text(text: &str) -> Option<serde_json::Value> {
653 let trimmed = text.trim();
654 let unfenced = strip_code_fence(trimmed);
655 serde_json::from_str(unfenced).ok()
656}
657
/// Returns the payload inside a Markdown code fence, or `text` unchanged if it does not
/// start with a fence.
///
/// The payload begins on the line after the opening fence, which may carry an info string
/// such as `json`. It ends at the first closing fence line. Text after the closing fence
/// (e.g. trailing prose) is ignored. An unterminated fence yields everything after the
/// opening line. The result is trimmed.
fn strip_code_fence(text: &str) -> &str {
    let Some(after_open) = text.strip_prefix("```") else {
        return text;
    };
    let body = after_open
        .split_once('\n')
        .map_or(after_open, |(_, body)| body);

    // A raw newline followed by ``` cannot occur inside valid JSON: raw newlines are not
    // allowed in JSON strings and ``` is not a JSON token. So the first occurrence must
    // start the closing fence.
    if let Some(end) = body.find("\n```") {
        return body[..end].trim();
    }
    // Closing fence on the payload's own line, e.g. "```{...}```" (or a longer fence).
    if let Some(inner) = body.trim_end().strip_suffix("```") {
        return inner.trim_end_matches('`').trim();
    }
    // Unterminated fence: still hand back the payload so a reply whose closing fence was
    // omitted can parse.
    body.trim()
}
668
669fn append_correction(
680 request: &mut ChatRequest,
681 previous: &ChatResponse,
682 support: StructuredOutputSupport,
683 correction: &str,
684) {
685 request
686 .messages
687 .push(Message::assistant_with_content(previous.content.clone()));
688
689 let respond_tool_use_id = if matches!(support, StructuredOutputSupport::ToolForcing) {
690 previous.content.iter().find_map(|block| match block {
691 ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => Some(id.clone()),
692 _ => None,
693 })
694 } else {
695 None
696 };
697
698 match respond_tool_use_id {
699 Some(tool_use_id) => {
700 request
701 .messages
702 .push(Message::tool_result(tool_use_id, correction, true));
703 }
704 None => request.messages.push(Message::user(correction)),
705 }
706}
707
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use agent_sdk_foundation::llm::{StopReason, Usage};
    use anyhow::Result;
    use async_trait::async_trait;

    use crate::streaming::StreamBox;

    /// Test double whose `chat` replays a fixed queue of outcomes (panicking when the queue
    /// runs out) and records every request it receives.
    struct ScriptedProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        outcomes: Mutex<std::collections::VecDeque<ChatOutcome>>,
        seen_requests: Mutex<Vec<ChatRequest>>,
        calls: AtomicUsize,
    }

    impl ScriptedProvider {
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            outcomes: Vec<ChatOutcome>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                outcomes: Mutex::new(outcomes.into()),
                seen_requests: Mutex::new(Vec::new()),
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of `chat` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for ScriptedProvider {
        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.seen_requests
                .lock()
                .expect("seen_requests lock")
                .push(request);
            let outcome = self
                .outcomes
                .lock()
                .expect("outcomes lock")
                .pop_front()
                .expect("ScriptedProvider: ran out of scripted outcomes");
            Ok(outcome)
        }

        // The non-streaming tests never stream; any accidental use surfaces as a transport error.
        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }

    /// Object schema requiring a string `name` and a non-negative integer `age`, with no
    /// extra properties.
    fn person_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"],
            "additionalProperties": false
        })
    }

    /// Single-user-message request that asks for a `person_schema()` response.
    fn request_with_format() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: vec![Message::user("Describe a person.")],
            tools: None,
            max_tokens: 256,
            max_tokens_explicit: true,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: Some(ResponseFormat::new("person", person_schema())),
            cache: None,
        }
    }

    /// Wraps content blocks in a successful outcome with fixed metadata.
    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
        ChatOutcome::Success(ChatResponse {
            id: "resp".to_owned(),
            content,
            model: "scripted-model".to_owned(),
            stop_reason: Some(StopReason::EndTurn),
            usage: Usage {
                input_tokens: 1,
                output_tokens: 1,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        })
    }

    fn text_block(text: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text {
            text: text.to_owned(),
        }]
    }

    /// A single `respond` tool call (id `call_1`) carrying `input`.
    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
        vec![ContentBlock::ToolUse {
            id: "call_1".to_owned(),
            name: RESPOND_TOOL_NAME.to_owned(),
            input,
            thought_signature: None,
        }]
    }

    // Native mode: valid JSON text on the first try succeeds with zero retries.
    #[tokio::test]
    async fn native_happy_path_validates_json_text() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Ada");
        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 0);
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    // Native mode: a ```json fenced reply is unwrapped before parsing.
    #[tokio::test]
    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
        let provider = ScriptedProvider::new(
            "gemini",
            StructuredOutputSupport::Native,
            vec![success(text_block(
                "```json\n{\"name\": \"Grace\", \"age\": 45}\n```",
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Grace");
        Ok(())
    }

    // Tool-forcing mode: the answer comes from the `respond` tool input, and the outgoing
    // request has the tool injected and forced.
    #[tokio::test]
    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![success(respond_tool_block(
                serde_json::json!({"name": "Linus", "age": 54}),
            ))],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;

        assert_eq!(out.value["name"], "Linus");
        assert_eq!(out.retries, 0);

        // Inspect the recorded request inside a scope so the lock guard is dropped early.
        let (has_respond_tool, forces_respond) = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            let tools = seen[0].tools.as_ref().expect("tools injected");
            (
                tools.iter().any(|t| t.name == RESPOND_TOOL_NAME),
                matches!(
                    seen[0].tool_choice,
                    Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
                ),
            )
        };
        assert!(has_respond_tool);
        assert!(forces_respond);
        Ok(())
    }

    // A schema mismatch triggers one corrective retry, which grows the conversation.
    #[tokio::test]
    async fn mismatch_then_retry_succeeds() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
                success(text_block(r#"{"name": "Ada", "age": 36}"#)),
            ],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await?;

        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 1);
        assert_eq!(provider.call_count(), 2);

        // The retry request must carry more messages (the rejected reply plus the correction).
        let grew = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            seen[1].messages.len() > seen[0].messages.len()
        };
        assert!(grew);
        Ok(())
    }

    // Tool-forcing retries must answer the forced tool_use with a matching tool_result.
    #[tokio::test]
    async fn tool_forcing_retry_appends_tool_result_for_forced_tool_use() -> Result<()> {
        use agent_sdk_foundation::llm::Content;

        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(
                    serde_json::json!({"name": "x", "age": 1}),
                )),
            ],
        );

        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await?;
        assert_eq!(out.retries, 1);

        let seen = provider.seen_requests.lock().expect("seen lock");
        let retry = &seen[1];

        let assistant_tool_use_id = retry
            .messages
            .iter()
            .find_map(|m| match &m.content {
                Content::Blocks(blocks) => blocks.iter().find_map(|b| match b {
                    ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => {
                        Some(id.clone())
                    }
                    _ => None,
                }),
                Content::Text(_) => None,
            })
            .expect("assistant respond tool_use present in retry");

        let has_matching_result = retry.messages.iter().any(|m| match &m.content {
            Content::Blocks(blocks) => blocks.iter().any(|b| {
                matches!(
                    b,
                    ContentBlock::ToolResult { tool_use_id, .. }
                        if *tool_use_id == assistant_tool_use_id
                )
            }),
            Content::Text(_) => false,
        });
        // Release the lock before asserting so a failure does not poison it.
        drop(seen);
        assert!(
            has_matching_result,
            "retry must carry a tool_result for the forced tool_use id"
        );
        Ok(())
    }

    // When every attempt fails validation, the typed error reports the attempt count and the
    // last candidate.
    #[tokio::test]
    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(serde_json::json!({"name": "y"}))),
                success(respond_tool_block(serde_json::json!({"name": "z"}))),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await
        .expect_err("schema never satisfied");

        match err {
            StructuredOutputError::RetriesExhausted {
                attempts,
                last_value,
                ..
            } => {
                assert_eq!(attempts, 3, "1 initial + 2 retries");
                assert_eq!(
                    last_value.as_ref().and_then(|v| v["name"].as_str()),
                    Some("z")
                );
            }
            other => panic!("expected RetriesExhausted, got {other:?}"),
        }
        assert_eq!(provider.call_count(), 3);
        Ok(())
    }

    // max_retries = 0 means exactly one call and no correction round.
    #[tokio::test]
    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada"}"#))],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 0 },
        )
        .await
        .expect_err("missing required `age`");

        assert!(matches!(
            err,
            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
        ));
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }

    // A request without response_format is rejected before any provider call.
    #[tokio::test]
    async fn missing_response_format_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = None;

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("no response format");
        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
    }

    // A schema that does not compile is reported as InvalidSchema.
    #[tokio::test]
    async fn invalid_schema_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = Some(ResponseFormat::new("bad", serde_json::json!({"type": 123})));

        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("invalid schema");
        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
    }

    // Non-success provider outcomes are surfaced as ProviderOutcome, not retried.
    #[tokio::test]
    async fn provider_rate_limit_surfaces_as_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![ChatOutcome::RateLimited(None)],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await
        .expect_err("rate limited");
        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
    }

    // Prose-only replies on every attempt end in NoStructuredOutput after all calls are used.
    #[tokio::test]
    async fn no_structured_output_on_final_attempt_errors() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block("I cannot do that.")),
                success(text_block("Still prose, sorry.")),
            ],
        );

        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await
        .expect_err("never produced JSON");
        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
        assert_eq!(provider.call_count(), 2);
    }

    /// Test double whose `chat_stream` replays a fixed list of deltas (the same list on every
    /// call). Its `chat` always returns a server error.
    struct StreamingProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        deltas: Mutex<Vec<StreamDelta>>,
    }

    impl StreamingProvider {
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            deltas: Vec<StreamDelta>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                deltas: Mutex::new(deltas),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for StreamingProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::ServerError("chat() not used".to_owned()))
        }

        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            // Snapshot the deltas so the returned stream does not hold the lock.
            let deltas = self.deltas.lock().map(|d| d.clone()).unwrap_or_default();
            Box::pin(async_stream::stream! {
                for delta in deltas {
                    yield Ok(delta);
                }
            })
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn provider(&self) -> &'static str {
            self.provider_name
        }

        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }

    /// Collects all partials and the final output. The first error item is returned via `?`.
    async fn drive_stream(
        mut stream: StructuredStream<'_>,
    ) -> Result<(Vec<serde_json::Value>, Option<StructuredOutput>)> {
        let mut partials = Vec::new();
        let mut final_out = None;
        while let Some(update) = stream.next().await {
            match update? {
                StructuredStreamUpdate::Partial(value) => partials.push(value),
                StructuredStreamUpdate::Final(out) => final_out = Some(out),
            }
        }
        Ok((partials, final_out))
    }

    // Native streaming: a partial appears after the first text chunk, then the validated final.
    #[tokio::test]
    async fn streaming_native_emits_partials_then_validated_final() -> Result<()> {
        let provider = StreamingProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                StreamDelta::TextDelta {
                    delta: r#"{"name": "Ada""#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::TextDelta {
                    delta: r#", "age": 36}"#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::Done {
                    stop_reason: Some(StopReason::EndTurn),
                },
            ],
        );

        let stream = run_structured_stream(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        );
        let (partials, final_out) = drive_stream(stream).await?;

        assert!(!partials.is_empty(), "expected at least one partial");
        assert_eq!(partials[0]["name"], "Ada");
        let final_out = final_out.expect("a validated final value");
        assert_eq!(final_out.value["name"], "Ada");
        assert_eq!(final_out.value["age"], 36);
        assert_eq!(final_out.retries, 0);
        Ok(())
    }

    // Tool-forcing streaming: partials are built from the `respond` tool's input deltas.
    #[tokio::test]
    async fn streaming_tool_forcing_reads_partial_tool_input() -> Result<()> {
        let provider = StreamingProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                StreamDelta::ToolUseStart {
                    id: "call_1".to_owned(),
                    name: RESPOND_TOOL_NAME.to_owned(),
                    block_index: 0,
                    thought_signature: None,
                },
                StreamDelta::ToolInputDelta {
                    id: "call_1".to_owned(),
                    delta: r#"{"name": "Linus""#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::ToolInputDelta {
                    id: "call_1".to_owned(),
                    delta: r#", "age": 54}"#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::Done {
                    stop_reason: Some(StopReason::ToolUse),
                },
            ],
        );

        let stream = run_structured_stream(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        );
        let (partials, final_out) = drive_stream(stream).await?;

        assert_eq!(partials[0]["name"], "Linus");
        let final_out = final_out.expect("a validated final value");
        assert_eq!(final_out.value["age"], 54);
        Ok(())
    }

    // The streaming entry point reports a missing response_format as its first item.
    #[tokio::test]
    async fn streaming_missing_response_format_errors() {
        let provider =
            StreamingProvider::new("openai", StructuredOutputSupport::Native, Vec::new());
        let mut req = request_with_format();
        req.response_format = None;

        let mut stream = run_structured_stream(&provider, req, StructuredConfig::default());
        let first = stream.next().await.expect("one item");
        assert!(matches!(
            first,
            Err(StructuredOutputError::MissingResponseFormat)
        ));
    }

    // Partial parsing closes strings, drops trailing commas, and fills a dangling colon with null.
    #[test]
    fn partial_from_buffer_repairs_incomplete_json() {
        assert_eq!(
            partial_from_buffer(r#"{"name": "Ada""#).map(|v| v["name"].clone()),
            Some(serde_json::json!("Ada"))
        );
        assert_eq!(
            partial_from_buffer(r#"{"a": 1,"#),
            Some(serde_json::json!({"a": 1}))
        );
        assert_eq!(
            partial_from_buffer(r#"{"a":"#),
            Some(serde_json::json!({"a": null}))
        );
        assert!(partial_from_buffer("").is_none());
        assert!(partial_from_buffer("not json").is_none());
    }
}