mod request;
mod response;
mod types;
use std::time::Duration;

use async_trait::async_trait;
use derive_builder::Builder;
use futures::StreamExt;
use reqwest::{Client, StatusCode};

use crate::llm::{
    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
};
use types::*;
const OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
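/// A chat-completions client for the OpenAI API.
///
/// The `default = ...` values in the builder attributes below are not code
/// generated (the derive uses `build_fn(skip)`); they are applied by the
/// hand-written `ChatOpenAIBuilder::build`.
///
/// A minimal construction sketch; the key is a placeholder, and the fence is
/// marked `ignore` because the enclosing crate path is not known here:
///
/// ```ignore
/// let model = ChatOpenAI::builder()
///     .model("gpt-4o-mini")
///     .api_key("sk-...".to_string())
///     .temperature(0.7)
///     .build()?;
/// ```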
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAI {
#[builder(setter(into))]
pub(super) model: String,
pub(super) api_key: String,
#[builder(setter(into, strip_option), default = "None")]
pub(super) base_url: Option<String>,
#[builder(default = "0.2")]
pub(super) temperature: f32,
#[builder(default = "Some(4096)")]
pub(super) max_completion_tokens: Option<u64>,
#[builder(setter(skip))]
pub(super) client: Client,
#[builder(setter(skip))]
pub(super) context_window: u64,
}
impl ChatOpenAI {
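    /// Builds a client from the environment: `OPENAI_API_KEY` is required,
    /// and `OPENAI_BASE_URL` optionally overrides the default endpoint.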
pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
let api_key = std::env::var("OPENAI_API_KEY")
.map_err(|_| LlmError::Config("OPENAI_API_KEY not set".into()))?;
let base_url = std::env::var("OPENAI_BASE_URL").ok();
let mut builder = Self::builder().model(model).api_key(api_key);
if let Some(url) = base_url {
builder = builder.base_url(url);
}
builder.build()
}
pub fn builder() -> ChatOpenAIBuilder {
ChatOpenAIBuilder::default()
}
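    /// True for OpenAI reasoning-model families (o1/o3/o4 and gpt-5), which
    /// take different request parameters than standard chat models (for
    /// example, they reject `temperature`).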
fn is_reasoning_model(&self) -> bool {
let model_lower = self.model.to_lowercase();
model_lower.starts_with("o1")
|| model_lower.starts_with("o3")
|| model_lower.starts_with("o4")
|| model_lower.starts_with("gpt-5")
}
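    /// Joins the configured (or default) base URL with the chat-completions
    /// path, tolerating a trailing slash on the base.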
fn api_url(&self) -> String {
let base = self.base_url.as_deref().unwrap_or(OPENAI_BASE_URL);
format!("{}{}", base.trim_end_matches('/'), CHAT_COMPLETIONS_PATH)
}
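    /// Maps an HTTP error status (plus response body) onto the crate's
    /// `LlmError` variants.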
fn map_error_status(status: StatusCode, body: String) -> LlmError {
match status {
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => LlmError::Auth(body),
StatusCode::NOT_FOUND => LlmError::ModelNotFound(body),
StatusCode::TOO_MANY_REQUESTS => LlmError::RateLimit,
_ => LlmError::Api(format!("OpenAI API error ({}): {}", status, body)),
}
}
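    /// Builds the shared HTTP client with a 120-second request timeout.
    /// Client construction failure (e.g. TLS backend initialization) is
    /// treated as unrecoverable and panics.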
fn build_client() -> Client {
Client::builder()
.timeout(Duration::from_secs(120))
.build()
.expect("Failed to create HTTP client")
}
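    /// Returns the published context-window size for known model families,
    /// falling back to 128k tokens for unrecognized models.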
    fn get_context_window(model: &str) -> u64 {
        let model_lower = model.to_lowercase();
        if model_lower.contains("gpt-4o") || model_lower.contains("gpt-4-turbo") {
            128_000
        } else if model_lower.starts_with("gpt-4.1") {
            // The GPT-4.1 family advertises a ~1M-token context window; without
            // this branch it would fall through to the 8k GPT-4 figure below.
            1_047_576
        } else if model_lower.starts_with("gpt-4") {
            8_192
        } else if model_lower.starts_with("gpt-3.5") {
            16_385
        } else if model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
            || model_lower.starts_with("o4")
        {
            200_000
        } else {
            128_000
        }
    }
}
impl ChatOpenAIBuilder {
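    /// Hand-written build function (the derive is declared with
    /// `build_fn(skip)`), so the defaults from the builder attributes must be
    /// applied manually here.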
pub fn build(&self) -> Result<ChatOpenAI, LlmError> {
let model = self
.model
.clone()
.ok_or_else(|| LlmError::Config("model is required".into()))?;
let api_key = self
.api_key
.clone()
.ok_or_else(|| LlmError::Config("api_key is required".into()))?;
        Ok(ChatOpenAI {
            context_window: ChatOpenAI::get_context_window(&model),
            client: ChatOpenAI::build_client(),
            model,
            api_key,
            base_url: self.base_url.clone().flatten(),
            temperature: self.temperature.unwrap_or(0.2),
            // With `build_fn(skip)` the `default = "Some(4096)"` attribute is
            // never generated, so the documented default is applied here.
            max_completion_tokens: self.max_completion_tokens.unwrap_or(Some(4096)),
        })
}
}
#[async_trait]
impl BaseChatModel for ChatOpenAI {
fn model(&self) -> &str {
&self.model
}
fn provider(&self) -> &str {
"openai"
}
fn context_window(&self) -> Option<u64> {
Some(self.context_window)
}
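    /// Sends a single (non-streaming) chat-completions request. Bodies that
    /// arrive with a `data:` prefix (as some OpenAI-compatible servers send
    /// even for non-streaming requests) are re-parsed as SSE.
    ///
    /// Illustrative call shape only; constructing `messages` depends on the
    /// crate's `Message` type and is elided here:
    ///
    /// ```ignore
    /// let completion = model.invoke(messages, None, None).await?;
    /// ```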
async fn invoke(
&self,
messages: Vec<Message>,
tools: Option<Vec<ToolDefinition>>,
tool_choice: Option<ToolChoice>,
) -> Result<ChatCompletion, LlmError> {
let request = self.build_request(messages, tools, tool_choice, false)?;
        let response = self
            .client
            .post(self.api_url())
            .bearer_auth(&self.api_key)
            // `.json()` also sets the Content-Type: application/json header.
            .json(&request)
            .send()
            .await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Self::map_error_status(status, body));
}
let body = response.text().await?;
tracing::debug!("OpenAI raw response: {}", body);
if body.trim_start().starts_with("data:") {
return self.parse_sse_as_completion(&body);
}
        let completion: OpenAIResponse = serde_json::from_str(&body).map_err(|e| {
            // Truncate on char boundaries: slicing `&body[..500]` directly can
            // panic if byte 500 falls inside a multi-byte UTF-8 sequence.
            let preview: String = body.chars().take(500).collect();
            LlmError::Api(format!("Failed to parse response: {}\nBody: {}", e, preview))
        })?;
Ok(self.parse_response(completion))
}
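    /// Sends a streaming chat-completions request and returns the parsed
    /// event stream.
    ///
    /// Illustrative consumption sketch (`messages` construction elided):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = model.invoke_stream(messages, None, None).await?;
    /// while let Some(event) = stream.next().await {
    ///     let event = event?;
    ///     // handle the streamed delta
    /// }
    /// ```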
async fn invoke_stream(
&self,
messages: Vec<Message>,
tools: Option<Vec<ToolDefinition>>,
tool_choice: Option<ToolChoice>,
) -> Result<ChatStream, LlmError> {
let request = self.build_request(messages, tools, tool_choice, true)?;
        let response = self
            .client
            .post(self.api_url())
            .bearer_auth(&self.api_key)
            // `.json()` also sets the Content-Type: application/json header.
            .json(&request)
            .send()
            .await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Self::map_error_status(status, body));
}
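        // Map raw transport chunks to parsed stream events. reqwest makes no
        // framing guarantees, so this relies on `parse_stream_chunk` tolerating
        // SSE events that are split across, or batched within, byte chunks.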
let stream = response.bytes_stream().filter_map(|result| async move {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
Self::parse_stream_chunk(&text)
}
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
});
Ok(Box::pin(stream))
}
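    /// True for model families with image-input support.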
fn supports_vision(&self) -> bool {
let model_lower = self.model.to_lowercase();
model_lower.contains("gpt-4o")
|| model_lower.contains("gpt-4-turbo")
|| model_lower.contains("gpt-4-vision")
|| model_lower.contains("gpt-4.1")
}
}