use crate::provider_api::{
FinishReason, LlmError, LlmErrorKind, LlmRequest, LlmResponse, TokenUsage,
};
use crate::secret::{SecretProvider, SecretString};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
/// Connection settings for an HTTP-based chat-completion provider.
///
/// `Debug` is implemented by hand for this type so the API key is never
/// printed.
#[derive(Clone)]
pub struct HttpProviderConfig {
/// API key; its exposed value is sent as the bearer token on each request.
pub api_key: SecretString,
/// Model identifier copied into the `model` field of every request body.
pub model: String,
/// URL prefix; the endpoint path is appended to it verbatim (no slash is
/// inserted or removed).
pub base_url: String,
/// Blocking `reqwest` client used to send the requests.
pub client: Client,
}
// The `client` field is deliberately left out of the output, hence the lint
// exemption below.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for HttpProviderConfig {
    /// Writes the struct with the API key replaced by a fixed placeholder.
    ///
    /// Only `api_key` (masked), `model` and `base_url` are emitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HttpProviderConfig");
        // Never format the real secret; a constant marker stands in for it.
        out.field("api_key", &"[REDACTED]");
        out.field("model", &self.model);
        out.field("base_url", &self.base_url);
        out.finish()
    }
}
impl HttpProviderConfig {
#[must_use]
pub fn new(
api_key: impl Into<String>,
model: impl Into<String>,
base_url: impl Into<String>,
) -> Self {
Self {
api_key: SecretString::new(api_key),
model: model.into(),
base_url: base_url.into(),
client: Client::new(),
}
}
pub fn from_secret_provider(
secret_provider: &dyn SecretProvider,
secret_key: &str,
model: impl Into<String>,
base_url: impl Into<String>,
) -> Result<Self, LlmError> {
let api_key = secret_provider
.get_secret(secret_key)
.map_err(|e| LlmError::auth(format!("{secret_key}: {e}")))?;
Ok(Self {
api_key,
model: model.into(),
base_url: base_url.into(),
client: Client::new(),
})
}
#[must_use]
pub fn with_client(mut self, client: Client) -> Self {
self.client = client;
self
}
}
/// One entry of the `messages` array in a chat-completion request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatMessage {
/// Speaker role as a free-form string; the constructors on this type use
/// `"system"`, `"user"` and `"assistant"`.
pub role: String,
/// Text of the message.
pub content: String,
}
impl ChatMessage {
    /// Builds a message carrying the `"system"` role.
    #[must_use]
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Builds a message carrying the `"user"` role.
    #[must_use]
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Builds a message carrying the `"assistant"` role.
    #[must_use]
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Shared constructor behind the role-specific helpers.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }
}
/// JSON body sent to a chat-completions endpoint.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
/// Model identifier.
pub model: String,
/// Conversation messages, in the order they are sent.
pub messages: Vec<ChatMessage>,
/// Token limit forwarded from the originating `LlmRequest`.
pub max_tokens: u32,
/// Sampling temperature forwarded from the originating `LlmRequest`.
pub temperature: f64,
/// Stop sequences; the key is left out of the JSON entirely when the list
/// is empty.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
}
impl ChatCompletionRequest {
    /// Translates a provider-agnostic `LlmRequest` into a chat-completion body
    /// for `model`.
    ///
    /// The message list is the optional system prompt (when present) followed
    /// by the user prompt. Token limit, temperature and stop sequences are
    /// copied across unchanged.
    #[must_use]
    pub fn from_llm_request(model: impl Into<String>, request: &LlmRequest) -> Self {
        let user_message = ChatMessage::user(&request.prompt);
        let messages = match &request.system {
            // The system prompt, if any, always precedes the user prompt.
            Some(system) => vec![ChatMessage::system(system), user_message],
            None => vec![user_message],
        };

        Self {
            model: model.into(),
            messages,
            max_tokens: request.max_tokens,
            temperature: request.temperature,
            stop: request.stop_sequences.clone(),
        }
    }
}
/// One element of the `choices` array in a chat-completion response.
#[derive(Deserialize, Debug)]
pub struct ChatChoice {
/// The generated message.
pub message: ChatChoiceMessage,
/// Raw finish reason as reported by the API, if any; interpreted by
/// `parse_finish_reason`.
pub finish_reason: Option<String>,
}
/// Message payload nested inside a [`ChatChoice`].
#[derive(Deserialize, Debug)]
pub struct ChatChoiceMessage {
/// Generated text.
// NOTE(review): declared as a required `String`, so a response whose
// `content` is `null` or absent will fail to deserialize — confirm that
// every targeted provider always sends a string here.
pub content: String,
}
/// Token accounting block (`usage`) of a chat-completion response.
///
/// The three counters are copied one-to-one into `TokenUsage` by
/// `chat_response_to_llm_response`.
#[derive(Deserialize, Debug)]
pub struct ChatUsage {
/// Value of the `prompt_tokens` key.
pub prompt_tokens: u32,
/// Value of the `completion_tokens` key.
pub completion_tokens: u32,
/// Value of the `total_tokens` key, taken as reported (not recomputed here).
pub total_tokens: u32,
}
/// Body of a successful chat-completion response, as deserialized from JSON.
#[derive(Deserialize, Debug)]
pub struct ChatCompletionResponse {
/// Model name echoed back by the API.
pub model: String,
/// Returned completions; `chat_response_to_llm_response` reads only the
/// first one.
pub choices: Vec<ChatChoice>,
/// Token usage for the call.
// NOTE(review): required field — a provider that omits `usage` will cause
// a deserialization error; confirm all targeted providers include it.
pub usage: ChatUsage,
}
/// Maps the API's raw finish-reason string onto a [`FinishReason`].
///
/// * `"length"` and `"max_tokens"` become `MaxTokens`
/// * `"content_filter"` becomes `ContentFilter`
/// * `"stop_sequence"` becomes `StopSequence`
/// * anything else, including a missing reason, becomes `Stop`
#[must_use]
pub fn parse_finish_reason(reason: Option<&str>) -> FinishReason {
    reason.map_or(FinishReason::Stop, |tag| match tag {
        "length" | "max_tokens" => FinishReason::MaxTokens,
        "content_filter" => FinishReason::ContentFilter,
        "stop_sequence" => FinishReason::StopSequence,
        // Unrecognized values are treated like a normal stop.
        _ => FinishReason::Stop,
    })
}
pub fn chat_response_to_llm_response(
response: ChatCompletionResponse,
) -> Result<LlmResponse, LlmError> {
let choice = response
.choices
.first()
.ok_or_else(|| LlmError::provider("No choices in response"))?;
Ok(LlmResponse {
content: choice.message.content.clone(),
model: response.model,
finish_reason: parse_finish_reason(choice.finish_reason.as_deref()),
usage: TokenUsage {
prompt_tokens: response.usage.prompt_tokens,
completion_tokens: response.usage.completion_tokens,
total_tokens: response.usage.total_tokens,
},
})
}
/// POSTs `request` as JSON to `config.base_url` + `endpoint` and converts the
/// reply into an [`LlmResponse`].
///
/// The call is blocking. The URL is a plain concatenation of the two parts:
/// no slash is added or removed, so callers own that convention.
///
/// # Errors
///
/// * a network `LlmError` when the request cannot be sent;
/// * whatever [`handle_openai_style_error`] produces for a non-success status;
/// * a parse `LlmError` when a success body is not a valid
///   [`ChatCompletionResponse`];
/// * whatever [`chat_response_to_llm_response`] produces (e.g. no choices).
pub fn make_chat_completion_request(
    config: &HttpProviderConfig,
    endpoint: &str,
    request: &ChatCompletionRequest,
) -> Result<LlmResponse, LlmError> {
    let url = format!("{}{}", config.base_url, endpoint);

    let http_response = config
        .client
        .post(&url)
        // Same `Authorization: Bearer <key>` header as building it by hand,
        // but reqwest marks the value as sensitive so it is kept out of debug
        // output of the request.
        .bearer_auth(config.api_key.expose())
        // `json` serializes the body and sets `Content-Type: application/json`
        // itself, so no explicit header is needed.
        .json(request)
        .send()
        .map_err(|e| LlmError::network(format!("Request failed: {e}")))?;

    if !http_response.status().is_success() {
        return handle_openai_style_error(http_response);
    }

    let api_response: ChatCompletionResponse = http_response
        .json()
        .map_err(|e| LlmError::parse(format!("Failed to parse response: {e}")))?;
    chat_response_to_llm_response(api_response)
}
/// Envelope of an OpenAI-style error body: a JSON object with a single
/// `error` member.
#[derive(Deserialize, Debug)]
pub struct OpenAiStyleError {
/// The error details.
pub error: OpenAiStyleErrorDetail,
}
/// Details carried inside an [`OpenAiStyleError`].
#[derive(Deserialize, Debug)]
pub struct OpenAiStyleErrorDetail {
/// Error message from the API; reused as the message of the resulting
/// `LlmError`.
pub message: String,
/// Error category, read from the JSON key `type`; may be absent.
/// `handle_openai_style_error` uses it to classify responses that are
/// neither 429 nor 5xx.
#[serde(rename = "type")]
pub error_type: Option<String>,
}
pub fn handle_openai_style_error(
http_response: reqwest::blocking::Response,
) -> Result<LlmResponse, LlmError> {
let http_code = http_response.status().as_u16();
let error_body: OpenAiStyleError = http_response
.json()
.map_err(|e| LlmError::parse(format!("Failed to parse error: {e}")))?;
let error_type = error_body.error.error_type.as_deref().unwrap_or("unknown");
let message = error_body.error.message;
if http_code == 429 || http_code >= 500 {
return Err(LlmError::new(
if http_code == 429 {
LlmErrorKind::RateLimit
} else {
LlmErrorKind::ProviderError
},
message,
true,
));
}
let llm_error = match error_type {
"authentication_error" => LlmError::auth(message),
"invalid_request_error" => LlmError::new(
LlmErrorKind::InvalidRequest,
message,
true, ),
"rate_limit_error" => LlmError::rate_limit(message),
_ => LlmError::provider(message),
};
Err(llm_error)
}
/// Implemented by providers that speak the OpenAI chat-completions protocol.
///
/// Implementors supply the connection settings and the endpoint path; the
/// provided method does the rest.
pub trait OpenAiCompatibleProvider {
    /// Connection settings (key, model, base URL, client) for this provider.
    fn config(&self) -> &HttpProviderConfig;

    /// Path appended to the configured base URL when sending a request.
    fn endpoint(&self) -> &str;

    /// Runs `request` against this provider's endpoint.
    ///
    /// Builds the chat-completion body using the configured model and hands it
    /// to [`make_chat_completion_request`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`make_chat_completion_request`].
    fn complete_openai_compatible(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
        let model = self.config().model.as_str();
        let body = ChatCompletionRequest::from_llm_request(model, request);
        make_chat_completion_request(self.config(), self.endpoint(), &body)
    }
}