use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
use futures_core::Stream;
use futures_util::stream::StreamExt;
use reqwest::{Client, Url};
use crate::backends::openai::wire::{ChatChunk, ChatRequest, ChatResponse};
use crate::backends::sse::{ByteStream, SseFrameStream};
use crate::error::{Error, Result};
/// Base URL a freshly constructed client targets until `with_base_url` replaces it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
/// Sentinel handed to `SseFrameStream`; the tests in this file assert that a
/// `data: [DONE]` frame ends the chunk stream.
const DONE_SENTINEL: &[u8] = b"[DONE]";
/// Request timeout installed on the HTTP client (5 minutes); only compiled for
/// non-wasm targets.
// NOTE(review): reqwest documents `ClientBuilder::timeout` as a total deadline
// that includes reading the response body, so a streamed completion running
// longer than this would presumably be cut off — confirm that is intended.
#[cfg(not(target_arch = "wasm32"))]
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
/// HTTP client for the OpenAI-style `v1/chat/completions` endpoint, offering a
/// blocking-style (`chat`) and a streaming (`stream_chat`) call.
pub struct OpenAiClient {
    /// Underlying `reqwest` client used for every request.
    http: Client,
    /// Static API key given to `new`; used as the bearer token only while no
    /// `key_provider` is installed.
    api_key: Box<str>,
    /// Optional callback that, when present, is invoked for the bearer token on
    /// each request instead of using `api_key`.
    key_provider: Option<crate::backends::KeyProvider>,
    /// Base URL against which `v1/chat/completions` is resolved.
    base_url: Url,
}
/// Credential-safe `Debug` output: the API key is never printed.
impl fmt::Debug for OpenAiClient {
    /// Writes the base URL, a redacted key placeholder and whether a key
    /// provider is installed.
    ///
    /// The HTTP client and the provider itself are intentionally left out, so
    /// the struct is finished as non-exhaustive (`..`) to make the omission
    /// visible to whoever reads the log line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAiClient")
            .field("base_url", &self.base_url.as_str())
            .field("api_key", &"<redacted>")
            // Only the presence is reported: the provider yields secrets and
            // must never be invoked from a formatter.
            .field("has_key_provider", &self.key_provider.is_some())
            .finish_non_exhaustive()
    }
}
impl OpenAiClient {
pub fn new(api_key: impl Into<String>) -> Result<Self> {
let builder =
Client::builder().user_agent(concat!("localharness/", env!("CARGO_PKG_VERSION")));
#[cfg(not(target_arch = "wasm32"))]
let builder = builder.timeout(DEFAULT_TIMEOUT);
let http = builder
.build()
.map_err(|e| Error::other(format!("reqwest client build: {e}")))?;
Ok(Self {
http,
api_key: api_key.into().into_boxed_str(),
key_provider: None,
base_url: Url::parse(DEFAULT_BASE_URL).expect("default base url is valid"),
})
}
pub fn with_key_provider(mut self, provider: crate::backends::KeyProvider) -> Self {
self.key_provider = Some(provider);
self
}
fn current_key(&self) -> String {
match &self.key_provider {
Some(p) => p(),
None => self.api_key.to_string(),
}
}
pub fn with_base_url(mut self, url: Url) -> Self {
self.base_url = url;
self
}
fn completions_url(&self) -> Result<Url> {
self.base_url
.join("v1/chat/completions")
.map_err(|e| Error::other(format!("invalid completions url: {e}")))
}
pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
let url = self.completions_url()?;
let mut body = req.clone();
body.stream = false;
body.stream_options = None;
let response = self
.http
.post(url)
.bearer_auth(self.current_key())
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| Error::other(format!("openai POST: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
return Err(Error::other(format!("openai HTTP {status}: {body}")));
}
response
.json::<ChatResponse>()
.await
.map_err(|e| Error::other(format!("openai JSON: {e}")))
}
pub async fn stream_chat(&self, req: &ChatRequest) -> Result<ChatSseStream> {
let url = self.completions_url()?;
let mut body = req.clone();
body.stream = true;
let response = self
.http
.post(url)
.bearer_auth(self.current_key())
.header("content-type", "application/json")
.header("accept", "text/event-stream")
.json(&body)
.send()
.await
.map_err(|e| Error::other(format!("openai POST: {e}")))?;
let debug_sse = std::env::var("LH_DEBUG_SSE").is_ok();
if debug_sse {
eprintln!(
"[openai resp] status={} content-type={:?}",
response.status(),
response.headers().get("content-type"),
);
}
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
if debug_sse {
eprintln!("[openai ERROR] HTTP {status}: {body}");
}
return Err(Error::other(format!("openai HTTP {status}: {body}")));
}
let byte_stream = response
.bytes_stream()
.map(|res| res.map_err(|e| Error::other(format!("openai chunk read: {e}"))));
Ok(ChatSseStream::new(Box::pin(byte_stream)))
}
}
/// Stream of `ChatChunk` values decoded from a server-sent-events byte stream.
pub struct ChatSseStream {
    /// Frame-level stream; each payload it yields is JSON-decoded into a
    /// `ChatChunk` by the `Stream` implementation.
    frames: SseFrameStream,
}
impl ChatSseStream {
    /// Wraps `upstream` in an `SseFrameStream` configured with the `[DONE]`
    /// sentinel and the `"openai"` label.
    pub fn new(upstream: ByteStream) -> Self {
        let frames = SseFrameStream::new(upstream, Some(DONE_SENTINEL), "openai");
        Self { frames }
    }
}
impl Stream for ChatSseStream {
    type Item = Result<ChatChunk>;

    /// Polls the frame stream and JSON-decodes each payload it yields.
    ///
    /// Frame errors are forwarded unchanged; `Pending` and end-of-stream pass
    /// straight through.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let frames = Pin::new(&mut self.frames);
        frames
            .poll_next(cx)
            .map(|next| next.map(|frame| frame.and_then(|payload| decode_chunk(&payload))))
    }
}
fn decode_chunk(payload: &[u8]) -> Result<ChatChunk> {
serde_json::from_slice::<ChatChunk>(payload).map_err(|e| {
Error::other(format!(
"openai sse decode: {e}; payload: {}",
String::from_utf8_lossy(payload)
))
})
}
/// Reference-counted (`Arc`) handle to an `OpenAiClient`.
pub type SharedClient = Arc<OpenAiClient>;
// Unit tests for `ChatSseStream`: they feed in-memory byte streams through the
// SSE framing and check the decoded chunks. No network access is involved.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::openai::wire::FinishReason;
    use bytes::Bytes;
    use futures_util::stream;

    // Builds a `ByteStream` that yields each slice in `parts` as one `Ok` item,
    // in order.
    fn bytes_from(parts: &[&[u8]]) -> ByteStream {
        let owned: Vec<Bytes> = parts.iter().map(|b| Bytes::copy_from_slice(b)).collect();
        Box::pin(stream::iter(owned.into_iter().map(Ok)))
    }

    // Reference SSE transcript: two text deltas ("Read" + "ing."), one tool
    // call whose JSON arguments arrive in three fragments, a final frame with
    // the finish reason and usage, then the `[DONE]` sentinel.
    fn canonical_frames() -> Vec<Vec<u8>> {
        let raw = [
            "data: {\"id\":\"c1\",\"model\":\"gpt-5-nano\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"Read\"}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ing.\"}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_x\",\"type\":\"function\",\"function\":{\"name\":\"view_file\",\"arguments\":\"\"}}]}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"path\\\":\"}}]}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"main.rs\\\"}\"}}]}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"tool_calls\"}],\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":33,\"total_tokens\":45}}\n\n",
            "data: [DONE]\n\n",
        ];
        raw.iter().map(|s| s.as_bytes().to_vec()).collect()
    }

    // Drains `s`, reassembling text and tool calls the way a consumer would,
    // and asserts the result matches `canonical_frames`.
    async fn assert_canonical(mut s: ChatSseStream) {
        use std::collections::BTreeMap;
        let mut text = String::new();
        // Keyed by tool-call index; the tuple is (id, function name, arguments).
        let mut tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new();
        let mut finish: Option<FinishReason> = None;
        let mut completion_tokens: Option<i32> = None;
        while let Some(chunk) = s.next().await {
            // Any decode or transport error fails the test here.
            let chunk = chunk.unwrap();
            for choice in &chunk.choices {
                if let Some(t) = &choice.delta.content {
                    text.push_str(t);
                }
                for tc in &choice.delta.tool_calls {
                    let entry = tools.entry(tc.index).or_default();
                    if let Some(id) = &tc.id {
                        entry.0 = id.clone();
                    }
                    if let Some(f) = &tc.function {
                        if let Some(name) = &f.name {
                            entry.1 = name.clone();
                        }
                        // Argument fragments are concatenated in arrival order.
                        if let Some(args) = &f.arguments {
                            entry.2.push_str(args);
                        }
                    }
                }
                if let Some(fr) = choice.finish_reason {
                    finish = Some(fr);
                }
            }
            if let Some(u) = chunk.usage {
                completion_tokens = u.completion_tokens;
            }
        }
        assert_eq!(text, "Reading.");
        let call = &tools[&0];
        assert_eq!(call.0, "call_x");
        assert_eq!(call.1, "view_file");
        // The concatenated fragments must form valid JSON.
        let parsed: serde_json::Value = serde_json::from_str(&call.2).unwrap();
        assert_eq!(parsed["path"], "main.rs");
        assert_eq!(finish, Some(FinishReason::ToolCalls));
        assert_eq!(completion_tokens, Some(33));
    }

    // Whole transcript delivered as a single upstream item.
    #[tokio::test]
    async fn decodes_canonical_sequence_one_chunk() {
        let blob: Vec<u8> = canonical_frames().concat();
        let s = ChatSseStream::new(bytes_from(&[&blob]));
        assert_canonical(s).await;
    }

    // Transcript delivered in 19-byte pieces, so upstream items end in the
    // middle of frames.
    #[tokio::test]
    async fn decodes_canonical_sequence_split_mid_frame() {
        let blob: Vec<u8> = canonical_frames().concat();
        let chunks: Vec<&[u8]> = blob.chunks(19).collect();
        let s = ChatSseStream::new(bytes_from(&chunks));
        assert_canonical(s).await;
    }

    // Same transcript with every `\n` rewritten to `\r\n`, delivered in
    // 11-byte pieces.
    #[tokio::test]
    async fn decodes_canonical_sequence_crlf() {
        let blob: Vec<u8> = canonical_frames().concat();
        let crlf: Vec<u8> = String::from_utf8(blob)
            .unwrap()
            .replace('\n', "\r\n")
            .into_bytes();
        let chunks: Vec<&[u8]> = crlf.chunks(11).collect();
        let s = ChatSseStream::new(bytes_from(&chunks));
        assert_canonical(s).await;
    }

    // A data frame placed after `[DONE]` must never be yielded.
    #[tokio::test]
    async fn done_sentinel_terminates_stream() {
        let frames = [
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n".as_bytes(),
            "data: [DONE]\n\n".as_bytes(),
            "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"LEAK\"}}]}\n\n".as_bytes(),
        ];
        let mut s = ChatSseStream::new(bytes_from(&frames));
        let first = s.next().await.unwrap().unwrap();
        assert_eq!(first.choices[0].delta.content.as_deref(), Some("hi"));
        assert!(s.next().await.is_none(), "[DONE] must terminate the stream");
    }

    // An undecodable payload surfaces as an `Err` item instead of panicking.
    #[tokio::test]
    async fn malformed_json_yields_error_not_panic() {
        let frames = ["data: {not valid json}\n\n".as_bytes()];
        let mut s = ChatSseStream::new(bytes_from(&frames));
        let item = s.next().await.unwrap();
        assert!(item.is_err(), "malformed JSON must be an Err, got {item:?}");
    }
}