use async_trait::async_trait;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::Deserialize;
use serde_json::Value as JsonValue;
use std::pin::Pin;
use crate::{
error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
ProviderEvent, StopReason, StreamOptions, Usage,
};
use super::shared_client;
/// Provider that sends requests to a `/responses` endpoint and streams the reply.
#[derive(Clone)]
pub struct OpenAiResponsesProvider {
/// HTTP client obtained from `shared_client()`.
client: &'static Client,
/// Fallback API key; a key supplied through `StreamOptions` takes precedence over it.
api_key: Option<String>,
/// When set, used instead of the model's own `base_url` to build the request URL.
base_url: Option<String>,
}
impl OpenAiResponsesProvider {
    /// Creates a provider that takes its API key from the `OPENAI_API_KEY` environment
    /// variable (no key is stored when the variable cannot be read) and that uses each
    /// model's own base URL.
    pub fn new() -> Self {
        let env_key = std::env::var("OPENAI_API_KEY").ok();
        Self::from_parts(env_key, None)
    }

    /// Creates a provider with an explicit API key and no base-URL override.
    #[allow(dead_code)]
    pub fn with_api_key(api_key: impl Into<String>) -> Self {
        Self::from_parts(Some(api_key.into()), None)
    }

    /// Creates a provider that always targets `base_url`, with an optional API key.
    pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
        Self::from_parts(api_key, Some(base_url.to_owned()))
    }

    /// Shared constructor body: pairs the given settings with the shared HTTP client.
    fn from_parts(api_key: Option<String>, base_url: Option<String>) -> Self {
        Self {
            client: shared_client(),
            api_key,
            base_url,
        }
    }
}
impl Default for OpenAiResponsesProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for OpenAiResponsesProvider {
async fn stream(
&self,
model: &Model,
context: &Context,
options: Option<StreamOptions>,
) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
let options = options.unwrap_or_default();
let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
let url = format!("{}/responses", effective_base_url);
let api_key = options
.api_key
.as_ref()
.or(self.api_key.as_ref())
.ok_or_else(|| ProviderError::MissingApiKey)?;
let input = build_input(context)?;
let mut body = serde_json::json!({
"model": model.id,
"input": input,
"stream": true,
});
if let Some(temp) = options.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max) = options.max_tokens {
body["max_tokens"] = serde_json::json!(max);
}
if !context.tools.is_empty() {
body["tools"] = build_tools(&context.tools);
}
if let Some(ref thinking_level) = options.thinking_level {
if thinking_level != &crate::ThinkingLevel::Off {
if let Some(effort) = thinking_level.as_str() {
body["reasoning"] = serde_json::json!({
"effort": effort,
});
}
}
}
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", api_key).parse().expect("valid bearer header"),
);
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().expect("valid header value"),
);
for (k, v) in &options.headers {
if let (Ok(name), Ok(value)) = (
k.parse::<reqwest::header::HeaderName>(),
v.parse::<reqwest::header::HeaderValue>(),
) {
headers.insert(name, value);
}
}
let response = self
.client
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(ProviderError::RequestFailed)?;
if !response.status().is_success() {
let status = response.status();
let body: String = response.text().await.unwrap_or_default();
return Err(ProviderError::HttpError(status.as_u16(), body));
}
let provider_name = model.provider.clone();
let model_id = model.id.clone();
let stream = response.bytes_stream().flat_map(
move |chunk: Result<Bytes, reqwest::Error>| match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes).to_string();
futures::stream::iter(parse_sse_events(&text, &provider_name, &model_id))
}
Err(e) => futures::stream::iter(vec![ProviderEvent::Error {
reason: StopReason::Error,
error: create_error_message(&e.to_string(), &provider_name, &model_id),
}]),
},
);
Ok(Box::pin(stream))
}
fn name(&self) -> &str {
"openai-responses"
}
}
/// Converts the conversation in `context` into the `input` array of the request body.
///
/// The system prompt, when present, becomes a leading `developer` message. User and
/// tool-result messages are sent with role `user`, assistant messages with role `assistant`.
///
/// # Errors
///
/// Propagates the error from `blocks_to_json` when a message contains a block that cannot
/// be serialized.
fn build_input(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
    let mut items: Vec<JsonValue> = Vec::new();

    if let Some(system_prompt) = &context.system_prompt {
        items.push(serde_json::json!({
            "role": "developer",
            "content": system_prompt,
        }));
    }

    for message in &context.messages {
        // Work out role and payload first so that the push happens in one place.
        let (role, content) = match message {
            crate::Message::User(user) => {
                let payload = match &user.content {
                    crate::MessageContent::Text(text) => serde_json::json!(text.clone()),
                    crate::MessageContent::Blocks(blocks) => blocks_to_json(blocks)?,
                };
                ("user", payload)
            }
            crate::Message::Assistant(assistant) => {
                ("assistant", blocks_to_json(&assistant.content)?)
            }
            crate::Message::ToolResult(tool_result) => {
                ("user", blocks_to_json(&tool_result.content)?)
            }
        };
        items.push(serde_json::json!({
            "role": role,
            "content": content,
        }));
    }

    Ok(items)
}
/// Serializes message content blocks into the JSON used for a message's `content` field.
///
/// A single text block collapses to a plain JSON string; every other input becomes an array
/// with one object per block.
///
/// # Errors
///
/// Returns [`ProviderError::InvalidResponse`] when `blocks` contains a
/// [`ContentBlock::Unknown`].
fn blocks_to_json(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
    // Shortcut: one lone text block is sent as a bare string.
    if let [only] = blocks {
        if let Some(text) = only.as_text() {
            return Ok(JsonValue::String(text.to_string()));
        }
    }

    let items = blocks
        .iter()
        .map(|block| match block {
            ContentBlock::Text(t) => Ok(serde_json::json!({
                "type": "output_text",
                "text": t.text,
            })),
            ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
                "type": "function_call",
                "id": tc.id,
                "name": tc.name,
                "arguments": tc.arguments.to_string(),
            })),
            ContentBlock::Thinking(th) => Ok(serde_json::json!({
                "type": "reasoning",
                "summary": [
                    {
                        "type": "summary_text",
                        "text": th.thinking,
                    }
                ]
            })),
            // The Responses API takes inline images as a data URL in `image_url`; it has no
            // `data` or `mime_type` keys on an `input_image` part.
            ContentBlock::Image(img) => Ok(serde_json::json!({
                "type": "input_image",
                "image_url": format!("data:{};base64,{}", img.mime_type, img.data),
            })),
            ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
                "Unknown content block type".into(),
            )),
        })
        .collect::<Result<Vec<JsonValue>, ProviderError>>()?;

    Ok(JsonValue::Array(items))
}
fn build_tools(tools: &[crate::Tool]) -> JsonValue {
let items: Vec<_> = tools
.iter()
.map(|tool| {
serde_json::json!({
"type": "function",
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
})
})
.collect();
serde_json::json!(items)
}
fn parse_sse_events(text: &str, provider: &str, model_id: &str) -> Vec<ProviderEvent> {
let mut events = Vec::new();
let mut partial_message = AssistantMessage::new(Api::OpenAiResponses, provider, model_id);
let mut current_text_index: Option<usize> = None;
let mut current_tool_call_index: Option<usize> = None;
let mut accumulated_usage = Usage::default();
let estimated_events = text
.split('\n')
.filter(|l| l.starts_with("event: ") || l.starts_with("data: "))
.count();
events.reserve(estimated_events);
for line in text.split('\n') {
let line = line.trim_end_matches('\r');
if line.is_empty() {
continue;
}
if line.starts_with("event: ") {
let event_name = line[7..].trim();
match event_name {
"response.created"
| "response.output_item.added"
| "response.content_part.added"
| "response.output_text.delta"
| "response.function_call_arguments.delta"
| "response.completed"
| "response.output_text.done"
| "response.reasoning.done" => {
}
_ => {}
}
continue;
}
if !line.starts_with("data: ") {
continue;
}
let data = line[6..].trim();
if data.is_empty() || data == "[DONE]" {
continue;
}
if let Ok(event) = serde_json::from_str::<ResponsesEvent>(data) {
match event {
ResponsesEvent::ResponseCreatedData { response } => {
if let Some(id) = response.id {
partial_message.response_id = Some(id);
}
events.push(ProviderEvent::Start {
partial: partial_message.clone(),
});
}
ResponsesEvent::OutputItemAdded { output_item } => {
match output_item.r#type.as_str() {
"message" => {
events.push(ProviderEvent::ToolCallStart {
content_index: output_item.index,
tool_call_id: output_item.id.clone(),
partial: partial_message.clone(),
});
current_tool_call_index = Some(output_item.index);
}
"function_call" => {
events.push(ProviderEvent::ToolCallStart {
content_index: output_item.index,
tool_call_id: output_item.id.clone(),
partial: partial_message.clone(),
});
current_tool_call_index = Some(output_item.index);
}
"reasoning" => {
events.push(ProviderEvent::ThinkingStart {
content_index: output_item.index,
partial: partial_message.clone(),
});
}
_ => {}
}
}
ResponsesEvent::ContentPartAdded { content_part } => {
match content_part.r#type.as_str() {
"output_text" => {
events.push(ProviderEvent::TextStart {
content_index: content_part.index,
partial: partial_message.clone(),
});
current_text_index = Some(content_part.index);
}
"function_call" => {
events.push(ProviderEvent::ToolCallStart {
content_index: content_part.index,
tool_call_id: None,
partial: partial_message.clone(),
});
current_tool_call_index = Some(content_part.index);
}
_ => {}
}
}
ResponsesEvent::OutputTextDelta { output_text: delta } => {
let content_idx = delta.content_index.or(current_text_index).unwrap_or(0);
events.push(ProviderEvent::TextDelta {
content_index: content_idx,
delta: delta.slice.unwrap_or_default(),
partial: partial_message.clone(),
});
if current_text_index.is_none() {
current_text_index = Some(content_idx);
}
}
ResponsesEvent::FunctionCallArgumentsDelta {
function_call: delta,
} => {
let content_idx = delta.content_index.or(current_tool_call_index).unwrap_or(0);
events.push(ProviderEvent::ToolCallDelta {
content_index: content_idx,
delta: delta.arguments.unwrap_or_default(),
partial: partial_message.clone(),
});
if current_tool_call_index.is_none() {
current_tool_call_index = Some(content_idx);
}
}
ResponsesEvent::OutputTextDone { output_text } => {
if let Some(idx) = current_text_index {
let text_content = output_text
.content
.map(|c| c.text.unwrap_or_default())
.unwrap_or_default();
events.push(ProviderEvent::TextEnd {
content_index: idx,
content: text_content,
partial: partial_message.clone(),
});
current_text_index = None;
}
}
ResponsesEvent::ReasoningDone { reasoning } => {
if let Some(summary) = reasoning.summary {
for item in summary {
if item.r#type == "summary_text" {
events.push(ProviderEvent::ThinkingEnd {
content_index: 0,
content: item.text.unwrap_or_default(),
partial: partial_message.clone(),
});
}
}
}
}
ResponsesEvent::ResponseWithUsage { response } => {
let is_incomplete = response.incomplete_details.is_some();
if let Some(usage) = response.usage {
accumulated_usage.input = usage.input_tokens;
accumulated_usage.output = usage.output_tokens;
accumulated_usage.total_tokens = usage.total_tokens;
if let Some(cached) = usage.input_tokens_details {
accumulated_usage.cache_read = cached.cached_tokens;
}
}
let stop_reason = if is_incomplete {
if let Some(incomplete) = response.incomplete_details {
match incomplete.reason.as_str() {
"max_output_tokens" => StopReason::Length,
"content_filter" => StopReason::Error,
_ => StopReason::Stop,
}
} else {
StopReason::Stop
}
} else {
StopReason::Stop
};
let mut done_msg = partial_message.clone();
done_msg.usage = accumulated_usage.clone();
events.push(ProviderEvent::Done {
reason: stop_reason,
message: done_msg,
});
}
_ => {}
}
}
}
events
}
fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
let mut message = AssistantMessage::new(Api::OpenAiResponses, provider, model_id);
message.stop_reason = StopReason::Error;
message.error_message = Some(msg.to_string());
message
}
/// One decoded `data:` payload of the event stream.
///
/// The enum is `untagged`: serde tries the variants in declaration order and keeps the
/// first one that deserializes, so the order below is significant. Each variant is
/// recognized by the single top-level key its payload is wrapped in.
///
/// NOTE(review): several payload structs consist only of optional fields, so an earlier
/// variant with the same top-level key wins whenever its payload deserializes:
/// `ResponseWithUsage` is tried before `ResponseCreatedData` (both keyed `response`), and
/// `OutputTextDelta` before `OutputTextDone` (both keyed `output_text`).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ResponsesEvent {
// `{"response": {...}}` — see the ordering note on the enum.
ResponseWithUsage {
response: ResponseWithUsageData,
},
// `{"output_item": {...}}`
OutputItemAdded {
output_item: OutputItem,
},
// `{"content_part": {...}}`
ContentPartAdded {
content_part: ContentPart,
},
// `{"output_text": {...}}` — see the ordering note on the enum.
OutputTextDelta {
output_text: TextDelta,
},
// `{"function_call": {...}}`
FunctionCallArgumentsDelta {
function_call: FunctionCallDelta,
},
// `{"output_text": {...}}` — tried only when `OutputTextDelta` fails to deserialize.
OutputTextDone {
output_text: OutputTextDone,
},
// `{"reasoning": {...}}`
ReasoningDone {
reasoning: ReasoningDone,
},
// `{"response": {...}}` — tried only when `ResponseWithUsage` fails to deserialize.
ResponseCreatedData {
response: ResponseCreatedData,
},
// Catch-all: any other valid JSON value.
#[allow(dead_code)] Unknown(JsonValue),
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct ResponseCreatedData {
id: Option<String>,
#[serde(rename = "object")]
object: Option<String>,
status: Option<String>,
#[serde(rename = "model")]
model: Option<String>,
created_at: Option<i64>,
}
/// Payload under the `output_item` key.
///
/// `index` and `type` are required: a payload lacking either does not deserialize as this
/// struct.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct OutputItem {
/// Forwarded as `content_index` on the start event emitted for this item.
index: usize,
/// Item kind; `parse_sse_events` reacts to "message", "function_call" and "reasoning".
#[serde(rename = "type")]
r#type: String,
/// Forwarded as `tool_call_id` on `ToolCallStart`.
id: Option<String>,
/// Deserialized but not read in this file.
status: Option<String>,
}
/// Payload under the `content_part` key; both fields are required.
#[derive(Debug, Deserialize)]
struct ContentPart {
/// Forwarded as `content_index` on the start event emitted for this part.
index: usize,
/// Part kind; `parse_sse_events` reacts to "output_text" and "function_call".
#[serde(rename = "type")]
r#type: String,
}
/// Payload under the `output_text` key for a text delta. All fields are optional, so any
/// JSON object deserializes as this struct.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct TextDelta {
/// Index to report on the `TextDelta` event; the parser falls back to the open text part.
content_index: Option<usize>,
/// Deserialized but not read in this file.
output_index: Option<usize>,
/// The text fragment; a missing value is emitted as an empty delta.
slice: Option<String>,
}
/// Payload under the `function_call` key for an arguments delta. All fields are optional.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct FunctionCallDelta {
/// Index to report on the `ToolCallDelta` event; the parser falls back to the open tool call.
content_index: Option<usize>,
/// Deserialized but not read in this file.
output_index: Option<usize>,
/// Deserialized but not read in this file.
name: Option<String>,
/// The argument-text fragment; a missing value is emitted as an empty delta.
arguments: Option<String>,
/// Deserialized but not read in this file.
call_id: Option<String>,
}
/// Payload under the `output_text` key for a finished text part. All fields are optional.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct OutputTextDone {
/// Deserialized but not read in this file.
content_index: Option<usize>,
/// Deserialized but not read in this file.
output_index: Option<usize>,
/// Holds the final text reported on `TextEnd`; missing content yields an empty string.
content: Option<TextContent>,
}
/// Wrapper object around the final text of a finished text part.
#[derive(Debug, Deserialize)]
struct TextContent {
/// The text itself, when present.
text: Option<String>,
}
/// Payload under the `reasoning` key. All fields are optional.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct ReasoningDone {
/// Deserialized but not read in this file.
content_index: Option<usize>,
/// Deserialized but not read in this file.
output_index: Option<usize>,
/// Summary entries; each "summary_text" entry produces one `ThinkingEnd` event.
summary: Option<Vec<SummaryItem>>,
}
/// One entry of a reasoning summary.
#[derive(Debug, Deserialize)]
struct SummaryItem {
/// Entry kind (required); only "summary_text" entries are used by the parser.
#[serde(rename = "type")]
r#type: String,
/// Text of the entry; a missing value is reported as an empty string.
text: Option<String>,
}
/// Payload under the `response` key for [`ResponsesEvent::ResponseWithUsage`].
///
/// Every field is optional, so any `response` object whose present fields have the
/// expected shape deserializes as this struct.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct ResponseWithUsageData {
/// Response identifier, when present.
id: Option<String>,
/// Status string, when present (presumably the response's lifecycle state — confirm).
status: Option<String>,
/// Token counts copied into the `Usage` of the final message.
usage: Option<UsageData>,
/// When present, its `reason` selects the stop reason of the `Done` event.
incomplete_details: Option<IncompleteDetails>,
}
/// Explains why a response ended early.
#[derive(Debug, Deserialize)]
struct IncompleteDetails {
/// Required reason string. The parser maps "max_output_tokens" to `StopReason::Length`,
/// "content_filter" to `StopReason::Error`, and anything else to `StopReason::Stop`.
reason: String,
}
/// Token accounting carried by a `response` payload.
///
/// The three counters are required; a `usage` object missing any of them does not
/// deserialize. `input_tokens_details` is optional.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct UsageData {
    input_tokens: usize,
    output_tokens: usize,
    total_tokens: usize,
    input_tokens_details: Option<InputTokensDetails>,
}
#[derive(Debug, Deserialize)]
struct InputTokensDetails {
#[serde(rename = "cached_tokens")]
cached_tokens: usize,
}
// Unit tests for request building and SSE parsing. No test performs network I/O.
// The multi-line raw strings below are column-sensitive: each `data: ` / `event: ` line
// must start at the beginning of the line for the parser to recognize it.
#[cfg(test)]
mod tests {
use super::*;
// These explicit imports take precedence over the glob import above, so `TextContent`
// here is the crate-level content type, not the private serde struct of this file.
use crate::{Context, Message, Model, TextContent};
use serde_json::json;
// Builds a model description pointing at the public OpenAI base URL.
fn create_test_model() -> Model {
Model::new(
"gpt-4o",
"GPT-4o",
Api::OpenAiResponses,
"openai-responses",
"https://api.openai.com/v1",
)
}
// Builds an empty conversation context.
fn create_test_context() -> Context {
Context::new()
}
// The provider reports its fixed identifier.
#[test]
fn test_provider_name() {
let provider = OpenAiResponsesProvider::new();
assert_eq!(provider.name(), "openai-responses");
}
// A plain-text user message becomes one `user` item with string content.
#[test]
fn test_build_input_with_text() {
let mut context = create_test_context();
context.add_message(Message::user("Hello, world!"));
let input = build_input(&context).unwrap();
assert_eq!(input.len(), 1);
assert_eq!(input[0]["role"], "user");
assert_eq!(input[0]["content"], "Hello, world!");
}
// The system prompt is emitted first, with role `developer`.
#[test]
fn test_build_input_with_system_prompt() {
let mut context = create_test_context();
context.set_system_prompt("You are a helpful assistant.");
context.add_message(Message::user("Hi!"));
let input = build_input(&context).unwrap();
assert_eq!(input.len(), 2);
assert_eq!(input[0]["role"], "developer");
assert_eq!(input[0]["content"], "You are a helpful assistant.");
}
// Each message produces exactly one input item.
#[test]
fn test_build_input_with_multiple_messages() {
let mut context = create_test_context();
context.add_message(Message::user("First message"));
context.add_message(Message::user("Second message"));
let input = build_input(&context).unwrap();
assert_eq!(input.len(), 2);
}
// A single text block collapses to a bare JSON string.
#[test]
fn test_blocks_to_json_text() {
let blocks = vec![ContentBlock::Text(TextContent::new("Hello"))];
let result = blocks_to_json(&blocks).unwrap();
assert_eq!(result, "Hello");
}
// Two or more blocks are serialized as an array with one entry per block.
#[test]
fn test_blocks_to_json_multiple_blocks() {
let blocks = vec![
ContentBlock::Text(TextContent::new("Hello")),
ContentBlock::Text(TextContent::new("World")),
];
let result = blocks_to_json(&blocks).unwrap();
assert!(result.is_array());
assert_eq!(result.as_array().unwrap().len(), 2);
}
// Tools are serialized as `function` entries carrying the tool name.
#[test]
fn test_build_tools() {
let tools = vec![crate::Tool {
name: "get_weather".to_string(),
description: "Get weather for a location".to_string(),
parameters: json!({
"type": "object",
"properties": {
"location": {"type": "string"}
}
}),
}];
let result = build_tools(&tools);
assert!(result.is_array());
let tool = &result[0];
assert_eq!(tool["type"], "function");
assert_eq!(tool["name"], "get_weather");
}
// A `response` payload produces at least one event.
// NOTE(review): the `if let` only inspects `api` when the first event is `Start`; any
// other first event passes silently, so this test does not pin the event kind.
#[test]
fn test_parse_response_created_event() {
let sse_data =
r#"data: {"response":{"id":"resp_123","status":"in_progress","model":"gpt-4o"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(!events.is_empty());
if let ProviderEvent::Start { partial } = &events[0] {
assert_eq!(partial.api, Api::OpenAiResponses);
}
}
// Pins the current behaviour that a "message" output item is reported as `ToolCallStart`.
#[test]
fn test_parse_output_item_added_event() {
let sse_data = r#"data: {"output_item":{"index":0,"id":"msg_123","type":"message","status":"in_progress"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events
.iter()
.any(|e| matches!(e, ProviderEvent::ToolCallStart { .. })));
}
// An `output_text` payload yields a `TextDelta` event.
#[test]
fn test_parse_text_delta_event() {
let sse_data = r#"data: {"output_text":{"content_index":0,"slice":"Hello"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events
.iter()
.any(|e| matches!(e, ProviderEvent::TextDelta { .. })));
}
// A `function_call` payload yields a `ToolCallDelta` event, even with partial JSON arguments.
#[test]
fn test_parse_function_call_delta_event() {
let sse_data = r#"data: {"function_call":{"content_index":0,"arguments":"{\"location"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events
.iter()
.any(|e| matches!(e, ProviderEvent::ToolCallDelta { .. })));
}
// A completed response with usage ends the stream with `Done` / `StopReason::Stop`.
#[test]
fn test_parse_completed_event_with_usage() {
let sse_data = r#"data: {"response":{"id":"resp_123","status":"completed","usage":{"input_tokens":100,"output_tokens":50,"total_tokens":150}}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events.iter().any(|e| matches!(
e,
ProviderEvent::Done {
reason: StopReason::Stop,
..
}
)));
}
// A reasoning payload with a "summary_text" entry yields a `ThinkingEnd` event.
#[test]
fn test_parse_reasoning_event() {
let sse_data = r#"data: {"reasoning":{"content_index":0,"summary":[{"type":"summary_text","text":"Thinking process..."}]}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events
.iter()
.any(|e| matches!(e, ProviderEvent::ThinkingEnd { .. })));
}
// Constructing with an explicit key does not change the provider name.
#[test]
fn test_provider_with_api_key() {
let provider = OpenAiResponsesProvider::with_api_key("sk-test-key");
assert_eq!(provider.name(), "openai-responses");
}
// Four data lines in one block each contribute at least one event.
#[test]
fn test_multiple_events_in_stream() {
let sse_data = r#"data: {"response":{"id":"resp_123"}}
data: {"output_item":{"index":0,"type":"message"}}
data: {"output_text":{"slice":"Hello"}}
data: {"response":{"status":"completed"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events.len() >= 4);
}
// A malformed JSON payload is skipped without affecting later lines.
#[test]
fn test_invalid_json_skipped() {
let sse_data = r#"event: response.created
data: {invalid json here}
event: response.created
data: {"response":{"id":"resp_123"}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(!events.is_empty());
}
// The `[DONE]` marker produces no event of its own.
#[test]
fn test_done_marker() {
let sse_data = r#"event: response.created
data: {"response":{"id":"resp_123"}}
data: [DONE]"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events.len() <= 2);
}
// `incomplete_details.reason == "max_output_tokens"` maps to `StopReason::Length`.
#[test]
fn test_incomplete_response() {
let sse_data = r#"data: {"response":{"id":"resp_123","incomplete_details":{"reason":"max_output_tokens"}}}"#;
let events = parse_sse_events(sse_data, "openai-responses", "gpt-4o");
assert!(events.iter().any(|e| matches!(
e,
ProviderEvent::Done {
reason: StopReason::Length,
..
}
)));
}
}