use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use crate::core::providers::base::BaseConfig;
use crate::core::traits::provider::ProviderConfig;
/// Provider configuration for the OpenAI API.
///
/// Shared connection settings (API key, base URL, timeout, retry count, extra headers,
/// API version) live in the flattened [`BaseConfig`], so they appear at the top level of
/// the serialized form alongside the OpenAI-specific fields below.
///
/// NOTE(review): `BaseConfig` also has an `organization` field (see the `Default` impl).
/// With `#[serde(flatten)]` both serialize under the same `organization` key, which can
/// produce duplicate keys on output. Confirm how `BaseConfig` is annotated and which field
/// callers actually read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// Shared provider settings, flattened into this struct's serialized form.
    #[serde(flatten)]
    pub base: BaseConfig,
    /// Optional OpenAI organization ID; `validate` rejects `Some("")`.
    pub organization: Option<String>,
    /// Optional OpenAI project ID; `validate` rejects `Some("")`.
    pub project: Option<String>,
    /// Model-name aliases used by `get_model_mapping`: requested name -> replacement name.
    /// Names without an entry are returned unchanged. Empty when absent from the input.
    #[serde(default)]
    pub model_mappings: HashMap<String, String>,
    /// Capability toggles. Falls back to `OpenAIFeatures::default()` when absent from the
    /// input, so configs need not spell out every flag.
    #[serde(default)]
    pub features: OpenAIFeatures,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFeatures {
pub o_series_optimizations: bool,
pub gpt5_features: bool,
pub audio_models: bool,
pub image_generation: bool,
pub audio_transcription: bool,
pub fine_tuning: bool,
pub vector_stores: bool,
pub realtime_audio: bool,
}
impl Default for OpenAIFeatures {
    /// Turns on o-series optimizations, audio models, image generation and audio
    /// transcription; leaves GPT-5 features, fine-tuning, vector stores and realtime
    /// audio switched off.
    fn default() -> Self {
        Self {
            // On by default.
            o_series_optimizations: true,
            audio_models: true,
            image_generation: true,
            audio_transcription: true,
            // Off by default; must be enabled explicitly.
            gpt5_features: false,
            fine_tuning: false,
            vector_stores: false,
            realtime_audio: false,
        }
    }
}
impl Default for OpenAIConfig {
    /// Builds a config that points at the public OpenAI endpoint.
    ///
    /// No API key is set. The timeout is 60 (interpreted as seconds by the
    /// `ProviderConfig::timeout` accessor) and there are 3 retries. There are no extra
    /// headers, no organization/project, no model aliases, and the default feature set.
    fn default() -> Self {
        let base = BaseConfig {
            api_key: None,
            api_base: Some(String::from("https://api.openai.com/v1")),
            timeout: 60,
            max_retries: 3,
            headers: HashMap::default(),
            organization: None,
            api_version: None,
        };

        Self {
            base,
            organization: None,
            project: None,
            model_mappings: HashMap::default(),
            features: OpenAIFeatures::default(),
        }
    }
}
impl OpenAIConfig {
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
config.base.api_key = Some(api_key);
}
if let Ok(org) = std::env::var("OPENAI_ORG_ID") {
config.organization = Some(org);
}
if let Ok(project) = std::env::var("OPENAI_PROJECT_ID") {
config.project = Some(project);
}
if let Ok(base_url) = std::env::var("OPENAI_API_BASE") {
config.base.api_base = Some(base_url);
}
if let Ok(timeout_str) = std::env::var("OPENAI_TIMEOUT")
&& let Ok(timeout) = timeout_str.parse::<u64>()
{
config.base.timeout = timeout;
}
config
}
pub fn validate(&self) -> Result<(), String> {
self.base.validate("openai")?;
if let Some(ref api_key) = self.base.api_key
&& !api_key.starts_with("sk-")
&& !api_key.starts_with("sk-proj-")
{
return Err("OpenAI API key must start with 'sk-' or 'sk-proj-'".to_string());
}
if let Some(ref org) = self.organization
&& org.is_empty()
{
return Err("Organization ID cannot be empty".to_string());
}
if let Some(ref project) = self.project
&& project.is_empty()
{
return Err("Project ID cannot be empty".to_string());
}
Ok(())
}
pub fn get_api_base(&self) -> String {
self.base
.api_base
.as_ref()
.unwrap_or(&"https://api.openai.com/v1".to_string())
.clone()
}
pub fn is_feature_enabled(&self, feature: OpenAIFeature) -> bool {
match feature {
OpenAIFeature::OSeriesOptimizations => self.features.o_series_optimizations,
OpenAIFeature::GPT5Features => self.features.gpt5_features,
OpenAIFeature::AudioModels => self.features.audio_models,
OpenAIFeature::ImageGeneration => self.features.image_generation,
OpenAIFeature::AudioTranscription => self.features.audio_transcription,
OpenAIFeature::FineTuning => self.features.fine_tuning,
OpenAIFeature::VectorStores => self.features.vector_stores,
OpenAIFeature::RealtimeAudio => self.features.realtime_audio,
}
}
pub fn get_model_mapping(&self, model: &str) -> String {
self.model_mappings
.get(model)
.unwrap_or(&model.to_string())
.clone()
}
}
/// Names a single capability flag of `OpenAIFeatures`.
///
/// Passed by value to `OpenAIConfig::is_feature_enabled`. The enum carries no data, so
/// it derives `Copy` (a value can be reused after being passed) and `Eq`/`Hash` (it can
/// be used as a `HashSet`/`HashMap` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpenAIFeature {
    /// Maps to `OpenAIFeatures::o_series_optimizations`.
    OSeriesOptimizations,
    /// Maps to `OpenAIFeatures::gpt5_features`.
    GPT5Features,
    /// Maps to `OpenAIFeatures::audio_models`.
    AudioModels,
    /// Maps to `OpenAIFeatures::image_generation`.
    ImageGeneration,
    /// Maps to `OpenAIFeatures::audio_transcription`.
    AudioTranscription,
    /// Maps to `OpenAIFeatures::fine_tuning`.
    FineTuning,
    /// Maps to `OpenAIFeatures::vector_stores`.
    VectorStores,
    /// Maps to `OpenAIFeatures::realtime_audio`.
    RealtimeAudio,
}
impl ProviderConfig for OpenAIConfig {
fn validate(&self) -> Result<(), String> {
self.validate()
}
fn api_key(&self) -> Option<&str> {
self.base.api_key.as_deref()
}
fn api_base(&self) -> Option<&str> {
self.base.api_base.as_deref()
}
fn timeout(&self) -> Duration {
Duration::from_secs(self.base.timeout)
}
fn max_retries(&self) -> u32 {
self.base.max_retries
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let cfg = OpenAIConfig::default();
        // The default points at the public endpoint and enables image generation.
        assert_eq!(cfg.get_api_base(), "https://api.openai.com/v1");
        assert!(cfg.features.image_generation);
    }

    #[test]
    fn test_config_validation() {
        let mut cfg = OpenAIConfig::default();
        // With no API key set, validation is expected to fail.
        assert!(cfg.validate().is_err());

        cfg.base.api_key = Some(String::from("sk-test123"));
        assert!(cfg.validate().is_ok());

        // A key without the `sk-` prefix is rejected.
        cfg.base.api_key = Some(String::from("invalid-key"));
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_feature_flags() {
        let cfg = OpenAIConfig::default();
        let enabled = |feature| cfg.is_feature_enabled(feature);

        assert!(enabled(OpenAIFeature::ImageGeneration));
        assert!(!enabled(OpenAIFeature::GPT5Features));
        assert!(!enabled(OpenAIFeature::RealtimeAudio));
    }

    #[test]
    fn test_model_mapping() {
        let mut cfg = OpenAIConfig::default();
        cfg.model_mappings
            .insert(String::from("gpt-4"), String::from("gpt-4-0613"));

        // A mapped name is rewritten...
        assert_eq!(cfg.get_model_mapping("gpt-4"), "gpt-4-0613");
        // ...and an unmapped name passes through untouched.
        assert_eq!(cfg.get_model_mapping("gpt-3.5-turbo"), "gpt-3.5-turbo");
    }
}