mod conversions;
use std::sync::Arc;
use async_trait::async_trait;
use futures_util::StreamExt;
use crate::{
capability::{CapabilityNegotiation, ChatCompletionStream, Identifiable, ModelCatalog},
error::{BackendConstructError, BackendError, CapabilityError},
provider::validation::{into_validated_streaming_request, validate_non_streaming_request},
types::{
chat::{ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ToolDefinition},
model::{ModelCatalogResponse, ModelInfo},
},
};
use super::LlmBackend;
/// [`LlmBackend`] adapter for providers that expose an OpenAI-compatible API.
///
/// This type is a thin wrapper: request preparation, sending, response parsing
/// and model listing are all delegated to the wrapped provider client.
#[derive(Clone, Debug)]
pub struct OpenAiCompatBackend {
    /// Provider client that performs the actual HTTP work.
    client: just_openai_compat::OpenAiCompatClient,
}

impl OpenAiCompatBackend {
    /// Wraps an already-configured provider client without any extra setup.
    pub fn from_provider_client(client: just_openai_compat::OpenAiCompatClient) -> Self {
        OpenAiCompatBackend { client }
    }
}
impl Identifiable for OpenAiCompatBackend {
fn family(&self) -> &'static str {
crate::family::OPENAI_COMPATIBLE
}
}
impl CapabilityNegotiation for OpenAiCompatBackend {
    /// The backend is its own model catalog, so this always succeeds.
    fn model_catalog(&self) -> Result<&dyn ModelCatalog, CapabilityError> {
        let catalog: &dyn ModelCatalog = self;
        Ok(catalog)
    }
}
#[async_trait]
impl LlmBackend for OpenAiCompatBackend {
fn prepare(&self, request: ChatCompletionRequest) -> Result<reqwest::Request, BackendError> {
validate_non_streaming_request(&request, "prepare", "prepare_streaming")?;
let provider_req: just_openai_compat::types::chat::ChatCompletionRequest = request.into();
self.client
.prepare(provider_req)
.map_err(|e| BackendError::provider(self.family(), e))
}
fn prepare_streaming(
&self,
request: ChatCompletionRequest,
) -> Result<reqwest::Request, BackendError> {
let request = into_validated_streaming_request(request, "prepare_streaming")?;
let provider_req: just_openai_compat::types::chat::ChatCompletionRequest = request.into();
self.client
.prepare_streaming(provider_req)
.map_err(|e| BackendError::provider(self.family(), e))
}
async fn send(&self, prepared: reqwest::Request) -> Result<reqwest::Response, BackendError> {
self.client
.send(prepared)
.await
.map_err(|e| BackendError::provider(self.family(), e))
}
async fn parse(
&self,
response: reqwest::Response,
) -> Result<ChatCompletionResponse, BackendError> {
let native: just_openai_compat::types::chat::ChatCompletion = self
.client
.parse(response)
.await
.map_err(|e| BackendError::provider(self.family(), e))?;
Ok(native.into())
}
async fn parse_streaming(
&self,
response: reqwest::Response,
) -> Result<ChatCompletionStream, BackendError> {
let stream = self
.client
.parse_streaming(response)
.await
.map_err(|e| BackendError::provider(self.family(), e))?;
let mapped = stream.map(|chunk| chunk.map(Into::into));
Ok(ChatCompletionStream::new(Box::pin(mapped)))
}
fn render_messages(&self, messages: &[ChatMessage]) -> Result<String, BackendError> {
let provider_messages: Vec<just_openai_compat::types::chat::ChatMessage> =
messages.iter().cloned().map(Into::into).collect();
serde_json::to_string(&provider_messages).map_err(BackendError::serialization)
}
fn render_tools(&self, tools: &[ToolDefinition]) -> Result<String, BackendError> {
let provider_tools: Vec<just_openai_compat::types::chat::ToolDefinition> =
tools.iter().cloned().map(Into::into).collect();
serde_json::to_string(&provider_tools).map_err(BackendError::serialization)
}
fn family() -> &'static str
where
Self: Sized,
{
crate::family::OPENAI_COMPATIBLE
}
#[allow(clippy::new_ret_no_self)]
fn new(
http: reqwest::ClientBuilder,
api_key: &str,
base_url: Option<&str>,
) -> Result<Arc<dyn LlmBackend>, BackendConstructError>
where
Self: Sized,
{
let mut builder = just_openai_compat::OpenAiCompatClient::builder()
.api_key(api_key)
.http_client(http);
if let Some(url) = base_url {
builder = builder.base_url(url);
}
let client = builder
.build()
.map_err(|e| BackendConstructError::provider(crate::family::OPENAI_COMPATIBLE, e))?;
Ok(Arc::new(Self::from_provider_client(client)))
}
}
#[async_trait]
impl ModelCatalog for OpenAiCompatBackend {
    /// Fetches the provider's model list and maps every entry to a [`ModelInfo`].
    ///
    /// `object` and `owned_by` are always populated (wrapped in `Some`) from the
    /// provider entry. Provider failures are wrapped with the backend family.
    async fn list_models(&self) -> Result<ModelCatalogResponse, BackendError> {
        let listing = self
            .client
            .list_models()
            .await
            .map_err(|err| BackendError::provider(self.family(), err))?;
        let data = listing
            .data
            .into_iter()
            .map(|entry| ModelInfo {
                id: entry.id,
                object: Some(entry.object),
                owned_by: Some(entry.owned_by),
            })
            .collect();
        Ok(ModelCatalogResponse { data })
    }
}