use async_trait::async_trait;
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest as Request, ChatResponse as Response};
use crate::core::Protocol;
use crate::core::protocol::ProtocolError;
#[cfg(feature = "streaming")]
use crate::types::ChatStream;
/// A chat-capable LLM backend.
///
/// The `Send + Sync + 'static` supertraits allow implementations to be shared
/// across threads and held as trait objects.
#[async_trait]
pub trait Provider: Send + Sync + 'static {
    /// Returns this provider's name.
    fn name(&self) -> &str;

    /// Sends `request` and waits for the complete (non-streaming) response.
    async fn chat(&self, request: &Request) -> Result<Response, LlmConnectorError>;

    /// Sends `request` and returns a stream of incremental response chunks.
    ///
    /// Only available when the `streaming` feature is enabled.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &Request) -> Result<ChatStream, LlmConnectorError>;

    /// Lists the model identifiers this backend reports.
    async fn fetch_models(&self) -> Result<Vec<String>, LlmConnectorError>;

    /// Exposes `self` as [`std::any::Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;
}
/// A [`Provider`] that drives a generic wire [`Protocol`] over HTTP.
///
/// The protocol builds request payloads and parses responses; the transport
/// performs the HTTP calls against endpoints derived from `base_url`.
pub struct ProtocolProvider<P: Protocol> {
    /// Protocol adapter used to build requests, endpoints and parse responses.
    protocol: P,
    /// Server root handed to the protocol's endpoint builders.
    base_url: String,
    /// HTTP client wrapper built in `new` from the provider config
    /// (API key, proxy, timeout, base URL).
    transport: crate::core::HttpTransport,
}
impl<P: Protocol> ProtocolProvider<P> {
    /// Creates a provider that speaks `protocol` against the server at `base_url`.
    ///
    /// A `ProviderConfig` is built from `api_key` and `base_url`, and the HTTP
    /// client is constructed from its proxy, timeout and base-URL settings.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `HttpTransport::build_client`.
    pub fn new(
        protocol: P,
        base_url: &str,
        api_key: &str,
    ) -> Result<Self, LlmConnectorError> {
        let config = crate::config::ProviderConfig::new(api_key)
            .with_base_url(base_url.to_string());
        let client = crate::core::HttpTransport::build_client(
            &config.proxy,
            config.timeout_ms,
            config.base_url.as_ref(),
        )?;
        let transport = crate::core::HttpTransport::new(client, config);
        Ok(Self {
            protocol,
            base_url: base_url.to_string(),
            transport,
        })
    }

    /// Returns the protocol adapter used to build requests and parse responses.
    pub fn protocol(&self) -> &P {
        &self.protocol
    }

    /// Streams a chat completion delivered as Server-Sent Events.
    ///
    /// The response body is consumed as raw bytes and `data:` lines are decoded as
    /// `StreamingResponse` chunks until a `[DONE]` sentinel or the end of the body.
    /// Errors from the byte stream are yielded as `LlmConnectorError::NetworkError`
    /// items; they do not end the stream.
    #[cfg(feature = "streaming")]
    async fn chat_stream_sse_impl(&self, request: &Request) -> Result<ChatStream, LlmConnectorError> {
        use futures_util::{future, stream, StreamExt};

        let endpoint = self.protocol.endpoint(&self.base_url);
        let protocol_request = self.protocol.build_request(request, true);
        let byte_stream = self.transport.stream(&endpoint, &protocol_request).await?;

        // Network chunks do not line up with SSE line boundaries (nor with UTF-8
        // character boundaries), so bytes are buffered across chunks and only
        // complete lines are parsed. The trailing `None` marks end-of-body so an
        // unterminated final line is still flushed.
        let sse_stream = byte_stream
            .map(Some)
            .chain(stream::once(future::ready(None)))
            .scan((Vec::<u8>::new(), false), |state, item| {
                let (buffer, done) = state;
                if *done {
                    // `[DONE]` was already seen: terminate the stream.
                    return future::ready(None);
                }
                let events = match item {
                    Some(Ok(chunk)) => {
                        buffer.extend_from_slice(&chunk);
                        Self::drain_sse_lines(buffer, done, false)
                    }
                    Some(Err(e)) => vec![Err(LlmConnectorError::NetworkError(e.to_string()))],
                    None => Self::drain_sse_lines(buffer, done, true),
                };
                future::ready(Some(events))
            })
            .flat_map(stream::iter);
        Ok(Box::pin(sse_stream))
    }

    /// Parses the complete lines held in `buffer` as SSE `data:` fields.
    ///
    /// Bytes after the last `\n` stay in `buffer` to be completed by the next chunk,
    /// unless `flush` is set (end of body), in which case everything is consumed.
    /// On a `[DONE]` payload, sets `*done` and stops parsing. Lines that are not
    /// `data:` fields, or whose payload does not deserialize, are skipped.
    #[cfg(feature = "streaming")]
    fn drain_sse_lines(
        buffer: &mut Vec<u8>,
        done: &mut bool,
        flush: bool,
    ) -> Vec<Result<crate::types::StreamingResponse, LlmConnectorError>> {
        let complete_len = if flush {
            buffer.len()
        } else {
            match buffer.iter().rposition(|&byte| byte == b'\n') {
                Some(newline) => newline + 1,
                None => return Vec::new(),
            }
        };
        let complete: Vec<u8> = buffer.drain(..complete_len).collect();
        let text = String::from_utf8_lossy(&complete);

        let mut responses = Vec::new();
        for line in text.lines() {
            // SSE allows an optional single space after the colon; `trim` covers both forms.
            let data = match line.strip_prefix("data:") {
                Some(rest) => rest.trim(),
                None => continue,
            };
            if data == "[DONE]" {
                *done = true;
                break;
            }
            if let Ok(mut response) = serde_json::from_str::<crate::types::StreamingResponse>(data) {
                // When the top-level `content` is empty, copy the first choice's delta text into it.
                if response.content.is_empty() {
                    if let Some(delta_text) = response
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.clone())
                    {
                        response.content = delta_text;
                    }
                }
                responses.push(Ok(response));
            }
        }
        responses
    }

    /// Emulates streaming for protocols that do not use SSE.
    ///
    /// Performs a regular non-streaming chat call and wraps the complete response in
    /// a single `chat.completion.chunk` item. Each choice's message text becomes that
    /// choice's delta, with the assistant role.
    #[cfg(feature = "streaming")]
    async fn chat_stream_fallback_impl(&self, request: &Request) -> Result<ChatStream, LlmConnectorError> {
        use futures_util::stream;

        let response = self.chat(request).await?;
        let stream_response = crate::types::StreamingResponse {
            id: response.id,
            object: "chat.completion.chunk".to_string(),
            created: response.created,
            model: response.model,
            choices: response
                .choices
                .into_iter()
                .map(|choice| crate::types::StreamingChoice {
                    index: choice.index,
                    delta: crate::types::Delta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(choice.message.content),
                        ..Default::default()
                    },
                    finish_reason: choice.finish_reason,
                    logprobs: None,
                })
                .collect(),
            content: response.content,
            reasoning_content: None,
            usage: response.usage,
            system_fingerprint: response.system_fingerprint,
        };
        let single_chunk_stream = stream::once(async { Ok(stream_response) });
        Ok(Box::pin(single_chunk_stream))
    }
}
/// Converts the raw body of a failed HTTP response into the JSON value passed to
/// `map_http_error`.
///
/// Behaviour by body type:
/// - A JSON body is passed through unchanged.
/// - An empty or whitespace-only body becomes `null`.
/// - Any other text (e.g. an HTML error page from a proxy) is kept as a JSON
///   string, so the server's message is not discarded.
fn http_error_body(text: String) -> serde_json::Value {
    if text.trim().is_empty() {
        return serde_json::Value::Null;
    }
    serde_json::from_str::<serde_json::Value>(&text)
        .unwrap_or_else(|_| serde_json::Value::String(text))
}

#[async_trait]
impl<P: Protocol> Provider for ProtocolProvider<P>
where
    P::Error: Send + Sync,
{
    fn name(&self) -> &str {
        self.protocol.name()
    }

    /// Sends a non-streaming chat request and parses the protocol response.
    ///
    /// # Errors
    ///
    /// - Transport errors from `HttpTransport::post`.
    /// - A non-2xx status, converted via `P::Error::map_http_error`.
    /// - Any error from `Protocol::validate_success_body`.
    /// - `ParseError` if the body cannot be read or deserialized into `P::Response`.
    async fn chat(&self, request: &Request) -> Result<Response, LlmConnectorError> {
        let endpoint = self.protocol.endpoint(&self.base_url);
        let protocol_request = self.protocol.build_request(request, false);
        let response = self.transport.post(&endpoint, &protocol_request).await?;

        let status = response.status().as_u16();
        if !response.status().is_success() {
            let text = response.text().await.unwrap_or_default();
            return Err(P::Error::map_http_error(status, http_error_body(text)));
        }

        let text = response
            .text()
            .await
            .map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
        // The protocol inspects the raw JSON before typed parsing (presumably to
        // reject error payloads delivered with a 2xx status). Non-JSON bodies are
        // presented as `null` here and then fail the typed parse below.
        let raw: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
        self.protocol.validate_success_body(status, &raw)?;
        let protocol_response: P::Response = serde_json::from_str(&text)
            .map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
        Ok(self.protocol.parse_response(protocol_response))
    }

    /// Streams a chat completion.
    ///
    /// Uses real SSE streaming when the protocol supports it. Otherwise, falls back
    /// to a single chunk built from a regular chat call.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &Request) -> Result<ChatStream, LlmConnectorError> {
        if self.protocol.uses_sse_stream() {
            self.chat_stream_sse_impl(request).await
        } else {
            self.chat_stream_fallback_impl(request).await
        }
    }

    /// Lists model ids from the protocol's models endpoint.
    ///
    /// Expects a JSON object whose `data` array holds entries with a string `id`.
    /// Entries without one are skipped.
    ///
    /// # Errors
    ///
    /// - `UnsupportedOperation` if the protocol has no models endpoint.
    /// - Transport errors from `HttpTransport::get`.
    /// - A non-2xx status, converted via `P::Error::map_http_error`.
    /// - `ParseError` if the body is not JSON or lacks a `data` array.
    async fn fetch_models(&self) -> Result<Vec<String>, LlmConnectorError> {
        let models_endpoint = self.protocol.models_endpoint(&self.base_url).ok_or_else(|| {
            LlmConnectorError::UnsupportedOperation(format!(
                "{} does not support model listing",
                self.protocol.name()
            ))
        })?;
        let response = self.transport.get(&models_endpoint).await?;

        let status = response.status().as_u16();
        if !response.status().is_success() {
            let text = response.text().await.unwrap_or_default();
            return Err(P::Error::map_http_error(status, http_error_body(text)));
        }

        let models_response: serde_json::Value = response
            .json()
            .await
            .map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
        models_response
            .get("data")
            .and_then(|data| data.as_array())
            .map(|data| {
                data.iter()
                    .filter_map(|m| m.get("id").and_then(|id| id.as_str()).map(String::from))
                    .collect::<Vec<String>>()
            })
            .ok_or_else(|| {
                LlmConnectorError::ParseError("Invalid models response format".to_string())
            })
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[async_trait]
impl<P: Protocol> crate::protocols::Provider for ProtocolProvider<P>
where
P::Error: Send + Sync,
{
fn name(&self) -> &str {
self.protocol.name()
}
async fn chat(&self, request: &crate::types::ChatRequest) -> Result<crate::types::ChatResponse, LlmConnectorError> {
let endpoint = self.protocol.endpoint(&self.base_url);
let protocol_request = self.protocol.build_request(request, false);
let response = self.transport.post(&endpoint, &protocol_request).await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body: serde_json::Value = response.json().await.unwrap_or_default();
return Err(P::Error::map_http_error(status, body));
}
let status = response.status().as_u16();
let text = response.text().await.map_err(|e| {
LlmConnectorError::ParseError(e.to_string())
})?;
let raw: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
self.protocol.validate_success_body(status, &raw)?;
let protocol_response: P::Response = serde_json::from_str(&text)
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
Ok(self.protocol.parse_response(protocol_response))
}
#[cfg(feature = "streaming")]
async fn chat_stream(&self, request: &crate::types::ChatRequest) -> Result<crate::types::ChatStream, LlmConnectorError> {
<Self as crate::core::Provider>::chat_stream(self, request).await
}
async fn fetch_models(&self) -> Result<Vec<String>, LlmConnectorError> {
let models_endpoint = self.protocol.models_endpoint(&self.base_url)
.ok_or_else(|| LlmConnectorError::UnsupportedOperation(
format!("{} does not support model listing", self.protocol.name())
))?;
let response = self.transport.get(&models_endpoint).await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body: serde_json::Value = response.json().await.unwrap_or_default();
return Err(P::Error::map_http_error(status, body));
}
let models_response: serde_json::Value = response.json().await
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
if let Some(data) = models_response.get("data").and_then(|d| d.as_array()) {
let models = data
.iter()
.filter_map(|m| m.get("id").and_then(|id| id.as_str()).map(String::from))
.collect();
Ok(models)
} else {
Err(LlmConnectorError::ParseError("Invalid models response format".to_string()))
}
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}