use std::sync::Arc;
use futures_util::{StreamExt, stream};
use serde::Deserialize;
use tokio_util::sync::CancellationToken;
use crate::http::HttpClient;
use crate::message::{AssistantContent, AssistantMessage, OutputContent, ToolCall};
use crate::provider::{EventStream, ModelInfo, Provider, ProviderError, Request};
use crate::stream::{AssistantStreamEvent, StopReason, Usage};
struct SseFrame {
data: String,
}
pub enum ParsedEvent {
Valid(Vec<OpenAiChatEvent>),
Malformed { data: String, error: String },
}
pub fn parse_sse_events(input: &str) -> impl Iterator<Item = ParsedEvent> + '_ {
parse_frames(input).filter_map(|frame| {
if frame.data == "[DONE]" {
return None;
}
match serde_json::from_str::<OpenAiRawChunk>(&frame.data) {
Ok(raw) => Some(ParsedEvent::Valid(OpenAiChatEvent::from_raw_vec(raw))),
Err(e) => Some(ParsedEvent::Malformed {
data: frame.data.clone(),
error: e.to_string(),
}),
}
})
}
fn parse_frames(input: &str) -> impl Iterator<Item = SseFrame> + '_ {
let mut lines = input.split('\n').peekable();
std::iter::from_fn(move || {
let mut data_parts: Vec<&str> = Vec::new();
loop {
match lines.next() {
Some(line) if line.starts_with(':') => continue,
Some(line) if line.trim_end_matches('\r').is_empty() => {
if !data_parts.is_empty() {
return Some(SseFrame {
data: data_parts.join("\n"),
});
}
continue;
}
Some(line) => {
let line = line.trim_end_matches('\r');
let (field, value) = if let Some(idx) = line.find(':') {
let v = if line.get(idx + 1..idx + 2) == Some(" ") {
&line[idx + 2..]
} else {
&line[idx + 1..]
};
(&line[..idx], v)
} else {
(line, "")
};
if field == "data" {
data_parts.push(value);
}
}
None => {
if !data_parts.is_empty() {
return Some(SseFrame {
data: data_parts.join("\n"),
});
}
return None;
}
}
}
})
}
#[derive(Debug, Deserialize)]
struct OpenAiRawChunk {
#[allow(dead_code)]
id: Option<String>,
#[allow(dead_code)]
object: Option<String>,
#[allow(dead_code)]
created: Option<u64>,
model: Option<String>,
choices: Option<Vec<RawChoice>>,
usage: Option<RawUsage>,
error: Option<RawError>,
}
#[derive(Debug, Deserialize)]
struct RawChoice {
#[allow(dead_code)]
index: Option<usize>,
delta: Option<RawDelta>,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct RawDelta {
role: Option<String>,
content: Option<String>,
tool_calls: Option<Vec<RawToolCall>>,
}
#[derive(Debug, Deserialize)]
struct RawToolCall {
index: usize,
id: Option<String>,
#[allow(dead_code)]
r#type: Option<String>,
function: Option<RawFunction>,
}
#[derive(Debug, Deserialize)]
struct RawFunction {
name: Option<String>,
arguments: Option<String>,
}
#[derive(Debug, Deserialize)]
struct RawUsage {
prompt_tokens: Option<u32>,
completion_tokens: Option<u32>,
#[allow(dead_code)]
total_tokens: Option<u32>,
prompt_tokens_details: Option<RawPromptTokenDetails>,
}
#[derive(Debug, Deserialize)]
struct RawPromptTokenDetails {
cached_tokens: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct RawError {
message: Option<String>,
#[allow(dead_code)]
r#type: Option<String>,
}
#[derive(Debug, Clone)]
pub enum OpenAiChatEvent {
RoleDelta {
role: Option<String>,
model: Option<String>,
},
ContentDelta { content: String },
ToolCallStart {
index: usize,
id: String,
name: String,
},
ToolCallDelta { index: usize, arguments: String },
Finish {
finish_reason: String,
usage: Option<Usage>,
},
Error { message: Option<String> },
}
impl OpenAiChatEvent {
fn from_raw_vec(raw: OpenAiRawChunk) -> Vec<Self> {
if let Some(err) = raw.error {
return vec![OpenAiChatEvent::Error {
message: err.message,
}];
}
let usage = raw.usage.map(|u| {
let cached = u
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens)
.unwrap_or(0);
Usage {
input_tokens: u.prompt_tokens.unwrap_or(0),
output_tokens: u.completion_tokens.unwrap_or(0),
cache_read_tokens: cached,
cache_write_tokens: 0,
}
});
let choices = match raw.choices {
Some(c) => c,
None => {
if let Some(u) = usage {
return vec![OpenAiChatEvent::Finish {
finish_reason: String::new(),
usage: Some(u),
}];
}
return vec![];
}
};
let mut events = Vec::new();
if let Some(choice) = choices.into_iter().next() {
if let Some(reason) = choice.finish_reason {
return vec![OpenAiChatEvent::Finish {
finish_reason: reason,
usage,
}];
}
let delta = match choice.delta {
Some(d) => d,
None => return events,
};
if let Some(tool_calls) = delta.tool_calls {
for tc in tool_calls {
let func = tc.function.unwrap_or(RawFunction {
name: None,
arguments: None,
});
if let Some(id) = tc.id {
let name = func.name.unwrap_or_default();
events.push(OpenAiChatEvent::ToolCallStart {
index: tc.index,
id,
name,
});
} else {
let arguments = func.arguments.unwrap_or_default();
if !arguments.is_empty() {
events.push(OpenAiChatEvent::ToolCallDelta {
index: tc.index,
arguments,
});
}
}
}
if !events.is_empty() {
return events;
}
}
if delta.role.is_some() {
return vec![OpenAiChatEvent::RoleDelta {
role: delta.role,
model: raw.model,
}];
}
let content = delta.content.unwrap_or_default();
if !content.is_empty() {
return vec![OpenAiChatEvent::ContentDelta { content }];
}
}
events
}
}
pub struct OpenAiChatMapper {
partial: AssistantMessage,
tool_calls: Vec<ToolCallState>,
saw_done: bool,
text_started: bool,
}
struct ToolCallState {
id: String,
name: String,
arguments: String,
}
impl OpenAiChatMapper {
pub fn new(api: crate::ApiKind, provider: &str) -> Self {
Self {
partial: empty_assistant_message(api, provider),
tool_calls: Vec::new(),
saw_done: false,
text_started: false,
}
}
pub fn process(&mut self, event: OpenAiChatEvent) -> Vec<AssistantStreamEvent> {
if self.saw_done {
return Vec::new();
}
match event {
OpenAiChatEvent::RoleDelta { model, .. } => {
if let Some(m) = model {
self.partial.model = m;
}
let start = self.partial.clone();
vec![AssistantStreamEvent::Start { partial: start }]
}
OpenAiChatEvent::ContentDelta { content } => {
if content.is_empty() {
return Vec::new();
}
let mut events = Vec::new();
if !self.text_started {
self.text_started = true;
self.partial.content.push(AssistantContent::Text {
text: String::new(),
});
events.push(AssistantStreamEvent::TextStart {
content_index: 0,
partial: self.partial.clone(),
});
}
if let Some(AssistantContent::Text { text }) = self.partial.content.last_mut() {
text.push_str(&content);
}
events.push(AssistantStreamEvent::TextDelta {
content_index: 0,
delta: content,
partial: self.partial.clone(),
});
events
}
OpenAiChatEvent::ToolCallStart { index, id, name } => {
let mut events = Vec::new();
if self.text_started {
self.text_started = false;
if let Some(AssistantContent::Text { text }) = self.partial.content.last() {
let content = text.clone();
events.push(AssistantStreamEvent::TextEnd {
content_index: 0,
content,
partial: self.partial.clone(),
});
}
}
while self.tool_calls.len() <= index {
self.tool_calls.push(ToolCallState {
id: String::new(),
name: String::new(),
arguments: String::new(),
});
}
let content_index = self.partial.content.len();
self.tool_calls[index] = ToolCallState {
id: id.clone(),
name: name.clone(),
arguments: String::new(),
};
self.partial.content.push(AssistantContent::ToolCall {
tool_call: ToolCall {
id,
name,
arguments: String::new(),
},
});
events.push(AssistantStreamEvent::ToolCallStart {
content_index,
partial: self.partial.clone(),
});
events
}
OpenAiChatEvent::ToolCallDelta { index, arguments } => {
if arguments.is_empty() || index >= self.tool_calls.len() {
return Vec::new();
}
self.tool_calls[index].arguments.push_str(&arguments);
let mut tool_count = 0;
let tool_content_index = self
.partial
.content
.iter()
.position(|c| {
if matches!(c, AssistantContent::ToolCall { .. }) {
if tool_count == index {
return true;
}
tool_count += 1;
}
false
})
.unwrap_or(0);
if let Some(AssistantContent::ToolCall { tool_call }) =
self.partial.content.get_mut(tool_content_index)
{
tool_call.arguments.push_str(&arguments);
}
vec![AssistantStreamEvent::ToolCallDelta {
content_index: tool_content_index,
delta: arguments,
partial: self.partial.clone(),
}]
}
OpenAiChatEvent::Finish {
finish_reason,
usage,
} => {
let mut events = Vec::new();
if self.text_started {
self.text_started = false;
if let Some(AssistantContent::Text { text }) = self.partial.content.last() {
let content = text.clone();
events.push(AssistantStreamEvent::TextEnd {
content_index: 0,
content,
partial: self.partial.clone(),
});
}
}
for (tc_idx, tc_state) in self.tool_calls.iter().enumerate() {
if tc_state.id.is_empty() {
continue;
}
let mut tool_count = 0;
let tool_content_index = self
.partial
.content
.iter()
.position(|c| {
if matches!(c, AssistantContent::ToolCall { .. }) {
if tool_count == tc_idx {
return true;
}
tool_count += 1;
}
false
})
.unwrap_or(0);
if let Some(AssistantContent::ToolCall { tool_call }) =
self.partial.content.get_mut(tool_content_index)
{
tool_call.arguments = tc_state.arguments.clone();
}
let tool_call = ToolCall {
id: tc_state.id.clone(),
name: tc_state.name.clone(),
arguments: tc_state.arguments.clone(),
};
events.push(AssistantStreamEvent::ToolCallEnd {
content_index: tool_content_index,
tool_call,
partial: self.partial.clone(),
});
}
if let Some(u) = usage {
self.partial.usage = u;
}
self.partial.stop_reason = map_stop_reason(&finish_reason);
self.saw_done = true;
events.push(AssistantStreamEvent::Done {
reason: self.partial.stop_reason,
message: self.partial.clone(),
});
events
}
OpenAiChatEvent::Error { message } => {
self.saw_done = true;
let mut err_msg = self.partial.clone();
err_msg.error_message = message;
vec![AssistantStreamEvent::Error {
reason: StopReason::Error,
message: err_msg,
}]
}
}
}
}
fn map_stop_reason(raw: &str) -> StopReason {
match raw {
"stop" => StopReason::Stop,
"length" => StopReason::Length,
"tool_calls" => StopReason::ToolUse,
"content_filter" => StopReason::Error,
_ => StopReason::Error,
}
}
fn empty_assistant_message(api: crate::ApiKind, provider: &str) -> AssistantMessage {
AssistantMessage {
content: Vec::new(),
api,
provider: provider.into(),
model: String::new(),
response_model: None,
response_id: None,
usage: Usage::default(),
stop_reason: StopReason::Stop,
error_message: None,
timestamp_ms: 0,
}
}
#[derive(Debug, Clone)]
pub struct CompatConfig {
pub system_role_override: Option<String>,
pub max_tokens_field: String,
pub tool_result_name_field: bool,
pub usage_in_stream: bool,
}
impl Default for CompatConfig {
fn default() -> Self {
Self {
system_role_override: None,
max_tokens_field: "max_tokens".into(),
tool_result_name_field: false,
usage_in_stream: false,
}
}
}
pub struct OpenAiChatProvider {
#[allow(dead_code)] api_key: String,
#[allow(dead_code)] base_url: String,
models: Vec<ModelInfo>,
compat: CompatConfig,
provider_id: String,
extra_headers: Vec<(String, String)>,
client: Arc<HttpClient>,
}
impl OpenAiChatProvider {
pub fn new(api_key: String, base_url: Option<String>) -> Self {
Self::new_with_compat(api_key, base_url, CompatConfig::default())
}
pub fn new_with_compat(
api_key: String,
base_url: Option<String>,
compat: CompatConfig,
) -> Self {
Self::with_client_and_compat(
api_key,
base_url,
compat,
"openai".into(),
vec![],
Arc::new(HttpClient::new()),
)
}
pub fn with_client(
api_key: String,
base_url: Option<String>,
provider_id: String,
extra_headers: Vec<(String, String)>,
client: Arc<HttpClient>,
) -> Self {
Self::with_client_and_compat(
api_key,
base_url,
CompatConfig::default(),
provider_id,
extra_headers,
client,
)
}
fn with_client_and_compat(
api_key: String,
base_url: Option<String>,
compat: CompatConfig,
provider_id: String,
extra_headers: Vec<(String, String)>,
client: Arc<HttpClient>,
) -> Self {
let base_url = base_url.unwrap_or_else(|| "https://api.openai.com".into());
let models = vec![
ModelInfo {
id: "gpt-4o".into(),
display_name: "GPT-4o".into(),
context_window: 128000,
max_output_tokens: 16384,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
ModelInfo {
id: "gpt-4o-mini".into(),
display_name: "GPT-4o Mini".into(),
context_window: 128000,
max_output_tokens: 16384,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
ModelInfo {
id: "o3".into(),
display_name: "o3".into(),
context_window: 200000,
max_output_tokens: 100000,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
ModelInfo {
id: "o4-mini".into(),
display_name: "o4-mini".into(),
context_window: 200000,
max_output_tokens: 100000,
supports_images: true,
supports_streaming: true,
supports_thinking: false,
},
];
Self {
api_key,
base_url,
models,
compat,
provider_id,
extra_headers,
client,
}
}
pub fn new_for_profile(
api_key: String,
base_url: String,
provider_id: String,
compat: CompatConfig,
extra_headers: Vec<(String, String)>,
models: Vec<ModelInfo>,
) -> Self {
Self {
api_key,
base_url,
models,
compat,
provider_id,
extra_headers,
client: Arc::new(HttpClient::new()),
}
}
pub fn with_shared_client(self, client: Arc<HttpClient>) -> Self {
Self { client, ..self }
}
pub fn http_client(&self) -> &Arc<HttpClient> {
&self.client
}
pub fn build_request_body(&self, request: &Request) -> serde_json::Value {
let model_id = request
.model
.split_once(':')
.map(|(_, id)| id)
.unwrap_or(&request.model);
let mut body = serde_json::json!({
"model": model_id,
"stream": true,
"messages": serialize_messages(&request.messages, &request.system, &self.compat),
});
if let Some(max_tokens) = request.max_tokens {
body[&self.compat.max_tokens_field] = serde_json::Value::Number(max_tokens.into());
}
if let Some(temp) = request.temperature
&& let Some(n) = serde_json::Number::from_f64(temp)
{
body["temperature"] = serde_json::Value::Number(n);
}
if !request.tools.is_empty() {
body["tools"] = serde_json::Value::Array(
request
.tools
.iter()
.map(|t| {
serde_json::json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.input_schema,
}
})
})
.collect(),
);
}
if !request.stop_sequences.is_empty() {
body["stop"] = serde_json::Value::Array(
request
.stop_sequences
.iter()
.map(|s| serde_json::Value::String(s.clone()))
.collect(),
);
}
body
}
pub fn stream_from_sse(&self, sse_body: &str, cancel: CancellationToken) -> EventStream {
let mut mapper = OpenAiChatMapper::new(crate::ApiKind::OpenAi, &self.provider_id);
let mut stream_events: Vec<Result<AssistantStreamEvent, ProviderError>> = Vec::new();
for parsed in parse_sse_events(sse_body) {
match parsed {
ParsedEvent::Valid(events) => {
for event in events {
stream_events.extend(mapper.process(event).into_iter().map(Ok));
}
}
ParsedEvent::Malformed { data, error } => {
stream_events.push(Err(ProviderError::StreamError(format!(
"malformed SSE data: {error} (data: {data:.80})"
))));
}
}
}
let _cancel = cancel;
Box::pin(stream::iter(stream_events))
}
#[allow(clippy::too_many_arguments)]
async fn stream_http(
http_client: reqwest::Client,
api_key: String,
base_url: String,
provider_id: String,
extra_headers: Vec<(String, String)>,
body: &serde_json::Value,
cancel: CancellationToken,
tx: &tokio::sync::mpsc::Sender<Result<AssistantStreamEvent, ProviderError>>,
) -> Result<(), ProviderError> {
let mut req = http_client
.post(format!("{base_url}/v1/chat/completions"))
.header("authorization", format!("Bearer {api_key}"))
.header("content-type", "application/json");
for (name, value) in &extra_headers {
req = req.header(name.as_str(), value.as_str());
}
let response = req
.body(serde_json::to_string(body).unwrap_or_default())
.send()
.await
.map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
let status = response.status();
if !status.is_success() {
let headers = response.headers().clone();
let error_body = response.text().await.unwrap_or_default();
return Err(map_http_status(status, &error_body, &headers));
}
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
let mut mapper = OpenAiChatMapper::new(crate::ApiKind::OpenAi, &provider_id);
loop {
let chunk = tokio::select! {
_ = cancel.cancelled() => {
return Ok(());
}
chunk = byte_stream.next() => {
match chunk {
Some(c) => c,
None => break,
}
}
};
let chunk = chunk.map_err(|e| ProviderError::StreamError(e.to_string()))?;
buffer.push_str(&String::from_utf8_lossy(&chunk));
for parsed in drain_sse_events(&mut buffer) {
match parsed {
ParsedEvent::Valid(events) => {
for event in events {
for stream_event in mapper.process(event) {
if tx.send(Ok(stream_event)).await.is_err() {
return Ok(());
}
}
}
}
ParsedEvent::Malformed { data, error } => {
let err = ProviderError::StreamError(format!(
"malformed SSE data: {error} (data: {data:.80})"
));
if tx.send(Err(err)).await.is_err() {
return Ok(());
}
}
}
}
}
if !mapper.saw_done {
let err = ProviderError::StreamError("stream ended without a terminal event".into());
let _ = tx.send(Err(err)).await;
}
Ok(())
}
}
struct ReceiverStream {
rx: tokio::sync::mpsc::Receiver<Result<AssistantStreamEvent, ProviderError>>,
}
impl futures_core::Stream for ReceiverStream {
type Item = Result<AssistantStreamEvent, ProviderError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}
fn drain_sse_events(buffer: &mut String) -> Vec<ParsedEvent> {
if buffer.contains('\r') {
*buffer = buffer.replace("\r\n", "\n").replace('\r', "\n");
}
let mut events = Vec::new();
while let Some(idx) = buffer.find("\n\n") {
let end = idx + 2;
let chunk: String = buffer.drain(..end).collect();
events.extend(parse_sse_events(&chunk));
}
events
}
fn map_http_status(
status: reqwest::StatusCode,
body: &str,
headers: &reqwest::header::HeaderMap,
) -> ProviderError {
match status.as_u16() {
401 => ProviderError::AuthFailed(format!("authentication failed: {body}")),
403 => ProviderError::AuthFailed(format!("access denied: {body}")),
429 => ProviderError::RateLimited {
retry_after_ms: crate::retry::parse_retry_after(headers),
},
408 | 504 => ProviderError::Timeout,
code => ProviderError::RequestFailed(format!("HTTP {code}: {body}")),
}
}
fn serialize_messages(
messages: &[crate::message::Message],
system: &Option<String>,
compat: &CompatConfig,
) -> serde_json::Value {
let mut result = Vec::new();
if let Some(sys) = system {
let role = compat.system_role_override.as_deref().unwrap_or("system");
result.push(serde_json::json!({
"role": role,
"content": sys,
}));
}
for msg in messages {
match msg {
crate::message::Message::User(u) => {
let content: Vec<serde_json::Value> = u
.content
.iter()
.map(|c| match c {
crate::message::InputContent::Text { text } => {
serde_json::json!({"type": "text", "text": text})
}
crate::message::InputContent::Image { source, media_type } => {
let url = match source {
crate::message::ImageSource::Url { url } => url.clone(),
crate::message::ImageSource::Base64 { data } => {
format!("data:{};base64,{}", media_type.as_str(), data)
}
crate::message::ImageSource::Bytes { data } => {
format!(
"data:{};base64,{}",
media_type.as_str(),
base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
data,
)
)
}
};
serde_json::json!({
"type": "image_url",
"image_url": {"url": url}
})
}
})
.collect();
if content.len() == 1
&& let Some(text_val) = content[0].get("text")
{
result.push(serde_json::json!({
"role": "user",
"content": text_val,
}));
continue;
}
result.push(serde_json::json!({
"role": "user",
"content": content,
}));
}
crate::message::Message::Assistant(a) => {
let mut tool_calls_json = Vec::new();
let mut text_parts = Vec::new();
for c in &a.content {
match c {
AssistantContent::Text { text } => {
text_parts.push(text.clone());
}
AssistantContent::ToolCall { tool_call } => {
let input: serde_json::Value =
serde_json::from_str(&tool_call.arguments)
.ok()
.filter(|v: &serde_json::Value| v.is_object())
.unwrap_or(serde_json::json!({}));
tool_calls_json.push(serde_json::json!({
"id": tool_call.id,
"type": "function",
"function": {
"name": tool_call.name,
"arguments": serde_json::to_string(&input).unwrap_or_default(),
}
}));
}
AssistantContent::Thinking { .. } => {}
}
}
let mut assistant_msg = serde_json::json!({
"role": "assistant",
});
if !tool_calls_json.is_empty() {
assistant_msg["tool_calls"] = serde_json::Value::Array(tool_calls_json);
assistant_msg["content"] = serde_json::Value::Null;
} else {
assistant_msg["content"] = serde_json::Value::String(text_parts.join(""));
}
result.push(assistant_msg);
}
crate::message::Message::ToolResult(t) => {
let content_text: String = t
.content
.iter()
.map(|c| match c {
OutputContent::Text { text } => text.clone(),
OutputContent::Image { media_type, .. } => {
format!("[image: {}]", media_type.as_str())
}
})
.collect();
let mut tool_msg = serde_json::json!({
"role": "tool",
"tool_call_id": t.tool_call_id,
"content": content_text,
});
if compat.tool_result_name_field {
tool_msg["name"] = serde_json::Value::String(t.tool_name.clone());
}
result.push(tool_msg);
}
}
}
serde_json::Value::Array(result)
}
impl Provider for OpenAiChatProvider {
fn stream(&self, request: Request) -> EventStream {
let api_key = self.api_key.clone();
let base_url = self.base_url.clone();
let provider_id = self.provider_id.clone();
let extra_headers = self.extra_headers.clone();
let body = self.build_request_body(&request);
let cancel = request.cancel.clone();
let http_client = self.client.client().clone();
let (tx, rx) = tokio::sync::mpsc::channel(64);
tokio::spawn(async move {
if let Err(e) = Self::stream_http(
http_client,
api_key,
base_url,
provider_id,
extra_headers,
&body,
cancel,
&tx,
)
.await
{
let _ = tx.send(Err(e)).await;
}
});
Box::pin(ReceiverStream { rx })
}
fn id(&self) -> &str {
&self.provider_id
}
fn models(&self) -> &[ModelInfo] {
&self.models
}
}