use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::collections::HashMap;
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse};
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// Describes one wire protocol: where its endpoints live, how a [`ChatRequest`]
/// becomes a protocol-specific payload, and how raw responses and HTTP failures
/// are turned back into crate-level types.
///
/// Implementors must be cheap enough to clone and free of borrowed data
/// (`Clone + 'static`), and usable across threads (`Send + Sync`).
#[async_trait]
pub trait Protocol: Send + Sync + Clone + 'static {
    /// Serializable payload type produced by [`Protocol::build_request`].
    type Request: Serialize + Send + Sync;

    /// Deserializable response type associated with the protocol.
    ///
    /// No method of this trait refers to it directly; parsing is exposed through
    /// [`Protocol::parse_response`], which works on the raw body text.
    type Response: for<'de> Deserialize<'de> + Send + Sync;

    /// Name of the protocol. The default [`Protocol::parse_models`] embeds it in
    /// its error message.
    fn name(&self) -> &str;

    /// Builds the chat endpoint URL from `base_url`.
    fn chat_endpoint(&self, base_url: &str) -> String;

    /// Builds the model-listing endpoint URL from the base URL.
    ///
    /// The default returns `None`, i.e. the protocol exposes no such endpoint.
    fn models_endpoint(&self, _base_url: &str) -> Option<String> {
        None
    }

    /// Converts a generic [`ChatRequest`] into the protocol's own payload.
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] when the request cannot be expressed in
    /// this protocol.
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError>;

    /// Parses a raw response body into a [`ChatResponse`].
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] when the body cannot be interpreted.
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError>;

    /// Parses a raw model-listing response body into model names.
    ///
    /// # Errors
    /// The default implementation always fails with
    /// [`LlmConnectorError::UnsupportedOperation`], naming the protocol.
    fn parse_models(&self, _response: &str) -> Result<Vec<String>, LlmConnectorError> {
        let message = format!("{} does not support model listing", self.name());
        Err(LlmConnectorError::UnsupportedOperation(message))
    }

    /// Maps an HTTP status code and the accompanying body text to an error.
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError;

    /// Header `(name, value)` pairs supplied by the protocol; empty by default.
    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![]
    }

    /// Turns an in-flight HTTP response into a [`ChatStream`].
    ///
    /// The default hands the response to `crate::sse::sse_to_streaming_response`
    /// and never fails by itself.
    #[cfg(feature = "streaming")]
    async fn parse_stream_response(
        &self,
        response: reqwest::Response,
    ) -> Result<ChatStream, LlmConnectorError> {
        let stream = crate::sse::sse_to_streaming_response(response);
        Ok(stream)
    }
}
/// Common interface of a chat provider: send chat requests, optionally stream
/// them, and list models.
#[async_trait]
pub trait Provider: Send + Sync {
/// Returns the provider's name.
fn name(&self) -> &str;
/// Sends `request` and returns the complete response.
///
/// # Errors
/// Returns an [`LlmConnectorError`] when the request fails.
async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;
/// Sends `request` and returns the response as a [`ChatStream`].
///
/// Only available when the `streaming` feature is enabled.
///
/// # Errors
/// Returns an [`LlmConnectorError`] when the request fails.
#[cfg(feature = "streaming")]
async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;
/// Lists the names of the models the provider reports.
///
/// # Errors
/// Returns an [`LlmConnectorError`] when the listing fails.
async fn models(&self) -> Result<Vec<String>, LlmConnectorError>;
/// Exposes the provider as [`Any`] so callers can attempt a downcast to the
/// concrete provider type.
fn as_any(&self) -> &dyn Any;
}
fn build_request_overrides(request: &ChatRequest) -> HashMap<String, String> {
let mut overrides = HashMap::new();
if let Some(ref key) = request.api_key {
overrides.insert("Authorization".to_string(), format!("Bearer {}", key));
overrides.insert("x-api-key".to_string(), key.clone());
overrides.insert("api-key".to_string(), key.clone()); }
if let Some(ref extra) = request.extra_headers {
overrides.extend(extra.clone());
}
overrides
}
/// [`Provider`] implementation that pairs a [`Protocol`] with an HTTP client.
pub struct GenericProvider<P: Protocol> {
/// Protocol supplying endpoints, request building, response parsing and error
/// mapping.
protocol: P,
/// HTTP client used to send requests; also supplies the default base URL.
client: super::HttpClient,
}
impl<P: Protocol> GenericProvider<P> {
pub fn new(protocol: P, client: super::HttpClient) -> Self {
Self { protocol, client }
}
pub fn protocol(&self) -> &P {
&self.protocol
}
pub fn client(&self) -> &super::HttpClient {
&self.client
}
fn resolve_endpoint(&self, request: &ChatRequest, endpoint_template: String) -> String {
let base_url = request
.base_url
.as_deref()
.unwrap_or_else(|| self.client.base_url())
.trim_end_matches('/');
endpoint_template.replace("{base_url}", base_url)
}
}
impl<P: Protocol> Clone for GenericProvider<P> {
    /// Clones the protocol and the HTTP client into a new provider.
    fn clone(&self) -> Self {
        let protocol = self.protocol.clone();
        let client = self.client.clone();
        Self { protocol, client }
    }
}
#[async_trait]
impl<P: Protocol> Provider for GenericProvider<P> {
    /// Returns the protocol's name.
    fn name(&self) -> &str {
        self.protocol.name()
    }

    /// Sends a chat request and parses the full response body.
    ///
    /// The request's `base_url`, when set, replaces the client's base URL. If the
    /// request carries header overrides (see `build_request_overrides`) they are
    /// sent via `post_with_overrides`; otherwise the plain `post` is used.
    ///
    /// # Errors
    /// - whatever `build_request` or the HTTP client returns;
    /// - [`LlmConnectorError::NetworkError`] when the body cannot be read;
    /// - the result of `map_error` for a non-success status. The detail passed to
    ///   it has the form `HTTP <status> - Body: <body>`, where a body that parses
    ///   as JSON is re-serialized and any other body is passed verbatim;
    /// - whatever `parse_response` returns.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let payload = self.protocol.build_request(request)?;

        let base_url: &str = match request.base_url.as_deref() {
            Some(custom) => custom,
            None => self.client.base_url(),
        };
        let endpoint = self.protocol.chat_endpoint(base_url);

        let header_overrides = build_request_overrides(request);
        let http_response = if header_overrides.is_empty() {
            self.client.post(&endpoint, &payload).await?
        } else {
            self.client
                .post_with_overrides(&endpoint, &payload, &header_overrides)
                .await?
        };

        // The status is captured first because reading the body consumes the response.
        let status = http_response.status();
        let body = http_response
            .text()
            .await
            .map_err(|err| LlmConnectorError::NetworkError(err.to_string()))?;

        if status.is_success() {
            return self.protocol.parse_response(&body);
        }

        let error_detail = match serde_json::from_str::<serde_json::Value>(&body) {
            Ok(json) => format!("HTTP {} - Body: {}", status, json),
            Err(_) => format!("HTTP {} - Body: {}", status, body),
        };
        Err(self.protocol.map_error(status.as_u16(), &error_detail))
    }

    /// Sends a chat request with `stream` forced to `Some(true)` and returns the
    /// resulting stream.
    ///
    /// The caller's request is not modified; a clone is used to build the
    /// payload. Base URL and header overrides are resolved as in `chat`.
    ///
    /// # Errors
    /// - whatever `build_request` or the HTTP client returns;
    /// - for a non-success status, the result of `map_error` with the raw body
    ///   (unlike `chat`, no `HTTP … - Body:` prefix is added), or
    ///   [`LlmConnectorError::NetworkError`] when that body cannot be read;
    /// - whatever `parse_stream_response` returns.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        let streaming_request = {
            let mut copy = request.clone();
            copy.stream = Some(true);
            copy
        };
        let payload = self.protocol.build_request(&streaming_request)?;

        let base_url: &str = match request.base_url.as_deref() {
            Some(custom) => custom,
            None => self.client.base_url(),
        };
        let endpoint = self.protocol.chat_endpoint(base_url);

        let header_overrides = build_request_overrides(request);
        let http_response = if header_overrides.is_empty() {
            self.client.stream(&endpoint, &payload).await?
        } else {
            self.client
                .stream_with_overrides(&endpoint, &payload, &header_overrides)
                .await?
        };

        let status = http_response.status();
        if status.is_success() {
            // The body is left unread so the protocol can consume it as a stream.
            return self.protocol.parse_stream_response(http_response).await;
        }

        let body = http_response
            .text()
            .await
            .map_err(|err| LlmConnectorError::NetworkError(err.to_string()))?;
        Err(self.protocol.map_error(status.as_u16(), &body))
    }

    /// Fetches and parses the protocol's model list.
    ///
    /// Always uses the client's base URL; there is no per-request override here.
    ///
    /// # Errors
    /// - [`LlmConnectorError::UnsupportedOperation`] when the protocol has no
    ///   models endpoint;
    /// - whatever the HTTP client returns;
    /// - [`LlmConnectorError::NetworkError`] when the body cannot be read;
    /// - the result of `map_error` with the raw body for a non-success status;
    /// - whatever `parse_models` returns.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let endpoint = match self.protocol.models_endpoint(self.client.base_url()) {
            Some(url) => url,
            None => {
                return Err(LlmConnectorError::UnsupportedOperation(format!(
                    "{} does not support model listing",
                    self.protocol.name()
                )));
            }
        };

        let http_response = self.client.get(&endpoint).await?;
        let status = http_response.status();
        let body = http_response
            .text()
            .await
            .map_err(|err| LlmConnectorError::NetworkError(err.to_string()))?;

        if status.is_success() {
            self.protocol.parse_models(&body)
        } else {
            Err(self.protocol.map_error(status.as_u16(), &body))
        }
    }

    /// Exposes this provider as [`Any`] for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}