use crate::config::{ProviderConfig, SharedProviderConfig};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse};
use async_trait::async_trait;
use reqwest::Client;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use std::sync::Arc;
#[cfg(feature = "streaming")]
use crate::types::{ChatStream, StreamingResponse};
/// Object-safe entry point for talking to an LLM backend.
///
/// Implementors expose their name, the models they know about, and the
/// chat operations (plus a streaming variant behind the `streaming` feature).
#[async_trait]
pub trait Provider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Returns the model identifiers this provider reports as supported.
    fn supported_models(&self) -> Vec<String>;

    /// Returns `true` when `model` is exactly equal to one of the entries
    /// returned by `supported_models` (case-sensitive, no prefix matching).
    fn supports_model(&self, model: &str) -> bool {
        let known_models = self.supported_models();
        known_models
            .iter()
            .any(|candidate| candidate.as_str() == model)
    }

    /// Sends `request` and resolves to the complete response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;

    /// Sends `request` and resolves to a stream of incremental responses.
    ///
    /// Only compiled when the `streaming` feature is enabled.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;
}
/// Describes one concrete backend to `GenericProvider`: its wire types, its
/// endpoint, and the conversions between those wire types and the crate's
/// chat types.
///
/// The `Clone + 'static` bounds let a provider hold and duplicate the adapter
/// without borrowing it from elsewhere.
#[async_trait]
pub trait ProviderAdapter: Send + Sync + Clone + 'static {
/// Backend-specific request payload; it is serialized as the JSON body of
/// the HTTP request.
type RequestType: Serialize + Send + Sync;
/// Backend-specific payload deserialized from the JSON body of a successful
/// non-streaming response.
type ResponseType: DeserializeOwned + Send + Sync;
/// Backend-specific payload deserialized from each streamed event.
#[cfg(feature = "streaming")]
type StreamResponseType: DeserializeOwned + Send + Sync;
/// Error mapper used to turn non-success HTTP responses into
/// `LlmConnectorError` values.
type ErrorMapperType: ErrorMapper;
/// Returns the backend's name.
fn name(&self) -> &str;
/// Returns the model identifiers the backend reports as supported.
fn supported_models(&self) -> Vec<String>;
/// Returns the URL requests are posted to; `base_url` is the configured
/// base URL, if any.
fn endpoint_url(&self, base_url: &Option<String>) -> String;
/// Converts a chat request into the backend payload. `stream` is `false`
/// for `chat` and `true` for `chat_stream`.
fn build_request_data(&self, request: &ChatRequest, stream: bool) -> Self::RequestType;
/// Converts a deserialized backend response into a `ChatResponse`.
fn parse_response_data(&self, response: Self::ResponseType) -> ChatResponse;
/// Converts one deserialized stream event into a `StreamingResponse`.
#[cfg(feature = "streaming")]
fn parse_stream_response_data(&self, response: Self::StreamResponseType) -> StreamingResponse;
}
/// Strategy for translating transport-level failures into
/// `LlmConnectorError` values. All functions are associated (no `self`), so
/// a mapper is selected purely by type.
pub trait ErrorMapper {
/// Maps a non-success HTTP `status` and its JSON `body` to an error.
/// `GenericProvider` passes `Value::Null` when the body is not valid JSON.
fn map_http_error(status: u16, body: Value) -> LlmConnectorError;
/// Maps a `reqwest` failure (the request itself did not complete) to an error.
fn map_network_error(error: reqwest::Error) -> LlmConnectorError;
/// Reports whether `error` is considered worth retrying.
fn is_retriable_error(error: &LlmConnectorError) -> bool;
}
/// Bundles the HTTP client with the provider configuration that supplies the
/// API key, optional extra headers and base URL for outgoing requests.
#[derive(Clone, Debug)]
pub struct HttpTransport {
/// Reference-counted `reqwest` client, so cloning the transport does not
/// build a new client.
pub client: Arc<Client>,
/// Provider configuration in its shared form.
pub config: SharedProviderConfig,
}
impl HttpTransport {
    /// Builds a transport from an owned client and configuration, wrapping
    /// each in its shared form.
    pub fn new(client: Client, config: ProviderConfig) -> Self {
        let client = Arc::new(client);
        let config = SharedProviderConfig::new(config);
        Self { client, config }
    }

    /// Builds a transport from an already shared client and configuration.
    pub fn from_shared(client: Arc<Client>, config: SharedProviderConfig) -> Self {
        Self { client, config }
    }

    /// Creates a `reqwest` client from the optional proxy, timeout and base URL.
    ///
    /// * `proxy` - when set, all traffic is routed through this proxy URL.
    /// * `timeout_ms` - when set, the request timeout in milliseconds.
    /// * `base_url` - when it parses and its host is `localhost` or
    ///   `127.0.0.1`, every proxy is disabled for the client.
    ///
    /// # Errors
    ///
    /// An invalid proxy URL is converted through `From<reqwest::Error>`;
    /// a failure to build the client is reported as `ConfigError`.
    pub fn build_client(
        proxy: &Option<String>,
        timeout_ms: Option<u64>,
        base_url: Option<&String>,
    ) -> Result<Client, LlmConnectorError> {
        let mut builder = Client::builder();

        if let Some(proxy_url) = proxy {
            let all_traffic = reqwest::Proxy::all(proxy_url)?;
            builder = builder.proxy(all_traffic);
        }

        if let Some(millis) = timeout_ms {
            builder = builder.timeout(std::time::Duration::from_millis(millis));
        }

        // Applied after `.proxy(..)` on purpose: `no_proxy` clears whatever
        // proxy was configured above when the target is the local machine.
        let targets_local_host = base_url
            .and_then(|base| reqwest::Url::parse(base).ok())
            .map_or(false, |parsed| {
                matches!(parsed.host_str(), Some("localhost" | "127.0.0.1"))
            });
        if targets_local_host {
            builder = builder.no_proxy();
        }

        match builder.build() {
            Ok(client) => Ok(client),
            Err(e) => Err(LlmConnectorError::ConfigError(e.to_string())),
        }
    }

    /// POSTs `body` as JSON to `url` and returns the raw response.
    ///
    /// The request carries a bearer `Authorization` header built from the
    /// configured API key, a JSON `Content-Type`, and then any extra headers
    /// from the configuration. The HTTP status is NOT inspected here.
    ///
    /// # Errors
    ///
    /// A failed send is converted through `From<reqwest::Error>`.
    pub async fn post<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        let bearer = format!("Bearer {}", &self.config.api_key);
        let mut builder = self
            .client
            .post(url)
            .header("Authorization", bearer)
            .header("Content-Type", "application/json");

        if let Some(extra_headers) = &self.config.headers {
            for (name, value) in extra_headers {
                builder = builder.header(name, value);
            }
        }

        let outcome = builder.json(body).send().await;
        outcome.map_err(LlmConnectorError::from)
    }

    /// POSTs `body` as JSON to `url` and returns the response body as a
    /// stream of byte chunks.
    ///
    /// The request is built exactly as in [`HttpTransport::post`].
    ///
    /// # Errors
    ///
    /// Besides send failures, any non-success HTTP status is reported as
    /// `ProviderError("HTTP error: <status>")`.
    #[cfg(feature = "streaming")]
    pub async fn stream<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<
        impl futures_util::Stream<Item = Result<reqwest::Bytes, reqwest::Error>>,
        LlmConnectorError,
    > {
        let response = self.post(url, body).await?;

        let status = response.status();
        if status.is_success() {
            Ok(response.bytes_stream())
        } else {
            Err(LlmConnectorError::ProviderError(format!(
                "HTTP error: {}",
                status
            )))
        }
    }
}
pub struct StandardErrorMapper;
impl ErrorMapper for StandardErrorMapper {
fn map_http_error(status: u16, body: Value) -> LlmConnectorError {
let error_message = body["error"]["message"].as_str().unwrap_or("Unknown error");
let error_type = body["error"]["type"].as_str().unwrap_or("unknown_error");
match status {
400 => LlmConnectorError::InvalidRequest(error_message.to_string()),
401 => LlmConnectorError::AuthenticationError(error_message.to_string()),
403 => LlmConnectorError::PermissionError(error_message.to_string()),
404 => LlmConnectorError::NotFoundError(error_message.to_string()),
429 => LlmConnectorError::RateLimitError(error_message.to_string()),
500..=599 => {
LlmConnectorError::ServerError(format!("HTTP {}: {}", status, error_message))
}
_ => LlmConnectorError::ProviderError(format!(
"HTTP {}: {} (type: {})",
status, error_message, error_type
)),
}
}
fn map_network_error(error: reqwest::Error) -> LlmConnectorError {
if error.is_timeout() {
LlmConnectorError::TimeoutError(error.to_string())
} else if error.is_connect() {
LlmConnectorError::ConnectionError(error.to_string())
} else {
LlmConnectorError::NetworkError(error.to_string())
}
}
fn is_retriable_error(error: &LlmConnectorError) -> bool {
matches!(
error,
LlmConnectorError::RateLimitError(_)
| LlmConnectorError::ServerError(_)
| LlmConnectorError::TimeoutError(_)
| LlmConnectorError::ConnectionError(_)
)
}
}
/// A `Provider` implementation assembled from a backend-specific
/// `ProviderAdapter` and a shared `HttpTransport`.
#[derive(Clone)]
pub struct GenericProvider<A: ProviderAdapter> {
// Supplies the endpoint URL and the request/response conversions.
adapter: A,
// Sends the HTTP requests and holds the provider configuration.
transport: HttpTransport,
}
impl<A: ProviderAdapter> GenericProvider<A> {
    /// Creates a provider for `adapter`, building an HTTP client from the
    /// proxy, timeout and base URL found in `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HttpTransport::build_client`].
    pub fn new(config: ProviderConfig, adapter: A) -> Result<Self, LlmConnectorError> {
        let http_client = HttpTransport::build_client(
            &config.proxy,
            config.timeout_ms,
            config.base_url.as_ref(),
        )?;

        Ok(Self {
            adapter,
            transport: HttpTransport::new(http_client, config),
        })
    }
}
#[async_trait]
impl<A: ProviderAdapter> Provider for GenericProvider<A> {
    /// Returns the adapter's name.
    fn name(&self) -> &str {
        self.adapter.name()
    }

    /// Returns the adapter's supported model list.
    fn supported_models(&self) -> Vec<String> {
        self.adapter.supported_models()
    }

    /// Sends a non-streaming chat request and converts the backend response.
    ///
    /// # Errors
    ///
    /// * send failures are converted through `From<reqwest::Error>`;
    /// * a non-success status is mapped by the adapter's `ErrorMapperType`
    ///   (the body is passed as `Value::Null` when it is not valid JSON);
    /// * a success body that does not deserialize yields `ParseError`.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let url = self.adapter.endpoint_url(&self.transport.config.base_url);
        let request_data = self.adapter.build_request_data(request, false);
        let response = self.transport.post(&url, &request_data).await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let body: Value = response.json().await.unwrap_or_default();
            return Err(A::ErrorMapperType::map_http_error(status, body));
        }

        let response_data: A::ResponseType = response
            .json()
            .await
            .map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
        Ok(self.adapter.parse_response_data(response_data))
    }

    /// Sends a streaming chat request and returns a stream of converted
    /// events.
    ///
    /// Each event is deserialized as the adapter's `StreamResponseType`;
    /// an event whose trimmed payload is `[DONE]` is skipped.
    ///
    /// # Errors
    ///
    /// The call itself fails like [`Provider::chat`] for send failures and
    /// non-success statuses. Inside the stream, an undecodable event yields
    /// `ParseError` and a transport failure yields `StreamError`.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        use crate::sse::sse_events;
        use futures_util::StreamExt;

        let url = self.adapter.endpoint_url(&self.transport.config.base_url);
        let request_data = self.adapter.build_request_data(request, true);

        // Go through the shared transport so streaming requests carry the same
        // auth, content-type AND configured extra headers as `chat`.
        let response = self.transport.post(&url, &request_data).await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let body: Value = response.json().await.unwrap_or_default();
            return Err(A::ErrorMapperType::map_http_error(status, body));
        }

        // The returned stream outlives this call, so it must own its adapter
        // instead of borrowing `self`.
        let adapter = self.adapter.clone();

        // NOTE(review): the call uses the `sse_events` name imported above;
        // confirm it matches the function exported by `crate::sse`.
        let mapped_stream = sse_events(response).filter_map(move |event| {
            let item = match event {
                Ok(data) if data.trim() == "[DONE]" => None,
                Ok(data) => Some(
                    serde_json::from_str::<A::StreamResponseType>(&data)
                        .map(|chunk| adapter.parse_stream_response_data(chunk))
                        .map_err(|e| LlmConnectorError::ParseError(e.to_string())),
                ),
                Err(e) => Some(Err(LlmConnectorError::StreamError(e.to_string()))),
            };
            // The conversion is synchronous; wrap it in an already-completed
            // future to satisfy `filter_map`.
            futures_util::future::ready(item)
        });

        Ok(Box::pin(mapped_stream))
    }
}