use std::time::{Duration, Instant};
use async_trait::async_trait;
use llmkit_core::{
pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
LlmError, LlmProvider, LlmResult,
};
use crate::types::{ApiError, ChatCompletionResponse, EmbeddingsResponse};
use crate::{chat, embed, stream};
/// Base URL that `OpenAiProvider::new` configures; endpoint paths are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Chat model that `OpenAiProvider::new` configures; used when a request names no model.
const DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Embedding model that `OpenAiProvider::new` configures; used when an embed request names no model.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// LLM provider that talks to the OpenAI HTTP API (chat completions and embeddings).
///
/// Construct it with `OpenAiProvider::new` or `OpenAiProvider::from_env`, then
/// adjust it with the consuming builder methods (`model`, `embed_model`,
/// `base_url`, `with_client`).
#[derive(Clone)]
pub struct OpenAiProvider {
// HTTP client used to send every request.
http: reqwest::Client,
// API key sent as a bearer token on every request.
api_key: String,
// Prefix that endpoint paths such as `/chat/completions` are appended to.
base_url: String,
// Chat model used when a `ChatRequest` does not specify one.
model: String,
// Embedding model used when an `EmbedRequest` does not specify one.
embed_model: String,
}
impl OpenAiProvider {
    /// Creates a provider that authenticates with `api_key` and uses the default
    /// base URL, chat model, and embedding model with a fresh `reqwest::Client`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            http: reqwest::Client::new(),
            api_key: api_key.into(),
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            embed_model: DEFAULT_EMBED_MODEL.to_string(),
        }
    }

    /// Creates a provider whose API key is read from the `OPENAI_API_KEY`
    /// environment variable.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::Auth` when the variable cannot be read.
    pub fn from_env() -> LlmResult<Self> {
        let key = std::env::var("OPENAI_API_KEY")
            .map_err(|_| LlmError::Auth("OPENAI_API_KEY not set".into()))?;
        Ok(Self::new(key))
    }

    /// Sets the chat model used when a request does not name one.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Sets the embedding model used when an embed request does not name one.
    #[must_use]
    pub fn embed_model(mut self, model: impl Into<String>) -> Self {
        self.embed_model = model.into();
        self
    }

    /// Sets the base URL that endpoint paths are appended to.
    ///
    /// Trailing `/` characters are removed: endpoint paths are joined with a
    /// `/` separator, so `https://host/v1/` would otherwise yield
    /// `https://host/v1//chat/completions`.
    #[must_use]
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        let mut url: String = base_url.into();
        // Truncate in place rather than allocating a second string.
        let kept = url.trim_end_matches('/').len();
        url.truncate(kept);
        self.base_url = url;
        self
    }

    /// Replaces the HTTP client, e.g. to supply custom timeouts or proxies.
    #[must_use]
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.http = client;
        self
    }

    /// Returns the model named by `req`, or this provider's default chat model
    /// when the request leaves it unset.
    fn resolved_model(&self, req: &ChatRequest) -> String {
        req.model.clone().unwrap_or_else(|| self.model.clone())
    }
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
let model = self.resolved_model(&req);
let body = chat::build_request(&req, model, false);
let start = Instant::now();
let resp = self
.http
.post(format!("{}/chat/completions", self.base_url))
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await
.map_err(map_reqwest_err)?;
let resp = check_status(resp).await?;
let parsed: ChatCompletionResponse = resp.json().await.map_err(map_reqwest_err)?;
let mut out = chat::map_response(parsed, start.elapsed().as_millis() as u64)?;
out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
Ok(out)
}
async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
let model = self.resolved_model(&req);
let body = chat::build_request(&req, model, true);
let resp = self
.http
.post(format!("{}/chat/completions", self.base_url))
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await
.map_err(map_reqwest_err)?;
let resp = check_status(resp).await?;
Ok(stream::parse(resp))
}
async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
let model = req.model.clone().unwrap_or_else(|| self.embed_model.clone());
let body = embed::build_request(req.input, model);
let resp = self
.http
.post(format!("{}/embeddings", self.base_url))
.bearer_auth(&self.api_key)
.json(&body)
.send()
.await
.map_err(map_reqwest_err)?;
let resp = check_status(resp).await?;
let parsed: EmbeddingsResponse = resp.json().await.map_err(map_reqwest_err)?;
Ok(embed::map_response(parsed))
}
fn name(&self) -> &'static str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
let model = self.resolved_model(req);
let pricing = pricing::pricing_for(&model)?;
let prompt_chars: usize = req.messages.iter().filter_map(|m| m.content.as_text()).map(|t| t.len()).sum();
let prompt_tokens = (prompt_chars / 4) as u32;
let completion_tokens = req.max_tokens.unwrap_or(256);
Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
}
}
/// Converts a `reqwest::Error` into the matching `LlmError`.
///
/// Timeouts become `LlmError::Timeout`, body-decoding failures become
/// `LlmError::Serialization`, and everything else becomes
/// `LlmError::Transport`; the latter two carry the error's display text.
pub(crate) fn map_reqwest_err(e: reqwest::Error) -> LlmError {
    // The timeout check comes first and needs no message.
    if e.is_timeout() {
        return LlmError::Timeout;
    }
    let detail = e.to_string();
    match e.is_decode() {
        true => LlmError::Serialization(detail),
        false => LlmError::Transport(detail),
    }
}
/// Passes a successful response through unchanged and turns any other status
/// into an `LlmError`.
///
/// The error message is the `error.message` field when the body parses as an
/// `ApiError`, otherwise the raw body. If that text is empty or whitespace
/// only, the HTTP status line is used instead so the error is never blank.
///
/// Status mapping: 401/403 -> `Auth`, 429 -> `RateLimited`,
/// 400/404/422 -> `InvalidRequest`, anything else -> `Provider`.
///
/// `retry_after` is populated only when the `Retry-After` header is a whole
/// number of seconds; other forms of the header (such as an HTTP date) do not
/// parse and leave it as `None`.
pub(crate) async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
    let status = resp.status();
    if status.is_success() {
        return Ok(resp);
    }

    // Read the header first: `text()` below consumes the response.
    let retry_after = resp
        .headers()
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.trim().parse::<u64>().ok())
        .map(Duration::from_secs);

    // Best effort: an unreadable body is treated as empty.
    let body = resp.text().await.unwrap_or_default();
    let message = match serde_json::from_str::<ApiError>(&body) {
        Ok(parsed) => parsed.error.message,
        Err(_) => body,
    };
    // An empty body would otherwise produce an error with no description.
    let message = if message.trim().is_empty() {
        format!("HTTP {}", status)
    } else {
        message
    };

    Err(match status.as_u16() {
        401 | 403 => LlmError::Auth(message),
        429 => LlmError::RateLimited { retry_after, message },
        400 | 404 | 422 => LlmError::InvalidRequest(message),
        code => LlmError::Provider { status: code, message },
    })
}