use std::marker::PhantomData;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
#[cfg(feature = "blocking")]
use crate::blocking::BlockingResponsesClient;
#[cfg(all(feature = "blocking", feature = "streaming"))]
use crate::responses::BlockingStructuredJSONStream;
use crate::types::{InputItem, JSONSchemaFormat, OutputFormat, OutputFormatKind, Response};
#[cfg(test)]
use crate::types::{ContentPart, MessageRole};
/// Record of a single rejected structured-output attempt.
///
/// One record is stored for every response whose text could not be turned into
/// the requested type; the collected records are handed back to the caller in
/// `StructuredExhaustedError::all_attempts`.
#[derive(Debug, Clone)]
pub struct AttemptRecord {
/// Attempt number; the first processed response is attempt 1.
pub attempt: u32,
/// Raw response text that was rejected.
pub raw_json: String,
/// Why the attempt was rejected.
pub error: StructuredErrorKind,
}
/// Reason a structured-output attempt was rejected.
#[derive(Debug, Clone)]
pub enum StructuredErrorKind {
    /// The response text could not be deserialized into the target type.
    Decode {
        /// Human-readable description of the deserialization failure.
        message: String,
    },
    /// The response was rejected because of one or more validation issues.
    Validation {
        /// The individual problems that were found.
        issues: Vec<ValidationIssue>,
    },
}

impl std::fmt::Display for StructuredErrorKind {
    /// Renders `decode error: <message>` or `validation error: <issue>; <issue>; ...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StructuredErrorKind::Decode { message } => write!(f, "decode error: {}", message),
            StructuredErrorKind::Validation { issues } => {
                f.write_str("validation error: ")?;
                // Emit the first issue bare and prefix every later one with the separator,
                // so no trailing "; " is produced.
                let mut remaining = issues.iter();
                if let Some(first) = remaining.next() {
                    write!(f, "{}", first)?;
                    for issue in remaining {
                        write!(f, "; {}", issue)?;
                    }
                }
                Ok(())
            }
        }
    }
}
/// A single problem found while validating structured output.
#[derive(Debug, Clone)]
pub struct ValidationIssue {
    /// Location of the offending value, when known (for example `person.age`).
    pub path: Option<String>,
    /// Description of what is wrong.
    pub message: String,
}

impl std::fmt::Display for ValidationIssue {
    /// Renders `<path>: <message>` when a path is present, otherwise just `<message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.path.as_deref() {
            Some(location) => write!(f, "{}: {}", location, self.message),
            None => write!(f, "{}", self.message),
        }
    }
}
/// Error describing a response whose text could not be decoded into the target type.
#[derive(Debug)]
pub struct StructuredDecodeError {
    /// Raw response text that failed to decode.
    pub raw_json: String,
    /// Description of the decode failure.
    pub message: String,
    /// Attempt number on which the failure happened.
    pub attempt: u32,
}

impl std::fmt::Display for StructuredDecodeError {
    /// Renders the attempt number and the failure message; the raw JSON is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            attempt, message, ..
        } = self;
        write!(
            f,
            "structured output decode error (attempt {}): {}",
            attempt, message
        )
    }
}

// No underlying cause is exposed: the default `source()` (returning `None`) is used.
impl std::error::Error for StructuredDecodeError {}
/// Error returned when no further structured-output attempt will be made.
#[derive(Debug)]
pub struct StructuredExhaustedError {
    /// Raw response text of the final rejected attempt.
    pub last_raw_json: String,
    /// Every rejected attempt, in the order they were recorded.
    pub all_attempts: Vec<AttemptRecord>,
    /// The rejection reason of the final attempt.
    pub final_error: StructuredErrorKind,
}

impl std::fmt::Display for StructuredExhaustedError {
    /// Renders the number of recorded attempts followed by the final error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let attempt_count = self.all_attempts.len();
        let last_failure = &self.final_error;
        write!(
            f,
            "structured output failed after {} attempts: {}",
            attempt_count, last_failure
        )
    }
}

// No underlying cause is exposed: the default `source()` (returning `None`) is used.
impl std::error::Error for StructuredExhaustedError {}
#[derive(Debug)]
pub enum StructuredError {
Decode(StructuredDecodeError),
Exhausted(StructuredExhaustedError),
Sdk(crate::Error),
}
impl std::fmt::Display for StructuredError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StructuredError::Decode(e) => write!(f, "{}", e),
StructuredError::Exhausted(e) => write!(f, "{}", e),
StructuredError::Sdk(e) => write!(f, "{}", e),
}
}
}
impl std::error::Error for StructuredError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
StructuredError::Decode(e) => Some(e),
StructuredError::Exhausted(e) => Some(e),
StructuredError::Sdk(e) => Some(e),
}
}
}
impl From<crate::Error> for StructuredError {
fn from(e: crate::Error) -> Self {
StructuredError::Sdk(e)
}
}
/// Strategy deciding whether, and with which extra input, a rejected structured-output
/// attempt is retried.
///
/// Implementations must be `Send + Sync`.
pub trait RetryHandler: Send + Sync {
/// Called after an attempt was rejected while further attempts are still allowed.
///
/// * `attempt` - number of the attempt that just failed (the first attempt is 1).
/// * `raw_json` - raw response text that was rejected.
/// * `error` - why the attempt was rejected.
/// * `messages` - the input items the caller supplied for this attempt
///   (presumably the conversation so far; confirm against the caller).
///
/// Return `Some(items)` to request another attempt using `items`; how they are combined
/// with the existing input is up to the caller of the retry executor. Return `None` to stop
/// retrying, which is reported as an exhausted error.
fn on_validation_error(
&self,
attempt: u32,
raw_json: &str,
error: &StructuredErrorKind,
messages: &[InputItem],
) -> Option<Vec<InputItem>>;
}
/// Retry handler that always asks for another attempt, feeding the error back as a user message.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultRetryHandler;

impl RetryHandler for DefaultRetryHandler {
    /// Builds a single user item that quotes `error` and asks for a schema-conforming reply.
    ///
    /// The attempt number, raw JSON and prior messages are ignored; the result is always `Some`.
    fn on_validation_error(
        &self,
        _attempt: u32,
        _raw_json: &str,
        error: &StructuredErrorKind,
        _messages: &[InputItem],
    ) -> Option<Vec<InputItem>> {
        let correction = InputItem::user(format!(
            concat!(
                "The previous response did not match the expected schema. Error: {}. ",
                "Please provide a response that matches the schema exactly."
            ),
            error
        ));
        Some(vec![correction])
    }
}
/// Options controlling structured-output requests.
///
/// `H` is the retry strategy; it defaults to [`DefaultRetryHandler`].
#[derive(Debug, Clone)]
pub struct StructuredOptions<H: RetryHandler = DefaultRetryHandler> {
/// Number of retries allowed after the first attempt (0 means a single attempt).
pub max_retries: u32,
/// Strategy consulted after a rejected attempt while retries remain.
pub retry_handler: H,
/// Optional schema name; when `None`, `output_format_from_type` derives one from the
/// target type's name.
pub schema_name: Option<String>,
}
impl Default for StructuredOptions<DefaultRetryHandler> {
fn default() -> Self {
Self {
max_retries: 0,
retry_handler: DefaultRetryHandler,
schema_name: None,
}
}
}
impl<H: RetryHandler> StructuredOptions<H> {
    /// Returns the options with `max_retries` replaced by `retries`.
    pub fn with_max_retries(self, retries: u32) -> Self {
        Self {
            max_retries: retries,
            ..self
        }
    }

    /// Returns the options with the retry handler swapped for `handler`.
    ///
    /// The handler type may change; `max_retries` and `schema_name` are carried over.
    pub fn with_retry_handler<H2: RetryHandler>(self, handler: H2) -> StructuredOptions<H2> {
        let StructuredOptions {
            max_retries,
            schema_name,
            ..
        } = self;
        StructuredOptions {
            max_retries,
            retry_handler: handler,
            schema_name,
        }
    }

    /// Returns the options with an explicit schema name.
    pub fn with_schema_name(self, name: impl Into<String>) -> Self {
        Self {
            schema_name: Some(name.into()),
            ..self
        }
    }
}
/// Successful outcome of a structured-output request.
#[derive(Debug, Clone)]
pub struct StructuredResult<T> {
/// The decoded value.
pub value: T,
/// Number of attempts made, including the successful one (1 means no retry was needed).
pub attempts: u32,
/// Request id copied from the response that produced `value`, if it carried one.
pub request_id: Option<String>,
}
/// Outcome of processing one response in the retry loop.
pub(crate) enum RetryDecision<T> {
/// The response decoded successfully.
Success(StructuredResult<T>),
/// The response was rejected; retry using the items returned by the retry handler.
Retry(Vec<InputItem>),
/// The response was rejected and no further attempt will be made.
Exhausted(StructuredExhaustedError),
}
/// Bookkeeping for the structured-output retry loop.
///
/// The executor performs no requests itself: the caller sends each request and feeds
/// the response to `process_response`, which returns what to do next.
pub(crate) struct RetryExecutor<'a, H: RetryHandler> {
// Options supplying the retry handler.
options: &'a StructuredOptions<H>,
// Rejected attempts recorded so far.
attempts: Vec<AttemptRecord>,
// Number of responses processed so far (0 before the first one).
current_attempt: u32,
// Total attempts allowed: the initial one plus `options.max_retries`.
max_attempts: u32,
// Request id of the most recently processed response.
last_request_id: Option<String>,
}
impl<'a, H: RetryHandler> RetryExecutor<'a, H> {
pub(crate) fn new(options: &'a StructuredOptions<H>) -> Self {
Self {
options,
attempts: Vec::new(),
current_attempt: 0,
max_attempts: options.max_retries + 1,
last_request_id: None,
}
}
pub(crate) fn process_response<T: DeserializeOwned>(
&mut self,
response: &Response,
messages: &[InputItem],
) -> Result<RetryDecision<T>, StructuredError> {
self.current_attempt += 1;
self.last_request_id = response.request_id.clone();
let raw_json = extract_json_content(response)?;
match serde_json::from_str::<T>(&raw_json) {
Ok(value) => Ok(RetryDecision::Success(StructuredResult {
value,
attempts: self.current_attempt,
request_id: self.last_request_id.clone(),
})),
Err(e) => {
let error = StructuredErrorKind::Decode {
message: e.to_string(),
};
self.attempts.push(AttemptRecord {
attempt: self.current_attempt,
raw_json: raw_json.clone(),
error: error.clone(),
});
if self.current_attempt >= self.max_attempts {
return Ok(RetryDecision::Exhausted(StructuredExhaustedError {
last_raw_json: raw_json,
all_attempts: std::mem::take(&mut self.attempts),
final_error: error,
}));
}
match self.options.retry_handler.on_validation_error(
self.current_attempt,
&raw_json,
&error,
messages,
) {
Some(retry_messages) => Ok(RetryDecision::Retry(retry_messages)),
None => {
Ok(RetryDecision::Exhausted(StructuredExhaustedError {
last_raw_json: raw_json,
all_attempts: std::mem::take(&mut self.attempts),
final_error: error,
}))
}
}
}
}
}
}
/// Builds a strict JSON-schema [`OutputFormat`] describing `T`.
///
/// * `schema_name` - name to attach to the schema. When `None`, a name is derived from
///   `T`'s type name: module paths are dropped and any generic arguments are joined with
///   `_`, e.g. `my_crate::Person` becomes `Person` and `Vec<my_crate::Person>` becomes
///   `Vec_Person`.
///
/// # Errors
///
/// Returns `StructuredError::Sdk` wrapping a serialization error when the generated schema
/// cannot be converted to a JSON value.
pub fn output_format_from_type<T: JsonSchema>(
    schema_name: Option<&str>,
) -> Result<OutputFormat, StructuredError> {
    let root_schema = schemars::schema_for!(T);
    let schema = serde_json::to_value(&root_schema)
        .map_err(|e| StructuredError::Sdk(crate::Error::Serialization(e)))?;
    let name = match schema_name {
        Some(explicit) => explicit.to_string(),
        None => default_schema_name(std::any::type_name::<T>()),
    };
    Ok(OutputFormat {
        kind: OutputFormatKind::JsonSchema,
        json_schema: Some(JSONSchemaFormat {
            name,
            description: None,
            schema,
            strict: Some(true),
        }),
    })
}

/// Derives an identifier-only schema name from a `std::any::type_name` string.
///
/// Taking the text after the last `::` is not enough: for generic, tuple or reference types
/// the type name contains `<`, `>`, `,`, `&` or spaces (e.g. `alloc::vec::Vec<my::Person>`),
/// which would leak into the name as `Person>`. Instead, every path in the type name is
/// reduced to its final segment and the segments are joined with `_`.
/// (Schema names are presumably restricted to identifier-like characters by providers;
/// verify against the target API.)
fn default_schema_name(type_name: &str) -> String {
    let name = type_name
        // Anything that cannot be part of a `a::b::C` path separates two paths.
        .split(|c: char| !(c.is_alphanumeric() || c == '_' || c == ':'))
        .filter_map(|path| path.rsplit("::").next().filter(|ident| !ident.is_empty()))
        .collect::<Vec<_>>()
        .join("_");
    if name.is_empty() {
        "response".to_string()
    } else {
        name
    }
}
/// Request builder whose response is decoded into `T`.
///
/// Wraps a plain `ResponseBuilder` together with the structured-output options;
/// `H` is the retry strategy and defaults to [`DefaultRetryHandler`].
pub struct StructuredResponseBuilder<T, H: RetryHandler = DefaultRetryHandler> {
// Underlying request builder that every request setter forwards to.
pub(crate) inner: crate::responses::ResponseBuilder,
// Retry and schema-name options for the structured request.
pub(crate) options: StructuredOptions<H>,
// Ties the builder to the target type `T` without storing a value of it.
pub(crate) _marker: PhantomData<T>,
}
impl<T: JsonSchema + DeserializeOwned> StructuredResponseBuilder<T, DefaultRetryHandler> {
pub fn new(inner: crate::responses::ResponseBuilder) -> Self {
Self {
inner,
options: StructuredOptions::default(),
_marker: PhantomData,
}
}
}
impl<T: JsonSchema + DeserializeOwned, H: RetryHandler> StructuredResponseBuilder<T, H> {
pub fn max_retries(mut self, retries: u32) -> Self {
self.options.max_retries = retries;
self
}
pub fn retry_handler<H2: RetryHandler>(self, handler: H2) -> StructuredResponseBuilder<T, H2> {
StructuredResponseBuilder {
inner: self.inner,
options: self.options.with_retry_handler(handler),
_marker: PhantomData,
}
}
pub fn schema_name(mut self, name: impl Into<String>) -> Self {
self.options.schema_name = Some(name.into());
self
}
pub fn provider(mut self, provider: crate::identifiers::ProviderId) -> Self {
self.inner = self.inner.provider(provider);
self
}
pub fn model(mut self, model: impl Into<crate::types::Model>) -> Self {
self.inner = self.inner.model(model);
self
}
pub fn customer_id(mut self, customer_id: impl Into<String>) -> Self {
self.inner = self.inner.customer_id(customer_id);
self
}
pub fn system(mut self, content: impl Into<String>) -> Self {
self.inner = self.inner.system(content);
self
}
pub fn user(mut self, content: impl Into<String>) -> Self {
self.inner = self.inner.user(content);
self
}
pub fn assistant(mut self, content: impl Into<String>) -> Self {
self.inner = self.inner.assistant(content);
self
}
pub fn tool_result(
mut self,
tool_call_id: impl Into<String>,
content: impl Into<String>,
) -> Self {
self.inner = self.inner.tool_result(tool_call_id, content);
self
}
pub fn item(mut self, item: InputItem) -> Self {
self.inner = self.inner.item(item);
self
}
pub fn input(mut self, input: Vec<InputItem>) -> Self {
self.inner = self.inner.input(input);
self
}
pub fn max_output_tokens(mut self, max_output_tokens: u32) -> Self {
self.inner = self.inner.max_output_tokens(max_output_tokens);
self
}
pub fn temperature(mut self, temperature: f64) -> Self {
self.inner = self.inner.temperature(temperature);
self
}
pub fn stop(mut self, stop: Vec<String>) -> Self {
self.inner = self.inner.stop(stop);
self
}
pub fn tools(mut self, tools: Vec<crate::types::Tool>) -> Self {
self.inner = self.inner.tools(tools);
self
}
pub fn tool_choice(mut self, tool_choice: crate::types::ToolChoice) -> Self {
self.inner = self.inner.tool_choice(tool_choice);
self
}
pub fn request_id(mut self, request_id: impl Into<String>) -> Self {
self.inner = self.inner.request_id(request_id);
self
}
pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.inner = self.inner.header(key, value);
self
}
pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
self.inner = self.inner.timeout(timeout);
self
}
pub fn retry(mut self, retry: crate::RetryConfig) -> Self {
self.inner = self.inner.retry(retry);
self
}
pub async fn send(
self,
client: &crate::client::ResponsesClient,
) -> Result<StructuredResult<T>, StructuredError> {
client
.create_structured::<T, H>(self.inner, self.options)
.await
}
#[cfg(feature = "blocking")]
pub fn send_blocking(
self,
client: &BlockingResponsesClient,
) -> Result<StructuredResult<T>, StructuredError> {
client.create_structured::<T, H>(self.inner, self.options)
}
#[cfg(feature = "streaming")]
pub async fn stream(
self,
client: &crate::client::ResponsesClient,
) -> Result<crate::responses::StructuredJSONStream<T>, StructuredError> {
let output_format = output_format_from_type::<T>(self.options.schema_name.as_deref())?;
self.inner
.output_format(output_format)
.stream_json(client)
.await
.map_err(StructuredError::Sdk)
}
#[cfg(all(feature = "blocking", feature = "streaming"))]
pub fn stream_blocking(
self,
client: &BlockingResponsesClient,
) -> Result<BlockingStructuredJSONStream<T>, StructuredError> {
let output_format = output_format_from_type::<T>(self.options.schema_name.as_deref())?;
self.inner
.output_format(output_format)
.stream_json_blocking(client)
.map_err(StructuredError::Sdk)
}
}
fn extract_json_content(response: &Response) -> Result<String, StructuredError> {
let content = response.text();
if content.trim().is_empty() {
return Err(StructuredError::Sdk(crate::Error::Transport(
crate::errors::TransportError {
kind: crate::errors::TransportErrorKind::EmptyResponse,
message: "response contained no content".to_string(),
source: None,
retries: None,
},
)));
}
Ok(content)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal schema-bearing type used as the structured-output target in these tests.
    #[allow(dead_code)]
    #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
    struct TestPerson {
        name: String,
        age: u32,
    }

    /// Without an explicit name the schema is named after the type and marked strict.
    #[test]
    fn test_output_format_from_type() {
        let format = output_format_from_type::<TestPerson>(None).unwrap();
        assert_eq!(format.kind, OutputFormatKind::JsonSchema);
        let schema = format
            .json_schema
            .expect("json_schema should be populated");
        assert_eq!(schema.name, "TestPerson");
        assert_eq!(schema.strict, Some(true));
    }

    /// An explicit schema name takes precedence over the type name.
    #[test]
    fn test_output_format_custom_name() {
        let schema = output_format_from_type::<TestPerson>(Some("person_info"))
            .unwrap()
            .json_schema
            .unwrap();
        assert_eq!(schema.name, "person_info");
    }

    /// Both error kinds render their distinguishing text.
    #[test]
    fn test_structured_error_kind_display() {
        let decode_text = StructuredErrorKind::Decode {
            message: "expected string".to_string(),
        }
        .to_string();
        assert!(decode_text.contains("decode error"));

        let issue = ValidationIssue {
            path: Some("person.age".to_string()),
            message: "expected integer".to_string(),
        };
        let validation_text = StructuredErrorKind::Validation {
            issues: vec![issue],
        }
        .to_string();
        assert!(validation_text.contains("person.age"));
    }

    /// The default handler answers with exactly one user message that mentions the schema.
    #[test]
    fn test_default_retry_handler() {
        let error = StructuredErrorKind::Decode {
            message: "parse error".to_string(),
        };
        let messages = DefaultRetryHandler.on_validation_error(1, "{}", &error, &[]);
        assert!(messages.is_some());
        let msgs = messages.unwrap();
        assert_eq!(msgs.len(), 1);

        let InputItem::Message { role, content, .. } = &msgs[0];
        assert_eq!(*role, MessageRole::User);

        // Concatenate the text parts; file parts carry no text.
        let mut text = String::new();
        for part in content.iter() {
            match part {
                ContentPart::Text { text: piece } => text.push_str(piece.as_str()),
                ContentPart::File { .. } => {}
            }
        }
        assert!(text.contains("schema"));
    }
}