use super::client::{ChatChunk, ChatRequest};
use super::mapper::{ChunkAccumulator, MapperError, request_from_resolved};
use crate::backend::{
AcceleratorInfo, Backend, BackendCapabilities, GenerateError, TokenEventV2, TokenStream,
TokenStreamV2,
};
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use inferd_proto::Resolved;
use inferd_proto::v2::ResolvedV2;
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, warn};
/// Connection settings for an OpenAI-compatible chat-completions server.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never rendered into logs or panic messages.
#[derive(Clone)]
pub struct OpenAiCompatConfig {
    /// Server root, e.g. `https://api.openai.com`. Trailing slashes are trimmed and
    /// `/v1/chat/completions` is appended to build the request URL.
    pub base_url: String,
    /// Bearer token for the `Authorization` header; when empty, no header is sent.
    pub api_key: String,
    /// Model identifier handed to the request mapper for every generation.
    pub model: String,
    /// Passed to reqwest's `ClientBuilder::timeout`.
    ///
    /// NOTE(review): reqwest applies that timeout to the whole request, including
    /// reading the response body, so it also caps the total duration of a streamed
    /// generation — confirm this is intended.
    pub timeout: Duration,
}

impl Default for OpenAiCompatConfig {
    /// Public OpenAI endpoint, no API key, `gpt-4o-mini`, and a 300-second timeout.
    fn default() -> Self {
        Self {
            base_url: "https://api.openai.com".into(),
            api_key: String::new(),
            model: "gpt-4o-mini".into(),
            timeout: Duration::from_secs(300),
        }
    }
}

impl std::fmt::Debug for OpenAiCompatConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAiCompatConfig")
            .field("base_url", &self.base_url)
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("timeout", &self.timeout)
            .finish()
    }
}
/// Failures raised by the OpenAI-compatible backend before they are converted
/// into a [`GenerateError`] at the `Backend` boundary.
#[derive(Debug, thiserror::Error)]
pub enum OpenAiCompatError {
    /// The HTTP client failed, either while being built or while sending a request.
    #[error("transport: {0}")]
    Transport(#[from] reqwest::Error),
    /// The upstream answered with a status outside 2xx; `body` carries (a prefix of)
    /// the response text.
    #[error("upstream HTTP {status}: {body}")]
    HttpStatus { status: u16, body: String },
    /// An SSE payload that could not be interpreted.
    ///
    /// NOTE(review): not constructed in this file's visible code — the SSE pump logs
    /// and skips malformed chunks instead.
    #[error("malformed SSE chunk: {0}")]
    MalformedChunk(String),
    /// The resolved request could not be translated into an OpenAI chat request.
    #[error("request mapping: {0}")]
    Mapper(#[from] MapperError),
}
impl From<OpenAiCompatError> for GenerateError {
    /// Classifies backend failures for callers.
    ///
    /// Mapper rejections caused by the request's own content (unsupported attachments,
    /// unknown content blocks, non-text tool results) become
    /// [`GenerateError::InvalidRequest`]. Everything else, including transport and
    /// upstream HTTP errors, becomes [`GenerateError::Unavailable`]. In both cases
    /// the error's `Display` text is used as the message.
    fn from(e: OpenAiCompatError) -> Self {
        let rejected_by_mapper = matches!(
            e,
            OpenAiCompatError::Mapper(
                MapperError::AttachmentUnsupported(_)
                    | MapperError::UnknownContentBlock
                    | MapperError::NonTextToolResult
            )
        );
        let message = e.to_string();
        if rejected_by_mapper {
            GenerateError::InvalidRequest(message)
        } else {
            GenerateError::Unavailable(message)
        }
    }
}
pub struct OpenAiCompat {
name: &'static str,
config: OpenAiCompatConfig,
client: reqwest::Client,
}
impl OpenAiCompat {
pub fn new(config: OpenAiCompatConfig) -> Result<Self, OpenAiCompatError> {
let client = reqwest::Client::builder().timeout(config.timeout).build()?;
Ok(Self {
name: "openai-compat",
config,
client,
})
}
fn endpoint(&self) -> String {
let base = self.config.base_url.trim_end_matches('/');
format!("{base}/v1/chat/completions")
}
fn build_request(&self, body: &ChatRequest) -> reqwest::RequestBuilder {
let mut rb = self
.client
.post(self.endpoint())
.header(CONTENT_TYPE, "application/json")
.json(body);
if !self.config.api_key.is_empty() {
rb = rb.header(
AUTHORIZATION,
format!("Bearer {}", self.config.api_key.as_str()),
);
}
rb
}
}
#[async_trait]
impl Backend for OpenAiCompat {
    fn name(&self) -> &str {
        self.name
    }

    /// Always reports ready. No upstream health probe is performed, so connectivity
    /// problems surface only when a request is sent.
    fn ready(&self) -> bool {
        true
    }

    /// v2 protocol with tool calling; vision, audio, video, thinking and embeddings
    /// are all reported as unsupported, and accelerator info is left at its default.
    fn capabilities(&self) -> BackendCapabilities {
        BackendCapabilities {
            v2: true,
            tools: true,
            vision: false,
            audio: false,
            video: false,
            thinking: false,
            embed: false,
            accelerator: AcceleratorInfo::default(),
        }
    }

    /// The v1 protocol is not supported; this always fails with
    /// [`GenerateError::Internal`].
    async fn generate(&self, _req: Resolved) -> Result<TokenStream, GenerateError> {
        Err(GenerateError::Internal(
            "openai-compat backend supports v2 only; use the v2 socket".into(),
        ))
    }

    /// Maps `req` to a chat-completions request and sends it upstream. The response
    /// body is decoded as server-sent events by a background task that feeds the
    /// returned stream.
    ///
    /// # Errors
    ///
    /// - [`GenerateError::InvalidRequest`] when the mapper rejects the request content.
    /// - [`GenerateError::Unavailable`] on transport failure or a non-2xx status. The
    ///   message carries the status and at most 4096 bytes of the upstream body.
    ///
    /// Problems after the response headers arrive (mid-stream transport errors,
    /// malformed chunks) are not reported through this `Result`; they are handled by
    /// the background SSE task.
    async fn generate_v2(&self, req: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
        // Cap on how much upstream error text is kept in the error message.
        const MAX_ERROR_BODY_BYTES: usize = 4096;

        let body =
            request_from_resolved(&req, &self.config.model).map_err(OpenAiCompatError::from)?;
        let request = self.build_request(&body);
        let response = request.send().await.map_err(OpenAiCompatError::from)?;
        if !response.status().is_success() {
            let status = response.status().as_u16();
            let mut body = response
                .text()
                .await
                .unwrap_or_else(|_| "<failed to read body>".into());
            if body.len() > MAX_ERROR_BODY_BYTES {
                // Slicing a str at a byte offset panics when the offset is inside a
                // multi-byte UTF-8 character. Back up to the nearest char boundary
                // (index 0 always is one) before truncating in place.
                let mut end = MAX_ERROR_BODY_BYTES;
                while !body.is_char_boundary(end) {
                    end -= 1;
                }
                body.truncate(end);
            }
            return Err(OpenAiCompatError::HttpStatus { status, body }.into());
        }
        // The small bounded channel applies backpressure from the consumer to the
        // SSE reader.
        let (tx, rx) = mpsc::channel(8);
        let event_stream = response.bytes_stream().eventsource();
        tokio::spawn(async move {
            drive_sse(event_stream, tx).await;
        });
        Ok(Box::pin(ReceiverStream::new(rx)))
    }

    /// Nothing to shut down: no local process or model is held.
    async fn stop(&self, _timeout: Duration) -> Result<(), GenerateError> {
        Ok(())
    }
}
async fn drive_sse<S>(mut event_stream: S, tx: mpsc::Sender<TokenEventV2>)
where
S: futures_util::Stream<
Item = Result<
eventsource_stream::Event,
eventsource_stream::EventStreamError<reqwest::Error>,
>,
> + Unpin,
{
let mut acc = ChunkAccumulator::new();
while let Some(event) = event_stream.next().await {
let event = match event {
Ok(ev) => ev,
Err(e) => {
warn!(error = %e, "openai-compat SSE transport error");
return;
}
};
if event.data == "[DONE]" {
break;
}
let chunk: ChatChunk = match serde_json::from_str(&event.data) {
Ok(c) => c,
Err(e) => {
warn!(
error = %e,
data = %truncate(&event.data, 256),
"openai-compat malformed chunk; dropping"
);
continue;
}
};
for ev in acc.ingest(chunk) {
if tx.send(ev).await.is_err() {
debug!("openai-compat generation cancelled (receiver dropped)");
return;
}
}
}
for ev in acc.finalize() {
if tx.send(ev).await.is_err() {
return;
}
}
}
/// Returns the prefix of `s` holding at most `n` characters (Unicode scalar values).
///
/// The cut always lands on a `char` boundary, so multi-byte UTF-8 never causes a panic.
fn truncate(s: &str, n: usize) -> &str {
    let cut = s
        .char_indices()
        .nth(n)
        .map_or(s.len(), |(byte_idx, _)| byte_idx);
    &s[..cut]
}