use crate::core::types::{
Content, GenerateOptions, Message, Prompt, Role, ToolDefinition, Usage,
};
use crate::core::{LanguageModel, Result};
use serde::{Deserialize, Serialize};
/// How [`generate_object`] asks the model for structured output.
///
/// Serialized in lowercase (`"json"` / `"tool"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputMode {
    /// The schema is placed in the system prompt and the model's text reply
    /// is parsed as JSON. This is the default mode.
    #[default]
    Json,
    /// The schema is additionally registered as a tool whose call arguments
    /// become the object; a plain-text reply is parsed as a fallback.
    Tool,
}
/// Options for [`generate_object`].
///
/// The `Default` impl yields an empty `model_id`, a `null` schema,
/// [`OutputMode::Json`], and every optional setting unset.
#[derive(Clone, Debug)]
pub struct ObjectGenerateOptions {
    /// Model identifier, forwarded to the model call via `GenerateOptions`.
    pub model_id: String,
    /// JSON Schema the generated object must conform to. It is embedded in
    /// the system prompt and, in [`OutputMode::Tool`], used as the tool's
    /// parameters.
    pub schema: serde_json::Value,
    /// Tool name used in [`OutputMode::Tool`]; `"json_output"` when unset.
    pub schema_name: Option<String>,
    /// Tool description used in [`OutputMode::Tool`];
    /// `"Generate a structured JSON object"` when unset.
    pub schema_description: Option<String>,
    /// How the structured output is requested from the model.
    pub mode: OutputMode,
    /// Optional token cap, passed through to `GenerateOptions`.
    pub max_tokens: Option<u32>,
    /// Optional temperature, passed through to `GenerateOptions`.
    pub temperature: Option<f32>,
    /// Extra system-prompt text placed before the schema instruction.
    pub system: Option<String>,
}
/// Output of [`generate_object`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ObjectGenerateResult {
    /// The structured value: either the matching tool call's arguments or the
    /// JSON parsed from the model's text reply.
    pub object: serde_json::Value,
    /// The model's text output exactly as returned (no fence stripping).
    pub raw_text: String,
    /// Usage reported by the underlying model call.
    pub usage: Usage,
    /// Finish reason reported by the underlying model call, unchanged.
    pub finish_reason: String,
}
pub async fn generate_object(
model: &dyn LanguageModel,
prompt_text: &str,
options: ObjectGenerateOptions,
) -> Result<ObjectGenerateResult> {
let mut messages = Vec::new();
let schema_instruction = format!(
"You MUST respond with valid JSON that conforms to this JSON Schema:\n```json\n{}\n```\nRespond ONLY with the JSON object, no markdown fences, no extra text.",
serde_json::to_string_pretty(&options.schema).unwrap_or_default()
);
let system_text = if let Some(ref sys) = options.system {
format!("{sys}\n\n{schema_instruction}")
} else {
schema_instruction
};
messages.push(Message {
role: Role::System,
content: vec![Content::Text {
text: system_text,
}],
});
messages.push(Message {
role: Role::User,
content: vec![Content::Text {
text: prompt_text.to_string(),
}],
});
let prompt = Prompt { messages };
match options.mode {
OutputMode::Json => {
let gen_options = GenerateOptions {
model_id: options.model_id,
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: None,
stop_sequences: None,
tools: None,
};
let result = model.generate(prompt, gen_options).await?;
let object = parse_json_from_text(&result.text)?;
Ok(ObjectGenerateResult {
object,
raw_text: result.text,
usage: result.usage,
finish_reason: result.finish_reason,
})
}
OutputMode::Tool => {
let tool_name = options
.schema_name
.unwrap_or_else(|| "json_output".to_string());
let tool_desc = options
.schema_description
.unwrap_or_else(|| "Generate a structured JSON object".to_string());
let tool = ToolDefinition {
name: tool_name.clone(),
description: tool_desc,
parameters: options.schema.clone(),
};
let gen_options = GenerateOptions {
model_id: options.model_id,
max_tokens: options.max_tokens,
temperature: options.temperature,
top_p: None,
stop_sequences: None,
tools: Some(vec![tool]),
};
let result = model.generate(prompt, gen_options).await?;
let object = if let Some(tc) = result
.tool_calls
.iter()
.find(|tc| tc.name == tool_name)
{
tc.arguments.clone()
} else {
parse_json_from_text(&result.text)?
};
Ok(ObjectGenerateResult {
object,
raw_text: result.text,
usage: result.usage,
finish_reason: result.finish_reason,
})
}
}
}
fn parse_json_from_text(text: &str) -> Result<serde_json::Value> {
let trimmed = text.trim();
let json_str = if trimmed.starts_with("```json") {
trimmed
.strip_prefix("```json")
.and_then(|s| s.strip_suffix("```"))
.unwrap_or(trimmed)
.trim()
} else if trimmed.starts_with("```") {
trimmed
.strip_prefix("```")
.and_then(|s| s.strip_suffix("```"))
.unwrap_or(trimmed)
.trim()
} else {
trimmed
};
serde_json::from_str(json_str).map_err(|e| {
crate::core::error::ProviderError::InvalidResponse(format!(
"Failed to parse structured output as JSON: {e}\nRaw text: {json_str}"
))
})
}
impl Default for ObjectGenerateOptions {
fn default() -> Self {
Self {
model_id: String::new(),
schema: serde_json::Value::Null,
schema_name: None,
schema_description: None,
mode: OutputMode::Json,
max_tokens: None,
temperature: None,
system: None,
}
}
}