use crate::core::types::{
Content, GenerateOptions, Message, Prompt, Role, StreamPart, ToolDefinition, Usage,
};
use crate::core::{LanguageModel, Result};
use futures::stream::BoxStream;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
/// How the schema is communicated to the model and where the structured
/// object is read from.
///
/// Serialized in lowercase (`"json"` / `"tool"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputMode {
    /// Send the schema as a `json_schema` `response_format` and parse the
    /// object from the model's text output. This is the default.
    #[default]
    Json,
    /// Expose the schema as the parameters of a single tool and read the
    /// object from that tool call's arguments (the non-streaming path falls
    /// back to parsing the text output when no matching tool call is present).
    Tool,
}
/// Options for [`generate_object`] and [`stream_object`].
#[derive(Debug, Clone)]
pub struct ObjectGenerateOptions {
    /// Model identifier, forwarded as `GenerateOptions::model_id`.
    pub model_id: String,
    /// JSON Schema the output must conform to. It is embedded in the system
    /// prompt and sent either as the `json_schema` response format (`Json`
    /// mode) or as the tool parameters (`Tool` mode).
    pub schema: serde_json::Value,
    /// Name for the schema / tool; `"json_output"` when `None`.
    pub schema_name: Option<String>,
    /// Description for the schema / tool; `"Generate a structured JSON object"` when `None`.
    pub schema_description: Option<String>,
    /// How the schema is sent and where the object is read from.
    pub mode: OutputMode,
    /// Forwarded to the model unchanged.
    pub max_tokens: Option<u32>,
    /// Forwarded to the model unchanged.
    pub temperature: Option<f32>,
    /// Optional system text placed before the generated schema instruction.
    pub system: Option<String>,
    /// Extra attempts `generate_object` makes after a failed one (not used by
    /// `stream_object`).
    pub max_retries: u32,
    /// `strict` flag of the `json_schema` response format in `Json` mode;
    /// `true` when `None`. Not sent in `Tool` mode.
    pub strict: Option<bool>,
}
/// Result of a successful [`generate_object`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObjectGenerateResult {
    /// The structured object: parsed from the text output (`Json` mode), or
    /// taken from the matching tool call's arguments (`Tool` mode, falling
    /// back to the text output).
    pub object: serde_json::Value,
    /// The model's text output as returned by the provider. In `Tool` mode this
    /// is the text part of the response, not the tool-call arguments.
    pub raw_text: String,
    /// Token usage reported by the model for the successful attempt.
    pub usage: Usage,
    /// Finish reason reported by the model for the successful attempt.
    pub finish_reason: String,
}
/// Events emitted by [`stream_object`].
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// lowercased variant name (`"textdelta"`, `"partial"`, `"final"`, `"error"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ObjectStreamPart {
    /// A raw chunk appended to the accumulated JSON text (a text delta in
    /// `Json` mode, a tool-call argument delta in `Tool` mode).
    TextDelta { delta: String },
    /// Best-effort parse of the text accumulated so far; may be incomplete and
    /// is not validated against the schema.
    Partial { object: serde_json::Value },
    /// The fully parsed object together with the latest usage seen on the stream.
    Final {
        object: serde_json::Value,
        usage: Usage,
    },
    /// An error forwarded from the provider stream or a final parse failure.
    Error { message: String },
}
pub async fn generate_object(
model: &dyn LanguageModel,
prompt_text: &str,
options: ObjectGenerateOptions,
) -> Result<ObjectGenerateResult> {
let max_retries = options.max_retries;
let mut last_error: Option<String> = None;
for attempt in 0..=max_retries {
let result = generate_object_once(model, prompt_text, &options, last_error.as_deref()).await;
match result {
Ok(gen_result) => {
match validate_schema(&gen_result.object, &options.schema) {
Ok(()) => return Ok(gen_result),
Err(validation_errors) => {
if attempt == max_retries {
return Err(crate::core::error::ProviderError::InvalidResponse(
format!(
"Structured output failed schema validation after {} retries: {}",
max_retries, validation_errors
),
));
}
last_error = Some(format!(
"Your JSON output did not match the required schema. Errors: {validation_errors}. Please fix and try again."
));
}
}
}
Err(e) => {
if attempt == max_retries {
return Err(e);
}
last_error = Some(format!(
"Failed to produce valid JSON: {e}. Please respond with valid JSON only."
));
}
}
}
Err(crate::core::error::ProviderError::InvalidResponse(
"Structured output generation exhausted all retries".to_string(),
))
}
async fn generate_object_once(
model: &dyn LanguageModel,
prompt_text: &str,
options: &ObjectGenerateOptions,
retry_context: Option<&str>,
) -> Result<ObjectGenerateResult> {
let mut messages = Vec::new();
let schema_instruction = format!(
"You MUST respond with valid JSON that conforms to this JSON Schema:\n```json\n{}\n```\nRespond ONLY with the JSON object, no markdown fences, no extra text.",
serde_json::to_string_pretty(&options.schema).unwrap_or_default()
);
let system_text = if let Some(ref sys) = options.system {
format!("{sys}\n\n{schema_instruction}")
} else {
schema_instruction
};
messages.push(Message {
role: Role::System,
content: vec![Content::Text {
text: system_text,
}],
});
messages.push(Message {
role: Role::User,
content: vec![Content::Text {
text: prompt_text.to_string(),
}],
});
if let Some(context) = retry_context {
messages.push(Message {
role: Role::User,
content: vec![Content::Text {
text: context.to_string(),
}],
});
}
let prompt = Prompt { messages };
match options.mode {
OutputMode::Json => {
let tool_name = options
.schema_name
.clone()
.unwrap_or_else(|| "json_output".to_string());
let tool_desc = options
.schema_description
.clone()
.unwrap_or_else(|| "Generate a structured JSON object".to_string());
let gen_options = GenerateOptions {
model_id: options.model_id.clone(),
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: None,
stop_sequences: None,
tools: None,
response_format: Some(serde_json::json!({
"type": "json_schema",
"json_schema": {
"name": tool_name,
"description": tool_desc,
"schema": options.schema.clone(),
"strict": options.strict.unwrap_or(true)
}
})),
reasoning_format: None,
reasoning_effort: None,
tool_choice: None,
parallel_tool_calls: None,
extra_headers: None,
server_tools: None,
include_citations: None,
include_tool_outputs: None,
};
let result = model.generate(prompt, gen_options).await?;
let object = parse_json_from_text(&result.text)?;
Ok(ObjectGenerateResult {
object,
raw_text: result.text,
usage: result.usage,
finish_reason: result.finish_reason,
})
}
OutputMode::Tool => {
let tool_name = options
.schema_name
.clone()
.unwrap_or_else(|| "json_output".to_string());
let tool_desc = options
.schema_description
.clone()
.unwrap_or_else(|| "Generate a structured JSON object".to_string());
let tool = ToolDefinition {
name: tool_name.clone(),
description: tool_desc,
parameters: options.schema.clone(),
};
let gen_options = GenerateOptions {
model_id: options.model_id.clone(),
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: None,
stop_sequences: None,
tools: Some(vec![tool]),
response_format: None, reasoning_format: None, reasoning_effort: None,
tool_choice: None, parallel_tool_calls: None, extra_headers: None,
server_tools: None, include_citations: None, include_tool_outputs: None,
};
let result = model.generate(prompt, gen_options).await?;
let object = if let Some(tc) = result
.tool_calls
.iter()
.find(|tc| tc.name == tool_name)
{
tc.arguments.clone()
} else {
parse_json_from_text(&result.text)?
};
Ok(ObjectGenerateResult {
object,
raw_text: result.text,
usage: result.usage,
finish_reason: result.finish_reason,
})
}
}
}
/// Streams a structured JSON object from `model`.
///
/// The schema is embedded in the system prompt. Depending on `options.mode`, it
/// is also sent as a `json_schema` `response_format` (`Json`) or as the
/// parameters of a single tool (`Tool`).
///
/// The returned stream yields:
/// - [`ObjectStreamPart::TextDelta`] for every text delta (`Json` mode) or
///   tool-call argument delta (`Tool` mode), in arrival order;
/// - [`ObjectStreamPart::Partial`] after every 5th such delta, when the text
///   accumulated so far can be completed into valid JSON;
/// - at the first `Finish` event, one [`ObjectStreamPart::Final`], or an
///   [`ObjectStreamPart::Error`] if the accumulated text does not parse. Later
///   `Finish` events are ignored;
/// - if the underlying stream ends without any `Finish`, one `Final` provided
///   the accumulated text parses;
/// - provider errors, as [`ObjectStreamPart::Error`], when they arrive.
///
/// Unlike [`generate_object`] there are no retries. Neither partial nor final
/// objects are validated against the schema.
///
/// # Errors
/// Returns the error from `generate_stream` if the stream cannot be started.
pub async fn stream_object(
    model: &dyn LanguageModel,
    prompt_text: &str,
    options: ObjectGenerateOptions,
) -> Result<BoxStream<'static, ObjectStreamPart>> {
    // Attempt a partial parse once per this many accumulated deltas to bound parsing cost.
    const PARTIAL_EVERY_N_CHUNKS: u32 = 5;

    let mut messages = Vec::new();
    let schema_instruction = format!(
        "You MUST respond with valid JSON that conforms to this JSON Schema:\n```json\n{}\n```\nRespond ONLY with the JSON object, no markdown fences, no extra text.",
        serde_json::to_string_pretty(&options.schema).unwrap_or_default()
    );
    let system_text = if let Some(ref sys) = options.system {
        format!("{sys}\n\n{schema_instruction}")
    } else {
        schema_instruction
    };
    messages.push(Message {
        role: Role::System,
        content: vec![Content::Text { text: system_text }],
    });
    messages.push(Message {
        role: Role::User,
        content: vec![Content::Text {
            text: prompt_text.to_string(),
        }],
    });
    let prompt = Prompt { messages };
    let tool_name = options
        .schema_name
        .clone()
        .unwrap_or_else(|| "json_output".to_string());
    let tool_desc = options
        .schema_description
        .clone()
        .unwrap_or_else(|| "Generate a structured JSON object".to_string());
    let gen_options = match options.mode {
        OutputMode::Json => GenerateOptions {
            model_id: options.model_id.clone(),
            max_tokens: options.max_tokens,
            temperature: options.temperature,
            top_p: None,
            stop_sequences: None,
            tools: None,
            response_format: Some(serde_json::json!({
                "type": "json_schema",
                "json_schema": {
                    "name": tool_name.clone(),
                    "description": tool_desc.clone(),
                    "schema": options.schema.clone(),
                    "strict": options.strict.unwrap_or(true)
                }
            })),
            reasoning_format: None,
            reasoning_effort: None,
            tool_choice: None,
            parallel_tool_calls: None,
            extra_headers: None,
            server_tools: None,
            include_citations: None,
            include_tool_outputs: None,
        },
        OutputMode::Tool => {
            let tool = ToolDefinition {
                name: tool_name.clone(),
                description: tool_desc.clone(),
                parameters: options.schema.clone(),
            };
            GenerateOptions {
                model_id: options.model_id.clone(),
                max_tokens: options.max_tokens,
                temperature: options.temperature,
                top_p: None,
                stop_sequences: None,
                tools: Some(vec![tool]),
                response_format: None,
                reasoning_format: None,
                reasoning_effort: None,
                tool_choice: None,
                parallel_tool_calls: None,
                extra_headers: None,
                server_tools: None,
                include_citations: None,
                include_tool_outputs: None,
            }
        }
    };
    let mut inner_stream = model.generate_stream(prompt, gen_options).await?;
    // Only the (Copy) mode is needed inside the stream; don't move all of `options` into it.
    let mode = options.mode;
    let stream = async_stream::stream! {
        let mut accumulated = String::new();
        let mut last_usage = Usage { prompt_tokens: 0, completion_tokens: 0, cache_hit_tokens: None, cache_miss_tokens: None };
        let mut chunk_count: u32 = 0;
        // Set at the first `Finish`. It prevents a second terminal part, both from
        // repeated `Finish` events and from the end-of-stream fallback below.
        let mut finished = false;
        while let Some(part) = inner_stream.next().await {
            match part {
                StreamPart::TextDelta { delta } => {
                    // Text carries the object payload only in JSON mode.
                    if matches!(mode, OutputMode::Json) {
                        accumulated.push_str(&delta);
                        chunk_count += 1;
                        yield ObjectStreamPart::TextDelta { delta };
                        if chunk_count.is_multiple_of(PARTIAL_EVERY_N_CHUNKS) {
                            if let Ok(partial) = try_parse_partial_json(&accumulated) {
                                yield ObjectStreamPart::Partial { object: partial };
                            }
                        }
                    }
                }
                StreamPart::ToolCallDelta { arguments_delta, .. } => {
                    // Tool-call arguments carry the object payload only in tool mode.
                    if matches!(mode, OutputMode::Tool) {
                        if let Some(delta) = arguments_delta {
                            accumulated.push_str(&delta);
                            chunk_count += 1;
                            yield ObjectStreamPart::TextDelta { delta };
                            if chunk_count.is_multiple_of(PARTIAL_EVERY_N_CHUNKS) {
                                if let Ok(partial) = try_parse_partial_json(&accumulated) {
                                    yield ObjectStreamPart::Partial { object: partial };
                                }
                            }
                        }
                    }
                }
                StreamPart::Usage { usage } => {
                    last_usage = usage;
                }
                StreamPart::Finish { .. } => {
                    if finished {
                        continue;
                    }
                    finished = true;
                    match parse_json_from_text(&accumulated) {
                        Ok(object) => {
                            yield ObjectStreamPart::Final {
                                object,
                                usage: last_usage.clone(),
                            };
                        }
                        Err(e) => {
                            yield ObjectStreamPart::Error {
                                message: format!("Failed to parse final JSON: {e}"),
                            };
                        }
                    }
                }
                StreamPart::Error { message } => {
                    yield ObjectStreamPart::Error { message };
                }
                StreamPart::ExecutedTool { .. }
                | StreamPart::ReasoningDelta { .. }
                | StreamPart::Citation { .. } => {}
            }
        }
        // Fallback for providers whose stream ends without a `Finish` event.
        if !finished && !accumulated.is_empty() {
            if let Ok(object) = parse_json_from_text(&accumulated) {
                yield ObjectStreamPart::Final {
                    object,
                    usage: last_usage,
                };
            }
        }
    };
    Ok(Box::pin(stream))
}
/// Checks `value` against the JSON Schema `schema`.
///
/// Returns `Ok(())` when `value` satisfies the schema. Otherwise returns every
/// validation message joined with `"; "`. If `schema` itself cannot be
/// compiled, returns `Err("Invalid JSON Schema: <reason>")`.
fn validate_schema(
    value: &serde_json::Value,
    schema: &serde_json::Value,
) -> std::result::Result<(), String> {
    let compiled = match jsonschema::validator_for(schema) {
        Ok(validator) => validator,
        Err(err) => return Err(format!("Invalid JSON Schema: {err}")),
    };
    let messages = compiled
        .iter_errors(value)
        .map(|err| err.to_string())
        .collect::<Vec<String>>();
    if !messages.is_empty() {
        return Err(messages.join("; "));
    }
    Ok(())
}
/// Parses a model's text output as a single JSON value.
///
/// The strict path runs first. It trims the text and removes one surrounding
/// Markdown fence (an opening ```` ```json ```` or ```` ``` ````, only when the
/// text also ends with ```` ``` ````), then parses what is left.
///
/// If the strict path fails, the first JSON object or array embedded in the
/// text is parsed instead. This recovers output that:
/// - is wrapped in prose;
/// - uses a fence with another language tag (e.g. ```` ```JSON ````);
/// - is missing the closing fence.
///
/// Input that the strict path accepts is handled exactly as before.
///
/// # Errors
/// Returns `ProviderError::InvalidResponse` with the strict-path parse error and
/// the fence-stripped text when neither path yields valid JSON.
fn parse_json_from_text(text: &str) -> Result<serde_json::Value> {
    let trimmed = text.trim();
    let json_str = strip_markdown_fence(trimmed);
    serde_json::from_str::<serde_json::Value>(json_str).or_else(|strict_err| {
        extract_embedded_json(trimmed).ok_or_else(|| {
            crate::core::error::ProviderError::InvalidResponse(format!(
                "Failed to parse structured output as JSON: {strict_err}\nRaw text: {json_str}"
            ))
        })
    })
}

/// Removes one Markdown code fence that both opens and closes `trimmed`.
/// Returns `trimmed` unchanged when it is not fully fenced.
fn strip_markdown_fence(trimmed: &str) -> &str {
    let opener = if trimmed.starts_with("```json") {
        "```json"
    } else if trimmed.starts_with("```") {
        "```"
    } else {
        return trimmed;
    };
    trimmed
        .strip_prefix(opener)
        .and_then(|s| s.strip_suffix("```"))
        .map_or(trimmed, str::trim)
}

/// Best-effort recovery of a JSON object or array embedded in other text.
///
/// Tries the first `{` and the first `[` in the text as value starts, earliest
/// first, and returns the first value that parses. Anything after that value
/// is ignored. At most two parse attempts are made.
fn extract_embedded_json(text: &str) -> Option<serde_json::Value> {
    let mut starts: Vec<usize> = text.find('{').into_iter().chain(text.find('[')).collect();
    starts.sort_unstable();
    starts.into_iter().find_map(|start| {
        // `start` is the byte index of an ASCII char, so slicing is on a char boundary.
        serde_json::Deserializer::from_str(&text[start..])
            .into_iter::<serde_json::Value>()
            .next()
            .and_then(|parsed| parsed.ok())
    })
}
/// Best-effort parse of a possibly incomplete JSON document. Used to emit
/// partial snapshots while streaming.
///
/// If `text` is already valid JSON, it is returned as-is. Otherwise the text is
/// scanned, ignoring brackets inside strings, to find which containers are still
/// open. Then:
/// - an unterminated string is closed (a dangling `\` is dropped first);
/// - otherwise, a trailing comma is dropped;
/// - the open containers are closed innermost-first, and the result is parsed.
///
/// Returns `Err(())` in these cases:
/// - nothing is open;
/// - a closer does not match its opener;
/// - the patched text still does not parse (e.g. the text stops right after an
///   object key).
fn try_parse_partial_json(text: &str) -> std::result::Result<serde_json::Value, ()> {
    let trimmed = text.trim();
    if let Ok(v) = serde_json::from_str::<serde_json::Value>(trimmed) {
        return Ok(v);
    }
    // Open containers, innermost last: closers must be appended in reverse nesting order.
    let mut open: Vec<char> = Vec::new();
    let mut in_string = false;
    let mut escape_next = false;
    for ch in trimmed.chars() {
        if escape_next {
            escape_next = false;
            continue;
        }
        match ch {
            '\\' if in_string => escape_next = true,
            '"' => in_string = !in_string,
            '{' | '[' if !in_string => open.push(ch),
            '}' | ']' if !in_string => {
                let expected = if ch == '}' { '{' } else { '[' };
                // A closer that doesn't match the innermost opener means this is not a JSON prefix.
                if open.pop() != Some(expected) {
                    return Err(());
                }
            }
            _ => {}
        }
    }
    if open.is_empty() {
        return Err(());
    }
    let mut patched = trimmed.to_string();
    if in_string {
        if escape_next {
            // A trailing backslash would escape the closing quote we are about to add.
            patched.pop();
        }
        patched.push('"');
    } else if patched.ends_with(',') {
        // Only strip a comma that is structural, never one inside a string.
        patched.pop();
    }
    for opener in open.iter().rev() {
        patched.push(if *opener == '{' { '}' } else { ']' });
    }
    serde_json::from_str::<serde_json::Value>(&patched).map_err(|_| ())
}
impl Default for ObjectGenerateOptions {
    /// Builds options with:
    /// - an empty model id;
    /// - a `null` schema;
    /// - JSON mode;
    /// - two retries;
    /// - every optional setting unset.
    fn default() -> Self {
        Self {
            model_id: String::default(),
            schema: serde_json::Value::default(),
            mode: OutputMode::default(),
            max_retries: 2,
            schema_name: None,
            schema_description: None,
            system: None,
            max_tokens: None,
            temperature: None,
            strict: None,
        }
    }
}