use std::sync::Arc;
use std::time::Duration;
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value;
use super::config::OpenAIConfig;
use crate::chat::ChatRequest;
use crate::error::Result;
use crate::llms::LlmError;
/// Envelope of an OpenAI API error response body: `{"error": {...}}`.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIErrorResponse {
    pub error: OpenAIError,
}
/// The `error` object carried inside an OpenAI error response.
#[derive(Debug, Clone, Deserialize)]
struct OpenAIError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Error category reported by the API (serialized under the key `type`).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Machine-readable error code; not present on every error.
    pub code: Option<String>,
}
/// OpenAI API client: shared configuration plus a reusable HTTP client.
///
/// The configuration is behind an `Arc`, so cloning the client shares it
/// rather than copying it.
#[derive(Debug, Clone)]
pub struct OpenAI {
    pub(crate) config: Arc<OpenAIConfig>,
    pub(crate) client: Client,
}
impl OpenAI {
    /// Builds a client from an explicit configuration.
    ///
    /// # Errors
    ///
    /// Returns an auth error when the API key is empty, or an internal
    /// error if the underlying HTTP client cannot be constructed.
    pub fn new(config: OpenAIConfig) -> Result<Self> {
        if config.api_key.is_empty() {
            return Err(LlmError::auth("openai", "API key is required").into());
        }
        let mut builder = Client::builder();
        // Only install a request timeout when one was configured.
        if let Some(timeout) = config.timeout_secs {
            builder = builder.timeout(Duration::from_secs(timeout));
        }
        let client = builder
            .build()
            .map_err(|e| LlmError::internal(format!("Failed to create HTTP client: {e}")))?;
        Ok(Self {
            config: Arc::new(config),
            client,
        })
    }

    /// Builds a client from environment variables via `OpenAIConfig::from_env`.
    ///
    /// # Errors
    ///
    /// Propagates configuration errors, plus everything [`Self::new`] can return.
    pub fn from_env() -> Result<Self> {
        let config = OpenAIConfig::from_env()?;
        Self::new(config)
    }

    /// Configured API key.
    #[must_use]
    pub fn api_key(&self) -> &str {
        &self.config.api_key
    }

    /// Base URL all endpoint URLs below are derived from.
    #[must_use]
    pub fn base_url(&self) -> &str {
        &self.config.base_url
    }

    /// Default model used when a request does not name one.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.config.model
    }

    /// URL of the chat completions endpoint.
    pub(crate) fn chat_url(&self) -> String {
        format!("{}/chat/completions", self.config.base_url)
    }

    /// URL of the text-to-speech endpoint.
    pub(crate) fn speech_url(&self) -> String {
        format!("{}/audio/speech", self.config.base_url)
    }

    /// URL of the audio transcription endpoint.
    pub(crate) fn transcriptions_url(&self) -> String {
        format!("{}/audio/transcriptions", self.config.base_url)
    }

    /// URL of the embeddings endpoint.
    pub(crate) fn embeddings_url(&self) -> String {
        format!("{}/embeddings", self.config.base_url)
    }

    /// POST builder carrying the Bearer token and, when configured, the
    /// `OpenAI-Organization` header. Common base shared by the JSON and
    /// multipart request builders below (previously duplicated in both).
    fn authorized_post(&self, url: &str) -> reqwest::RequestBuilder {
        let mut req = self
            .client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.config.api_key));
        if let Some(org) = &self.config.organization {
            req = req.header("OpenAI-Organization", org);
        }
        req
    }

    /// Authenticated POST builder for JSON-body endpoints.
    pub(crate) fn build_request(&self, url: &str) -> reqwest::RequestBuilder {
        self.authorized_post(url)
            .header("Content-Type", "application/json")
    }

    /// Authenticated POST builder for multipart endpoints. No Content-Type
    /// is set here: `reqwest` supplies the multipart boundary header when the
    /// form is attached.
    pub(crate) fn build_multipart_request(&self, url: &str) -> reqwest::RequestBuilder {
        self.authorized_post(url)
    }

    /// Serializes `request` into the JSON body sent to the API.
    ///
    /// Falls back to the configured default model when the request does not
    /// specify one; when `streaming` is set, enables streaming and asks the
    /// API to include token usage in the final stream chunk.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the request cannot be serialized.
    pub(crate) fn build_chat_body(
        &self,
        request: &ChatRequest,
        streaming: bool,
    ) -> Result<Value> {
        let mut body = serde_json::to_value(request)
            .map_err(|e| LlmError::internal(format!("Failed to serialize request: {e}")))?;
        if request.model.is_empty() {
            body["model"] = Value::String(self.config.model.clone());
        }
        if streaming {
            body["stream"] = Value::Bool(true);
            body["stream_options"] = serde_json::json!({"include_usage": true});
        }
        Ok(body)
    }

    /// Maps an HTTP error response to an `LlmError`.
    ///
    /// Understands the OpenAI `{"error": {...}}` payload; any body that does
    /// not parse as that shape is surfaced as a raw HTTP status error.
    pub(crate) fn parse_error(status: u16, body: &str) -> LlmError {
        if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(body) {
            let error = error_response.error;
            // Prefer the machine-readable code; fall back to the error type.
            let code = error.code.unwrap_or_else(|| error.error_type.clone());
            return match status {
                401 => LlmError::auth("openai", error.message),
                429 => LlmError::rate_limited("openai"),
                400 if error.message.contains("context_length") => {
                    let (used, max) = parse_context_length_tokens(&error.message);
                    LlmError::context_exceeded(used, max)
                }
                _ => LlmError::provider_code("openai", code, error.message),
            };
        }
        LlmError::http_status(status, body.to_owned())
    }
}
/// Extracts `(used_tokens, max_tokens)` from an OpenAI context-length error
/// message such as:
///
/// "This model's maximum context length is 4097 tokens. However, your
/// messages resulted in 5000 tokens."
///
/// Either count is 0 when its marker phrase is absent or not followed by
/// digits.
fn parse_context_length_tokens(message: &str) -> (usize, usize) {
    let max = leading_number_after(message, "maximum context length is ");
    let used = leading_number_after(message, "resulted in ");
    (used, max)
}

/// Parses the run of ASCII digits immediately following the first occurrence
/// of `marker` in `message`; returns 0 when the marker is missing or no
/// digits follow it. A number that terminates the string is handled too
/// (`find` of a non-digit yields `None` there, so we fall back to the slice
/// length instead of skipping the parse).
fn leading_number_after(message: &str, marker: &str) -> usize {
    match message.find(marker) {
        Some(pos) => {
            let after = &message[pos + marker.len()..];
            let end = after
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(after.len());
            after[..end].parse().unwrap_or(0)
        }
        None => 0,
    }
}