use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::any::Any;
use crate::types::{ChatRequest, ChatResponse};
use crate::error::LlmConnectorError;
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// Contract a backend protocol implements so that `GenericProvider` can drive it.
///
/// Within this file the provider uses an implementation to obtain endpoint URLs, to turn a
/// [`ChatRequest`] into the protocol's own request payload, and to translate response bodies
/// and non-success HTTP statuses into crate-level results.
#[async_trait]
pub trait Protocol: Send + Sync + Clone + 'static {
    /// Payload type produced by [`Protocol::build_request`]; it must be serializable.
    type Request: Serialize + Send + Sync;

    /// Deserializable response type associated with the protocol.
    ///
    /// NOTE(review): nothing in this file reads this type (`parse_response` receives the raw
    /// body text), so it is presumably used inside implementations — confirm.
    type Response: for<'de> Deserialize<'de> + Send + Sync;

    /// Name of the protocol; it is interpolated into the "does not support model listing"
    /// error message.
    fn name(&self) -> &str;

    /// Builds the chat endpoint URL from the client's base URL.
    fn chat_endpoint(&self, base_url: &str) -> String;

    /// Builds the model-listing endpoint URL from the client's base URL.
    ///
    /// The default returns `None`, which `GenericProvider::models` reports as an
    /// unsupported operation.
    fn models_endpoint(&self, _base_url: &str) -> Option<String> {
        None
    }

    /// Converts a crate-level [`ChatRequest`] into this protocol's request payload.
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] chosen by the implementation when the request cannot
    /// be converted.
    fn build_request(&self, request: &ChatRequest) -> Result<Self::Request, LlmConnectorError>;

    /// Parses the text of a successful chat response body into a [`ChatResponse`].
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] chosen by the implementation when parsing fails.
    fn parse_response(&self, response: &str) -> Result<ChatResponse, LlmConnectorError>;

    /// Parses the text of a model-listing response body into a list of strings.
    ///
    /// # Errors
    /// The default implementation always fails with
    /// [`LlmConnectorError::UnsupportedOperation`], naming this protocol in the message.
    fn parse_models(&self, _response: &str) -> Result<Vec<String>, LlmConnectorError> {
        let message = format!("{} does not support model listing", self.name());
        Err(LlmConnectorError::UnsupportedOperation(message))
    }

    /// Translates a non-success HTTP status and its response body into a crate-level error.
    fn map_error(&self, status: u16, body: &str) -> LlmConnectorError;

    /// Header name/value pairs for authentication; the default is an empty list.
    ///
    /// NOTE(review): no code in this file calls this method — presumably the headers are
    /// applied where the HTTP client is constructed; confirm against the caller.
    fn auth_headers(&self) -> Vec<(String, String)> {
        vec![]
    }

    /// Turns a streaming HTTP response into a [`ChatStream`].
    ///
    /// The default hands the response to `crate::sse::sse_to_streaming_response` and never
    /// fails itself. Only compiled when the `streaming` feature is enabled.
    #[cfg(feature = "streaming")]
    async fn parse_stream_response(
        &self,
        response: reqwest::Response,
    ) -> Result<ChatStream, LlmConnectorError> {
        let stream = crate::sse::sse_to_streaming_response(response);
        Ok(stream)
    }
}
/// Interface for a chat backend: chat, optional streaming chat, and model listing.
///
/// `as_any` exposes the implementor as `&dyn Any`.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Sends a chat request and returns the complete response.
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] when the request cannot be completed.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;

    /// Sends a chat request and returns a stream of the response.
    ///
    /// Only available when the `streaming` feature is enabled.
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] when the stream cannot be established.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;

    /// Lists the models offered by this provider, as strings.
    ///
    /// # Errors
    /// Returns an [`LlmConnectorError`] when the listing cannot be obtained.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError>;

    /// Returns `self` as [`Any`], which permits downcasting to the concrete type.
    fn as_any(&self) -> &dyn Any;
}
/// Provider assembled from a [`Protocol`] implementation and an HTTP client.
pub struct GenericProvider<P>
where
    P: Protocol,
{
    /// Supplies endpoints, request building, response parsing and error mapping.
    protocol: P,
    /// Issues the HTTP calls and supplies the base URL handed to the protocol.
    client: super::HttpClient,
}
impl<P: Protocol> GenericProvider<P> {
pub fn new(protocol: P, client: super::HttpClient) -> Self {
Self { protocol, client }
}
pub fn protocol(&self) -> &P {
&self.protocol
}
pub fn client(&self) -> &super::HttpClient {
&self.client
}
}
impl<P: Protocol> Clone for GenericProvider<P> {
fn clone(&self) -> Self {
Self {
protocol: self.protocol.clone(),
client: self.client.clone(),
}
}
}
/// Implements [`Provider`] by combining the wrapped protocol with the HTTP client.
#[async_trait]
impl<P> Provider for GenericProvider<P>
where
    P: Protocol,
{
    /// Reports the name of the wrapped protocol.
    fn name(&self) -> &str {
        self.protocol.name()
    }

    /// Posts a chat request and parses the reply through the protocol.
    ///
    /// # Errors
    /// - whatever `build_request` or the client's `post` returns on failure;
    /// - [`LlmConnectorError::NetworkError`] when the response body cannot be read;
    /// - the protocol's `map_error` result for a non-success HTTP status;
    /// - whatever `parse_response` returns when the body cannot be parsed.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
        let payload = self.protocol.build_request(request)?;
        let endpoint = self.protocol.chat_endpoint(self.client.base_url());

        let http_response = self.client.post(&endpoint, &payload).await?;
        let status = http_response.status();

        // The body is fetched before the status is examined: both the success path and the
        // error path need its text.
        let body = match http_response.text().await {
            Ok(text) => text,
            Err(e) => return Err(LlmConnectorError::NetworkError(e.to_string())),
        };

        if status.is_success() {
            self.protocol.parse_response(&body)
        } else {
            Err(self.protocol.map_error(status.as_u16(), &body))
        }
    }

    /// Posts a chat request with `stream` forced to `Some(true)` and returns the stream the
    /// protocol builds from the response.
    ///
    /// # Errors
    /// - whatever `build_request` or the client's `stream` returns on failure;
    /// - for a non-success HTTP status, the protocol's `map_error` result, or
    ///   [`LlmConnectorError::NetworkError`] if the error body cannot be read;
    /// - whatever `parse_stream_response` returns.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
        // Work on a copy so the caller's request is left untouched.
        let streaming_request = {
            let mut copy = request.clone();
            copy.stream = Some(true);
            copy
        };

        let payload = self.protocol.build_request(&streaming_request)?;
        let endpoint = self.protocol.chat_endpoint(self.client.base_url());

        let http_response = self.client.stream(&endpoint, &payload).await?;
        let status = http_response.status();

        if status.is_success() {
            // The response is passed on unread so the protocol can consume it as a stream.
            return self.protocol.parse_stream_response(http_response).await;
        }

        // Failure path only: read the whole body so the protocol can describe the error.
        let body = http_response
            .text()
            .await
            .map_err(|e| LlmConnectorError::NetworkError(e.to_string()))?;
        Err(self.protocol.map_error(status.as_u16(), &body))
    }

    /// Fetches the protocol's model-listing endpoint and parses the reply.
    ///
    /// # Errors
    /// - [`LlmConnectorError::UnsupportedOperation`] when the protocol has no models endpoint;
    /// - whatever the client's `get` returns on failure;
    /// - [`LlmConnectorError::NetworkError`] when the response body cannot be read;
    /// - the protocol's `map_error` result for a non-success HTTP status;
    /// - whatever `parse_models` returns.
    async fn models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let endpoint = match self.protocol.models_endpoint(self.client.base_url()) {
            Some(url) => url,
            None => {
                let message = format!("{} does not support model listing", self.protocol.name());
                return Err(LlmConnectorError::UnsupportedOperation(message));
            }
        };

        let http_response = self.client.get(&endpoint).await?;
        let status = http_response.status();
        let body = match http_response.text().await {
            Ok(text) => text,
            Err(e) => return Err(LlmConnectorError::NetworkError(e.to_string())),
        };

        if status.is_success() {
            self.protocol.parse_models(&body)
        } else {
            Err(self.protocol.map_error(status.as_u16(), &body))
        }
    }

    /// Returns this provider as [`Any`], which permits downcasting to `GenericProvider<P>`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}