use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, error, info};
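/// Provider backed by the OpenAI HTTP API, carrying an optional organization
/// header and a static per-model pricing table used for cost estimation.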
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
base: BaseProvider,
organization: Option<String>,
pricing_cache: HashMap<String, ModelPricing>,
}
impl OpenAIProvider {
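    /// Creates a provider from `config`, defaulting the base URL to the
    /// public OpenAI endpoint and validating the configuration with a live
    /// health check before returning.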
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
let provider = Self {
base: BaseProvider { base_url, ..base },
organization: config.organization.clone(),
pricing_cache: Self::initialize_pricing_cache(),
};
provider.validate_config().await?;
info!("OpenAI provider '{}' initialized successfully", config.name);
Ok(provider)
}
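    /// Rejects an empty API key, then confirms the API is reachable via a
    /// health check.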
async fn validate_config(&self) -> Result<()> {
if self.base.api_key.is_empty() {
return Err(
ProviderError::Authentication("OpenAI API key is required".to_string()).into(),
);
}
match self.health_check().await {
Ok(_) => {
debug!("OpenAI provider configuration validated successfully");
Ok(())
}
Err(e) => {
error!("OpenAI provider configuration validation failed: {}", e);
Err(e)
}
}
}
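    /// Seeds the static pricing table. Rates are USD per 1,000 tokens and
    /// reflect published OpenAI pricing at the time this table was written.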
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"gpt-4".to_string(),
ModelPricing {
model: "gpt-4".to_string(),
input_cost_per_1k: 0.03,
output_cost_per_1k: 0.06,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gpt-4-turbo".to_string(),
ModelPricing {
model: "gpt-4-turbo".to_string(),
input_cost_per_1k: 0.01,
output_cost_per_1k: 0.03,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gpt-3.5-turbo".to_string(),
ModelPricing {
model: "gpt-3.5-turbo".to_string(),
input_cost_per_1k: 0.0015,
output_cost_per_1k: 0.002,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"text-embedding-ada-002".to_string(),
ModelPricing {
model: "text-embedding-ada-002".to_string(),
input_cost_per_1k: 0.0001,
output_cost_per_1k: 0.0,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
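        // DALL-E is billed per image rather than per token (roughly $0.04 for
        // a standard 1024x1024 dall-e-3 image), so this per-1k entry is an
        // approximation for the cost-estimation path.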
cache.insert(
"dall-e-3".to_string(),
ModelPricing {
model: "dall-e-3".to_string(),
                input_cost_per_1k: 0.04,
                output_cost_per_1k: 0.0,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
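    /// Builds the standard request headers: bearer authorization, an optional
    /// `OpenAI-Organization` header, and a JSON content type.
    ///
    /// Panics if the API key or organization contains bytes that are invalid
    /// in an HTTP header value.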
fn create_headers(&self) -> reqwest::header::HeaderMap {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", self.base.api_key).parse().unwrap(),
);
if let Some(org) = &self.organization {
headers.insert("OpenAI-Organization", org.parse().unwrap());
}
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers
}
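    /// POSTs `body` to `endpoint` under the configured base URL and maps
    /// non-success HTTP statuses onto typed `ProviderError` variants.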
async fn make_openai_request(
&self,
endpoint: &str,
body: serde_json::Value,
) -> Result<reqwest::Response> {
let url = format!(
"{}/{}",
self.base.base_url.trim_end_matches('/'),
endpoint.trim_start_matches('/')
);
let response = self
.base
.client
.post(&url)
.headers(self.create_headers())
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
503 => ProviderError::Unavailable(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
Ok(response)
}
}
#[async_trait]
impl Provider for OpenAIProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::OpenAI
}
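    /// Accepts models from the configured list, plus a prefix heuristic for
    /// OpenAI model families that may not be listed explicitly.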
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model)
|| model.starts_with("gpt-")
|| model.starts_with("text-")
|| model.starts_with("dall-e")
}
async fn supports_images(&self) -> bool {
true
}
async fn supports_embeddings(&self) -> bool {
true
}
async fn supports_streaming(&self) -> bool {
true
}
async fn list_models(&self) -> Result<Vec<Model>> {
debug!("Listing OpenAI models");
let response = self
.base
.client
.get(format!("{}/models", self.base.base_url))
.headers(self.create_headers())
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
let models_response: serde_json::Value = self.base.parse_json_response(response).await?;
let mut models = Vec::new();
if let Some(data) = models_response.get("data").and_then(|d| d.as_array()) {
for model_data in data {
if let Some(id) = model_data.get("id").and_then(|i| i.as_str()) {
models.push(Model {
id: id.to_string(),
object: "model".to_string(),
                        created: model_data
                            .get("created")
                            .and_then(|c| c.as_u64())
                            .unwrap_or(0),
                        owned_by: model_data
                            .get("owned_by")
                            .and_then(|o| o.as_str())
                            .unwrap_or("openai")
                            .to_string(),
});
}
}
}
Ok(models)
}
async fn health_check(&self) -> Result<()> {
debug!("Performing OpenAI health check");
let response = self
.base
.client
.get(format!("{}/models", self.base.base_url))
.headers(self.create_headers())
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if response.status().is_success() {
Ok(())
} else {
Err(ProviderError::Unavailable(format!(
"Health check failed with status: {}",
response.status()
))
.into())
}
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("OpenAI chat completion for model: {}", request.model);
let body = json!({
"model": request.model,
"messages": request.messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"n": request.n,
"stream": request.stream,
"stop": request.stop,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"logit_bias": request.logit_bias,
"user": request.user
});
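        // Streaming responses arrive as SSE and need incremental parsing, so
        // route streaming callers to chat_completion_stream instead.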
if request.stream.unwrap_or(false) {
return Err(ProviderError::InvalidRequest(
"Streaming requests should use chat_completion_stream method".to_string(),
)
.into());
}
let response = self.make_openai_request("chat/completions", body).await?;
let chat_response: ChatCompletionResponse = self.base.parse_json_response(response).await?;
Ok(chat_response)
}
async fn chat_completion_stream(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin + 'static>> {
debug!(
"OpenAI streaming chat completion for model: {}",
request.model
);
        // Force streaming on; the remaining fields pass through unchanged.
        let body = json!({
            "model": request.model,
            "messages": request.messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "n": request.n,
            "stream": true,
            "stop": request.stop,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "logit_bias": request.logit_bias,
            "user": request.user
        });
        let response = self.make_openai_request("chat/completions", body).await?;
let stream = crate::core::streaming::providers::OpenAIStreaming::create_stream(response);
Ok(Box::new(stream))
}
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
debug!("OpenAI completion for model: {}", request.model);
let body = json!({
"model": request.model,
"prompt": request.prompt,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"n": request.n,
"stream": request.stream,
"stop": request.stop,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"logit_bias": request.logit_bias,
"user": request.user
});
let response = self.make_openai_request("completions", body).await?;
let completion_response: CompletionResponse =
self.base.parse_json_response(response).await?;
Ok(completion_response)
}
async fn embedding(
&self,
request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
debug!("OpenAI embedding for model: {}", request.model);
let body = json!({
"model": request.model,
"input": request.input,
"user": request.user
});
let response = self.make_openai_request("embeddings", body).await?;
let embedding_response: EmbeddingResponse = self.base.parse_json_response(response).await?;
Ok(embedding_response)
}
async fn image_generation(
&self,
request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
debug!("OpenAI image generation");
let body = json!({
"model": request.model.unwrap_or_else(|| "dall-e-3".to_string()),
"prompt": request.prompt,
"n": request.n,
"size": request.size,
"response_format": request.response_format,
"user": request.user
});
let response = self.make_openai_request("images/generations", body).await?;
let image_response: ImageGenerationResponse =
self.base.parse_json_response(response).await?;
Ok(image_response)
}
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
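            // Fallback assumption: unknown models are estimated at
            // gpt-4-turbo-class rates rather than failing the request.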
Ok(ModelPricing {
model: model.to_string(),
                input_cost_per_1k: 0.01,
                output_cost_per_1k: 0.03,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
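    /// Converts token counts into a USD cost using the model's per-1k rates.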
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ProviderConfig;
fn create_test_config() -> ProviderConfig {
ProviderConfig {
name: "test-openai".to_string(),
provider_type: "openai".to_string(),
api_key: "test-key".to_string(),
base_url: Some("https://api.openai.com/v1".to_string()),
models: vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()],
timeout: 30,
max_retries: 3,
organization: None,
api_version: None,
project: None,
weight: 1.0,
rpm: 1000,
tpm: 10000,
enabled: true,
max_concurrent_requests: 10,
retry: crate::config::RetryConfig::default(),
health_check: crate::config::HealthCheckConfig::default(),
settings: std::collections::HashMap::new(),
tags: vec![],
}
}
#[tokio::test]
async fn test_openai_provider_creation() {
let config = create_test_config();
        // Construction runs a live health check; with a dummy key this is
        // expected to fail rather than yield a half-configured provider.
        assert!(OpenAIProvider::new(&config).await.is_err());
    }
#[tokio::test]
async fn test_model_support() {
let config = create_test_config();
if let Ok(provider) = OpenAIProvider::new(&config).await {
assert!(provider.supports_model("gpt-4").await);
assert!(provider.supports_model("gpt-3.5-turbo").await);
assert!(!provider.supports_model("claude-3").await);
}
}
#[test]
fn test_pricing_cache() {
let cache = OpenAIProvider::initialize_pricing_cache();
assert!(cache.contains_key("gpt-4"));
assert!(cache.contains_key("gpt-3.5-turbo"));
assert!(cache.contains_key("text-embedding-ada-002"));
}
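    // A value-level sketch complementing the key-presence test above; the
    // expected numbers mirror the constants in initialize_pricing_cache, so
    // this catches accidental edits to the static table.
    #[test]
    fn test_pricing_cache_values() {
        let cache = OpenAIProvider::initialize_pricing_cache();
        let gpt4 = cache.get("gpt-4").expect("gpt-4 pricing should be present");
        assert!((gpt4.input_cost_per_1k - 0.03).abs() < f64::EPSILON);
        assert!((gpt4.output_cost_per_1k - 0.06).abs() < f64::EPSILON);
        let embedding = cache
            .get("text-embedding-ada-002")
            .expect("embedding pricing should be present");
        assert_eq!(embedding.output_cost_per_1k, 0.0);
    }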
#[tokio::test]
async fn test_cost_calculation() {
let config = create_test_config();
if let Ok(provider) = OpenAIProvider::new(&config).await {
let cost = provider.calculate_cost("gpt-4", 1000, 500).await.unwrap();
assert!((cost - 0.06).abs() < 0.001);
}
}
}