use std::pin::Pin;
use agent_sdk_foundation::llm::{
ChatOutcome, ChatRequest, ChatResponse, ContentBlock, Message, ResponseFormat, Tool,
ToolChoice, Usage,
};
use agent_sdk_foundation::types::ToolTier;
use futures::{Stream, StreamExt};
use crate::provider::{LlmProvider, StructuredOutputSupport};
use crate::streaming::{StreamAccumulator, StreamDelta, StreamErrorKind};
/// Name of the synthetic tool injected by `apply_tool_forcing` when the provider
/// reports [`StructuredOutputSupport::ToolForcing`]; the model is forced to call
/// it and the call's `input` is taken as the structured answer.
const RESPOND_TOOL_NAME: &str = "respond";
/// Tuning knobs for [`run_structured`] and [`run_structured_stream`].
#[derive(Debug, Clone, Copy)]
pub struct StructuredConfig {
    /// Extra attempts allowed after the first one when the model's answer is
    /// missing or fails schema validation. Total provider calls are
    /// `max_retries + 1` (saturating at `u32::MAX`).
    pub max_retries: u32,
}

impl Default for StructuredConfig {
    /// Two retries, i.e. at most three provider calls in total.
    fn default() -> Self {
        const DEFAULT_MAX_RETRIES: u32 = 2;
        StructuredConfig {
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
/// A structured answer that passed validation against the request's schema.
#[derive(Debug, Clone)]
pub struct StructuredOutput {
    /// The JSON value that satisfied the `response_format` schema.
    pub value: serde_json::Value,
    /// The provider response the value was extracted from.
    pub response: ChatResponse,
    /// Number of retries that were needed (0 means the first attempt succeeded).
    pub retries: u32,
}
/// Errors produced by [`run_structured`] and [`run_structured_stream`].
#[derive(Debug, thiserror::Error)]
pub enum StructuredOutputError {
    /// The request carried no `response_format`, so there is no schema to target.
    #[error("structured output requested without a response_format on the request")]
    MissingResponseFormat,
    /// The `response_format` schema could not be compiled into a validator;
    /// carries the compiler's error message.
    #[error("invalid output JSON schema: {0}")]
    InvalidSchema(String),
    /// The final attempt's reply contained no extractable structured value
    /// (no `respond` tool call under tool forcing, or text that is not JSON).
    #[error("model produced no structured output to validate")]
    NoStructuredOutput,
    /// The provider returned a non-success outcome (rate limit, invalid request,
    /// server error, ...). These are reported immediately, not retried.
    #[error("provider returned a non-success outcome: {0}")]
    ProviderOutcome(String),
    /// Every attempt was used and the final attempt's value still failed validation.
    #[error(
        "structured output failed schema validation after {attempts} attempt(s); last errors: {errors}"
    )]
    RetriesExhausted {
        /// Total attempts made (`max_retries + 1`).
        attempts: u32,
        /// Validation errors of the final attempt, joined with `"; "`.
        errors: String,
        /// The last value that failed validation.
        last_value: Option<serde_json::Value>,
    },
    /// The provider call or stream itself failed; the underlying error is shown as-is.
    #[error(transparent)]
    Transport(#[from] anyhow::Error),
}
pub async fn run_structured(
provider: &dyn LlmProvider,
mut request: ChatRequest,
config: StructuredConfig,
) -> Result<StructuredOutput, StructuredOutputError> {
let response_format = request
.response_format
.clone()
.ok_or(StructuredOutputError::MissingResponseFormat)?;
let validator = jsonschema::validator_for(&response_format.schema)
.map_err(|e| StructuredOutputError::InvalidSchema(e.to_string()))?;
let support = provider.structured_output_support();
if matches!(support, StructuredOutputSupport::ToolForcing) {
apply_tool_forcing(&mut request, &response_format);
}
let max_attempts = config.max_retries.saturating_add(1);
let mut last_value: Option<serde_json::Value> = None;
let mut last_errors = String::new();
for attempt in 0..max_attempts {
let attempt_request = if attempt + 1 == max_attempts {
std::mem::replace(&mut request, ChatRequest::new(String::new(), Vec::new()))
} else {
request.clone()
};
let outcome = provider.chat(attempt_request).await?;
let response = match outcome {
ChatOutcome::Success(response) => response,
ChatOutcome::RateLimited(_) => {
return Err(StructuredOutputError::ProviderOutcome(
"rate limited".to_owned(),
));
}
ChatOutcome::InvalidRequest(msg) => {
return Err(StructuredOutputError::ProviderOutcome(format!(
"invalid request: {msg}"
)));
}
ChatOutcome::ServerError(msg) => {
return Err(StructuredOutputError::ProviderOutcome(format!(
"server error: {msg}"
)));
}
_ => {
return Err(StructuredOutputError::ProviderOutcome(
"unrecognized provider outcome".to_owned(),
));
}
};
let candidate = extract_candidate(&response, support);
let Some(value) = candidate else {
if attempt + 1 >= max_attempts {
return Err(StructuredOutputError::NoStructuredOutput);
}
append_correction(
&mut request,
&response,
support,
"Your previous reply did not contain a structured answer. \
Respond with a single JSON value that satisfies the requested schema.",
);
"missing structured output".clone_into(&mut last_errors);
continue;
};
let errors = collect_schema_errors(&validator, &value);
if errors.is_empty() {
return Ok(StructuredOutput {
value,
response,
retries: attempt,
});
}
last_errors = errors.join("; ");
last_value = Some(value);
if attempt + 1 < max_attempts {
let correction = format!(
"Your previous JSON output did not satisfy the schema. \
Fix these validation errors and resend the full JSON value: {last_errors}"
);
append_correction(&mut request, &response, support, &correction);
}
}
Err(StructuredOutputError::RetriesExhausted {
attempts: max_attempts,
errors: last_errors,
last_value,
})
}
/// One successful item produced by [`run_structured_stream`].
#[derive(Debug, Clone)]
pub enum StructuredStreamUpdate {
    /// A best-effort, *unvalidated* snapshot parsed from the partially streamed
    /// first attempt; it may be superseded or ultimately rejected.
    Partial(serde_json::Value),
    /// The schema-validated result; after it the stream ends.
    Final(StructuredOutput),
}
/// Boxed `Send` stream returned by [`run_structured_stream`], borrowing the
/// provider for `'a`.
pub type StructuredStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StructuredStreamUpdate, StructuredOutputError>> + Send + 'a>>;
/// Streaming variant of [`run_structured`].
///
/// The first attempt is streamed via `provider.chat_stream`. As the answer arrives,
/// it yields [`StructuredStreamUpdate::Partial`] snapshots (best-effort, unvalidated).
/// Retries after a missing or invalid answer use the non-streaming `provider.chat`,
/// so they produce no partials.
///
/// The stream ends after exactly one [`StructuredStreamUpdate::Final`] or one `Err`
/// item. The error cases mirror those of [`run_structured`]. An in-band stream
/// error on the first attempt is reported as [`StructuredOutputError::ProviderOutcome`].
pub fn run_structured_stream(
    provider: &dyn LlmProvider,
    request: ChatRequest,
    config: StructuredConfig,
) -> StructuredStream<'_> {
    Box::pin(async_stream::stream! {
        let mut request = request;
        let Some(response_format) = request.response_format.clone() else {
            yield Err(StructuredOutputError::MissingResponseFormat);
            return;
        };
        let validator = match jsonschema::validator_for(&response_format.schema) {
            Ok(validator) => validator,
            Err(e) => {
                yield Err(StructuredOutputError::InvalidSchema(e.to_string()));
                return;
            }
        };
        let support = provider.structured_output_support();
        if matches!(support, StructuredOutputSupport::ToolForcing) {
            apply_tool_forcing(&mut request, &response_format);
        }
        let max_attempts = config.max_retries.saturating_add(1);
        let model = provider.model().to_owned();
        let mut last_value: Option<serde_json::Value> = None;
        let mut last_errors = String::new();
        for attempt in 0..max_attempts {
            let response = if attempt == 0 {
                // Only the first attempt is streamed, so only it produces partials.
                let mut attempt_stream =
                    Box::pin(stream_first_attempt(provider, request.clone(), support, model.clone()));
                let mut completed: Option<ChatResponse> = None;
                while let Some(item) = attempt_stream.next().await {
                    match item {
                        StreamAttemptItem::Partial(value) => {
                            yield Ok(StructuredStreamUpdate::Partial(value));
                        }
                        StreamAttemptItem::Complete(response) => completed = Some(response),
                        StreamAttemptItem::Failed(error) => {
                            yield Err(error);
                            return;
                        }
                    }
                }
                match completed {
                    Some(response) => response,
                    None => {
                        // `stream_first_attempt` should always finish with `Complete`
                        // or `Failed`. Surface a broken invariant as an error rather
                        // than ending with neither a final value nor an error.
                        yield Err(StructuredOutputError::Transport(anyhow::anyhow!(
                            "streamed attempt ended without a completed response"
                        )));
                        return;
                    }
                }
            } else {
                // Retries fall back to a plain, non-streaming request.
                match provider.chat(request.clone()).await {
                    Ok(ChatOutcome::Success(response)) => response,
                    Ok(other) => {
                        yield Err(non_success_outcome_error(&other));
                        return;
                    }
                    Err(e) => {
                        yield Err(StructuredOutputError::Transport(e));
                        return;
                    }
                }
            };
            let Some(value) = extract_candidate(&response, support) else {
                if attempt + 1 >= max_attempts {
                    yield Err(StructuredOutputError::NoStructuredOutput);
                    return;
                }
                append_correction(
                    &mut request,
                    &response,
                    support,
                    "Your previous reply did not contain a structured answer. \
                     Respond with a single JSON value that satisfies the requested schema.",
                );
                "missing structured output".clone_into(&mut last_errors);
                continue;
            };
            let errors = collect_schema_errors(&validator, &value);
            if errors.is_empty() {
                yield Ok(StructuredStreamUpdate::Final(StructuredOutput {
                    value,
                    response,
                    retries: attempt,
                }));
                return;
            }
            last_errors = errors.join("; ");
            last_value = Some(value);
            if attempt + 1 < max_attempts {
                let correction = format!(
                    "Your previous JSON output did not satisfy the schema. \
                     Fix these validation errors and resend the full JSON value: {last_errors}"
                );
                append_correction(&mut request, &response, support, &correction);
            }
        }
        yield Err(StructuredOutputError::RetriesExhausted {
            attempts: max_attempts,
            errors: last_errors,
            last_value,
        });
    })
}
/// Internal events emitted by `stream_first_attempt`.
enum StreamAttemptItem {
    /// A new best-effort snapshot of the answer, distinct from the previous one.
    Partial(serde_json::Value),
    /// The stream finished cleanly; carries the accumulated response.
    Complete(ChatResponse),
    /// The attempt failed (transport error or in-band stream error).
    Failed(StructuredOutputError),
}
/// Streams the first structured attempt via `provider.chat_stream`.
///
/// As the answer arrives, best-effort partial JSON snapshots are emitted.
///
/// The returned stream yields zero or more [`StreamAttemptItem::Partial`] values,
/// each different from the one before, and then ends with exactly one
/// [`StreamAttemptItem::Complete`] or [`StreamAttemptItem::Failed`].
///
/// A transport error fails the attempt immediately. An in-band
/// [`StreamDelta::Error`] is remembered (the last one wins). It is reported as
/// `Failed` once the provider stream ends.
fn stream_first_attempt(
    provider: &dyn LlmProvider,
    request: ChatRequest,
    support: StructuredOutputSupport,
    model: String,
) -> impl Stream<Item = StreamAttemptItem> + Send + '_ {
    async_stream::stream! {
        let mut accumulator = StreamAccumulator::new();
        let mut partial_buf = String::new();
        let mut respond_tool_ids: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        let mut last_partial: Option<serde_json::Value> = None;
        let mut stream_error: Option<(String, StreamErrorKind)> = None;
        let mut stream = provider.chat_stream(request);
        while let Some(item) = stream.next().await {
            let delta = match item {
                Ok(delta) => delta,
                Err(e) => {
                    yield StreamAttemptItem::Failed(StructuredOutputError::Transport(e));
                    return;
                }
            };
            let buffered_before = partial_buf.len();
            accumulate_partial_buffer(&delta, support, &mut partial_buf, &mut respond_tool_ids);
            if let StreamDelta::Error { message, kind } = &delta {
                stream_error = Some((message.clone(), *kind));
            }
            accumulator.apply(&delta);
            // The buffer is append-only, so an unchanged length means unchanged
            // contents. Re-parsing would reproduce the previous result, so skip it
            // for deltas that did not feed the answer (usage, done, other tools, ...).
            if partial_buf.len() == buffered_before {
                continue;
            }
            if let Some(value) = partial_from_buffer(&partial_buf)
                && last_partial.as_ref() != Some(&value)
            {
                last_partial = Some(value.clone());
                yield StreamAttemptItem::Partial(value);
            }
        }
        if let Some((message, kind)) = stream_error {
            yield StreamAttemptItem::Failed(stream_error_to_outcome(&message, kind));
            return;
        }
        yield StreamAttemptItem::Complete(build_streamed_response(accumulator, model));
    }
}
/// Appends the part of `delta` that belongs to the structured answer to `buffer`.
///
/// - `Native`: every text delta is appended.
/// - `ToolForcing`: ids of `respond` tool calls are recorded in `respond_tool_ids`
///   when they start, and only input deltas for those ids are appended.
///
/// All other deltas are ignored.
fn accumulate_partial_buffer(
    delta: &StreamDelta,
    support: StructuredOutputSupport,
    buffer: &mut String,
    respond_tool_ids: &mut std::collections::HashSet<String>,
) {
    match support {
        StructuredOutputSupport::Native => {
            if let StreamDelta::TextDelta { delta: text, .. } = delta {
                buffer.push_str(text);
            }
        }
        StructuredOutputSupport::ToolForcing => match delta {
            StreamDelta::ToolUseStart { id, name, .. } if name == RESPOND_TOOL_NAME => {
                respond_tool_ids.insert(id.clone());
            }
            StreamDelta::ToolInputDelta {
                id,
                delta: fragment,
                ..
            } if respond_tool_ids.contains(id) => {
                buffer.push_str(fragment);
            }
            _ => {}
        },
    }
}
/// Converts an in-band stream error into [`StructuredOutputError::ProviderOutcome`].
///
/// The labels match the non-streaming outcome wording. Any kind other than
/// `RateLimited` or `InvalidRequest` is reported as a server error.
fn stream_error_to_outcome(message: &str, kind: StreamErrorKind) -> StructuredOutputError {
    StructuredOutputError::ProviderOutcome(match kind {
        StreamErrorKind::InvalidRequest => format!("invalid request: {message}"),
        StreamErrorKind::RateLimited => String::from("rate limited"),
        _ => format!("server error: {message}"),
    })
}
fn non_success_outcome_error(outcome: &ChatOutcome) -> StructuredOutputError {
let label = match outcome {
ChatOutcome::RateLimited(_) => "rate limited".to_owned(),
ChatOutcome::InvalidRequest(msg) => format!("invalid request: {msg}"),
ChatOutcome::ServerError(msg) => format!("server error: {msg}"),
_ => "unrecognized provider outcome".to_owned(),
};
StructuredOutputError::ProviderOutcome(label)
}
/// Assembles a [`ChatResponse`] from a fully drained [`StreamAccumulator`].
///
/// The response `id` is left empty. If the stream reported no usage, all token
/// counters are zero.
fn build_streamed_response(mut accumulator: StreamAccumulator, model: String) -> ChatResponse {
    let usage = accumulator.take_usage().unwrap_or_else(|| Usage {
        input_tokens: 0,
        output_tokens: 0,
        cached_input_tokens: 0,
        cache_creation_input_tokens: 0,
    });
    let stop_reason = accumulator.take_stop_reason();
    let content = accumulator.into_content_blocks();
    ChatResponse {
        id: String::new(),
        content,
        model,
        stop_reason,
        usage,
    }
}
/// Best-effort parse of a partially streamed structured answer.
///
/// This is used only for streaming snapshots, never for the validated result.
///
/// A leading Markdown code fence is skipped once its info-string line (e.g. `json`)
/// is complete. For fenced input, a trailing closing-fence line is dropped too, so
/// snapshots keep updating while the closing backticks stream in. The remainder is
/// closed with [`repair_partial_json`] and parsed.
///
/// Returns `None` for empty input, unparseable fragments, and scalars. Only objects
/// and arrays are reported.
fn partial_from_buffer(buffer: &str) -> Option<serde_json::Value> {
    /// Drops a final line made only of backticks, i.e. a closing fence that may
    /// still be arriving. Raw newlines are illegal inside JSON strings, so such a
    /// line can never be part of the JSON body itself.
    fn strip_closing_fence_line(fenced: &str) -> &str {
        let fenced = fenced.trim_end();
        match fenced.rsplit_once('\n') {
            Some((head, last))
                if !last.trim().is_empty() && last.trim().bytes().all(|b| b == b'`') =>
            {
                head
            }
            _ => fenced,
        }
    }

    let trimmed = buffer.trim_start();
    let body = match trimmed.strip_prefix("```") {
        Some(after_open) => {
            // Until the info-string line ends there is nothing parseable yet.
            let (_, fenced) = after_open.split_once('\n')?;
            strip_closing_fence_line(fenced)
        }
        None => trimmed,
    }
    .trim();
    if body.is_empty() {
        return None;
    }
    let repaired = repair_partial_json(body);
    serde_json::from_str::<serde_json::Value>(&repaired)
        .ok()
        .filter(|value| value.is_object() || value.is_array())
}
/// Closes an in-progress JSON fragment so that it can be parsed as a snapshot.
///
/// `buffer` is scanned while tracking string and escape state and the stack of
/// open containers. Afterwards the function:
/// - drops a dangling backslash and closes an open string;
/// - removes a trailing comma, or fills a trailing `:` with `null`;
/// - appends the missing `}` / `]` closers, innermost first.
///
/// The result is not guaranteed to be valid JSON, e.g. for a half-written key or
/// literal. Callers must still parse it and handle failure.
fn repair_partial_json(buffer: &str) -> String {
    let mut pending_closers = Vec::<char>::new();
    let mut inside_string = false;
    let mut after_backslash = false;
    for c in buffer.chars() {
        if !inside_string {
            match c {
                '"' => inside_string = true,
                '{' => pending_closers.push('}'),
                '[' => pending_closers.push(']'),
                // Stray closers are tolerated; popping an empty stack is a no-op.
                '}' | ']' => {
                    let _ = pending_closers.pop();
                }
                _ => {}
            }
        } else if after_backslash {
            after_backslash = false;
        } else if c == '\\' {
            after_backslash = true;
        } else if c == '"' {
            inside_string = false;
        }
    }
    let mut repaired = String::from(buffer);
    if after_backslash {
        // A lone trailing backslash would escape the quote appended below.
        repaired.pop();
    }
    if inside_string {
        repaired.push('"');
    }
    let kept = repaired.trim_end().len();
    repaired.truncate(kept);
    if repaired.ends_with(',') {
        repaired.pop();
        let kept = repaired.trim_end().len();
        repaired.truncate(kept);
    } else if repaired.ends_with(':') {
        repaired.push_str(" null");
    }
    repaired.extend(pending_closers.iter().rev());
    repaired
}
/// Returns one message per schema violation in `value`.
///
/// Each message is prefixed with the instance path of the offending value. An
/// empty vector means `value` is valid.
fn collect_schema_errors(
    validator: &jsonschema::Validator,
    value: &serde_json::Value,
) -> Vec<String> {
    let mut messages = Vec::new();
    for violation in validator.iter_errors(value) {
        messages.push(format!("at `{}`: {violation}", violation.instance_path()));
    }
    messages
}
fn apply_tool_forcing(request: &mut ChatRequest, response_format: &ResponseFormat) {
let respond_tool = Tool {
name: RESPOND_TOOL_NAME.to_owned(),
description: format!(
"Return the final answer as structured data named `{}`. \
You MUST call this tool exactly once with arguments matching the schema.",
response_format.name
),
input_schema: response_format.schema.clone(),
display_name: "Structured response".to_owned(),
tier: ToolTier::Observe,
};
match request.tools {
Some(ref mut tools) => {
tools.retain(|t| t.name != RESPOND_TOOL_NAME);
tools.push(respond_tool);
}
None => request.tools = Some(vec![respond_tool]),
}
request.tool_choice = Some(ToolChoice::Tool(RESPOND_TOOL_NAME.to_owned()));
}
fn extract_candidate(
response: &ChatResponse,
support: StructuredOutputSupport,
) -> Option<serde_json::Value> {
match support {
StructuredOutputSupport::ToolForcing => {
response.content.iter().find_map(|block| match block {
ContentBlock::ToolUse { name, input, .. } if name == RESPOND_TOOL_NAME => {
Some(input.clone())
}
_ => None,
})
}
StructuredOutputSupport::Native => {
let text = response.first_text()?;
parse_json_text(text)
}
}
}
/// Parses a model's text reply as a JSON value.
///
/// Surrounding whitespace and a Markdown code fence are tolerated. Returns `None`
/// when the text is not JSON.
fn parse_json_text(text: &str) -> Option<serde_json::Value> {
    let candidate = strip_code_fence(text.trim());
    match serde_json::from_str::<serde_json::Value>(candidate) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
/// Removes a Markdown code fence around `text`, if present.
///
/// If `text` starts with "```", the info-string line (e.g. `json`) is skipped, and
/// a closing fence is removed when present. If the closing fence is missing, e.g.
/// because the output was truncated, the fenced body is still returned. The
/// un-stripped text starts with backticks and could never parse anyway.
/// Text without an opening fence is returned unchanged.
fn strip_code_fence(text: &str) -> &str {
    let Some(after_open) = text.strip_prefix("```") else {
        return text;
    };
    let body = after_open
        .split_once('\n')
        .map_or(after_open, |(_, body)| body);
    body.strip_suffix("```")
        // Tolerate longer closing fences such as "````".
        .map_or(body, |inner| inner.trim_end_matches('`'))
        .trim()
}
/// Appends the failed assistant turn and corrective feedback to `request.messages`.
///
/// The previous assistant content is replayed verbatim. Under tool forcing, if that
/// content contains a `respond` tool call, the correction is sent as a
/// `Message::tool_result` for that call's id, with the third argument `true`
/// (presumably the error flag; confirm against `Message::tool_result`). Otherwise
/// it is sent as a plain user message.
fn append_correction(
    request: &mut ChatRequest,
    previous: &ChatResponse,
    support: StructuredOutputSupport,
    correction: &str,
) {
    request
        .messages
        .push(Message::assistant_with_content(previous.content.clone()));
    let forced_call_id = match support {
        StructuredOutputSupport::ToolForcing => {
            previous.content.iter().find_map(|block| match block {
                ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => {
                    Some(id.clone())
                }
                _ => None,
            })
        }
        StructuredOutputSupport::Native => None,
    };
    let feedback = if let Some(tool_use_id) = forced_call_id {
        Message::tool_result(tool_use_id, correction, true)
    } else {
        Message::user(correction)
    };
    request.messages.push(feedback);
}
// Unit tests for the structured-output runner, the streaming variant, and the
// partial-JSON helpers. Providers are scripted test doubles; no network is used.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use agent_sdk_foundation::llm::{StopReason, Usage};
    use anyhow::Result;
    use async_trait::async_trait;
    use crate::streaming::StreamBox;
    /// Non-streaming test double. `chat()` replays scripted outcomes in order and
    /// records every request it receives.
    struct ScriptedProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        outcomes: Mutex<std::collections::VecDeque<ChatOutcome>>,
        seen_requests: Mutex<Vec<ChatRequest>>,
        calls: AtomicUsize,
    }
    impl ScriptedProvider {
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            outcomes: Vec<ChatOutcome>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                outcomes: Mutex::new(outcomes.into()),
                seen_requests: Mutex::new(Vec::new()),
                calls: AtomicUsize::new(0),
            }
        }
        /// Number of `chat()` invocations so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }
    #[async_trait]
    impl LlmProvider for ScriptedProvider {
        async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            self.seen_requests
                .lock()
                .expect("seen_requests lock")
                .push(request);
            // Panics if the code under test makes more calls than were scripted.
            let outcome = self
                .outcomes
                .lock()
                .expect("outcomes lock")
                .pop_front()
                .expect("ScriptedProvider: ran out of scripted outcomes");
            Ok(outcome)
        }
        // The non-streaming tests never stream; fail loudly if one does.
        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            Box::pin(async_stream::stream! {
                yield Err(anyhow::anyhow!("streaming not used in structured tests"));
            })
        }
        fn model(&self) -> &str {
            &self.model
        }
        fn provider(&self) -> &'static str {
            self.provider_name
        }
        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }
    /// Schema used by all tests: an object with a required string `name` and a
    /// required non-negative integer `age`, with no extra properties.
    fn person_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name", "age"],
            "additionalProperties": false
        })
    }
    /// A minimal request with one user message and a `person` response format.
    fn request_with_format() -> ChatRequest {
        ChatRequest {
            system: String::new(),
            messages: vec![Message::user("Describe a person.")],
            tools: None,
            max_tokens: 256,
            max_tokens_explicit: true,
            session_id: None,
            cached_content: None,
            thinking: None,
            tool_choice: None,
            response_format: Some(ResponseFormat::new("person", person_schema())),
            cache: None,
        }
    }
    /// Wraps `content` in a successful outcome with fixed metadata.
    fn success(content: Vec<ContentBlock>) -> ChatOutcome {
        ChatOutcome::Success(ChatResponse {
            id: "resp".to_owned(),
            content,
            model: "scripted-model".to_owned(),
            stop_reason: Some(StopReason::EndTurn),
            usage: Usage {
                input_tokens: 1,
                output_tokens: 1,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        })
    }
    /// A single text block.
    fn text_block(text: &str) -> Vec<ContentBlock> {
        vec![ContentBlock::Text {
            text: text.to_owned(),
        }]
    }
    /// A single `respond` tool call with id `call_1` and the given input.
    fn respond_tool_block(input: serde_json::Value) -> Vec<ContentBlock> {
        vec![ContentBlock::ToolUse {
            id: "call_1".to_owned(),
            name: RESPOND_TOOL_NAME.to_owned(),
            input,
            thought_signature: None,
        }]
    }
    /// Native mode: a valid JSON text reply is accepted on the first call.
    #[tokio::test]
    async fn native_happy_path_validates_json_text() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada", "age": 36}"#))],
        );
        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;
        assert_eq!(out.value["name"], "Ada");
        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 0);
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }
    /// Native mode: a ```json fenced reply is unwrapped before parsing.
    #[tokio::test]
    async fn native_happy_path_strips_markdown_fence() -> Result<()> {
        let provider = ScriptedProvider::new(
            "gemini",
            StructuredOutputSupport::Native,
            vec![success(text_block(
                "```json\n{\"name\": \"Grace\", \"age\": 45}\n```",
            ))],
        );
        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;
        assert_eq!(out.value["name"], "Grace");
        Ok(())
    }
    /// Tool forcing: the `respond` tool is injected and forced, and its input is
    /// returned as the value.
    #[tokio::test]
    async fn tool_forcing_happy_path_reads_tool_input() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![success(respond_tool_block(
                serde_json::json!({"name": "Linus", "age": 54}),
            ))],
        );
        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await?;
        assert_eq!(out.value["name"], "Linus");
        assert_eq!(out.retries, 0);
        // Lock scope is limited to this block so the guard is dropped before asserting.
        let (has_respond_tool, forces_respond) = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            let tools = seen[0].tools.as_ref().expect("tools injected");
            (
                tools.iter().any(|t| t.name == RESPOND_TOOL_NAME),
                matches!(
                    seen[0].tool_choice,
                    Some(ToolChoice::Tool(ref n)) if n == RESPOND_TOOL_NAME
                ),
            )
        };
        assert!(has_respond_tool);
        assert!(forces_respond);
        Ok(())
    }
    /// An invalid first answer triggers a retry with a longer conversation that
    /// carries the correction.
    #[tokio::test]
    async fn mismatch_then_retry_succeeds() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block(r#"{"name": "Ada", "age": "old"}"#)),
                success(text_block(r#"{"name": "Ada", "age": 36}"#)),
            ],
        );
        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await?;
        assert_eq!(out.value["age"], 36);
        assert_eq!(out.retries, 1);
        assert_eq!(provider.call_count(), 2);
        let grew = {
            let seen = provider.seen_requests.lock().expect("seen lock");
            seen[1].messages.len() > seen[0].messages.len()
        };
        assert!(grew);
        Ok(())
    }
    /// Under tool forcing, the retry must answer the forced tool_use with a
    /// tool_result carrying the same id.
    #[tokio::test]
    async fn tool_forcing_retry_appends_tool_result_for_forced_tool_use() -> Result<()> {
        use agent_sdk_foundation::llm::Content;
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(
                    serde_json::json!({"name": "x", "age": 1}),
                )),
            ],
        );
        let out = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await?;
        assert_eq!(out.retries, 1);
        let seen = provider.seen_requests.lock().expect("seen lock");
        let retry = &seen[1];
        let assistant_tool_use_id = retry
            .messages
            .iter()
            .find_map(|m| match &m.content {
                Content::Blocks(blocks) => blocks.iter().find_map(|b| match b {
                    ContentBlock::ToolUse { id, name, .. } if name == RESPOND_TOOL_NAME => {
                        Some(id.clone())
                    }
                    _ => None,
                }),
                Content::Text(_) => None,
            })
            .expect("assistant respond tool_use present in retry");
        let has_matching_result = retry.messages.iter().any(|m| match &m.content {
            Content::Blocks(blocks) => blocks.iter().any(|b| {
                matches!(
                    b,
                    ContentBlock::ToolResult { tool_use_id, .. }
                    if *tool_use_id == assistant_tool_use_id
                )
            }),
            Content::Text(_) => false,
        });
        // Release the lock before asserting so a failure does not poison it.
        drop(seen);
        assert!(
            has_matching_result,
            "retry must carry a tool_result for the forced tool_use id"
        );
        Ok(())
    }
    /// When every attempt fails validation, `RetriesExhausted` reports the attempt
    /// count and the last value.
    #[tokio::test]
    async fn retry_exhaustion_yields_typed_error() -> Result<()> {
        let provider = ScriptedProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                success(respond_tool_block(serde_json::json!({"name": "x"}))),
                success(respond_tool_block(serde_json::json!({"name": "y"}))),
                success(respond_tool_block(serde_json::json!({"name": "z"}))),
            ],
        );
        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 2 },
        )
        .await
        .expect_err("schema never satisfied");
        match err {
            StructuredOutputError::RetriesExhausted {
                attempts,
                last_value,
                ..
            } => {
                assert_eq!(attempts, 3, "1 initial + 2 retries");
                assert_eq!(
                    last_value.as_ref().and_then(|v| v["name"].as_str()),
                    Some("z")
                );
            }
            other => panic!("expected RetriesExhausted, got {other:?}"),
        }
        assert_eq!(provider.call_count(), 3);
        Ok(())
    }
    /// `max_retries: 0` means exactly one provider call.
    #[tokio::test]
    async fn zero_retries_fails_after_single_attempt() -> Result<()> {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block(r#"{"name": "Ada"}"#))],
        );
        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 0 },
        )
        .await
        .expect_err("missing required `age`");
        assert!(matches!(
            err,
            StructuredOutputError::RetriesExhausted { attempts: 1, .. }
        ));
        assert_eq!(provider.call_count(), 1);
        Ok(())
    }
    /// A request without `response_format` fails before any provider call.
    #[tokio::test]
    async fn missing_response_format_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = None;
        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("no response format");
        assert!(matches!(err, StructuredOutputError::MissingResponseFormat));
    }
    /// A schema that fails to compile is reported as `InvalidSchema`.
    #[tokio::test]
    async fn invalid_schema_is_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![success(text_block("{}"))],
        );
        let mut req = request_with_format();
        req.response_format = Some(ResponseFormat::new("bad", serde_json::json!({"type": 123})));
        let err = run_structured(&provider, req, StructuredConfig::default())
            .await
            .expect_err("invalid schema");
        assert!(matches!(err, StructuredOutputError::InvalidSchema(_)));
    }
    /// Non-success outcomes (here a rate limit) surface as `ProviderOutcome`.
    #[tokio::test]
    async fn provider_rate_limit_surfaces_as_typed_error() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![ChatOutcome::RateLimited(None)],
        );
        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        )
        .await
        .expect_err("rate limited");
        assert!(matches!(err, StructuredOutputError::ProviderOutcome(_)));
    }
    /// Prose-only replies on every attempt end with `NoStructuredOutput`.
    #[tokio::test]
    async fn no_structured_output_on_final_attempt_errors() {
        let provider = ScriptedProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                success(text_block("I cannot do that.")),
                success(text_block("Still prose, sorry.")),
            ],
        );
        let err = run_structured(
            &provider,
            request_with_format(),
            StructuredConfig { max_retries: 1 },
        )
        .await
        .expect_err("never produced JSON");
        assert!(matches!(err, StructuredOutputError::NoStructuredOutput));
        assert_eq!(provider.call_count(), 2);
    }
    /// Streaming test double: `chat_stream()` replays a fixed list of deltas;
    /// `chat()` always returns a server error.
    struct StreamingProvider {
        provider_name: &'static str,
        model: String,
        support: StructuredOutputSupport,
        deltas: Mutex<Vec<StreamDelta>>,
    }
    impl StreamingProvider {
        fn new(
            provider_name: &'static str,
            support: StructuredOutputSupport,
            deltas: Vec<StreamDelta>,
        ) -> Self {
            Self {
                provider_name,
                model: "scripted-model".to_owned(),
                support,
                deltas: Mutex::new(deltas),
            }
        }
    }
    #[async_trait]
    impl LlmProvider for StreamingProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            Ok(ChatOutcome::ServerError("chat() not used".to_owned()))
        }
        // Clones the scripted deltas so the stream does not hold the lock.
        fn chat_stream(&self, _request: ChatRequest) -> StreamBox<'_> {
            let deltas = self.deltas.lock().map(|d| d.clone()).unwrap_or_default();
            Box::pin(async_stream::stream! {
                for delta in deltas {
                    yield Ok(delta);
                }
            })
        }
        fn model(&self) -> &str {
            &self.model
        }
        fn provider(&self) -> &'static str {
            self.provider_name
        }
        fn structured_output_support(&self) -> StructuredOutputSupport {
            self.support
        }
    }
    /// Collects all partials and the final output; the first `Err` item aborts
    /// via `?`.
    async fn drive_stream(
        mut stream: StructuredStream<'_>,
    ) -> Result<(Vec<serde_json::Value>, Option<StructuredOutput>)> {
        let mut partials = Vec::new();
        let mut final_out = None;
        while let Some(update) = stream.next().await {
            match update? {
                StructuredStreamUpdate::Partial(value) => partials.push(value),
                StructuredStreamUpdate::Final(out) => final_out = Some(out),
            }
        }
        Ok((partials, final_out))
    }
    /// Native streaming: a partial appears as soon as `name` is complete, then the
    /// validated final value.
    #[tokio::test]
    async fn streaming_native_emits_partials_then_validated_final() -> Result<()> {
        let provider = StreamingProvider::new(
            "openai",
            StructuredOutputSupport::Native,
            vec![
                StreamDelta::TextDelta {
                    delta: r#"{"name": "Ada""#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::TextDelta {
                    delta: r#", "age": 36}"#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::Done {
                    stop_reason: Some(StopReason::EndTurn),
                },
            ],
        );
        let stream = run_structured_stream(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        );
        let (partials, final_out) = drive_stream(stream).await?;
        assert!(!partials.is_empty(), "expected at least one partial");
        assert_eq!(partials[0]["name"], "Ada");
        let final_out = final_out.expect("a validated final value");
        assert_eq!(final_out.value["name"], "Ada");
        assert_eq!(final_out.value["age"], 36);
        assert_eq!(final_out.retries, 0);
        Ok(())
    }
    /// Tool-forcing streaming: partials come from `respond` tool input deltas.
    #[tokio::test]
    async fn streaming_tool_forcing_reads_partial_tool_input() -> Result<()> {
        let provider = StreamingProvider::new(
            "anthropic",
            StructuredOutputSupport::ToolForcing,
            vec![
                StreamDelta::ToolUseStart {
                    id: "call_1".to_owned(),
                    name: RESPOND_TOOL_NAME.to_owned(),
                    block_index: 0,
                    thought_signature: None,
                },
                StreamDelta::ToolInputDelta {
                    id: "call_1".to_owned(),
                    delta: r#"{"name": "Linus""#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::ToolInputDelta {
                    id: "call_1".to_owned(),
                    delta: r#", "age": 54}"#.to_owned(),
                    block_index: 0,
                },
                StreamDelta::Done {
                    stop_reason: Some(StopReason::ToolUse),
                },
            ],
        );
        let stream = run_structured_stream(
            &provider,
            request_with_format(),
            StructuredConfig::default(),
        );
        let (partials, final_out) = drive_stream(stream).await?;
        assert_eq!(partials[0]["name"], "Linus");
        let final_out = final_out.expect("a validated final value");
        assert_eq!(final_out.value["age"], 54);
        Ok(())
    }
    /// The streaming path reports a missing `response_format` as its first item.
    #[tokio::test]
    async fn streaming_missing_response_format_errors() {
        let provider =
            StreamingProvider::new("openai", StructuredOutputSupport::Native, Vec::new());
        let mut req = request_with_format();
        req.response_format = None;
        let mut stream = run_structured_stream(&provider, req, StructuredConfig::default());
        let first = stream.next().await.expect("one item");
        assert!(matches!(
            first,
            Err(StructuredOutputError::MissingResponseFormat)
        ));
    }
    /// Repair closes open strings, drops a trailing comma, and fills a dangling key
    /// with `null`; empty or non-JSON input yields `None`.
    #[test]
    fn partial_from_buffer_repairs_incomplete_json() {
        assert_eq!(
            partial_from_buffer(r#"{"name": "Ada""#).map(|v| v["name"].clone()),
            Some(serde_json::json!("Ada"))
        );
        assert_eq!(
            partial_from_buffer(r#"{"a": 1,"#),
            Some(serde_json::json!({"a": 1}))
        );
        assert_eq!(
            partial_from_buffer(r#"{"a":"#),
            Some(serde_json::json!({"a": null}))
        );
        assert!(partial_from_buffer("").is_none());
        assert!(partial_from_buffer("not json").is_none());
    }
}