use super::super::error::{ConvertError, Result};
use super::super::ir::{IrDelta, IrResponse};
use super::event::SseEvent;
use crate::provider::Endpoint;
use eventsource_stream::Eventsource;
use futures_util::StreamExt;
use serde_json::Value;
use std::collections::BTreeMap;
#[derive(Default)]
struct ResponsesState {
response_id: Option<String>,
model: Option<String>,
output_items: BTreeMap<usize, ResponseOutputItem>,
}
#[derive(Default)]
struct ResponseOutputItem {
item_type: Option<String>,
id: Option<String>,
call_id: Option<String>,
name: Option<String>,
status: Option<String>,
text: String,
reasoning_summary: BTreeMap<usize, String>,
reasoning_content: BTreeMap<usize, String>,
arguments: String,
}
#[derive(Clone, Debug, Default)]
pub struct SseMetadata {
pub response_id: Option<String>,
pub model: Option<String>,
}
pub struct SseAccumulator {
endpoint: Endpoint,
response: IrResponse,
responses: ResponsesState,
}
impl SseAccumulator {
pub fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
response: IrResponse::default(),
responses: ResponsesState::default(),
}
}
pub fn push_value(&mut self, value: &Value) -> Vec<IrDelta> {
let deltas = match self.endpoint {
Endpoint::ChatCompletions => {
self.observe_chat_chunk(value);
crate::chat::delta_from_chat_chunk(value)
}
Endpoint::Responses => self.delta_from_responses_event(value),
Endpoint::Messages => {
self.observe_messages_event(value);
crate::messages::delta_from_messages_event(value)
}
};
for delta in deltas.iter().cloned() {
self.response.push_delta(delta);
}
deltas
}
pub fn finish(self) -> IrResponse {
self.response
}
pub fn metadata(&self) -> SseMetadata {
SseMetadata {
response_id: self.responses.response_id.clone(),
model: self.responses.model.clone(),
}
}
fn delta_from_responses_event(&mut self, value: &Value) -> Vec<IrDelta> {
self.observe_responses_response(value);
self.observe_responses_output_item(value);
self.observe_responses_part(value);
let mut deltas = crate::responses::delta_from_responses_event(value);
self.observe_responses_deltas(value, &deltas);
for delta in &mut deltas {
if let IrDelta::ToolCall { index, id, name, .. } = delta {
if let Some(item) = self.responses.output_items.get(index) {
if id.is_none() {
*id = item.call_id.clone().or_else(|| item.id.clone());
}
if name.is_none() {
*name = item.name.clone();
}
}
}
}
deltas
}
fn observe_chat_chunk(&mut self, value: &Value) {
if self.responses.response_id.is_none() {
self.responses.response_id = value.get("id").and_then(Value::as_str).map(str::to_string);
}
if self.responses.model.is_none() {
self.responses.model = value.get("model").and_then(Value::as_str).map(str::to_string);
}
}
fn observe_messages_event(&mut self, value: &Value) {
if !matches!(value.get("type").and_then(Value::as_str), Some("message_start")) {
return;
}
let Some(message) = value.get("message") else {
return;
};
if self.responses.response_id.is_none() {
self.responses.response_id = message.get("id").and_then(Value::as_str).map(str::to_string);
}
if self.responses.model.is_none() {
self.responses.model = message.get("model").and_then(Value::as_str).map(str::to_string);
}
}
fn observe_responses_response(&mut self, value: &Value) {
let Some(response) = value.get("response") else {
return;
};
if self.responses.response_id.is_none() {
self.responses.response_id = response.get("id").and_then(Value::as_str).map(str::to_string);
}
if self.responses.model.is_none() {
self.responses.model = response.get("model").and_then(Value::as_str).map(str::to_string);
}
if let Some(usage) = crate::ir::usage_from_openai(response) {
self.response.usage = Some(usage);
}
}
fn observe_responses_output_item(&mut self, value: &Value) {
match value.get("type").and_then(Value::as_str) {
Some("response.output_item.added") | Some("response.output_item.done") => {}
_ => return,
}
let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
return;
};
let Some(item) = value.get("item") else {
return;
};
let entry = self.responses.output_items.entry(index).or_default();
if entry.item_type.is_none() {
entry.item_type = item.get("type").and_then(Value::as_str).map(str::to_string);
}
if entry.status.is_none() {
entry.status = item.get("status").and_then(Value::as_str).map(str::to_string);
}
if entry.id.is_none() {
entry.id = item
.get("id")
.or_else(|| value.get("item_id"))
.and_then(Value::as_str)
.map(str::to_string);
}
if entry.call_id.is_none() {
entry.call_id = item.get("call_id").and_then(Value::as_str).map(str::to_string);
}
if entry.name.is_none() {
entry.name = item.get("name").and_then(Value::as_str).map(str::to_string);
}
if let Some(arguments) = item
.get("arguments")
.or_else(|| item.get("input"))
.and_then(Value::as_str)
{
entry.arguments = arguments.to_string();
}
}
fn observe_responses_part(&mut self, value: &Value) {
let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
return;
};
let Some(entry) = self.responses.output_items.get_mut(&index) else {
return;
};
match value.get("type").and_then(Value::as_str) {
Some("response.output_text.done") => {
if let Some(text) = value.get("text").and_then(Value::as_str) {
entry.text = text.to_string();
}
}
Some("response.reasoning_summary_text.done") => {
if let (Some(summary_index), Some(text)) = (
value.get("summary_index").and_then(Value::as_u64).map(|v| v as usize),
value.get("text").and_then(Value::as_str),
) {
entry.reasoning_summary.insert(summary_index, text.to_string());
}
}
Some("response.function_call_arguments.done") | Some("response.custom_tool_call_input.done") => {
if let Some(arguments) = value
.get("arguments")
.or_else(|| value.get("input"))
.and_then(Value::as_str)
{
entry.arguments = arguments.to_string();
}
}
_ => {}
}
}
fn observe_responses_deltas(&mut self, value: &Value, deltas: &[IrDelta]) {
let Some(index) = value.get("output_index").and_then(Value::as_u64).map(|v| v as usize) else {
return;
};
let entry = self.responses.output_items.entry(index).or_default();
for delta in deltas {
match delta {
IrDelta::Text(text) => entry.text.push_str(text),
IrDelta::Reasoning(text) => {
let target = match value.get("type").and_then(Value::as_str) {
Some("response.reasoning_summary_text.delta") => value
.get("summary_index")
.and_then(Value::as_u64)
.map(|v| v as usize)
.map(|i| entry.reasoning_summary.entry(i).or_default()),
Some("response.reasoning_text.delta") => value
.get("content_index")
.and_then(Value::as_u64)
.map(|v| v as usize)
.map(|i| entry.reasoning_content.entry(i).or_default()),
_ => None,
};
if let Some(buf) = target {
buf.push_str(text);
}
}
IrDelta::ToolCall { arguments_delta, .. } => entry.arguments.push_str(arguments_delta),
_ => {}
}
}
}
}
pub async fn accumulate(endpoint: Endpoint, resp: reqwest::Response) -> Result<IrResponse> {
let mut acc = SseAccumulator::new(endpoint);
let mut stream = resp.bytes_stream().eventsource();
while let Some(item) = stream.next().await {
let ev = item.map_err(|e| ConvertError::sse(e.to_string()))?;
let event = SseEvent::from(ev);
if event.is_done() {
break;
}
let value = event
.json
.as_ref()
.ok_or_else(|| ConvertError::sse("expected JSON SSE payload"))?;
acc.push_value(value);
}
Ok(acc.finish())
}