use std::collections::HashMap;
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use tracing::instrument;
use tt_shared::{
filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
Provider, ProviderError, RequestContext,
};
use crate::client::{build_client, ClientConfig};
use crate::errors::{map_reqwest_error, map_response_error};
use crate::{stream, translate};
/// Static configuration describing one OpenAI-compatible provider.
///
/// An instance is handed to `OpenAICompatibleProvider::new` and stored for the
/// lifetime of the provider.
pub struct CompatConfig {
/// Provider identifier; returned by `Provider::id` and recorded as the
/// `provider` field on the request tracing spans.
pub id: &'static str,
/// Base URL used when the request's credentials carry no `base_url` of their own.
pub default_base_url: String,
/// Model list returned (as a clone) by `Provider::models`.
pub models: Vec<ModelInfo>,
/// Pricing entries, looked up by model name in `Provider::pricing`.
pub pricing_table: HashMap<String, ModelPricing>,
/// Value exposed through `OpenAICompatibleProvider::fee_multiplier`; how callers
/// apply it is not visible in this file.
pub fee_multiplier: f64,
/// Forwarded as the second argument to `validate_provider_url`.
/// NOTE(review): presumably permits local/private addresses — confirm against that helper.
pub allow_local: bool,
}
/// `Provider` implementation that sends bearer-authenticated JSON requests to
/// `/chat/completions` and `/embeddings` under a configurable base URL.
pub struct OpenAICompatibleProvider {
/// HTTP client built once in `new`; used directly for non-streaming calls and
/// cloned for the streaming call.
client: Client,
/// Static provider configuration (id, default base URL, models, pricing, ...).
cfg: CompatConfig,
}
impl OpenAICompatibleProvider {
    /// Creates a provider from HTTP client settings and static provider configuration.
    ///
    /// # Panics
    ///
    /// Panics if `build_client` fails for `client_cfg`; the panic message names
    /// the provider id and includes the underlying error.
    pub fn new(client_cfg: ClientConfig, cfg: CompatConfig) -> Self {
        match build_client(&client_cfg) {
            Ok(client) => Self { client, cfg },
            Err(e) => panic!("failed to build HTTP client for {}: {e}", cfg.id),
        }
    }

    /// Returns the fee multiplier stored in this provider's configuration.
    pub fn fee_multiplier(&self) -> f64 {
        self.cfg.fee_multiplier
    }

    /// Picks the base URL for a request: the per-request override from the
    /// credentials when one is set, otherwise the configured default.
    fn base_url<'a>(&'a self, ctx: &'a RequestContext) -> &'a str {
        match ctx.credentials.base_url.as_deref() {
            Some(override_url) => override_url,
            None => self.cfg.default_base_url.as_str(),
        }
    }
}
#[async_trait]
impl Provider for OpenAICompatibleProvider {
fn id(&self) -> &'static str {
self.cfg.id
}
fn models(&self) -> Vec<ModelInfo> {
self.cfg.models.clone()
}
fn pricing(&self, model: &str) -> Option<ModelPricing> {
self.cfg.pricing_table.get(model).cloned()
}
fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
crate::translate::dropped_params(req)
}
#[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
async fn chat_completion(
&self,
req: ChatCompletionRequest,
ctx: &RequestContext,
) -> Result<ChatCompletionResponse, ProviderError> {
let base_url = self.base_url(ctx);
validate_provider_url(base_url, self.cfg.allow_local)
.map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
let url = format!("{base_url}/chat/completions");
let body = translate::translate_request(req)?;
let mut rb = self
.client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", ctx.credentials.api_key.expose()),
)
.header("Content-Type", "application/json")
.json(&body);
for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
rb = rb.header(name, value);
}
let response = rb.send().await.map_err(map_reqwest_error)?;
let status = response.status().as_u16();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let response_text = response.text().await.map_err(map_reqwest_error)?;
if status >= 400 {
return Err(map_response_error(
status,
&response_text,
retry_after.as_deref(),
));
}
translate::deserialize_response(&response_text)
}
#[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
async fn chat_completion_stream(
&self,
req: ChatCompletionRequest,
ctx: &RequestContext,
) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
let base_url = self.base_url(ctx);
validate_provider_url(base_url, self.cfg.allow_local)
.map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
let base_url = base_url.to_string();
let client = self.client.clone();
stream::stream_chat_completion(client, &base_url, req, ctx).await
}
#[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
async fn embeddings(
&self,
req: EmbeddingsRequest,
ctx: &RequestContext,
) -> Result<EmbeddingsResponse, ProviderError> {
let base_url = self.base_url(ctx);
validate_provider_url(base_url, self.cfg.allow_local)
.map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
let url = format!("{base_url}/embeddings");
let body = translate::translate_embeddings_request(req)?;
let mut rb = self
.client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", ctx.credentials.api_key.expose()),
)
.header("Content-Type", "application/json")
.json(&body);
for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
rb = rb.header(name, value);
}
let response = rb.send().await.map_err(map_reqwest_error)?;
let status = response.status().as_u16();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let response_text = response.text().await.map_err(map_reqwest_error)?;
if status >= 400 {
return Err(map_response_error(
status,
&response_text,
retry_after.as_deref(),
));
}
translate::deserialize_embeddings_response(&response_text)
}
}