use std::pin::Pin;
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
use futures::Stream;
use crate::ir::{ChatRequest, ChatResponse, StreamEvent};
/// Identifies which request/response protocol an adapter speaks.
///
/// Reported by `Adapter::protocol`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    /// OpenAI-style protocol.
    OpenAI,
    /// Anthropic-style protocol.
    Anthropic,
    /// OpenAI "Responses"-style protocol, distinct from [`Protocol::OpenAI`].
    OpenAIResponses,
    /// Gemini-style protocol.
    Gemini,
}
/// An HTTP response held entirely in memory: status code, header pairs and body bytes.
pub struct RawResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Header `(name, value)` pairs in their original order; a name may appear more than once.
    pub headers: Vec<(String, String)>,
    /// The complete response body.
    pub body: Vec<u8>,
}
impl RawResponse {
    /// Converts this buffered response into an axum response.
    ///
    /// - A status outside the range `StatusCode::from_u16` accepts becomes `502 Bad Gateway`.
    /// - `Transfer-Encoding` headers (any ASCII case) are dropped. The body is handed to axum
    ///   as one in-memory buffer, so the original framing no longer applies.
    /// - Headers whose name or value is not a valid HTTP header are skipped individually.
    ///   The rest of the response (status and body) is still returned. Previously a single
    ///   bad header made the builder fail and produced an empty `200 OK` instead.
    pub fn into_axum(self) -> axum::response::Response {
        use axum::http::{HeaderName, HeaderValue, StatusCode};

        let status = StatusCode::from_u16(self.status).unwrap_or(StatusCode::BAD_GATEWAY);

        let mut response = axum::response::Response::new(axum::body::Body::from(self.body));
        *response.status_mut() = status;

        let headers = response.headers_mut();
        for (name, value) in self.headers {
            if name.eq_ignore_ascii_case("transfer-encoding") {
                continue;
            }
            let (Ok(name), Ok(value)) = (
                HeaderName::from_bytes(name.as_bytes()),
                HeaderValue::from_bytes(value.as_bytes()),
            ) else {
                // Unrepresentable header: drop just this one rather than the whole response.
                continue;
            };
            // `append` (not `insert`) keeps repeated headers such as `Set-Cookie`, matching
            // the builder's `.header()` behavior.
            headers.append(name, value);
        }

        response
    }
}
/// A cheaply clonable, shared, mutable string.
///
/// All clones share one `Arc<RwLock<String>>`. A value stored with
/// [`update`](Self::update) through any clone is returned by [`read`](Self::read) on every
/// other clone.
#[derive(Clone)]
pub struct WatchedField(Arc<RwLock<String>>);

impl WatchedField {
    /// Creates a new field holding `value`.
    pub fn new(value: String) -> Self {
        Self(Arc::new(RwLock::new(value)))
    }

    /// Returns a copy of the current value.
    ///
    /// A poisoned lock is recovered instead of panicking. The guarded `String` is only ever
    /// replaced wholesale by [`update`](Self::update), so it cannot be left half-written by a
    /// panicking lock holder.
    pub fn read(&self) -> String {
        self.0
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }

    /// Replaces the current value with `new_value`; the change is visible to every clone.
    ///
    /// Like [`read`](Self::read), this recovers from a poisoned lock instead of panicking.
    pub fn update(&self, new_value: String) {
        *self
            .0
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = new_value;
    }
}

/// Alias of [`WatchedField`] used for API keys, so the key can be replaced in place.
pub type ApiKey = WatchedField;
#[derive(Debug, thiserror::Error)]
pub enum AdapterError {
#[error("backend request failed: {0}")]
BackendError(String),
#[error("protocol translation error: {0}")]
TranslationError(String),
#[error("stream error: {0}")]
StreamError(String),
#[error("feature not supported: {feature} (provider: {provider})")]
UnsupportedFeature { provider: String, feature: String },
}
/// A chat backend that can be driven through the crate's IR types.
///
/// Implementors must supply the identity methods and the typed `chat` / `chat_stream`
/// methods. The raw passthrough methods default to returning
/// [`AdapterError::UnsupportedFeature`]. The reconfiguration hooks default to doing nothing.
#[async_trait]
pub trait Adapter: Send + Sync {
    // ---- identity ---------------------------------------------------------------------

    /// Provider name; also used as the `provider` field of `UnsupportedFeature` errors
    /// returned by the default raw methods.
    fn provider_name(&self) -> &str;

    /// Whether this adapter accepts requests for `model`.
    fn supports_model(&self, model: &str) -> bool;

    /// The [`Protocol`] this adapter speaks.
    fn protocol(&self) -> Protocol;

    // ---- typed chat -------------------------------------------------------------------

    /// Sends `request` and returns the complete response.
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, AdapterError>;

    /// Sends `request` and returns a stream of events. Each item may itself be an error.
    async fn chat_stream(
        &self,
        request: &ChatRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AdapterError>> + Send>>, AdapterError>;

    // ---- raw passthrough (optional) ---------------------------------------------------

    /// Forwards a raw request body. Unsupported unless overridden.
    async fn chat_raw(&self, _body: &[u8]) -> Result<RawResponse, AdapterError> {
        let provider = self.provider_name().to_owned();
        Err(AdapterError::UnsupportedFeature {
            provider,
            feature: String::from("raw passthrough"),
        })
    }

    /// Forwards a raw streaming request body. Unsupported unless overridden.
    async fn chat_stream_raw(&self, _body: &[u8]) -> Result<RawResponse, AdapterError> {
        let provider = self.provider_name().to_owned();
        Err(AdapterError::UnsupportedFeature {
            provider,
            feature: String::from("raw stream passthrough"),
        })
    }

    // ---- runtime reconfiguration (optional) -------------------------------------------

    /// Replaces the API key. The default implementation ignores the call.
    fn update_api_key(&self, _new_key: String) {}

    /// Replaces the base URL. The default implementation ignores the call.
    fn update_base_url(&self, _new_url: String) {}
}
/// Forwards every [`Adapter`] method to the boxed value, so `Box<T>` (including
/// `Box<dyn Adapter>`) behaves exactly like the adapter it contains.
///
/// Every method, including the ones with default bodies, must be forwarded explicitly.
/// Otherwise the trait's default would run in place of the inner adapter's override.
#[async_trait]
impl<T: Adapter + ?Sized> Adapter for Box<T> {
    fn provider_name(&self) -> &str {
        (**self).provider_name()
    }

    fn supports_model(&self, model: &str) -> bool {
        (**self).supports_model(model)
    }

    fn protocol(&self) -> Protocol {
        (**self).protocol()
    }

    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, AdapterError> {
        (**self).chat(request).await
    }

    async fn chat_stream(
        &self,
        request: &ChatRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, AdapterError>> + Send>>, AdapterError>
    {
        (**self).chat_stream(request).await
    }

    async fn chat_raw(&self, body: &[u8]) -> Result<RawResponse, AdapterError> {
        (**self).chat_raw(body).await
    }

    async fn chat_stream_raw(&self, body: &[u8]) -> Result<RawResponse, AdapterError> {
        (**self).chat_stream_raw(body).await
    }

    fn update_api_key(&self, new_key: String) {
        // Without this forward, the trait's default no-op runs. Updating the key on a
        // `Box<dyn Adapter>` would then silently do nothing.
        (**self).update_api_key(new_key)
    }

    fn update_base_url(&self, new_url: String) {
        // Same reason as `update_api_key`: avoid falling back to the default no-op.
        (**self).update_base_url(new_url)
    }
}