use bytes::Bytes;
use futures_util::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::error::ProviderError;
use crate::provider::ProviderId;
/// One event parsed from a `text/event-stream` body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SseEvent {
    /// Value of the block's `event:` field (the last one wins), if the block had one.
    // NOTE(review): the compiler reports this field as never read, hence the `allow`;
    // drop the attribute once a caller dispatches on the event type.
    #[allow(dead_code)]
    pub(crate) event: Option<String>,
    /// The block's `data:` field values, joined with `\n`.
    pub(crate) data: String,
}

impl SseEvent {
    /// Returns `true` when this event is the `[DONE]` end-of-stream sentinel used by
    /// OpenAI-style streams. Whitespace around the payload is ignored.
    #[must_use]
    pub(crate) fn is_openai_done(&self) -> bool {
        const DONE_SENTINEL: &str = "[DONE]";
        let payload = self.data.trim();
        payload == DONE_SENTINEL
    }
}
/// Adapts a stream of raw byte chunks into a stream of parsed Server-Sent Events.
///
/// Chunks are decoded as UTF-8 and appended to an internal text buffer. Complete events
/// (terminated by a blank line) are split off and yielded one at a time. Transport errors
/// from the inner stream are surfaced as `ProviderError::Transport`, tagged with `provider`.
pub(crate) struct SseStream<S> {
    /// Underlying stream of byte chunks (`reqwest` body chunks, per the `Stream` impl bound).
    inner: S,
    /// Decoded text received so far that has not yet been emitted as an event.
    buffer: String,
    /// Trailing bytes of an incomplete multi-byte UTF-8 sequence that was split across
    /// chunks. They are held back until the rest of the character arrives.
    pending_utf8: Vec<u8>,
    /// Set once `inner` has returned `Ready(None)` so it is never polled again.
    inner_finished: bool,
    /// Provider this stream belongs to; attached to transport errors.
    provider: ProviderId,
}

impl<S> SseStream<S> {
    /// Wraps `inner`, tagging any transport errors it yields with `provider`.
    pub(crate) fn new(inner: S, provider: ProviderId) -> Self {
        Self {
            inner,
            buffer: String::new(),
            pending_utf8: Vec::new(),
            inner_finished: false,
            provider,
        }
    }
}

impl<S> Stream for SseStream<S>
where
    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
    type Item = Result<SseEvent, ProviderError>;

    /// Yields the next complete event, pulling more chunks from `inner` as needed.
    ///
    /// Buffered events are always drained before `inner` is polled again. Once `inner`
    /// ends, any held-back partial UTF-8 bytes are flushed (lossily) and the remaining
    /// complete events are yielded. The stream then reports `None` without touching
    /// `inner` again. A trailing event with no terminating blank line is dropped.
    fn poll_next(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if let Some(event) = extract_event(&mut this.buffer) {
                return Poll::Ready(Some(Ok(event)));
            }
            if this.inner_finished {
                return Poll::Ready(None);
            }
            match Pin::new(&mut this.inner).poll_next(context) {
                Poll::Ready(Some(Ok(chunk))) => {
                    this.pending_utf8.extend_from_slice(&chunk);
                    decode_utf8_into(&mut this.pending_utf8, &mut this.buffer);
                }
                Poll::Ready(Some(Err(source))) => {
                    return Poll::Ready(Some(Err(ProviderError::Transport {
                        provider: this.provider.clone(),
                        source,
                    })));
                }
                Poll::Ready(None) => {
                    // Under the `Stream` contract, a finished stream may misbehave
                    // (even panic) if polled again, so remember that `inner` is exhausted.
                    this.inner_finished = true;
                    if !this.pending_utf8.is_empty() {
                        this.buffer
                            .push_str(&String::from_utf8_lossy(&this.pending_utf8));
                        this.pending_utf8.clear();
                    }
                    // Loop back: drain any complete events left in `buffer`, then end.
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

/// Moves every decodable byte of `pending` into `out` as UTF-8 text.
///
/// Invalid byte sequences are replaced with U+FFFD, matching `String::from_utf8_lossy`.
/// An *incomplete* sequence at the very end, however, is left in `pending`, because the
/// rest of that character may arrive in the next network chunk. Decoding each chunk on
/// its own would corrupt any multi-byte character that straddles a chunk boundary.
fn decode_utf8_into(pending: &mut Vec<u8>, out: &mut String) {
    let mut consumed = 0;
    while consumed < pending.len() {
        match std::str::from_utf8(&pending[consumed..]) {
            Ok(text) => {
                out.push_str(text);
                consumed = pending.len();
            }
            Err(error) => {
                let valid_end = consumed + error.valid_up_to();
                // The prefix up to `valid_up_to` is valid UTF-8, so this is lossless.
                out.push_str(&String::from_utf8_lossy(&pending[consumed..valid_end]));
                match error.error_len() {
                    Some(invalid_len) => {
                        out.push(char::REPLACEMENT_CHARACTER);
                        consumed = valid_end + invalid_len;
                    }
                    None => {
                        // Truncated sequence at the end: keep it for the next chunk.
                        consumed = valid_end;
                        break;
                    }
                }
            }
        }
    }
    pending.drain(..consumed);
}
/// Removes the next data-bearing event block from the front of `buffer` and parses it.
///
/// Some blocks produce no event because `parse_event_block` returns `None` for them, for
/// example comment-only keep-alives such as `: ping`. These are discarded and scanning
/// continues. A data event that arrived in the same chunk is therefore returned
/// immediately, instead of waiting for more network input.
///
/// Returns `None` once `buffer` holds no further complete block. Incomplete trailing text
/// is left in place for the next chunk.
fn extract_event(buffer: &mut String) -> Option<SseEvent> {
    while let Some(boundary) = find_event_boundary(buffer) {
        let raw: String = buffer.drain(..boundary.end).collect();
        if let Some(event) = parse_event_block(&raw) {
            return Some(event);
        }
    }
    None
}
/// Locates the end of the first complete SSE event block in `buffer`.
///
/// An event ends at a blank line: a line terminator immediately followed by another line
/// terminator. Each terminator may be `\r\n`, `\n`, or a lone `\r`, and the two need not
/// match, so `\r\n\r\n`, `\n\n`, `\r\r` and `\n\r\n` all qualify.
///
/// Returns the byte offset just past the second terminator, or `None` when the buffer does
/// not yet contain a complete block. The offset always follows an ASCII byte, so it is a
/// valid `char` boundary for `String::drain`.
///
/// A lone `\r` at the very end of the buffer counts as a terminator. If it is really the
/// first half of a `\r\n`, the leftover `\n` simply starts the next block as an empty
/// line, which the block parser ignores.
fn find_event_boundary(buffer: &str) -> Option<EventBoundary> {
    let bytes = buffer.as_bytes();
    let mut index = 0;
    while index < bytes.len() {
        match line_ending_len(bytes, index) {
            None => index += 1,
            Some(first_len) => {
                let after_first = index + first_len;
                if let Some(second_len) = line_ending_len(bytes, after_first) {
                    return Some(EventBoundary {
                        end: after_first + second_len,
                    });
                }
                index = after_first;
            }
        }
    }
    None
}

/// Returns the length of the line terminator (`\r\n`, `\n` or `\r`) that starts at
/// `index`, or `None` if there is none (including when `index` is out of bounds).
fn line_ending_len(bytes: &[u8], index: usize) -> Option<usize> {
    match bytes.get(index).copied()? {
        b'\n' => Some(1),
        b'\r' if bytes.get(index + 1) == Some(&b'\n') => Some(2),
        b'\r' => Some(1),
        _ => None,
    }
}

/// Position of the end of an event block within the stream buffer.
#[derive(Debug, Clone, Copy)]
struct EventBoundary {
    /// Byte offset one past the blank line that terminates the block.
    end: usize,
}
/// Parses one raw SSE event block into an [`SseEvent`].
///
/// Follows the `text/event-stream` field rules:
/// - lines may end in `\r\n`, `\n`, or a lone `\r`;
/// - lines starting with `:` are comments and are ignored;
/// - each line is split at its first `:` into a field name and a value. A line with no
///   `:` is a field with an empty value. Exactly one leading space is removed from the
///   value; any further whitespace is part of the payload;
/// - `event` sets the event name (the last one wins), and `data` values are joined with
///   `\n`. Other fields (`id`, `retry`, unknown names) are ignored.
///
/// Returns `None` when the block has no `data` field (e.g. a comment-only keep-alive),
/// because such blocks do not produce an event.
fn parse_event_block(block: &str) -> Option<SseEvent> {
    let mut event_name: Option<String> = None;
    let mut data_lines: Vec<&str> = Vec::new();
    for line in block.split(|c: char| c == '\r' || c == '\n') {
        // Empty segments come from the block's terminating blank line or from the gap
        // between the `\r` and `\n` of a CRLF pair; neither carries a field.
        if line.is_empty() || line.starts_with(':') {
            continue;
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "event" => event_name = Some(value.to_owned()),
            "data" => data_lines.push(value),
            _ => {}
        }
    }
    if data_lines.is_empty() {
        return None;
    }
    Some(SseEvent {
        event: event_name,
        data: data_lines.join("\n"),
    })
}