use std::sync::Arc;

use arc_swap::ArcSwap;
use async_trait::async_trait;
use futures::stream::{self, BoxStream, StreamExt};
use reqwest::header;
use secrecy::ExposeSecret;
use url::Url;

use atomr_infer_core::batch::ExecuteBatch;
use atomr_infer_core::cost::from_rates;
use atomr_infer_core::deployment::RateLimits;
use atomr_infer_core::error::{InferenceError, InferenceResult};
use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
use atomr_infer_core::runtime::{ProviderKind, RuntimeKind, TransportKind};
use atomr_infer_core::tokens::{FinishReason, TokenChunk, TokenUsage};
use atomr_infer_remote_core::session::SessionSnapshot;
use atomr_infer_remote_core::sse::{decode_sse_stream, SseChunk};

use crate::config::OpenAiConfig;
use crate::cost::OpenAiPricing;
use crate::error::classify_openai_error;
use crate::wire::{ChatChunk, ChatRequest, ChatResponse};
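
/// [`ModelRunner`] backed by the OpenAI chat completions API.
///
/// The session snapshot (HTTP client plus credential) lives behind an
/// [`ArcSwap`], so a rotated credential is picked up on the next request
/// without interrupting in-flight ones.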
pub struct OpenAiRunner {
config: OpenAiConfig,
session: Arc<ArcSwap<SessionSnapshot>>,
chat_url: Url,
}
impl OpenAiRunner {
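    /// Builds a runner for the configured API variant, resolving the chat
    /// completions URL once up front.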
pub fn new(config: OpenAiConfig, session: Arc<ArcSwap<SessionSnapshot>>) -> InferenceResult<Self> {
let chat_url = config
.variant
.chat_completions_url()
.map_err(|e| InferenceError::Internal(format!("openai endpoint url: {e}")))?;
Ok(Self {
config,
session,
chat_url,
})
}
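
    /// Assembles per-request auth headers from the current session snapshot:
    /// a bearer `Authorization` header plus the optional
    /// `openai-organization` and `openai-project` headers.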
fn auth_headers(&self) -> InferenceResult<header::HeaderMap> {
let mut h = header::HeaderMap::new();
let snap = self.session.load();
        // Build the Authorization value straight from the secret to avoid an
        // extra owned copy of the credential.
        let mut value = header::HeaderValue::from_str(&format!(
            "Bearer {}",
            snap.credential.expose_secret()
        ))
        .map_err(|e| InferenceError::Internal(format!("invalid bearer token: {e}")))?;
        // Redact the credential from Debug output of the header map.
        value.set_sensitive(true);
        h.insert(header::AUTHORIZATION, value);
if let Some(org) = &self.config.organization {
h.insert(
header::HeaderName::from_static("openai-organization"),
header::HeaderValue::from_str(org)
.map_err(|e| InferenceError::Internal(format!("invalid org header: {e}")))?,
);
}
if let Some(proj) = &self.config.project {
h.insert(
header::HeaderName::from_static("openai-project"),
header::HeaderValue::from_str(proj)
.map_err(|e| InferenceError::Internal(format!("invalid project header: {e}")))?,
);
}
Ok(h)
}
}
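
/// Lifts one decoded SSE frame into a [`TokenChunk`].
///
/// Returns `None` for the `[DONE]` sentinel so the stream terminates cleanly,
/// and an error for a data payload that fails to decode.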
fn lift_chunk(request_id: &str, sc: SseChunk) -> Option<InferenceResult<TokenChunk>> {
if sc.data == "[DONE]" {
return None;
}
match serde_json::from_str::<ChatChunk>(&sc.data) {
Err(e) => Some(Err(InferenceError::Internal(format!("openai chunk decode: {e}")))),
Ok(parsed) => {
let mut text_delta = String::new();
let mut finish = None;
for ch in &parsed.choices {
if let Some(c) = &ch.delta.content {
text_delta.push_str(c);
}
                // Only overwrite when this choice actually carries a finish
                // reason, so a reason-less later choice cannot clobber it.
                if let Some(reason) = ch.finish_reason.as_deref() {
                    finish = map_finish_reason(reason);
                }
}
let usage = parsed.usage.as_ref().map(|u| TokenUsage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
cached_tokens: u
.prompt_tokens_details
.as_ref()
.map(|d| d.cached_tokens)
.unwrap_or(0),
reasoning_tokens: u
.completion_tokens_details
.as_ref()
.map(|d| d.reasoning_tokens)
.unwrap_or(0),
});
Some(Ok(TokenChunk {
request_id: request_id.to_string(),
text_delta,
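                // Forward tool-call deltas from the first choice that carries any.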
tool_call_delta: parsed.choices.into_iter().find_map(|c| c.delta.tool_calls),
usage,
finish_reason: finish,
}))
}
}
}
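
/// Maps OpenAI-style finish reason strings onto [`FinishReason`].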
fn map_finish_reason(s: &str) -> Option<FinishReason> {
match s {
"stop" | "end_turn" => Some(FinishReason::Stop),
"length" => Some(FinishReason::Length),
"tool_calls" | "function_call" => Some(FinishReason::ToolCalls),
"content_filter" => Some(FinishReason::ContentFilter),
        // Treat unrecognized reasons as a normal stop rather than failing.
        _ => Some(FinishReason::Stop),
}
}
#[async_trait]
impl ModelRunner for OpenAiRunner {
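    /// Sends the batch to the chat completions endpoint. Both the streaming
    /// and single-shot paths hand back a [`RunHandle`] over a chunk stream.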
#[tracing::instrument(skip(self, batch), fields(request_id = %batch.request_id, model = %batch.model))]
async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle> {
let snap = self.session.load_full();
let body = ChatRequest::from_batch(&batch);
let req = snap
.client
.post(self.chat_url.clone())
.headers(self.auth_headers()?)
.json(&body);
let resp = req
.send()
.await
.map_err(|e| InferenceError::NetworkError(e.to_string()))?;
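        // Map non-2xx responses to typed errors, preserving `Retry-After`
        // so callers can back off appropriately.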
if !resp.status().is_success() {
let status = resp.status().as_u16();
let retry_after = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let body = resp.text().await.ok();
return Err(classify_openai_error(status, retry_after.as_deref(), body));
}
let request_id = batch.request_id.clone();
if batch.stream {
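            // Decode SSE frames and lift each data payload into a chunk as
            // it arrives.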
let stream = decode_sse_stream(resp.bytes_stream());
let request_id_for_stream = request_id.clone();
let lifted = stream.filter_map(move |item| {
let id = request_id_for_stream.clone();
async move {
match item {
Ok(chunk) => lift_chunk(&id, chunk),
Err(e) => Some(Err(e)),
}
}
});
Ok(RunHandle::streaming(lifted.boxed()))
} else {
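            // Single-shot: parse the full response, then surface it through
            // the same streaming interface as a one-element stream.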
let parsed: ChatResponse = resp
.json()
.await
.map_err(|e| InferenceError::Internal(format!("openai response decode: {e}")))?;
let mut text = String::new();
let mut finish = None;
            for ch in &parsed.choices {
                if let Some(s) = ch.message.content.as_deref() {
                    text.push_str(s);
                }
                if let Some(reason) = ch.finish_reason.as_deref() {
                    finish = map_finish_reason(reason);
                }
            }
            // Mirror the streaming path so cached/reasoning token details are
            // reported consistently (assumes the non-streaming usage payload
            // carries the same detail fields).
            let usage = parsed.usage.map(|u| TokenUsage {
                input_tokens: u.prompt_tokens,
                output_tokens: u.completion_tokens,
                cached_tokens: u
                    .prompt_tokens_details
                    .as_ref()
                    .map(|d| d.cached_tokens)
                    .unwrap_or(0),
                reasoning_tokens: u
                    .completion_tokens_details
                    .as_ref()
                    .map(|d| d.reasoning_tokens)
                    .unwrap_or(0),
            });
let chunk = TokenChunk {
request_id,
text_delta: text,
tool_call_delta: None,
usage,
finish_reason: finish.or(Some(FinishReason::Stop)),
};
let s: BoxStream<'static, InferenceResult<TokenChunk>> = stream::iter(vec![Ok(chunk)]).boxed();
Ok(RunHandle::streaming(s))
}
}
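
    /// Bearer auth holds no per-connection state: a rotated credential is
    /// picked up from the `ArcSwap` on the next request, so there is nothing
    /// to rebuild here.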
async fn rebuild_session(&mut self, _cause: SessionRebuildCause) -> InferenceResult<()> {
Ok(())
}
fn runtime_kind(&self) -> RuntimeKind {
RuntimeKind::OpenAi
}
fn transport_kind(&self) -> TransportKind {
TransportKind::RemoteNetwork {
provider: ProviderKind::OpenAi,
}
}
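
    /// Remote HTTP calls never pin an interpreter GIL.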
fn gil_pinned(&self) -> bool {
false
}
fn rate_limits(&self) -> Option<&RateLimits> {
Some(&self.config.rate_limits)
}
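
    /// Estimates cost from published per-model rates; models without a
    /// published rate are estimated at zero.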
fn estimate_cost_usd(&self, batch: &ExecuteBatch) -> f64 {
OpenAiPricing::published()
.get(&batch.model)
.map(|p| from_rates(p.input_per_mtok_usd, p.output_per_mtok_usd, batch).usd)
.unwrap_or(0.0)
}
}