use crate::error::{Error, Result};
use crate::providers::{ListModels, ProviderApi, ProviderClient};
use crate::sse::drain_sse_frames;
use crate::types::{
ContentPart, DEFAULT_OPENAI_MODEL, Event, Message, Model, ModelInfo, Provider, Response,
ResponseRequest, Role, ToolCall, ToolSpec,
};
use futures_util::StreamExt;
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
/// Deserialized body of the model-listing endpoint response.
///
/// Only the `data` array is read; it is converted into `ModelInfo` values by
/// `OpenAiProvider::model_infos_from_response`.
#[derive(Debug, Deserialize)]
struct OpenAiModelsResponse {
/// One entry per model returned by the endpoint.
data: Vec<OpenAiModel>,
}
/// A single model entry as it appears in the model-listing response.
///
/// Several fields come in pairs of alternative names; when converting to `ModelInfo`
/// the first field of each pair wins and the second is used only as a fallback.
/// Every field except `id` is optional and defaults to `None` when absent.
#[derive(Debug, Deserialize)]
struct OpenAiModel {
/// Model identifier (the only required field).
id: String,
/// Preferred human-readable name.
#[serde(default)]
display_name: Option<String>,
/// Fallback human-readable name, used when `display_name` is absent.
#[serde(default)]
name: Option<String>,
/// Preferred creation timestamp, kept as the string the server sent.
#[serde(default)]
created_at: Option<String>,
/// Fallback numeric creation value; stringified when `created_at` is absent.
#[serde(default)]
created: Option<u64>,
/// Preferred input-token limit.
#[serde(default)]
max_input_tokens: Option<u64>,
/// Fallback input-token limit, used when `max_input_tokens` is absent.
#[serde(default)]
context_length: Option<u64>,
/// Preferred output-token limit.
#[serde(default)]
max_output_tokens: Option<u64>,
/// Fallback output-token limit, used when `max_output_tokens` is absent.
#[serde(default)]
max_tokens: Option<u64>,
}
/// Stateless unit type that carries the OpenAI request-building and response-parsing logic.
///
/// It holds no credentials: the API key and base URL are passed to each call.
#[derive(Debug, Clone, Copy)]
pub struct OpenAiProvider;
/// OpenAI client configuration: an API key paired with the base URL it is used against.
#[derive(Debug, Clone)]
pub struct OpenAiApi {
    /// Bearer token sent with every request.
    api_key: String,
    /// Root URL of the API; endpoint paths are appended to it.
    base_url: String,
}

impl OpenAiApi {
    /// Builds a client configuration from anything convertible into owned strings.
    pub fn new(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let base_url: String = base_url.into();
        Self { api_key, base_url }
    }
}
impl OpenAiProvider {
/// Returns the responses endpoint for `base_url`, ignoring any trailing slashes on it.
fn url(base_url: &str) -> String {
    let root = base_url.trim_end_matches('/');
    format!("{}/v1/responses", root)
}
/// Returns the model-listing endpoint for `base_url`, ignoring any trailing slashes on it.
fn models_url(base_url: &str) -> String {
    let root = base_url.trim_end_matches('/');
    format!("{}/v1/models", root)
}
/// Converts conversation messages into the `input` array of a responses request.
///
/// * Tool messages contribute one `function_call_output` item per tool-result part;
///   every other part of a tool message is ignored.
/// * For system/user/assistant messages, each tool-call part becomes a `function_call`
///   item, emitted before the message's own `{role, content}` item.
/// * Text and thinking parts are sent as `output_text` for assistant messages and
///   `input_text` otherwise; images are sent as `input_image` (base64 images as a
///   `data:` URL). Citation parts are dropped.
/// * A message whose content list ends up empty produces no `{role, content}` item.
fn input_from_messages(messages: &[Message]) -> Vec<serde_json::Value> {
    let mut items: Vec<serde_json::Value> = Vec::new();

    for message in messages {
        let (role, text_kind) = match message.role {
            Role::Tool => {
                for part in &message.content {
                    if let ContentPart::ToolResult { id, content, .. } = part {
                        items.push(serde_json::json!({
                            "type": "function_call_output",
                            "call_id": id,
                            "output": content.to_string(),
                        }));
                    }
                }
                continue;
            }
            Role::System => ("system", "input_text"),
            Role::User => ("user", "input_text"),
            Role::Assistant => ("assistant", "output_text"),
        };

        // Tool calls go straight into `items`; everything else is gathered into the
        // message's content list, which is appended after the loop.
        let mut content: Vec<serde_json::Value> = Vec::new();
        for part in &message.content {
            match part {
                ContentPart::ToolCall {
                    id,
                    name,
                    arguments,
                } => items.push(serde_json::json!({
                    "type": "function_call",
                    "call_id": id,
                    "name": name,
                    "arguments": arguments.to_string(),
                })),
                ContentPart::Text(text) => content.push(serde_json::json!({
                    "type": text_kind,
                    "text": text
                })),
                ContentPart::ImageUrl { url } => content.push(serde_json::json!({
                    "type": "input_image",
                    "image_url": url,
                })),
                ContentPart::ImageBase64 { media_type, data } => {
                    content.push(serde_json::json!({
                        "type": "input_image",
                        "image_url": format!("data:{media_type};base64,{data}"),
                    }))
                }
                ContentPart::Thinking { text, .. } => content.push(serde_json::json!({
                    "type": text_kind,
                    "text": text,
                })),
                ContentPart::Citation { .. } | ContentPart::ToolResult { .. } => {}
            }
        }

        if !content.is_empty() {
            items.push(serde_json::json!({ "role": role, "content": content }));
        }
    }

    items
}
/// Serializes tool specifications into the `tools` array of a responses request,
/// one `function` entry per tool, in input order.
fn tools_from_tools(tools: &[ToolSpec]) -> Vec<serde_json::Value> {
    let mut specs = Vec::with_capacity(tools.len());
    for tool in tools {
        specs.push(serde_json::json!({
            "type": "function",
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.parameters,
        }));
    }
    specs
}
/// Picks a model id when the request did not name one.
///
/// Uses the id of the first model returned by the listing endpoint. Falls back to
/// `DEFAULT_OPENAI_MODEL` when the listing fails or is empty; a listing error is
/// deliberately swallowed.
async fn resolve_model(http: &reqwest::Client, api_key: &str, base_url: &str) -> String {
    let listed = ListModels::list_models(&OpenAiProvider, http, api_key, base_url).await;
    let Ok(models) = listed else {
        return DEFAULT_OPENAI_MODEL.to_string();
    };
    match models.first() {
        Some(first) => first.id.clone(),
        None => DEFAULT_OPENAI_MODEL.to_string(),
    }
}
fn model_infos_from_response(resp: OpenAiModelsResponse) -> Vec<ModelInfo> {
resp.data
.into_iter()
.map(|m| {
let display_name = m.display_name.or(m.name);
let created_at = m.created_at.or_else(|| m.created.map(|n| n.to_string()));
let max_input_tokens = m
.max_input_tokens
.or(m.context_length)
.and_then(|n| u32::try_from(n).ok());
let max_output_tokens = m
.max_output_tokens
.or(m.max_tokens)
.and_then(|n| u32::try_from(n).ok());
ModelInfo {
id: m.id,
display_name,
provider: Provider::OpenAi,
created_at,
max_input_tokens,
max_output_tokens,
}
})
.collect()
}
fn extract_output_text(v: &serde_json::Value) -> String {
let v = if v.get("type").and_then(|x| x.as_str()) == Some("response.completed") {
v.get("response").unwrap_or(v)
} else {
v
};
if let Some(s) = v.get("output_text").and_then(|x| x.as_str())
&& !s.trim().is_empty()
{
return s.to_string();
}
let mut out = String::new();
let Some(items) = v.get("output").and_then(|x| x.as_array()) else {
return out;
};
for item in items {
if let Some(t) = item.get("text").and_then(|x| x.as_str()) {
out.push_str(t);
continue;
}
let Some(content) = item.get("content").and_then(|x| x.as_array()) else {
continue;
};
for part in content {
let text = part.get("text").and_then(|x| x.as_str());
if let Some(t) = text {
out.push_str(t);
}
}
}
out
}
/// Concatenates the `text` of every entry in an output item's `content` array.
///
/// Returns an empty string when `content` is missing or is not an array; entries
/// without a string `text` are skipped.
fn text_from_message_output_item(item: &serde_json::Value) -> String {
    item.get("content")
        .and_then(|content| content.as_array())
        .into_iter()
        .flatten()
        .filter_map(|part| part.get("text").and_then(|text| text.as_str()))
        .collect()
}
/// Reconciles the text accumulated so far (`out`) with an authoritative text (`full`),
/// emitting only what the consumer has not seen yet.
///
/// * `full` empty or equal to `out`: nothing happens.
/// * `out` is a prefix of `full` (including empty `out`): the missing tail is appended
///   to `out` and emitted as one `Event::TextDelta`.
/// * `out` already ends with `full`: nothing happens. This occurs when `out` spans
///   several output items and `full` is the final text of the latest one; resetting
///   here would drop the earlier items and re-emit text that was already streamed.
/// * Otherwise the two have diverged: `out` is replaced by `full`, which is emitted
///   whole as a resynchronizing `Event::TextDelta`.
fn emit_assistant_text_tail<F: FnMut(Event) + ?Sized>(
    out: &mut String,
    full: &str,
    on_event: &mut F,
) {
    if full.is_empty() || full == out.as_str() {
        return;
    }

    // `strip_prefix` succeeds for an empty `out` too, in which case the tail is `full`.
    if let Some(rest) = full.strip_prefix(out.as_str()) {
        out.push_str(rest);
        on_event(Event::TextDelta(rest.to_string()));
        return;
    }

    if out.ends_with(full) {
        return;
    }

    out.clear();
    out.push_str(full);
    on_event(Event::TextDelta(full.to_string()));
}
/// Extracts tool calls from a response object or a `response.completed` event.
///
/// A `response.completed` envelope is unwrapped to its inner `response` first (when
/// present). Calls are read from `function_call` items in the `output` array; only if
/// none are found there is `/choices/0/message/tool_calls` consulted as a fallback.
/// Entries without a non-empty name are skipped. String arguments are parsed as JSON
/// and become `Null` when parsing fails; non-string arguments are cloned unchanged.
fn extract_tool_calls(v: &serde_json::Value) -> Vec<ToolCall> {
    /// Normalizes an arguments field into a JSON value.
    fn decode_arguments(raw: Option<&serde_json::Value>) -> serde_json::Value {
        match raw {
            Some(serde_json::Value::String(encoded)) => {
                serde_json::from_str::<serde_json::Value>(encoded)
                    .unwrap_or(serde_json::Value::Null)
            }
            Some(other) => other.clone(),
            None => serde_json::Value::Null,
        }
    }

    /// Reads `key` from `node` as a string slice, if present.
    fn str_at<'a>(node: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        node.get(key).and_then(|field| field.as_str())
    }

    let root = if str_at(v, "type") == Some("response.completed") {
        v.get("response").unwrap_or(v)
    } else {
        v
    };

    let mut calls: Vec<ToolCall> = Vec::new();

    let output_items = root
        .get("output")
        .and_then(|output| output.as_array())
        .into_iter()
        .flatten();
    for item in output_items {
        if str_at(item, "type") != Some("function_call") {
            continue;
        }
        let Some(name) = str_at(item, "name").filter(|name| !name.is_empty()) else {
            continue;
        };
        calls.push(ToolCall {
            id: str_at(item, "call_id")
                .or_else(|| str_at(item, "id"))
                .map(|id| id.to_string()),
            name: name.to_string(),
            arguments: decode_arguments(item.get("arguments")),
        });
    }

    if !calls.is_empty() {
        return calls;
    }

    // Fallback: the chat-completions style location for tool calls.
    let legacy_calls = root
        .pointer("/choices/0/message/tool_calls")
        .and_then(|list| list.as_array())
        .into_iter()
        .flatten();
    for call in legacy_calls {
        let name = call
            .pointer("/function/name")
            .and_then(|field| field.as_str())
            .filter(|name| !name.is_empty());
        let Some(name) = name else {
            continue;
        };
        calls.push(ToolCall {
            id: str_at(call, "id").map(|id| id.to_string()),
            name: name.to_string(),
            arguments: decode_arguments(call.pointer("/function/arguments")),
        });
    }

    calls
}
}
/// `ProviderApi` for a configured OpenAI client: every call is forwarded to the
/// stateless `OpenAiProvider` together with this client's API key and base URL.
#[async_trait::async_trait(?Send)]
impl ProviderApi for OpenAiApi {
    /// Always `Provider::OpenAi`.
    fn provider(&self) -> Provider {
        Provider::OpenAi
    }

    /// Sends a non-streaming request using the stored credentials.
    async fn send(&self, http: &reqwest::Client, req: ResponseRequest) -> Result<Response> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        OpenAiProvider.send(http, key, base, req).await
    }

    /// Sends a streaming request using the stored credentials, forwarding events to
    /// `on_event`.
    async fn stream(
        &self,
        http: &reqwest::Client,
        req: ResponseRequest,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        OpenAiProvider.stream(http, key, base, req, on_event).await
    }

    /// Lists the models available to the stored credentials.
    async fn list_models(&self, http: &reqwest::Client) -> Result<Vec<ModelInfo>> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        OpenAiProvider.list_models(http, key, base).await
    }
}
impl ListModels for OpenAiProvider {
fn list_models(
&self,
http: &reqwest::Client,
api_key: &str,
base_url: &str,
) -> impl std::future::Future<Output = Result<Vec<ModelInfo>>> + Send {
let url = Self::models_url(base_url);
let http = http.clone();
let api_key = api_key.to_string();
async move {
let resp = http.get(url).bearer_auth(api_key).send().await?;
let status = resp.status();
let text = resp.text().await?;
if !status.is_success() {
return Err(Error::Api {
provider: Provider::OpenAi,
status: status.as_u16(),
body: text,
});
}
let parsed: OpenAiModelsResponse = serde_json::from_str(&text)?;
Ok(Self::model_infos_from_response(parsed))
}
}
}
impl ProviderClient for OpenAiProvider {
/// Performs a single non-streaming responses request and returns the parsed reply.
///
/// When the request names no model, one is chosen via `resolve_model`.
/// `max_output_tokens` and `tools` are added to the payload only when provided.
///
/// # Errors
/// Propagates transport errors and JSON decoding errors; a non-success HTTP status
/// yields `Error::Api` carrying the status code and the response body.
async fn send(
    &self,
    http: &reqwest::Client,
    api_key: &str,
    base_url: &str,
    req: ResponseRequest,
) -> Result<Response> {
    let endpoint = Self::url(base_url);
    let input = Self::input_from_messages(&req.messages);
    let model = if let Some(chosen) = req.model {
        chosen.0
    } else {
        Self::resolve_model(http, api_key, base_url).await
    };

    let mut payload = serde_json::json!({
        "model": model,
        "input": input,
    });
    if let Some(limit) = req.max_output_tokens {
        payload["max_output_tokens"] = serde_json::json!(limit);
    }
    if !req.tools.is_empty() {
        payload["tools"] = serde_json::json!(Self::tools_from_tools(&req.tools));
    }

    let reply = http
        .post(endpoint)
        .bearer_auth(api_key)
        .json(&payload)
        .send()
        .await?;
    let status = reply.status();
    // The body is read before the status check so it can be attached to the error.
    let raw = reply.text().await?;
    if !status.is_success() {
        return Err(Error::Api {
            provider: Provider::OpenAi,
            status: status.as_u16(),
            body: raw,
        });
    }

    let parsed: serde_json::Value = serde_json::from_str(&raw)?;
    let message = Message::text(Role::Assistant, Self::extract_output_text(&parsed));
    let tool_calls = Self::extract_tool_calls(&parsed);

    Ok(Response {
        model: Model::new(payload["model"].as_str().unwrap_or(DEFAULT_OPENAI_MODEL)),
        message,
        tool_calls,
        metadata: serde_json::Value::Null,
        #[cfg(feature = "raw-json")]
        raw_json: Some(parsed),
    })
}
/// Performs a streaming responses request, forwarding incremental events to `on_event`
/// and returning the final assembled response.
///
/// Text deltas are accumulated and emitted as `Event::TextDelta`. Final-text events
/// (`response.output_text.done`, `response.output_item.done`, `response.completed`)
/// are reconciled against the accumulated text via `emit_assistant_text_tail`.
/// Function calls are assembled from `response.output_item.added` plus the
/// `response.function_call_arguments.*` events and emitted as `Event::ToolCall` once
/// their arguments parse as JSON. Calls found in `response.completed` are emitted only
/// if they were not already emitted during the stream.
///
/// The stream ends at `response.completed` (which also emits `Event::Completed`), at
/// a `[DONE]` sentinel, or when the byte stream is exhausted. The last two return the
/// accumulated text with no tool calls and emit no `Event::Completed`.
///
/// # Errors
/// Propagates transport and SSE framing errors; a non-success HTTP status yields
/// `Error::Api`; an `error` / `response.error` event yields `Error::InvalidInput`.
async fn stream<F>(
    &self,
    http: &reqwest::Client,
    api_key: &str,
    base_url: &str,
    req: ResponseRequest,
    on_event: &mut F,
) -> Result<Response>
where
    F: FnMut(Event) + ?Sized,
{
    /// Moves the decodable part of `pending` into `text`.
    ///
    /// A trailing *incomplete* multi-byte sequence stays in `pending` until the next
    /// network chunk completes it, so characters split across chunk boundaries are
    /// not corrupted. Genuinely invalid bytes are replaced with U+FFFD, as a lossy
    /// decode would do.
    fn decode_utf8_prefix(pending: &mut Vec<u8>, text: &mut String) {
        loop {
            let (valid_len, invalid_len) = match std::str::from_utf8(pending.as_slice()) {
                Ok(_) => (pending.len(), None),
                Err(err) => (err.valid_up_to(), err.error_len()),
            };
            // The prefix is known to be valid, so this never inserts replacements.
            text.push_str(&String::from_utf8_lossy(&pending[..valid_len]));
            match invalid_len {
                Some(bad) => {
                    text.push('\u{FFFD}');
                    pending.drain(..valid_len + bad);
                }
                None => {
                    pending.drain(..valid_len);
                    return;
                }
            }
        }
    }

    /// Resolves which pending function call an arguments event refers to.
    ///
    /// An explicit `call_id` wins. Otherwise the event's `item_id` is translated
    /// through the aliases recorded at `response.output_item.added`; an unknown
    /// `item_id` resolves to nothing so no call is invented.
    fn stream_call_key(
        event: &serde_json::Value,
        call_id_by_item: &HashMap<String, String>,
    ) -> Option<String> {
        if let Some(call_id) = event.get("call_id").and_then(|x| x.as_str()) {
            return Some(call_id.to_string());
        }
        let item_id = event.get("item_id").and_then(|x| x.as_str())?;
        call_id_by_item.get(item_id).cloned()
    }

    let url = Self::url(base_url);
    let input = Self::input_from_messages(&req.messages);
    let model = match req.model {
        Some(m) => m.0,
        None => Self::resolve_model(http, api_key, base_url).await,
    };
    let mut body = serde_json::json!({
        "model": model,
        "input": input,
        "stream": true,
    });
    if let Some(max) = req.max_output_tokens {
        body["max_output_tokens"] = serde_json::json!(max);
    }
    if !req.tools.is_empty() {
        body["tools"] = serde_json::json!(Self::tools_from_tools(&req.tools));
    }

    let resp = http
        .post(url)
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await?;
    let status = resp.status();
    if !status.is_success() {
        let text = resp.text().await?;
        return Err(Error::Api {
            provider: Provider::OpenAi,
            status: status.as_u16(),
            body: text,
        });
    }

    // Assistant text accumulated so far.
    let mut out = String::new();
    // Function calls still being assembled, keyed by call id: (name, raw JSON arguments).
    let mut tool_calls: HashMap<String, (Option<String>, String)> = HashMap::new();
    // Output-item id -> key in `tool_calls`, for events that carry only `item_id`.
    let mut call_id_by_item: HashMap<String, String> = HashMap::new();
    // Call ids already emitted as `Event::ToolCall`, to avoid duplicates at completion.
    let mut emitted_stream_tool_ids: HashSet<String> = HashSet::new();
    // Raw bytes not yet decodable as UTF-8 (a character split across chunks).
    let mut pending_bytes: Vec<u8> = Vec::new();
    let mut data_buf = String::new();
    let mut frames = Vec::new();
    let mut bytes = resp.bytes_stream();

    while let Some(chunk) = bytes.next().await {
        let chunk = chunk?;
        pending_bytes.extend_from_slice(&chunk);
        decode_utf8_prefix(&mut pending_bytes, &mut data_buf);
        drain_sse_frames(&mut data_buf, &mut frames)?;

        for frame in frames.drain(..) {
            if frame.data.trim().starts_with("[DONE]") {
                return Ok(Response {
                    model: Model::new(body["model"].as_str().unwrap_or(DEFAULT_OPENAI_MODEL)),
                    message: Message::text(Role::Assistant, out),
                    tool_calls: Vec::new(),
                    metadata: serde_json::Value::Null,
                    #[cfg(feature = "raw-json")]
                    raw_json: None,
                });
            }

            let v: Option<serde_json::Value> = serde_json::from_str(&frame.data).ok();
            // Prefer the SSE `event:` name; fall back to the payload's `type` field.
            let ev = frame
                .event
                .as_deref()
                .filter(|e| !e.is_empty())
                .or_else(|| {
                    v.as_ref()
                        .and_then(|val| val.get("type").and_then(|x| x.as_str()))
                })
                .unwrap_or("");

            match ev {
                "error" | "response.error" => {
                    return Err(Error::InvalidInput(
                        format!("openai stream error event: {}", frame.data).into(),
                    ));
                }
                "response.output_text.delta" => {
                    if let Some(val) = &v {
                        // `delta` is either a plain string or an object with `text`.
                        let delta = val.get("delta").and_then(|x| x.as_str()).or_else(|| {
                            val.get("delta")
                                .and_then(|d| d.get("text"))
                                .and_then(|x| x.as_str())
                        });
                        if let Some(delta) = delta {
                            out.push_str(delta);
                            on_event(Event::TextDelta(delta.to_string()));
                        }
                    }
                }
                "response.output_text.done" => {
                    if let Some(val) = &v
                        && let Some(t) = val.get("text").and_then(|x| x.as_str())
                    {
                        Self::emit_assistant_text_tail(&mut out, t, on_event);
                    }
                }
                "response.output_item.done" => {
                    if let Some(val) = &v
                        && let Some(item) = val.get("item")
                        && item.get("type").and_then(|x| x.as_str()) == Some("message")
                    {
                        let t = Self::text_from_message_output_item(item);
                        Self::emit_assistant_text_tail(&mut out, t.as_str(), on_event);
                    }
                }
                "response.output_item.added" => {
                    if let Some(val) = &v
                        && let Some(item) = val.get("item")
                        && item.get("type").and_then(|x| x.as_str()) == Some("function_call")
                    {
                        let item_id = item.get("id").and_then(|x| x.as_str());
                        let call_id = item
                            .get("call_id")
                            .and_then(|x| x.as_str())
                            .or(item_id);
                        if let Some(id) = call_id {
                            if let Some(item_id) = item_id {
                                call_id_by_item.insert(item_id.to_string(), id.to_string());
                            }
                            let name = item
                                .get("name")
                                .and_then(|x| x.as_str())
                                .map(|s| s.to_string());
                            let args = item
                                .get("arguments")
                                .and_then(|x| x.as_str())
                                .unwrap_or("")
                                .to_string();
                            tool_calls.insert(id.to_string(), (name, args));
                        }
                    }
                }
                "response.function_call_arguments.delta" => {
                    if let Some(val) = &v
                        && let Some(call_id) = stream_call_key(val, &call_id_by_item)
                        && let Some(delta) = val.get("delta").and_then(|x| x.as_str())
                    {
                        let entry = tool_calls
                            .entry(call_id)
                            .or_insert((None, String::new()));
                        entry.1.push_str(delta);
                    }
                }
                "response.function_call_arguments.done" => {
                    if let Some(val) = &v
                        && let Some(call_id) = stream_call_key(val, &call_id_by_item)
                        && let Some((name_opt, raw_args)) = tool_calls.remove(&call_id)
                    {
                        let name = name_opt.unwrap_or_else(|| "tool".to_string());
                        // Unparseable arguments are not emitted here; the call can still
                        // surface from the `response.completed` payload.
                        if let Ok(args) = serde_json::from_str::<serde_json::Value>(&raw_args) {
                            on_event(Event::ToolCall(ToolCall {
                                id: Some(call_id.clone()),
                                name,
                                arguments: args,
                            }));
                            emitted_stream_tool_ids.insert(call_id);
                        }
                    }
                }
                "response.completed" => {
                    // Prefer the completed payload's text only when it is longer than
                    // what was streamed.
                    let mut message_text = out.clone();
                    if let Some(val) = &v {
                        let extracted = Self::extract_output_text(val);
                        if !extracted.trim().is_empty() && extracted.len() > message_text.len() {
                            message_text = extracted;
                        }
                    }
                    Self::emit_assistant_text_tail(&mut out, message_text.as_str(), on_event);

                    let extracted_tool_calls =
                        v.as_ref().map(Self::extract_tool_calls).unwrap_or_default();
                    for tc in &extracted_tool_calls {
                        let already = tc.id.as_ref().is_some_and(|id| {
                            !id.is_empty() && emitted_stream_tool_ids.contains(id.as_str())
                        });
                        if !already {
                            on_event(Event::ToolCall(tc.clone()));
                            if let Some(id) = tc.id.clone().filter(|s| !s.is_empty()) {
                                emitted_stream_tool_ids.insert(id);
                            }
                        }
                    }

                    let resp = Response {
                        model: Model::new(
                            body["model"].as_str().unwrap_or(DEFAULT_OPENAI_MODEL),
                        ),
                        message: Message::text(Role::Assistant, out.clone()),
                        tool_calls: extracted_tool_calls,
                        metadata: serde_json::Value::Null,
                        #[cfg(feature = "raw-json")]
                        raw_json: None,
                    };
                    on_event(Event::Completed(resp.clone()));
                    return Ok(resp);
                }
                _ => {}
            }
        }
    }

    // The byte stream ended without `response.completed` or `[DONE]`.
    Ok(Response {
        model: Model::new(body["model"].as_str().unwrap_or(DEFAULT_OPENAI_MODEL)),
        message: Message::text(Role::Assistant, out),
        tool_calls: Vec::new(),
        metadata: serde_json::Value::Null,
        #[cfg(feature = "raw-json")]
        raw_json: None,
    })
}
}
/// Unit tests for the pure JSON-parsing helpers of `OpenAiProvider`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON fixture, panicking if it is malformed.
    fn fixture(raw: &str) -> serde_json::Value {
        serde_json::from_str(raw).unwrap()
    }

    /// A `response.completed` envelope is unwrapped before `output_text` is read.
    #[test]
    fn extract_output_text_unwraps_response_completed_event() {
        let event = fixture(
            r#"{
"type": "response.completed",
"sequence_number": 9,
"response": {
"output_text": "42"
}
}"#,
        );
        assert_eq!(OpenAiProvider::extract_output_text(&event), "42");
    }

    /// Without `output_text`, the text is gathered from message content parts.
    #[test]
    fn extract_output_text_from_message_output_items() {
        let response = fixture(
            r#"{
"output": [
{
"type": "message",
"role": "assistant",
"content": [
{ "type": "output_text", "text": "hi" }
]
}
]
}"#,
        );
        assert_eq!(OpenAiProvider::extract_output_text(&response), "hi");
    }

    /// Function calls inside a `response.completed` envelope are found and named.
    #[test]
    fn extract_tool_calls_from_response_completed_envelope() {
        let event = fixture(
            r#"{
"type": "response.completed",
"response": {
"output": [
{
"type": "function_call",
"call_id": "call_abc",
"name": "add",
"arguments": {"a": 19, "b": 23}
}
]
}
}"#,
        );
        let found = OpenAiProvider::extract_tool_calls(&event);
        assert_eq!(found.len(), 1);
        let call = &found[0];
        assert_eq!(call.name, "add");
        assert_eq!(call.id.as_deref(), Some("call_abc"));
    }

    /// Fallback fields (`name`, `context_length`) populate the converted model info.
    #[test]
    fn parses_models_response_into_model_info() {
        let raw = r#"
{
"data": [
{
"id": "gpt-test",
"name": "GPT Test",
"created": 123,
"context_length": 9999,
"max_output_tokens": 111
}
]
}
"#;
        let listing: OpenAiModelsResponse = serde_json::from_str(raw).unwrap();
        let models = OpenAiProvider::model_infos_from_response(listing);
        assert_eq!(models.len(), 1);
        let model = &models[0];
        assert_eq!(model.id, "gpt-test");
        assert_eq!(model.display_name.as_deref(), Some("GPT Test"));
        assert_eq!(model.max_input_tokens, Some(9999));
        assert_eq!(model.max_output_tokens, Some(111));
    }
}