use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use super::types::*;
/// One event from a streamed response, (de)serialized as a JSON object.
///
/// `#[serde(tag = "type")]` makes this enum internally tagged. The JSON object's
/// `"type"` field selects the variant, and that variant's fields sit alongside
/// `"type"` in the same object. Every variant sets an explicit `rename`, which
/// takes precedence over the enum-level `rename_all`. The wire tag is therefore
/// always the dotted name given on each variant (e.g. `response.created`).
///
/// On deserialization, JSON fields a variant does not declare are ignored (serde's
/// default, since `deny_unknown_fields` is not set). A `"type"` value that matches
/// no variant below is a deserialization error.
///
/// NOTE(review): `output_index` / `content_index` / `summary_index` presumably
/// locate the item / part / summary within the response. `sequence_number`
/// presumably orders events within one stream. Confirm against the upstream API.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// `response.created`: carries a full `Response` object.
    #[serde(rename = "response.created")]
    ResponseCreated {
        response: Response,
        sequence_number: i32,
    },
    /// `response.in_progress`: only the sequence number is read.
    #[serde(rename = "response.in_progress")]
    ResponseInProgress {
        sequence_number: i32,
    },
    /// `response.completed`: carries a full `Response` object.
    #[serde(rename = "response.completed")]
    ResponseCompleted {
        response: Response,
        sequence_number: i32,
    },
    /// `response.failed`: carries an `error` of type `ResponseError` (the type,
    /// not the variant of the same name at the end of this enum).
    #[serde(rename = "response.failed")]
    ResponseFailed {
        error: ResponseError,
        sequence_number: i32,
    },
    /// `response.incomplete`: only the sequence number is read; any other payload
    /// fields are ignored.
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete {
        sequence_number: i32,
    },
    /// `response.output_item.added`: carries the newly added `OutputItem`.
    #[serde(rename = "response.output_item.added")]
    ResponseOutputItemAdded {
        item: OutputItem,
        output_index: i32,
        sequence_number: i32,
    },
    /// `response.output_item.done`: only the indices are read; any item payload is
    /// ignored.
    #[serde(rename = "response.output_item.done")]
    ResponseOutputItemDone {
        output_index: i32,
        sequence_number: i32,
    },
    /// `response.content_part.added`: carries the newly added `ContentPart`.
    #[serde(rename = "response.content_part.added")]
    ResponseContentPartAdded {
        part: ContentPart,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.content_part.done`: only the indices are read; any part payload is
    /// ignored.
    #[serde(rename = "response.content_part.done")]
    ResponseContentPartDone {
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.output_text.delta`: one fragment of output text.
    #[serde(rename = "response.output_text.delta")]
    ResponseTextDelta {
        delta: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        logprobs: Option<Vec<Logprob>>,
    },
    /// `response.output_text.done`: the output text carried as a single `text` field.
    #[serde(rename = "response.output_text.done")]
    ResponseTextDone {
        text: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
        // Missing on the wire -> empty vector.
        #[serde(default)]
        annotations: Vec<Annotation>,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        logprobs: Option<Vec<Logprob>>,
    },
    /// `response.function_call_arguments.delta`: one fragment of a function call's
    /// argument string.
    #[serde(rename = "response.function_call_arguments.delta")]
    ResponseFunctionCallArgumentsDelta {
        delta: String,
        item_id: String,
        output_index: i32,
        sequence_number: i32,
    },
    /// `response.function_call_arguments.done`: the argument string in one field.
    #[serde(rename = "response.function_call_arguments.done")]
    ResponseFunctionCallArgumentsDone {
        arguments: String,
        item_id: String,
        // Missing on the wire -> `None`.
        #[serde(default)]
        call_id: Option<String>,
        // Missing on the wire -> `None`.
        #[serde(default)]
        name: Option<String>,
        output_index: i32,
        sequence_number: i32,
    },
    /// `response.reasoning_summary_part.added`: `part` is kept as untyped JSON.
    #[serde(rename = "response.reasoning_summary_part.added")]
    ResponseReasoningSummaryPartAdded {
        item_id: String,
        output_index: i32,
        summary_index: i32,
        part: Value,
        sequence_number: i32,
    },
    /// `response.reasoning_summary_part.done`: `part` is kept as untyped JSON.
    #[serde(rename = "response.reasoning_summary_part.done")]
    ResponseReasoningSummaryPartDone {
        item_id: String,
        output_index: i32,
        summary_index: i32,
        part: Value,
        sequence_number: i32,
    },
    /// `response.reasoning_summary_text.delta`: one fragment of reasoning-summary
    /// text.
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ResponseReasoningSummaryTextDelta {
        delta: String,
        item_id: String,
        output_index: i32,
        summary_index: i32,
        sequence_number: i32,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        obfuscation: Option<String>,
    },
    /// `response.reasoning_summary_text.done`: reasoning-summary text in one field.
    #[serde(rename = "response.reasoning_summary_text.done")]
    ResponseReasoningSummaryTextDone {
        text: String,
        item_id: String,
        output_index: i32,
        summary_index: i32,
        sequence_number: i32,
    },
    /// `response.reasoning_text.delta`: one fragment of reasoning text.
    #[serde(rename = "response.reasoning_text.delta")]
    ResponseReasoningTextDelta {
        delta: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.reasoning_text.done`: reasoning text in one field.
    #[serde(rename = "response.reasoning_text.done")]
    ResponseReasoningTextDone {
        text: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.refusal.delta`: one fragment of refusal text.
    #[serde(rename = "response.refusal.delta")]
    ResponseRefusalDelta {
        delta: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.refusal.done`: refusal text in one field.
    #[serde(rename = "response.refusal.done")]
    ResponseRefusalDone {
        refusal: String,
        item_id: String,
        output_index: i32,
        content_index: i32,
        sequence_number: i32,
    },
    /// `response.error`: carries an `error` of type `ResponseError`. The field's
    /// type resolves to the imported type, not to this variant.
    #[serde(rename = "response.error")]
    ResponseError {
        error: ResponseError,
        sequence_number: i32,
    },
}
/// Incrementally rebuilds a streamed response from [`StreamEvent`]s.
///
/// Text, function-call-argument and reasoning buffers are keyed by the indices
/// carried on the events (`output_index`, plus `content_index` where the event has
/// one), converted to `usize`.
#[derive(Debug, Clone)]
pub struct StreamAccumulator {
    /// The most recent full `Response` received on a `response.created` or
    /// `response.completed` event; `None` until one arrives.
    pub response_snapshot: Option<Response>,
    /// Output text accumulated per `(output_index, content_index)`.
    text_buffers: HashMap<(usize, usize), String>,
    /// Function-call argument text accumulated per `output_index`.
    function_arg_buffers: HashMap<usize, String>,
    /// Reasoning text accumulated per `(output_index, content_index)`.
    reasoning_buffers: HashMap<(usize, usize), String>,
}
impl StreamAccumulator {
    /// Creates an accumulator with no response snapshot and empty buffers.
    pub fn new() -> Self {
        Self {
            response_snapshot: None,
            text_buffers: HashMap::new(),
            function_arg_buffers: HashMap::new(),
            reasoning_buffers: HashMap::new(),
        }
    }

    /// Folds one stream event into the accumulated state.
    ///
    /// * `response.created` / `response.completed` replace `response_snapshot`.
    /// * Output-text, function-call-argument and reasoning-text delta events
    ///   append their fragment to the buffer for their index.
    /// * The matching `*.done` events carry the complete payload. When that payload
    ///   is non-empty it replaces the buffer, so a dropped or never-sent delta
    ///   cannot leave the buffer permanently truncated. An empty payload leaves
    ///   the buffer untouched, so content already received is not discarded.
    /// * Every other event is ignored.
    ///
    /// Returns `Some(delta)` for output-text and reasoning-text delta events, so
    /// callers can forward the fragment (the two kinds are not distinguished), and
    /// `None` for everything else.
    ///
    /// Indices arrive as `i32` and are converted with `as usize`. A negative index
    /// (presumably never sent by a well-formed stream — confirm) would wrap to a
    /// very large key instead of being rejected.
    pub fn handle_event(&mut self, event: &StreamEvent) -> Option<String> {
        match event {
            StreamEvent::ResponseCreated { response, .. }
            | StreamEvent::ResponseCompleted { response, .. } => {
                self.response_snapshot = Some(response.clone());
                None
            }
            StreamEvent::ResponseTextDelta {
                delta,
                output_index,
                content_index,
                ..
            } => {
                let key = (*output_index as usize, *content_index as usize);
                self.text_buffers.entry(key).or_default().push_str(delta);
                Some(delta.clone())
            }
            StreamEvent::ResponseTextDone {
                text,
                output_index,
                content_index,
                ..
            } => {
                // Authoritative full text: prefer it over the concatenated deltas.
                if !text.is_empty() {
                    let key = (*output_index as usize, *content_index as usize);
                    self.text_buffers.insert(key, text.clone());
                }
                None
            }
            StreamEvent::ResponseFunctionCallArgumentsDelta {
                delta,
                output_index,
                ..
            } => {
                self.function_arg_buffers
                    .entry(*output_index as usize)
                    .or_default()
                    .push_str(delta);
                None
            }
            StreamEvent::ResponseFunctionCallArgumentsDone {
                arguments,
                output_index,
                ..
            } => {
                // Authoritative full argument string: prefer it over the deltas.
                if !arguments.is_empty() {
                    self.function_arg_buffers
                        .insert(*output_index as usize, arguments.clone());
                }
                None
            }
            StreamEvent::ResponseReasoningTextDelta {
                delta,
                output_index,
                content_index,
                ..
            } => {
                let key = (*output_index as usize, *content_index as usize);
                self.reasoning_buffers.entry(key).or_default().push_str(delta);
                Some(delta.clone())
            }
            StreamEvent::ResponseReasoningTextDone {
                text,
                output_index,
                content_index,
                ..
            } => {
                // Authoritative full reasoning text: prefer it over the deltas.
                if !text.is_empty() {
                    let key = (*output_index as usize, *content_index as usize);
                    self.reasoning_buffers.insert(key, text.clone());
                }
                None
            }
            _ => None,
        }
    }

    /// Returns the latest response snapshot, if a created/completed event was seen.
    pub fn get_final_response(&self) -> Option<&Response> {
        self.response_snapshot.as_ref()
    }

    /// Returns the output text accumulated for `(output_index, content_index)`.
    pub fn get_text(&self, output_index: usize, content_index: usize) -> Option<&String> {
        self.text_buffers.get(&(output_index, content_index))
    }

    /// Returns the function-call argument text accumulated for `output_index`.
    pub fn get_function_args(&self, output_index: usize) -> Option<&String> {
        self.function_arg_buffers.get(&output_index)
    }

    /// Returns the reasoning text accumulated for `(output_index, content_index)`.
    pub fn get_reasoning(&self, output_index: usize, content_index: usize) -> Option<&String> {
        self.reasoning_buffers.get(&(output_index, content_index))
    }
}
impl Default for StreamAccumulator {
fn default() -> Self {
Self::new()
}
}
/// Heuristically decides whether a streaming error looks transient and worth retrying.
///
/// The error is rendered with the alternate `{:#}` format. For `anyhow::Error` this
/// includes the chain of causes. The rendering is lower-cased and checked for any of
/// a fixed set of transport/DNS/body-decoding substrings.
///
/// Returns `true` if at least one marker is found, `false` otherwise.
pub fn is_chunk_error_recoverable(error: &anyhow::Error) -> bool {
    // All markers are already lower-case so they match the lower-cased message.
    // NOTE(review): the bare "incomplete" marker is broad and matches any message
    // containing that word — confirm that is intended.
    const RECOVERABLE_MARKERS: [&str; 11] = [
        "unexpected eof",
        "connection reset",
        "broken pipe",
        "connection closed",
        "incomplete",
        "chunk size",
        "dns error",
        "failed to lookup address",
        "nodename nor servname provided",
        "decoding response body",
        "reading a body from connection",
    ];

    let message = format!("{:#}", error).to_lowercase();
    RECOVERABLE_MARKERS
        .iter()
        .any(|&marker| message.contains(marker))
}