use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Stream;
use futures_util::StreamExt;
use serde::de::DeserializeOwned;
use crate::error::OpenAIError;
/// One event produced by [`SseDecoder`].
///
/// `data` is every `data:` line of the event joined with `\n`. `id` is the last
/// event ID seen on the stream so far (it carries over to later events), while
/// `event` and `retry` are set only when this event itself supplied them.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ServerSentEvent {
    /// Value of the event's `event:` field, if any.
    pub event: Option<String>,
    /// The event's `data:` lines, separated by `\n`.
    pub data: String,
    /// Last event ID in effect when this event was dispatched.
    pub id: Option<String>,
    /// Value of the event's `retry:` field, if one was present and valid.
    pub retry: Option<u64>,
}
/// Incremental parser for a `text/event-stream` body.
///
/// Raw bytes are pushed in chunk by chunk; complete lines are parsed into the
/// fields of the event under construction, and a blank line dispatches it.
#[derive(Default, Debug)]
pub struct SseDecoder {
    /// Raw bytes not yet consumed as complete lines.
    buf: Vec<u8>,
    /// `event:` value of the event under construction.
    event: Option<String>,
    /// `data:` lines of the event under construction.
    data: Vec<String>,
    /// Last event ID; kept across dispatched events.
    id: Option<String>,
    /// `retry:` value of the event under construction.
    retry: Option<u64>,
    /// Whether the start-of-stream byte-order-mark check has been settled.
    bom_checked: bool,
}
/// Upper bound, in bytes, on data buffered for a single unterminated SSE line (16 MiB).
const MAX_SSE_BUFFER: usize = 16 << 20;
impl SseDecoder {
    /// Creates a decoder positioned at the start of a new stream.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of bytes buffered because they do not yet form a complete line.
    ///
    /// Callers can compare this against a limit to bound the memory held by a
    /// single unterminated line.
    pub fn buffered_len(&self) -> usize {
        self.buf.len()
    }

    /// Appends `chunk` to the stream and returns every event it completes.
    ///
    /// Lines end with `\n`, `\r\n`, or a lone `\r`. A `\r` that is the last buffered
    /// byte is held back until the next call (or [`SseDecoder::flush`]) shows whether
    /// a `\n` follows it. Each complete line is decoded as UTF-8, replacing invalid
    /// sequences with U+FFFD. Only whole lines are decoded, so multi-byte characters
    /// split across chunks are preserved. A UTF-8 byte-order mark is removed only
    /// from the very start of the stream.
    pub fn feed(&mut self, chunk: &[u8]) -> Vec<ServerSentEvent> {
        // Leftover bytes from earlier calls contain no line terminator except possibly
        // a final held-back `\r`. The search can therefore resume one byte before the
        // new data instead of rescanning a long partial line on every chunk.
        let mut search_from = self.buf.len().saturating_sub(1);
        self.buf.extend_from_slice(chunk);
        if self.strip_leading_bom() {
            search_from = 0;
        }

        // Detach the buffer for two reasons: lines can be borrowed from it while
        // `process_line` mutates the decoder, and consumed bytes are drained once
        // per call rather than once per line.
        let buf = std::mem::take(&mut self.buf);
        let mut events = Vec::new();
        let mut line_start = 0;
        while let Some(offset) = buf[search_from..]
            .iter()
            .position(|&b| b == b'\n' || b == b'\r')
        {
            let pos = search_from + offset;
            let term_len = match (buf[pos], buf.get(pos + 1).copied()) {
                // Possibly the first half of a `\r\n` pair split across chunks.
                (b'\r', None) => break,
                (b'\r', Some(b'\n')) => 2,
                _ => 1,
            };
            let line = String::from_utf8_lossy(&buf[line_start..pos]);
            if let Some(event) = self.process_line(&line) {
                events.push(event);
            }
            line_start = pos + term_len;
            search_from = line_start;
        }
        self.buf = buf;
        self.buf.drain(..line_start);
        events
    }

    /// Ends the stream and returns any event that was still being assembled.
    ///
    /// Buffered bytes are treated as a final line, and a trailing held-back `\r`
    /// counts as its terminator. A pending event is dispatched even without the
    /// blank line that would normally end it, using the same rule as a blank line.
    /// Afterwards no bytes or per-event fields remain buffered; the last event ID
    /// is kept.
    pub fn flush(&mut self) -> Vec<ServerSentEvent> {
        let mut events = Vec::new();
        let mut tail = std::mem::take(&mut self.buf);
        if tail.last() == Some(&b'\r') {
            tail.pop();
        }
        if !tail.is_empty() {
            if let Some(event) = self.process_line(&String::from_utf8_lossy(&tail)) {
                events.push(event);
            }
        }
        if self.has_pending_event() {
            events.push(self.dispatch());
        }
        // Hand the now-empty allocation back for reuse.
        tail.clear();
        self.buf = tail;
        events
    }

    /// Removes a UTF-8 byte-order mark from the very start of the stream.
    ///
    /// Returns `true` if bytes were removed. The decision stays open only while
    /// every buffered byte could still be the beginning of a BOM. After that it is
    /// settled for good, so a BOM appearing later in the stream is kept as data.
    fn strip_leading_bom(&mut self) -> bool {
        const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];
        if self.bom_checked {
            return false;
        }
        if self.buf.starts_with(&BOM) {
            self.buf.drain(..BOM.len());
            self.bom_checked = true;
            return true;
        }
        if !BOM.starts_with(&self.buf) {
            self.bom_checked = true;
        }
        false
    }

    /// Whether any per-event field of the event under construction has been set.
    fn has_pending_event(&self) -> bool {
        self.event.is_some() || !self.data.is_empty() || self.retry.is_some()
    }

    /// Applies one complete line (without its terminator) to the event being built.
    ///
    /// Returns the finished event when `line` is blank and an event is pending.
    fn process_line(&mut self, line: &str) -> Option<ServerSentEvent> {
        if line.is_empty() {
            return if self.has_pending_event() {
                Some(self.dispatch())
            } else {
                None
            };
        }
        if line.starts_with(':') {
            // Comment line; ignored.
            return None;
        }
        // `field: value`. At most one space after the colon belongs to the syntax.
        // A line with no colon is a field name with an empty value.
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "event" => self.event = Some(value.to_string()),
            "data" => self.data.push(value.to_string()),
            "id" => {
                // Per the SSE spec, an ID containing NUL is ignored.
                if !value.contains('\0') {
                    self.id = Some(value.to_string());
                }
            }
            "retry" => {
                // Only ASCII digits are valid. `u64::from_str` alone would also accept
                // a leading `+`. Values that overflow `u64` are ignored as well.
                if !value.is_empty() && value.bytes().all(|b| b.is_ascii_digit()) {
                    if let Ok(retry) = value.parse() {
                        self.retry = Some(retry);
                    }
                }
            }
            // Unknown fields are ignored.
            _ => {}
        }
        None
    }

    /// Emits the event under construction and resets its per-event fields.
    ///
    /// The last event ID is copied rather than cleared because it persists for
    /// subsequent events.
    fn dispatch(&mut self) -> ServerSentEvent {
        ServerSentEvent {
            event: self.event.take(),
            data: std::mem::take(&mut self.data).join("\n"),
            id: self.id.clone(),
            retry: self.retry.take(),
        }
    }
}
/// Typed stream of JSON payloads read from a server-sent-events HTTP response.
///
/// Each event's `data` is deserialized into `T` by the `Stream` implementation.
pub struct EventStream<T> {
    // Raw chunks of the response body.
    bytes: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
    // Splits the body bytes into SSE events.
    decoder: SseDecoder,
    // Decoded events not yet handed to the consumer.
    pending: VecDeque<ServerSentEvent>,
    // Set once the stream has ended (sentinel, error, or end of body); later polls yield `None`.
    done: bool,
    // Set once the response body has reported end-of-stream and the decoder was flushed.
    bytes_exhausted: bool,
    // Ties `T` to the stream without storing one; `fn() -> T` keeps the stream's
    // auto traits (e.g. `Send`) independent of `T`.
    _marker: std::marker::PhantomData<fn() -> T>,
}
impl<T: DeserializeOwned> EventStream<T> {
    /// Wraps `response`, whose body is read as an SSE byte stream.
    ///
    /// The body is only consumed as the returned stream is polled.
    pub(crate) fn new(response: reqwest::Response) -> Self {
        let body = response.bytes_stream();
        Self {
            bytes: Box::pin(body),
            decoder: SseDecoder::new(),
            pending: VecDeque::new(),
            done: false,
            bytes_exhausted: false,
            _marker: std::marker::PhantomData,
        }
    }

    /// Converts one SSE event into a stream item.
    ///
    /// Returns `None` when the payload begins with the `[DONE]` sentinel, which ends
    /// the stream. Otherwise the result is one of the following:
    /// - `Err(OpenAIError::Stream)` if the payload is not valid JSON.
    /// - `Err(OpenAIError::Stream)` if the JSON has a non-null `error` member. The
    ///   message is taken from `error.message` when that is a non-empty string.
    /// - `Err(OpenAIError::Json)` if the JSON does not deserialize into `T`.
    /// - `Ok(item)` otherwise.
    fn item_from_event(event: ServerSentEvent) -> Option<Result<T, OpenAIError>> {
        if event.data.starts_with("[DONE]") {
            return None;
        }
        let parsed = serde_json::from_str::<serde_json::Value>(&event.data)
            .map_err(|err| OpenAIError::Stream(format!("invalid JSON in stream: {err}")));
        Some(parsed.and_then(|value| {
            // Servers report mid-stream failures as a JSON body carrying an `error` member.
            let server_error = value
                .get("error")
                .filter(|error| !error.is_null())
                .map(|error| {
                    error
                        .get("message")
                        .and_then(serde_json::Value::as_str)
                        .filter(|message| !message.is_empty())
                        .unwrap_or("An error occurred during streaming")
                        .to_string()
                });
            match server_error {
                Some(message) => Err(OpenAIError::Stream(message)),
                None => serde_json::from_value(value).map_err(OpenAIError::Json),
            }
        }))
    }
}
impl<T: DeserializeOwned> Stream for EventStream<T> {
    type Item = Result<T, OpenAIError>;

    /// Yields the next decoded item, pulling more body bytes as needed.
    ///
    /// Events already decoded are always delivered before a buffer-limit error is
    /// reported, so complete events that arrived in the same chunk as an oversized
    /// partial line are not lost. The stream ends (every later poll returns
    /// `Ready(None)`) after the `[DONE]` sentinel, after any error, or once the
    /// body is exhausted and all flushed events have been delivered.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if this.done {
                return Poll::Ready(None);
            }
            if let Some(event) = this.pending.pop_front() {
                // Events without data (e.g. retry-only) carry no JSON payload.
                if event.data.is_empty() {
                    continue;
                }
                match Self::item_from_event(event) {
                    Some(Ok(item)) => return Poll::Ready(Some(Ok(item))),
                    Some(Err(err)) => {
                        this.done = true;
                        return Poll::Ready(Some(Err(err)));
                    }
                    None => {
                        this.done = true;
                        return Poll::Ready(None);
                    }
                }
            }
            if this.bytes_exhausted {
                this.done = true;
                return Poll::Ready(None);
            }
            // Checked only after the events decoded from earlier chunks have been
            // handed out, and before reading more bytes. This keeps memory bounded
            // without discarding complete events.
            if this.decoder.buffered_len() > MAX_SSE_BUFFER {
                this.done = true;
                return Poll::Ready(Some(Err(OpenAIError::Stream(format!(
                    "SSE line exceeded the {MAX_SSE_BUFFER} byte buffer limit"
                )))));
            }
            match this.bytes.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Some(Ok(chunk))) => {
                    this.pending.extend(this.decoder.feed(&chunk));
                }
                Poll::Ready(Some(Err(err))) => {
                    this.done = true;
                    let error = if err.is_timeout() {
                        OpenAIError::Timeout
                    } else {
                        OpenAIError::Connection(err.to_string())
                    };
                    return Poll::Ready(Some(Err(error)));
                }
                Poll::Ready(None) => {
                    this.bytes_exhausted = true;
                    this.pending.extend(this.decoder.flush());
                }
            }
        }
    }
}
impl<T: DeserializeOwned + Send + 'static> EventStream<T> {
    /// Drains the stream and returns every item in order.
    ///
    /// # Errors
    ///
    /// Returns the first error the stream yields. Nothing after it is read.
    pub async fn collect_all(mut self) -> Result<Vec<T>, OpenAIError> {
        let mut collected = Vec::new();
        loop {
            match self.next().await {
                Some(Ok(item)) => collected.push(item),
                Some(Err(err)) => return Err(err),
                None => return Ok(collected),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decodes_simple_event() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: {\"x\":1}\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "{\"x\":1}");
        assert_eq!(events[0].event, None);
    }

    #[test]
    fn decodes_event_split_across_chunks() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: {\"x\"").is_empty());
        assert!(decoder.feed(b":1}\n").is_empty());
        let events = decoder.feed(b"\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "{\"x\":1}");
    }

    #[test]
    fn handles_crlf_and_cr_terminators() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: a\r\n\r\ndata: b\r\rdata: c\n\n");
        assert_eq!(
            events.iter().map(|e| e.data.as_str()).collect::<Vec<_>>(),
            vec!["a", "b", "c"]
        );
    }

    #[test]
    fn crlf_split_across_chunks() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: a\r").is_empty());
        let events = decoder.feed(b"\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "a");
    }

    #[test]
    fn multiline_data_joined_with_newline() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: line1\ndata: line2\n\n");
        assert_eq!(events[0].data, "line1\nline2");
    }

    #[test]
    fn parses_event_id_retry_and_comments() {
        let mut decoder = SseDecoder::new();
        let events =
            decoder.feed(b": keep-alive\nevent: message\nid: 42\nretry: 100\ndata: hi\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event.as_deref(), Some("message"));
        assert_eq!(events[0].id.as_deref(), Some("42"));
        assert_eq!(events[0].retry, Some(100));
        assert_eq!(events[0].data, "hi");
    }

    #[test]
    fn strips_single_leading_space_only() {
        let mut decoder = SseDecoder::new();
        // Two spaces after the colon: only the first belongs to the field syntax.
        let events = decoder.feed(b"data:  padded\ndata:none\n\n");
        assert_eq!(events[0].data, " padded\nnone");
    }

    #[test]
    fn strips_leading_bom() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"\xEF\xBB\xBFdata: first\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "first");
        let events = decoder.feed(b"data: \xEF\xBB\xBFx\n\n");
        assert_eq!(events[0].data, "\u{FEFF}x");
    }

    #[test]
    fn flush_emits_trailing_event() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: tail").is_empty());
        let events = decoder.flush();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "tail");
    }

    #[test]
    fn flush_treats_held_back_cr_as_terminator() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: t\r").is_empty());
        let events = decoder.flush();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "t");
    }

    #[test]
    fn id_persists_across_events() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"id: 1\ndata: a\n\ndata: b\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].id.as_deref(), Some("1"));
        assert_eq!(events[1].id.as_deref(), Some("1"));
    }

    #[test]
    fn id_containing_nul_is_ignored() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"id: a\0b\ndata: x\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id, None);
    }

    #[test]
    fn utf8_character_split_across_chunks() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: \xC3").is_empty());
        let events = decoder.feed(b"\xA9\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "\u{e9}");
    }

    #[test]
    fn field_without_colon_has_empty_value() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data\ndata\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "\n");
    }
}