use crate::config::{ProviderConfig, SharedProviderConfig};
use crate::error::LlmConnectorError;
use crate::types::{ChatRequest, ChatResponse};
use async_trait::async_trait;
use reqwest::Client;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use std::sync::Arc;
#[cfg(feature = "streaming")]
use crate::types::{ChatStream, StreamingResponse};
#[async_trait]
/// Common interface implemented by every chat provider.
///
/// Implementors supply `name`, `fetch_models`, `chat`, `as_any` and (with the
/// `streaming` feature) `chat_stream`. The remaining streaming methods are
/// provided adapters layered on top of `chat_stream`.
pub trait Provider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Fetches the list of model identifiers offered by the provider.
    async fn fetch_models(&self) -> Result<Vec<String>, LlmConnectorError>;

    /// Sends a chat request and returns the complete response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError>;

    /// Sends a chat request and returns the response as a stream of chunks.
    #[cfg(feature = "streaming")]
    async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError>;

    /// Streams a chat response as Ollama-style chunks.
    ///
    /// Every chunk produced by [`Provider::chat_stream`] is converted with the
    /// model name taken from `request`. Upstream errors are forwarded unchanged.
    #[cfg(feature = "streaming")]
    async fn chat_stream_ollama_pure(
        &self,
        request: &ChatRequest,
    ) -> Result<crate::types::OllamaChatStream, LlmConnectorError> {
        use futures_util::StreamExt;

        let stream = self.chat_stream(request).await?;
        let model = request.model.clone();
        let ollama_stream = stream.map(move |result| {
            result.map(|openai_chunk| ollama_chunk_from_openai(&openai_chunk, model.clone()))
        });
        Ok(Box::pin(ollama_stream))
    }

    /// Streams a chat response in the format selected by `config`.
    ///
    /// With `StreamingFormat::OpenAI` each chunk is passed to
    /// `StreamChunk::from_openai`; with `StreamingFormat::Ollama` it is first
    /// converted to an Ollama-style chunk (using the chunk's own model name) and
    /// then passed to `StreamChunk::from_ollama`.
    ///
    /// # Errors
    ///
    /// Conversion failures are reported as `LlmConnectorError::ParseError`;
    /// upstream errors are forwarded unchanged.
    #[cfg(feature = "streaming")]
    async fn chat_stream_universal(
        &self,
        request: &ChatRequest,
        config: &crate::types::StreamingConfig,
    ) -> Result<crate::types::UniversalChatStream, LlmConnectorError> {
        use crate::types::{StreamChunk, StreamingFormat};
        use futures_util::StreamExt;

        let stream = self.chat_stream(request).await?;
        let format = config.format;
        let stream_format = config.stream_format;
        let universal_stream = stream.map(move |result| {
            result.and_then(|openai_chunk| match format {
                StreamingFormat::OpenAI => StreamChunk::from_openai(&openai_chunk, stream_format)
                    .map_err(|e| LlmConnectorError::ParseError(e.to_string())),
                StreamingFormat::Ollama => {
                    let ollama_chunk =
                        ollama_chunk_from_openai(&openai_chunk, openai_chunk.model.clone());
                    StreamChunk::from_ollama(&ollama_chunk, stream_format)
                        .map_err(|e| LlmConnectorError::ParseError(e.to_string()))
                }
            })
        });
        Ok(Box::pin(universal_stream))
    }

    /// Streams a chat response whose `content` field carries the chunk
    /// re-encoded in the format selected by `config`.
    ///
    /// `StreamingFormat::OpenAI` returns the underlying stream untouched. For
    /// `StreamingFormat::Ollama` every chunk's `content` is replaced by the
    /// string produced by `convert_streaming_format`, and if the upstream ends
    /// without a final chunk (one carrying usage or a finish reason) a
    /// synthetic final chunk built by `create_final_ollama_chunk` is appended.
    ///
    /// # Errors
    ///
    /// Conversion failures are reported as `LlmConnectorError::ParseError`;
    /// upstream errors are forwarded unchanged.
    #[cfg(feature = "streaming")]
    async fn chat_stream_with_format(
        &self,
        request: &ChatRequest,
        config: &crate::types::StreamingConfig,
    ) -> Result<ChatStream, LlmConnectorError> {
        use crate::types::{convert_streaming_format, create_final_ollama_chunk, StreamingFormat};
        use futures_util::StreamExt;
        use std::sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        };

        let stream = self.chat_stream(request).await?;
        match config.format {
            StreamingFormat::OpenAI => Ok(stream),
            StreamingFormat::Ollama => {
                let model = request.model.clone();
                // Set by the mapping stage, read once by the tail after the
                // upstream is exhausted.
                let seen_final = Arc::new(AtomicBool::new(false));
                let seen_final_writer = Arc::clone(&seen_final);

                let converted_stream = stream.map(move |result| {
                    result.and_then(|chunk| {
                        let is_final = chunk.usage.is_some()
                            || chunk.choices.iter().any(|c| c.finish_reason.is_some());
                        if is_final {
                            seen_final_writer.store(true, Ordering::SeqCst);
                        }
                        let converted =
                            convert_streaming_format(&chunk, StreamingFormat::Ollama, is_final)
                                .map_err(|e| LlmConnectorError::ParseError(e.to_string()));
                        converted.map(|json_str| {
                            let mut response = chunk;
                            response.content = json_str;
                            response
                        })
                    })
                });

                // The tail yields at most one item. `None` means "nothing to
                // append", so no sentinel error has to be filtered out and
                // genuine upstream errors are never dropped by accident.
                let tail = futures_util::stream::once(async move {
                    let item: Option<Result<crate::types::StreamingResponse, LlmConnectorError>> =
                        if seen_final.load(Ordering::SeqCst) {
                            None
                        } else {
                            let final_json = create_final_ollama_chunk(&model, None);
                            Some(Ok(crate::types::StreamingResponse {
                                model,
                                content: final_json,
                                ..Default::default()
                            }))
                        };
                    item
                })
                .filter_map(futures_util::future::ready);

                Ok(Box::pin(converted_stream.chain(tail)))
            }
        }
    }

    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any;
}

/// Converts one OpenAI-style streaming chunk into an Ollama-style chunk.
///
/// A chunk carrying usage data or any finish reason becomes a final chunk
/// (its text content is not forwarded). Otherwise the chunk's top-level
/// `content` is used, falling back to the first choice's delta content when
/// the top-level field is empty.
#[cfg(feature = "streaming")]
fn ollama_chunk_from_openai(
    chunk: &StreamingResponse,
    model: String,
) -> crate::types::OllamaStreamChunk {
    use crate::types::OllamaStreamChunk;

    let is_final = chunk.usage.is_some()
        || chunk.choices.iter().any(|choice| choice.finish_reason.is_some());
    if is_final {
        return OllamaStreamChunk::final_chunk(model, chunk.usage.as_ref());
    }

    let content = if chunk.content.is_empty() {
        chunk
            .choices
            .first()
            .and_then(|choice| choice.delta.content.clone())
            .unwrap_or_default()
    } else {
        chunk.content.clone()
    };
    OllamaStreamChunk::new(model, content, false)
}
#[async_trait]
/// Describes how to talk to one concrete provider API.
///
/// An adapter supplies the endpoint URLs, the wire types for requests and
/// responses, and the conversions between those wire types and the crate's
/// own `ChatRequest` / `ChatResponse` types.
pub trait ProviderAdapter: Send + Sync + Clone + 'static {
    /// Serializable request body sent to the provider.
    type RequestType: Serialize + Send + Sync;

    /// Deserializable body of a non-streaming response.
    type ResponseType: DeserializeOwned + Send + Sync;

    /// Deserializable body of a single streaming event.
    #[cfg(feature = "streaming")]
    type StreamResponseType: DeserializeOwned + Send + Sync;

    /// Error mapper used to translate HTTP and network failures.
    type ErrorMapperType: ErrorMapper;

    /// Returns the adapter's name.
    fn name(&self) -> &str;

    /// Returns the chat endpoint URL for the given optional base URL.
    fn endpoint_url(&self, base_url: &Option<String>) -> String;

    /// Returns the model-listing endpoint URL, if there is one.
    ///
    /// The default implementation ignores `base_url` and returns `None`.
    fn models_endpoint_url(&self, base_url: &Option<String>) -> Option<String> {
        // The argument is deliberately unused by the default implementation.
        let _ = base_url;
        None
    }

    /// Builds the provider-specific request body; `stream` tells the adapter
    /// whether the request is for a streaming call.
    fn build_request_data(&self, request: &ChatRequest, stream: bool) -> Self::RequestType;

    /// Converts a provider response into a `ChatResponse`.
    fn parse_response_data(&self, response: Self::ResponseType) -> ChatResponse;

    /// Converts one provider streaming event into a `StreamingResponse`.
    #[cfg(feature = "streaming")]
    fn parse_stream_response_data(&self, response: Self::StreamResponseType) -> StreamingResponse;

    /// Whether the provider streams server-sent events; defaults to `true`.
    #[cfg(feature = "streaming")]
    fn uses_sse_stream(&self) -> bool {
        true
    }

    /// Hook for rejecting a response that arrived with a success status.
    ///
    /// The default implementation accepts everything.
    fn validate_success_body(&self, _status: u16, _raw: &Value) -> Result<(), LlmConnectorError> {
        Ok(())
    }
}
/// Translates transport-level failures into `LlmConnectorError` values.
pub trait ErrorMapper {
    /// Maps an HTTP status code and the parsed response body to an error.
    fn map_http_error(status: u16, body: Value) -> LlmConnectorError;

    /// Maps a `reqwest` error to an error.
    fn map_network_error(error: reqwest::Error) -> LlmConnectorError;

    /// Reports whether `error` is considered retriable.
    fn is_retriable_error(error: &LlmConnectorError) -> bool;
}
/// HTTP client paired with the provider configuration it sends requests for.
#[derive(Clone, Debug)]
pub struct HttpTransport {
    /// Shared `reqwest` client.
    pub client: Arc<Client>,
    /// Shared provider configuration (API key, headers, base URL, ...).
    pub config: SharedProviderConfig,
}
impl HttpTransport {
    /// Creates a transport that owns `client` and `config`, wrapping both in
    /// their shared forms.
    pub fn new(client: Client, config: ProviderConfig) -> Self {
        Self::from_shared(Arc::new(client), SharedProviderConfig::new(config))
    }

    /// Creates a transport from an already shared client and configuration.
    pub fn from_shared(client: Arc<Client>, config: SharedProviderConfig) -> Self {
        Self { client, config }
    }

    /// Builds a `reqwest` client.
    ///
    /// * `proxy` - proxy URL applied to all traffic, when present.
    /// * `timeout_ms` - request timeout in milliseconds, when present.
    /// * `base_url` - when it parses and points at a loopback host
    ///   (`localhost`, `127.0.0.1` or `[::1]`), all proxies are cleared again
    ///   so local servers are reached directly.
    ///
    /// # Errors
    ///
    /// An invalid proxy URL is propagated through the `reqwest::Error`
    /// conversion; a failure to build the client is reported as
    /// `LlmConnectorError::ConfigError`.
    pub fn build_client(
        proxy: &Option<String>,
        timeout_ms: Option<u64>,
        base_url: Option<&String>,
    ) -> Result<Client, LlmConnectorError> {
        let mut client_builder = Client::builder();

        if let Some(proxy) = proxy {
            client_builder = client_builder.proxy(reqwest::Proxy::all(proxy)?);
        }
        if let Some(timeout) = timeout_ms {
            client_builder = client_builder.timeout(std::time::Duration::from_millis(timeout));
        }

        // An unparsable base URL is ignored here; it only affects the
        // loopback proxy bypass.
        let targets_loopback = base_url
            .and_then(|base| reqwest::Url::parse(base).ok())
            .map_or(false, |url| {
                matches!(
                    url.host_str(),
                    Some("localhost") | Some("127.0.0.1") | Some("[::1]")
                )
            });
        if targets_loopback {
            client_builder = client_builder.no_proxy();
        }

        client_builder
            .build()
            .map_err(|e| LlmConnectorError::ConfigError(e.to_string()))
    }

    /// Formats the `Authorization` header value from the configured API key.
    fn bearer_token(&self) -> String {
        format!("Bearer {}", &self.config.api_key)
    }

    /// Appends the configured custom headers, if any, to `request`.
    fn with_custom_headers(&self, mut request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some(headers) = &self.config.headers {
            for (key, value) in headers {
                request = request.header(key, value);
            }
        }
        request
    }

    /// Sends an authorized `GET` request to `url`.
    ///
    /// The response is returned regardless of its status code; only transport
    /// failures produce an error.
    pub async fn get(&self, url: &str) -> Result<reqwest::Response, LlmConnectorError> {
        let request = self
            .client
            .get(url)
            .header("Authorization", self.bearer_token());
        self.with_custom_headers(request)
            .send()
            .await
            .map_err(LlmConnectorError::from)
    }

    /// Sends an authorized `POST` request to `url` with `body` serialized as JSON.
    ///
    /// The response is returned regardless of its status code; only transport
    /// failures produce an error.
    pub async fn post<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<reqwest::Response, LlmConnectorError> {
        let request = self
            .client
            .post(url)
            .header("Authorization", self.bearer_token())
            .header("Content-Type", "application/json");
        self.with_custom_headers(request)
            .json(body)
            .send()
            .await
            .map_err(LlmConnectorError::from)
    }

    /// Sends the same request as [`HttpTransport::post`] and returns the
    /// response body as a byte stream.
    ///
    /// # Errors
    ///
    /// Besides transport failures, a non-success status is reported as
    /// `LlmConnectorError::ProviderError` carrying the status code.
    #[cfg(feature = "streaming")]
    pub async fn stream<T: Serialize>(
        &self,
        url: &str,
        body: &T,
    ) -> Result<
        impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>>,
        LlmConnectorError,
    > {
        let response = self.post(url, body).await?;
        if !response.status().is_success() {
            return Err(LlmConnectorError::ProviderError(format!(
                "HTTP error: {}",
                response.status()
            )));
        }
        Ok(response.bytes_stream())
    }
}
pub struct StandardErrorMapper;
impl ErrorMapper for StandardErrorMapper {
fn map_http_error(status: u16, body: Value) -> LlmConnectorError {
let error_message = body["error"]["message"].as_str().unwrap_or("Unknown error");
let error_type = body["error"]["type"].as_str().unwrap_or("unknown_error");
let cleaned_message = if error_message.contains("platform.openai.com") {
if let Some(idx) = error_message.find(". You can find your API key at") {
&error_message[..idx]
} else {
error_message
}
} else {
error_message
};
match status {
400 => LlmConnectorError::InvalidRequest(cleaned_message.to_string()),
401 => LlmConnectorError::AuthenticationError(format!(
"{}. Please verify your API key is correct and has the necessary permissions.",
cleaned_message
)),
403 => LlmConnectorError::PermissionError(cleaned_message.to_string()),
404 => LlmConnectorError::NotFoundError(cleaned_message.to_string()),
429 => LlmConnectorError::RateLimitError(cleaned_message.to_string()),
500..=599 => {
LlmConnectorError::ServerError(format!("HTTP {}: {}", status, cleaned_message))
}
_ => LlmConnectorError::ProviderError(format!(
"HTTP {}: {} (type: {})",
status, cleaned_message, error_type
)),
}
}
fn map_network_error(error: reqwest::Error) -> LlmConnectorError {
if error.is_timeout() {
LlmConnectorError::TimeoutError(error.to_string())
} else if error.is_connect() {
LlmConnectorError::ConnectionError(error.to_string())
} else {
LlmConnectorError::NetworkError(error.to_string())
}
}
fn is_retriable_error(error: &LlmConnectorError) -> bool {
matches!(
error,
LlmConnectorError::RateLimitError(_)
| LlmConnectorError::ServerError(_)
| LlmConnectorError::TimeoutError(_)
| LlmConnectorError::ConnectionError(_)
)
}
}
/// Provider implementation that pairs a [`ProviderAdapter`] with an
/// [`HttpTransport`].
#[derive(Clone)]
pub struct GenericProvider<A: ProviderAdapter> {
    /// Provider-specific request/response handling.
    adapter: A,
    /// HTTP client and configuration used to reach the provider.
    transport: HttpTransport,
}
impl<A: ProviderAdapter> GenericProvider<A> {
    /// Creates a provider for `adapter`, building the HTTP client from the
    /// proxy, timeout and base URL found in `config`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `HttpTransport::build_client` reports.
    pub fn new(config: ProviderConfig, adapter: A) -> Result<Self, LlmConnectorError> {
        HttpTransport::build_client(&config.proxy, config.timeout_ms, config.base_url.as_ref())
            .map(|client| Self {
                adapter,
                transport: HttpTransport::new(client, config),
            })
    }

    /// Returns a reference to the underlying adapter.
    pub fn adapter(&self) -> &A {
        &self.adapter
    }
}
#[async_trait]
impl<A: ProviderAdapter> Provider for GenericProvider<A> {
fn name(&self) -> &str {
self.adapter.name()
}
async fn fetch_models(&self) -> Result<Vec<String>, LlmConnectorError> {
let url = self.adapter.models_endpoint_url(&self.transport.config.base_url)
.ok_or_else(|| LlmConnectorError::UnsupportedOperation(
format!("{} does not support model listing", self.adapter.name())
))?;
let response = self.transport.get(&url).await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body: Value = response.json().await.unwrap_or_default();
return Err(A::ErrorMapperType::map_http_error(status, body));
}
let models_response: Value = response
.json()
.await
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
if let Some(data) = models_response.get("data").and_then(|d| d.as_array()) {
let models = data
.iter()
.filter_map(|m| m.get("id").and_then(|id| id.as_str()).map(String::from))
.collect();
Ok(models)
} else {
Err(LlmConnectorError::ParseError(
"Invalid models response format".to_string()
))
}
}
async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, LlmConnectorError> {
let url = self.adapter.endpoint_url(&self.transport.config.base_url);
let request_data = self.adapter.build_request_data(request, false);
if std::env::var("LLM_DEBUG_REQUEST_RAW").map(|v| v == "1").unwrap_or(false) {
if let Ok(j) = serde_json::to_string(&request_data) {
eprintln!("[request-raw] {}", j);
}
}
let response = self.transport.post(&url, &request_data).await?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body: Value = response.json().await.unwrap_or_default();
return Err(A::ErrorMapperType::map_http_error(status, body));
}
let status_code = response.status().as_u16();
let text = response
.text()
.await
.map_err(|e| LlmConnectorError::ParseError(e.to_string()))?;
if std::env::var("LLM_DEBUG_RESPONSE_RAW").map(|v| v == "1").unwrap_or(false) {
eprintln!("[response-raw] {}", text);
}
let raw: Value = serde_json::from_str(&text).unwrap_or_default();
if let Err(err) = self.adapter.validate_success_body(status_code, &raw) {
return Err(err);
}
match serde_json::from_str::<A::ResponseType>(&text) {
Ok(response_data) => {
let mut chat_response = self.adapter.parse_response_data(response_data);
chat_response.populate_reasoning_synonyms(&raw);
Ok(chat_response)
}
Err(e) => {
if std::env::var("LLM_DEBUG_PARSE_FALLBACK").map(|v| v == "1").unwrap_or(false) {
eprintln!("[parse-fallback] strict parse failed: {}\nbody: {}", e, text);
}
let model = raw.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
let id = raw
.get("id")
.or_else(|| raw.get("request_id"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let object = raw
.get("object")
.and_then(|v| v.as_str())
.unwrap_or("chat.completion")
.to_string();
let created = raw.get("created").and_then(|v| v.as_u64()).unwrap_or(0);
let usage = raw
.get("usage")
.and_then(|u| serde_json::from_value::<crate::types::Usage>(u.clone()).ok());
let (choice_msg, finish_reason) = if let Some(choices) = raw.get("choices").and_then(|c| c.as_array()) {
let first = choices.get(0);
let content = first
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(|s| s.as_str())
.unwrap_or("")
.to_string();
let fr = first
.and_then(|c| c.get("finish_reason"))
.and_then(|s| s.as_str())
.map(|s| s.to_string());
(crate::types::Message::assistant(content), fr)
} else {
(crate::types::Message::assistant(String::new()), None)
};
let choices = vec![crate::types::Choice {
index: 0,
message: choice_msg,
finish_reason,
logprobs: None,
}];
let mut chat_response = crate::types::ChatResponse {
id,
object,
created,
model,
choices,
content: raw
.get("choices")
.and_then(|c| c.as_array())
.and_then(|arr| arr.get(0))
.and_then(|c0| c0.get("message"))
.and_then(|m| m.get("content"))
.and_then(|s| s.as_str())
.unwrap_or("")
.to_string(),
usage,
system_fingerprint: raw
.get("system_fingerprint")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
};
chat_response.populate_reasoning_synonyms(&raw);
Ok(chat_response)
}
}
}
#[cfg(feature = "streaming")]
async fn chat_stream(&self, request: &ChatRequest) -> Result<ChatStream, LlmConnectorError> {
use crate::sse::json_lines_events;
use futures_util::StreamExt;
let url = self.adapter.endpoint_url(&self.transport.config.base_url);
let request_data = self.adapter.build_request_data(request, true);
let response = self
.transport
.client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", &self.transport.config.api_key),
)
.header("Content-Type", "application/json")
.json(&request_data)
.send()
.await
.map_err(LlmConnectorError::from)?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body: Value = response.json().await.unwrap_or_default();
return Err(A::ErrorMapperType::map_http_error(status, body));
}
let adapter = self.adapter.clone();
let event_stream = if self.adapter.uses_sse_stream() {
crate::sse::sse_events(response)
} else {
json_lines_events(response)
};
let mapped_stream = event_stream.filter_map(move |event| {
let adapter = adapter.clone();
async move {
match event {
Ok(data) => {
if std::env::var("LLM_DEBUG_STREAM_RAW").map(|v| v == "1").unwrap_or(false) {
eprintln!("[stream-raw] {}", data);
}
if data.trim() == "[DONE]" { return None; }
match serde_json::from_str::<A::StreamResponseType>(&data) {
Ok(stream_response) => {
let raw: Value = serde_json::from_str(&data).unwrap_or_default();
let mut sr = adapter.parse_stream_response_data(stream_response);
sr.populate_reasoning_synonyms(&raw);
Some(Ok(sr))
}
Err(_e) => {
let raw: Value = serde_json::from_str(&data).unwrap_or_default();
let model = raw
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let id = raw
.get("id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| format!("{}-{}", adapter.name(), model));
let content_opt = raw
.pointer("/choices/0/delta/content")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.or_else(|| raw.pointer("/message/content").and_then(|v| v.as_str().map(|s| s.to_string())));
let content = content_opt.unwrap_or_default();
let finish_reason = raw
.pointer("/choices/0/finish_reason")
.and_then(|v| v.as_str().map(|s| s.to_string()));
let mut sr = crate::types::StreamingResponse {
id,
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model,
choices: vec![crate::types::StreamingChoice {
index: 0,
delta: crate::types::Delta {
role: None,
content: if content.is_empty() { None } else { Some(content.clone()) },
tool_calls: None,
reasoning_content: None,
..Default::default()
},
finish_reason,
logprobs: None,
}],
content,
reasoning_content: None,
usage: None,
system_fingerprint: None,
};
sr.populate_reasoning_synonyms(&raw);
Some(Ok(sr))
}
}
}
Err(e) => Some(Err(LlmConnectorError::StreamingError(e.to_string()))),
}
}
});
Ok(Box::pin(mapped_stream))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}