//! Advanced chat-completion support for the OpenAI provider: structured
//! outputs, reasoning configuration, predicted outputs, and audio
//! responses, plus validation and cost-estimation helpers.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::core::providers::unified_provider::ProviderError;
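/// Wire format for the `response_format` request field. `json_object`
/// omits the schema; `json_schema` requires one.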
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuredOutput {
#[serde(rename = "type")]
pub output_type: StructuredOutputType,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<JsonSchema>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StructuredOutputType {
JsonObject,
JsonSchema,
}
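/// Named schema for `json_schema` mode; `strict: Some(true)` requests
/// exact schema adherence.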
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchema {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub schema: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
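/// Reasoning options for o-series models.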
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_reasoning: Option<bool>,
}
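/// Predicted-output configuration for speeding up regeneration of
/// largely known content.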
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionConfig {
pub content: PredictionContent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum PredictionContent {
#[serde(rename = "content")]
Content { content: Vec<PredictionPart> },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum PredictionPart {
#[serde(rename = "text")]
Text { text: String },
}
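/// Voice and output format for audio responses.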
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioConfig {
pub voice: AudioVoice,
pub format: AudioResponseFormat,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AudioVoice {
Alloy,
Echo,
Fable,
Onyx,
Nova,
Shimmer,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AudioResponseFormat {
Mp3,
Opus,
Aac,
Flac,
Wav,
Pcm,
}
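/// Chat-completion request carrying the advanced feature set; `None`
/// fields are omitted from the serialized payload.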
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedChatRequest {
pub messages: Vec<serde_json::Value>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<StructuredOutput>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<ReasoningConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prediction: Option<PredictionConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub audio: Option<AudioConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedChatResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<AdvancedChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<AdvancedUsage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedChatChoice {
pub index: u32,
pub message: AdvancedChatMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedChatMessage {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub audio: Option<AudioResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioResponse {
pub data: String,
pub format: AudioResponseFormat,
#[serde(skip_serializing_if = "Option::is_none")]
pub transcript: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_tokens: Option<u32>,
}
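/// Stateless helpers for capability checks, request construction,
/// validation, and cost estimation.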
pub struct AdvancedChatUtils;
impl AdvancedChatUtils {
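    /// Model-name prefixes with structured-output support; matched by
    /// prefix, so dated snapshots such as "gpt-4o-2024-08-06" qualify.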
pub fn get_structured_output_models() -> Vec<&'static str> {
vec![
"gpt-4o", "gpt-4.1", "gpt-5", "gpt-5.4", "o1", "o3", "o4-mini",
]
}
pub fn get_reasoning_models() -> Vec<&'static str> {
vec![
"o1", "o3", "o4",
]
}
pub fn get_audio_models() -> Vec<&'static str> {
vec!["gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01"]
}
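    /// Prefix match against `get_structured_output_models`, with explicit
    /// carve-outs for legacy o1 and audio/realtime variants.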
pub fn supports_structured_outputs(model: &str) -> bool {
if model.starts_with("o1-preview")
|| model.starts_with("o1-mini")
|| model.starts_with("gpt-4o-audio")
|| model.starts_with("gpt-4o-realtime")
{
return false;
}
Self::get_structured_output_models()
.iter()
.any(|prefix| model.starts_with(prefix))
}
pub fn is_reasoning_model(model: &str) -> bool {
Self::get_reasoning_models()
.iter()
.any(|prefix| model.starts_with(prefix))
}
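    /// Exact match rather than prefix: audio output is limited to
    /// specific preview snapshots.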
pub fn supports_audio_responses(model: &str) -> bool {
Self::get_audio_models().contains(&model)
}
pub fn create_json_schema_output(
name: String,
description: Option<String>,
schema: serde_json::Value,
strict: bool,
) -> StructuredOutput {
StructuredOutput {
output_type: StructuredOutputType::JsonSchema,
json_schema: Some(JsonSchema {
name,
description,
schema,
strict: Some(strict),
}),
}
}
pub fn create_reasoning_config(
max_reasoning_tokens: Option<u32>,
include_reasoning: bool,
) -> ReasoningConfig {
ReasoningConfig {
max_reasoning_tokens,
include_reasoning: Some(include_reasoning),
}
}
pub fn create_audio_config(voice: AudioVoice, format: AudioResponseFormat) -> AudioConfig {
AudioConfig { voice, format }
}
pub fn create_prediction_config(text: String) -> PredictionConfig {
PredictionConfig {
content: PredictionContent::Content {
content: vec![PredictionPart::Text { text }],
},
}
}
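    /// Validates a request against the target model's capabilities and
    /// parameter limits, returning the first violation found.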
pub fn validate_request(request: &AdvancedChatRequest) -> Result<(), ProviderError> {
if request.response_format.is_some() && !Self::supports_structured_outputs(&request.model) {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "Model does not support structured outputs".to_string(),
});
}
if let Some(reasoning_config) = &request.reasoning {
if !Self::is_reasoning_model(&request.model) {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "Reasoning configuration only supported by o-series models"
.to_string(),
});
}
let is_legacy_o1 =
request.model.starts_with("o1-preview") || request.model.starts_with("o1-mini");
if is_legacy_o1 {
if request.temperature.is_some() {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "temperature parameter not supported for legacy reasoning models"
.to_string(),
});
}
if request.top_p.is_some() {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "top_p parameter not supported for legacy reasoning models"
.to_string(),
});
}
}
if let Some(max_reasoning) = reasoning_config.max_reasoning_tokens
&& max_reasoning > 20000
{
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "max_reasoning_tokens cannot exceed 20000".to_string(),
});
}
}
if request.audio.is_some() && !Self::supports_audio_responses(&request.model) {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "Model does not support audio responses".to_string(),
});
}
if let Some(temp) = request.temperature
&& !(0.0..=2.0).contains(&temp)
{
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "temperature must be between 0.0 and 2.0".to_string(),
});
}
if let Some(n) = request.n
&& (n == 0 || n > 128)
{
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: "n must be between 1 and 128".to_string(),
});
}
Ok(())
}
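    /// Summarizes feature support; legacy o1 variants lack function
    /// calling, streaming, and temperature control.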
pub fn get_model_capabilities(model: &str) -> ModelCapabilities {
let is_reasoning = Self::is_reasoning_model(model);
let is_legacy_o1 = model.starts_with("o1-preview") || model.starts_with("o1-mini");
ModelCapabilities {
structured_outputs: Self::supports_structured_outputs(model),
reasoning: is_reasoning,
audio_responses: Self::supports_audio_responses(model),
function_calling: !is_legacy_o1,
streaming: !is_legacy_o1,
temperature_control: !is_legacy_o1,
}
}
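    /// Estimates the request cost in USD. Reasoning tokens, when present,
    /// are billed at the output-token rate.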
pub fn estimate_advanced_cost(
model: &str,
input_tokens: u32,
output_tokens: u32,
reasoning_tokens: Option<u32>,
) -> Result<f64, ProviderError> {
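        // Prices in USD per 1K tokens, as (input, output) pairs.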
let (input_cost, output_cost) = match model {
"gpt-4o" | "gpt-4o-2024-08-06" => (0.0025, 0.01),
"gpt-4o-mini" | "gpt-4o-mini-2024-07-18" => (0.00015, 0.0006),
"o1-preview" | "o1-preview-2024-09-12" => (0.015, 0.06),
"o1-mini" | "o1-mini-2024-09-12" => (0.003, 0.012),
"gpt-4o-audio-preview" | "gpt-4o-audio-preview-2024-10-01" => (0.0025, 0.01),
_ => {
return Err(ProviderError::InvalidRequest {
provider: "openai",
message: format!("Unknown advanced model: {}", model),
});
}
};
let mut total_cost = (input_tokens as f64 / 1000.0) * input_cost;
total_cost += (output_tokens as f64 / 1000.0) * output_cost;
if let Some(reasoning_tokens) = reasoning_tokens {
total_cost += (reasoning_tokens as f64 / 1000.0) * output_cost;
}
Ok(total_cost)
}
}
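/// Feature flags summarizing what a given model supports.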
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
pub structured_outputs: bool,
pub reasoning: bool,
pub audio_responses: bool,
pub function_calling: bool,
pub streaming: bool,
pub temperature_control: bool,
}
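/// Ready-made JSON schemas for common structured-output tasks
/// (classification, sentiment, entity extraction).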
pub struct CommonSchemas;
impl CommonSchemas {
pub fn classification_schema(categories: Vec<String>) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"category": {
"type": "string",
"enum": categories
},
"confidence": {
"type": "number",
"minimum": 0.0,
"maximum": 1.0
}
},
"required": ["category", "confidence"],
"additionalProperties": false
})
}
pub fn sentiment_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"sentiment": {
"type": "string",
"enum": ["positive", "negative", "neutral"]
},
"score": {
"type": "number",
"minimum": -1.0,
"maximum": 1.0
}
},
"required": ["sentiment", "score"],
"additionalProperties": false
})
}
pub fn entity_extraction_schema() -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"text": { "type": "string" },
"type": { "type": "string" },
"start": { "type": "integer", "minimum": 0 },
"end": { "type": "integer", "minimum": 0 }
},
"required": ["text", "type", "start", "end"]
}
}
},
"required": ["entities"],
"additionalProperties": false
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_supports_structured_outputs() {
assert!(AdvancedChatUtils::supports_structured_outputs("gpt-4o"));
assert!(AdvancedChatUtils::supports_structured_outputs(
"gpt-4o-mini"
));
assert!(AdvancedChatUtils::supports_structured_outputs(
"gpt-4o-2024-08-06"
));
assert!(AdvancedChatUtils::supports_structured_outputs("gpt-4.1"));
assert!(AdvancedChatUtils::supports_structured_outputs(
"gpt-4.1-mini"
));
assert!(AdvancedChatUtils::supports_structured_outputs("gpt-5"));
assert!(AdvancedChatUtils::supports_structured_outputs("gpt-5.2"));
assert!(AdvancedChatUtils::supports_structured_outputs("gpt-5.4"));
assert!(AdvancedChatUtils::supports_structured_outputs(
"gpt-5.4-mini"
));
assert!(AdvancedChatUtils::supports_structured_outputs(
"gpt-5.4-turbo"
));
assert!(AdvancedChatUtils::supports_structured_outputs("o3"));
assert!(AdvancedChatUtils::supports_structured_outputs("o3-mini"));
assert!(AdvancedChatUtils::supports_structured_outputs("o4-mini"));
assert!(!AdvancedChatUtils::supports_structured_outputs(
"o1-preview"
));
assert!(!AdvancedChatUtils::supports_structured_outputs("o1-mini"));
assert!(!AdvancedChatUtils::supports_structured_outputs(
"gpt-4o-audio-preview"
));
assert!(!AdvancedChatUtils::supports_structured_outputs(
"gpt-4o-audio-preview-2024-10-01"
));
assert!(!AdvancedChatUtils::supports_structured_outputs(
"gpt-3.5-turbo"
));
}
#[test]
fn test_is_reasoning_model() {
assert!(AdvancedChatUtils::is_reasoning_model("o1-preview"));
assert!(AdvancedChatUtils::is_reasoning_model("o1-mini"));
assert!(AdvancedChatUtils::is_reasoning_model("o1"));
assert!(AdvancedChatUtils::is_reasoning_model("o3"));
assert!(AdvancedChatUtils::is_reasoning_model("o3-mini"));
assert!(AdvancedChatUtils::is_reasoning_model("o3-pro"));
assert!(AdvancedChatUtils::is_reasoning_model("o4-mini"));
assert!(!AdvancedChatUtils::is_reasoning_model("gpt-4o"));
assert!(!AdvancedChatUtils::is_reasoning_model("gpt-5.4"));
}
#[test]
fn test_supports_audio_responses() {
assert!(AdvancedChatUtils::supports_audio_responses(
"gpt-4o-audio-preview"
));
assert!(!AdvancedChatUtils::supports_audio_responses("gpt-4o"));
}
#[test]
fn test_create_json_schema_output() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"name": {"type": "string"}
}
});
let output = AdvancedChatUtils::create_json_schema_output(
"test_schema".to_string(),
Some("Test schema".to_string()),
schema.clone(),
true,
);
assert!(matches!(
output.output_type,
StructuredOutputType::JsonSchema
));
assert!(output.json_schema.is_some());
let json_schema = output.json_schema.unwrap();
assert_eq!(json_schema.name, "test_schema");
assert_eq!(json_schema.strict, Some(true));
}
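    #[test]
    fn test_structured_output_serialization() {
        // Illustrative sketch of the expected wire shape: the enum tag
        // serializes in snake_case and a `None` schema is omitted.
        let output = StructuredOutput {
            output_type: StructuredOutputType::JsonObject,
            json_schema: None,
        };
        let value = serde_json::to_value(&output).unwrap();
        assert_eq!(value, serde_json::json!({ "type": "json_object" }));
    }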
#[test]
fn test_create_reasoning_config() {
let config = AdvancedChatUtils::create_reasoning_config(Some(10000), true);
assert_eq!(config.max_reasoning_tokens, Some(10000));
assert_eq!(config.include_reasoning, Some(true));
}
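    #[test]
    fn test_create_prediction_config_serialization() {
        // Illustrative sketch of the internally tagged prediction payload
        // produced by create_prediction_config.
        let config = AdvancedChatUtils::create_prediction_config("hello".to_string());
        let value = serde_json::to_value(&config).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "content": {
                    "type": "content",
                    "content": [{ "type": "text", "text": "hello" }]
                }
            })
        );
    }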
#[test]
fn test_validate_request() {
let mut request = AdvancedChatRequest {
messages: vec![],
model: "gpt-4o".to_string(),
response_format: Some(StructuredOutput {
output_type: StructuredOutputType::JsonObject,
json_schema: None,
}),
reasoning: None,
prediction: None,
audio: None,
temperature: Some(0.7),
max_tokens: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
n: None,
stream: None,
logprobs: None,
top_logprobs: None,
user: None,
metadata: None,
store: None,
service_tier: None,
};
assert!(AdvancedChatUtils::validate_request(&request).is_ok());
request.model = "gpt-3.5-turbo".to_string();
assert!(AdvancedChatUtils::validate_request(&request).is_err());
request.model = "o1-preview".to_string();
request.response_format = None;
request.reasoning = Some(ReasoningConfig {
max_reasoning_tokens: Some(5000),
include_reasoning: Some(true),
});
        assert!(AdvancedChatUtils::is_reasoning_model("o1-preview"));
        // Legacy o1 models reject sampling parameters, so clear the
        // temperature set during construction before checking the
        // success path.
        request.temperature = None;
        assert!(AdvancedChatUtils::validate_request(&request).is_ok());
        // Restoring temperature must fail validation for a legacy
        // reasoning model.
        request.temperature = Some(0.7);
        assert!(AdvancedChatUtils::validate_request(&request).is_err());
}
#[test]
fn test_get_model_capabilities() {
let gpt4o_caps = AdvancedChatUtils::get_model_capabilities("gpt-4o");
assert!(gpt4o_caps.structured_outputs);
assert!(!gpt4o_caps.reasoning);
assert!(gpt4o_caps.function_calling);
assert!(gpt4o_caps.streaming);
let o1_legacy = AdvancedChatUtils::get_model_capabilities("o1-preview");
assert!(!o1_legacy.structured_outputs);
assert!(o1_legacy.reasoning);
assert!(!o1_legacy.function_calling);
assert!(!o1_legacy.streaming);
let o3_caps = AdvancedChatUtils::get_model_capabilities("o3");
assert!(o3_caps.structured_outputs);
assert!(o3_caps.reasoning);
assert!(o3_caps.function_calling);
assert!(o3_caps.streaming);
let o4_mini_caps = AdvancedChatUtils::get_model_capabilities("o4-mini");
assert!(o4_mini_caps.structured_outputs);
assert!(o4_mini_caps.reasoning);
assert!(o4_mini_caps.function_calling);
assert!(o4_mini_caps.streaming);
let gpt54_caps = AdvancedChatUtils::get_model_capabilities("gpt-5.4");
assert!(gpt54_caps.structured_outputs);
assert!(!gpt54_caps.reasoning);
assert!(gpt54_caps.function_calling);
assert!(gpt54_caps.streaming);
}
#[test]
fn test_estimate_advanced_cost() {
        // Compare within a tolerance rather than relying on exact
        // floating-point equality.
        let cost = AdvancedChatUtils::estimate_advanced_cost("gpt-4o", 1000, 500, None).unwrap();
        assert!((cost - (0.0025 + 0.005)).abs() < 1e-12);
        let cost_with_reasoning =
            AdvancedChatUtils::estimate_advanced_cost("o1-preview", 1000, 500, Some(2000)).unwrap();
        assert!((cost_with_reasoning - (0.015 + 0.03 + 0.12)).abs() < 1e-12);
}
#[test]
fn test_common_schemas() {
let classification = CommonSchemas::classification_schema(vec![
"positive".to_string(),
"negative".to_string(),
]);
assert!(classification.is_object());
let sentiment = CommonSchemas::sentiment_schema();
assert!(sentiment.is_object());
let entities = CommonSchemas::entity_extraction_schema();
assert!(entities.is_object());
}
}