mod request;
mod response;
mod types;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::StreamExt;
use reqwest::{Client, StatusCode};
use std::time::Duration;
use crate::llm::{
BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
};
use types::*;
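/// Chat model client for any endpoint that speaks the OpenAI
/// `/chat/completions` wire format.
///
/// A minimal construction sketch (the model name, URL, and env var are
/// placeholders, not defaults shipped by this crate):
///
/// ```ignore
/// let model = ChatOpenAICompatible::builder()
///     .model("example-model")
///     .base_url("https://api.example.com/v1")
///     .provider("example")
///     .api_key(std::env::var("EXAMPLE_API_KEY").ok())
///     .build()?;
/// ```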
#[derive(Builder, Clone)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct ChatOpenAICompatible {
#[builder(setter(into))]
pub(super) model: String,
#[builder(setter(into), default = "None")]
pub(super) api_key: Option<String>,
#[builder(setter(into))]
pub(super) base_url: String,
#[builder(setter(into))]
pub(super) provider: String,
#[builder(default = "0.2")]
pub(super) temperature: f32,
#[builder(default = "Some(4096)")]
pub(super) max_completion_tokens: Option<u64>,
#[builder(setter(skip))]
pub(super) client: Client,
#[builder(setter(skip))]
pub(super) context_window: u64,
#[builder(default = "true")]
pub(super) use_bearer_auth: bool,
}
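/// Static description of one concrete provider: its default endpoint, the
/// environment variables that may override the API key and base URL, its
/// auth header style, and its default sampling temperature.
///
/// An illustrative value (every field here is a placeholder, not a shipped
/// provider):
///
/// ```ignore
/// let config = OpenAICompatibleProviderConfig {
///     provider: "local",
///     default_base_url: "http://localhost:8080/v1",
///     api_key_env: None,
///     base_url_env: Some("LOCAL_BASE_URL"),
///     use_bearer_auth: true,
///     default_temperature: 0.2,
/// };
/// ```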
pub(crate) struct OpenAICompatibleProviderConfig<'a> {
pub provider: &'a str,
pub default_base_url: &'a str,
pub api_key_env: Option<&'a str>,
pub base_url_env: Option<&'a str>,
pub use_bearer_auth: bool,
pub default_temperature: f32,
}
impl ChatOpenAICompatible {
pub fn builder() -> ChatOpenAICompatibleBuilder {
ChatOpenAICompatibleBuilder::default()
}
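    /// Shared constructor for concrete providers. Settings resolve with the
    /// precedence: explicit argument, then the environment variable named in
    /// `config` (if any), then the config default. Errors if `model` is
    /// missing, or if the provider names an API key variable and no key can
    /// be found either explicitly or in the environment.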
pub(crate) fn build_provider(
config: OpenAICompatibleProviderConfig<'_>,
model: Option<String>,
api_key: Option<String>,
base_url: Option<String>,
temperature: Option<f32>,
max_tokens: Option<u64>,
) -> Result<Self, LlmError> {
let model = model.ok_or_else(|| LlmError::Config("model is required".into()))?;
        let api_key = match (api_key, config.api_key_env) {
            (Some(key), _) => Some(key),
            // A provider that names a key env var requires a key: read it
            // from the environment or fail pointing at the variable.
            (None, Some(env_var)) => Some(
                std::env::var(env_var)
                    .map_err(|_| LlmError::Config(format!("{} not set", env_var)))?,
            ),
            (None, None) => None,
        };
let base_url = base_url
.or_else(|| {
config
.base_url_env
.and_then(|env_var| std::env::var(env_var).ok())
})
.unwrap_or_else(|| config.default_base_url.to_string());
ChatOpenAICompatible::builder()
.model(model)
.base_url(base_url)
.provider(config.provider)
.api_key(api_key)
.use_bearer_auth(config.use_bearer_auth)
.temperature(temperature.unwrap_or(config.default_temperature))
.max_completion_tokens(max_tokens)
.build()
}
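    /// Maps HTTP status codes onto the crate's error variants; anything not
    /// matched explicitly falls through to a generic `Api` error carrying
    /// the status and response body.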
fn map_error_status(status: StatusCode, body: String) -> LlmError {
match status {
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => LlmError::Auth(body),
StatusCode::NOT_FOUND => LlmError::ModelNotFound(body),
StatusCode::TOO_MANY_REQUESTS => LlmError::RateLimit,
_ => LlmError::Api(format!("API error ({}): {}", status, body)),
}
}
fn build_client() -> Client {
Client::builder()
.timeout(Duration::from_secs(120))
.build()
.expect("Failed to create HTTP client")
}
    fn default_context_window() -> u64 {
        // Default assumed for every provider; the builder skips this field,
        // so it cannot currently be overridden.
        128_000
}
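    /// Joins the base URL with the completions path, tolerating a trailing
    /// slash: `https://host/v1` and `https://host/v1/` both resolve to
    /// `https://host/v1/chat/completions`.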
fn api_url(&self) -> String {
format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
}
    /// POST request to the completions endpoint with the provider's
    /// authentication scheme applied: `Bearer <key>` by default, or the raw
    /// key for providers that expect it verbatim.
    fn authorized_post(&self) -> reqwest::RequestBuilder {
        let mut req = self
            .client
            .post(self.api_url())
            .header("Content-Type", "application/json");
        if let Some(ref api_key) = self.api_key {
            if self.use_bearer_auth {
                req = req.header("Authorization", format!("Bearer {}", api_key));
            } else {
                req = req.header("Authorization", api_key.clone());
            }
        }
        req
    }
}
impl ChatOpenAICompatibleBuilder {
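    /// Hand-rolled `build` (the derived one is skipped via `build_fn(skip)`)
    /// so that `client` and `context_window` can be filled in here and
    /// missing required fields surface as `LlmError::Config` rather than
    /// derive_builder's own error type.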
pub fn build(&self) -> Result<ChatOpenAICompatible, LlmError> {
let model = self
.model
.clone()
.ok_or_else(|| LlmError::Config("model is required".into()))?;
let base_url = self
.base_url
.clone()
.ok_or_else(|| LlmError::Config("base_url is required".into()))?;
let provider = self
.provider
.clone()
.ok_or_else(|| LlmError::Config("provider is required".into()))?;
Ok(ChatOpenAICompatible {
client: ChatOpenAICompatible::build_client(),
context_window: ChatOpenAICompatible::default_context_window(),
model,
api_key: self.api_key.clone().flatten(),
base_url,
provider,
temperature: self.temperature.unwrap_or(0.2),
            // Fall back to the declared builder default of `Some(4096)` when
            // the setter was never called; an explicit `None` still means
            // "no limit".
            max_completion_tokens: self.max_completion_tokens.unwrap_or(Some(4096)),
use_bearer_auth: self.use_bearer_auth.unwrap_or(true),
})
}
}
#[async_trait]
impl BaseChatModel for ChatOpenAICompatible {
fn model(&self) -> &str {
&self.model
}
fn provider(&self) -> &str {
&self.provider
}
fn context_window(&self) -> Option<u64> {
Some(self.context_window)
}
async fn invoke(
&self,
messages: Vec<Message>,
tools: Option<Vec<ToolDefinition>>,
tool_choice: Option<ToolChoice>,
) -> Result<ChatCompletion, LlmError> {
let request = self.build_request(messages, tools, tool_choice, false)?;
        let response = self.authorized_post().json(&request).send().await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Self::map_error_status(
status,
format!("{}: {}", self.provider, body),
));
}
let completion: OpenAICompatibleResponse = response.json().await?;
Ok(Self::parse_response(completion))
}
async fn invoke_stream(
&self,
messages: Vec<Message>,
tools: Option<Vec<ToolDefinition>>,
tool_choice: Option<ToolChoice>,
) -> Result<ChatStream, LlmError> {
let request = self.build_request(messages, tools, tool_choice, true)?;
        let response = self.authorized_post().json(&request).send().await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Self::map_error_status(
status,
format!("{}: {}", self.provider, body),
));
}
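        // Decode each transport chunk as (lossy) UTF-8 and hand it to
        // parse_stream_chunk; chunks that yield no event are dropped from
        // the stream, and transport errors surface as Stream errors.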
let stream = response.bytes_stream().filter_map(|result| async move {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
Self::parse_stream_chunk(&text)
}
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
});
Ok(Box::pin(stream))
}
    fn supports_vision(&self) -> bool {
        // Assumed true for every OpenAI-compatible provider; there is
        // currently no per-provider capability flag.
        true
    }
}
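// A test sketch for the pure parts of this module. It assumes only what the
// code above already relies on: `LlmError::Config` is a one-field tuple
// variant and the builder's required fields are `model`, `base_url`, and
// `provider`.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn build_requires_model() {
        let result = ChatOpenAICompatible::builder()
            .base_url("http://localhost:8080/v1")
            .provider("test")
            .build();
        assert!(matches!(result, Err(LlmError::Config(_))));
    }
    #[test]
    fn api_url_tolerates_trailing_slash() {
        let Ok(model) = ChatOpenAICompatible::builder()
            .model("test-model")
            .base_url("http://localhost:8080/v1/")
            .provider("test")
            .build()
        else {
            panic!("builder with all required fields should succeed");
        };
        assert_eq!(
            model.api_url(),
            "http://localhost:8080/v1/chat/completions"
        );
    }
}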