use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::stream::Stream;
use futures::StreamExt;
/// Incremental Server-Sent-Events decoder layered over a stream of byte chunks.
///
/// As a [`Stream`] it yields one `Option<String>` per event block:
/// * `Some(payload)` – the block's `data: ` lines joined with `\n`;
/// * `None` – the `[DONE]` terminator;
/// * `Some("__SSE_ERROR__:<msg>")` – an error reported by the inner stream.
struct SSEParser<S> {
    /// Upstream source of raw byte chunks.
    inner: S,
    /// Decoded text that has been received but not yet split into blocks.
    acc: String,
    /// Trailing bytes of the latest chunk that form an incomplete UTF-8
    /// sequence; they are completed by the next chunk.
    pending: Vec<u8>,
    /// Set once `inner` has reported end of stream, so it is never polled again.
    finished: bool,
}

impl<S> SSEParser<S>
where
    S: Stream<Item = Result<Bytes, String>>,
{
    /// Wraps `inner` with empty buffers.
    fn new(inner: S) -> Self {
        Self {
            inner,
            acc: String::new(),
            pending: Vec::new(),
            finished: false,
        }
    }

    /// Extracts the next event from the buffered text.
    ///
    /// Returns `None` when no complete block is buffered, `Some(None)` for the
    /// `[DONE]` terminator and `Some(Some(payload))` for a data block. Blocks
    /// without any `data: ` line (comments, keep-alives) are consumed and
    /// skipped, so a following block that is already buffered is returned
    /// right away instead of waiting for more input.
    fn poll_block(&mut self) -> Option<Option<String>> {
        loop {
            // The terminator is honoured even before its blank line arrives;
            // whatever follows it in the buffer is discarded.
            if self.acc.starts_with("data: [DONE]") {
                self.acc.clear();
                return Some(None);
            }
            let (end, sep_len) = Self::find_block_end(&self.acc)?;
            let event = Self::parse_block(&self.acc[..end]);
            // Remove the block in place rather than reallocating the buffer.
            self.acc.drain(..end + sep_len);
            if event.is_some() {
                return event;
            }
        }
    }

    /// Locates the first blank-line separator in `buf`.
    ///
    /// Returns `(offset, separator_length)`; both `\n\n` and `\r\n\r\n` are
    /// recognised and the earlier one wins.
    fn find_block_end(buf: &str) -> Option<(usize, usize)> {
        let lf = buf.find("\n\n").map(|at| (at, 2));
        let crlf = buf.find("\r\n\r\n").map(|at| (at, 4));
        match (lf, crlf) {
            (Some(a), Some(b)) => Some(if a.0 < b.0 { a } else { b }),
            (a, b) => a.or(b),
        }
    }

    /// Interprets one event block (without its trailing separator).
    ///
    /// The contents of all `data: ` lines are joined with `\n`. Returns `None`
    /// if that leaves nothing, `Some(None)` if the payload is exactly `[DONE]`,
    /// and `Some(Some(payload))` otherwise.
    fn parse_block(block: &str) -> Option<Option<String>> {
        let mut payload = String::new();
        for data in block
            .lines()
            .filter_map(|line| line.strip_prefix("data: "))
        {
            if !payload.is_empty() {
                payload.push('\n');
            }
            payload.push_str(data);
        }
        if payload.is_empty() {
            None
        } else if payload == "[DONE]" {
            Some(None)
        } else {
            Some(Some(payload))
        }
    }

    /// Decodes `chunk` into `acc`, carrying an incomplete trailing UTF-8
    /// sequence over to the next chunk instead of replacing it with U+FFFD.
    fn push_bytes(&mut self, chunk: &[u8]) {
        if self.pending.is_empty() {
            // Common case: decode straight from the chunk, buffer only the tail.
            let used = Self::decode_utf8(chunk, &mut self.acc);
            self.pending.extend_from_slice(&chunk[used..]);
        } else {
            self.pending.extend_from_slice(chunk);
            let used = Self::decode_utf8(&self.pending, &mut self.acc);
            self.pending.drain(..used);
        }
    }

    /// Appends the decodable part of `bytes` to `out` and returns how many
    /// bytes were consumed.
    ///
    /// Invalid sequences are replaced by U+FFFD; an incomplete sequence at the
    /// very end is left unconsumed so the caller can retry with more data.
    fn decode_utf8(bytes: &[u8], out: &mut String) -> usize {
        let mut consumed = 0;
        while consumed < bytes.len() {
            match std::str::from_utf8(&bytes[consumed..]) {
                Ok(text) => {
                    out.push_str(text);
                    consumed = bytes.len();
                }
                Err(err) => {
                    let valid_end = consumed + err.valid_up_to();
                    // This prefix is valid UTF-8, so the lossy conversion only borrows.
                    out.push_str(&String::from_utf8_lossy(&bytes[consumed..valid_end]));
                    consumed = valid_end;
                    match err.error_len() {
                        Some(bad_len) => {
                            out.push(char::REPLACEMENT_CHARACTER);
                            consumed += bad_len;
                        }
                        // Truncated sequence at the end: wait for the next chunk.
                        None => break,
                    }
                }
            }
        }
        consumed
    }
}

impl<S> Stream for SSEParser<S>
where
    S: Stream<Item = Result<Bytes, String>> + Unpin,
{
    type Item = Option<String>;

    /// Drains buffered events first, then pulls more bytes from the inner
    /// stream until an event is available, the inner stream is pending, or it
    /// has ended.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if let Some(event) = this.poll_block() {
                return Poll::Ready(Some(event));
            }
            // A finished inner stream must not be polled again.
            if this.finished {
                return Poll::Ready(None);
            }
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => this.push_bytes(&bytes),
                Poll::Ready(Some(Err(e))) => {
                    // Errors travel in-band, tagged with a marker prefix.
                    return Poll::Ready(Some(Some(format!("__SSE_ERROR__:{e}"))));
                }
                Poll::Ready(None) => {
                    this.finished = true;
                    if !this.pending.is_empty() {
                        // The stream ended in the middle of a character.
                        this.acc.push_str(&String::from_utf8_lossy(&this.pending));
                        this.pending.clear();
                    }
                    // `poll_block` has already consumed every terminated block,
                    // so what is left is at most one unterminated block; parse
                    // it with the same rules as any other block.
                    let tail = std::mem::take(&mut this.acc);
                    return Poll::Ready(Self::parse_block(&tail));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Type-erased stream of `Result<String, E>` items.
///
/// Instances are built by `ModelStream::new`, which derives the items from an
/// SSE byte stream; `E` is the caller's error type.
pub struct ModelStream<E> {
    /// Boxed, pinned pipeline that actually produces the items.
    inner: Pin<Box<dyn Stream<Item = Result<String, E>> + Send>>,
}
// The only field is a `Pin<Box<..>>`, which is `Unpin` whatever it points to,
// so moving a `ModelStream` never moves the pinned stream itself. This impl
// states that guarantee explicitly for every `E`.
impl<E> Unpin for ModelStream<E> {}
impl<E: 'static + Send> ModelStream<E> {
    /// Builds a text stream from a raw SSE byte stream.
    ///
    /// * `byte_stream` – source of raw chunks; it is decoded by `SSEParser`.
    /// * `parse_fn` – turns one event payload into zero or more text pieces.
    ///   Empty pieces are dropped, and payloads it rejects are skipped without
    ///   surfacing an error.
    /// * `err_map` – converts a transport error message into the caller's
    ///   error type `E`.
    ///
    /// The resulting stream yields `Ok(text)` for every non-empty piece and
    /// `Err(err_map(msg))` for every error reported by `byte_stream`.
    pub fn new(
        byte_stream: Pin<Box<dyn Stream<Item = Result<Bytes, String>> + Send>>,
        parse_fn: fn(&str) -> Result<Vec<String>, String>,
        err_map: fn(String) -> E,
    ) -> Self {
        let sse = SSEParser::new(byte_stream);
        let chunks = sse.flat_map(
            move |payload| -> Pin<Box<dyn Stream<Item = Result<String, E>> + Send>> {
                match payload {
                    // End-of-stream marker: contributes nothing.
                    None => Box::pin(futures::stream::empty()),
                    // `SSEParser` reports transport errors in-band behind this
                    // marker. Strip it exactly once so a message that itself
                    // starts with the marker text is passed on intact.
                    // NOTE(review): a genuine data payload beginning with the
                    // marker cannot be told apart from a transport error.
                    Some(data) => match data.strip_prefix("__SSE_ERROR__:") {
                        Some(msg) => {
                            let msg = msg.to_string();
                            Box::pin(futures::stream::once(async move { Err(err_map(msg)) }))
                        }
                        None => {
                            let texts: Vec<Result<String, E>> = match parse_fn(&data) {
                                Ok(pieces) => pieces
                                    .into_iter()
                                    .filter(|piece| !piece.is_empty())
                                    .map(Ok)
                                    .collect(),
                                // Unparseable payloads are skipped (best effort).
                                Err(_) => Vec::new(),
                            };
                            Box::pin(futures::stream::iter(texts))
                        }
                    },
                }
            },
        );
        Self {
            inner: Box::pin(chunks),
        }
    }
}
impl<E> Stream for ModelStream<E> {
    type Item = Result<String, E>;

    /// Forwards the poll to the boxed inner stream and returns its result
    /// unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `ModelStream` is `Unpin`, so a plain mutable reference is available.
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}