use std::pin::Pin;
use async_stream::try_stream;
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use crate::{Client, Result};
/// Pinned, heap-allocated stream of decoded [`SseEvent`]s, as returned by [`decode`].
///
/// Each item is a `crate::Result`, so failures are delivered in-band with the events.
pub(crate) type EventStream =
    Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send + 'static>>;
/// One decoded Server-Sent Event.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct SseEvent {
    /// Value of the event's `event:` field, or `None` when the event had no such field.
    pub event: Option<String>,
    /// Event payload: the values of its `data:` lines joined with `\n`.
    pub data: String,
}
pub(crate) fn decode<S>(input: S) -> EventStream
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + 'static,
{
Box::pin(try_stream! {
let mut stream = Box::pin(input);
let mut buffer = String::new();
let mut event_name: Option<String> = None;
let mut data_lines: Vec<String> = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(Client::map_transport_error)?;
buffer.push_str(&String::from_utf8_lossy(&chunk));
while let Some(line) = next_line(&mut buffer) {
if line.is_empty() {
if !data_lines.is_empty() {
yield SseEvent {
event: event_name.take(),
data: data_lines.join("\n"),
};
data_lines.clear();
}
continue;
}
if let Some(rest) = line.strip_prefix("event:") {
event_name = Some(rest.trim_start().to_string());
continue;
}
if let Some(rest) = line.strip_prefix("data:") {
data_lines.push(rest.trim_start().to_string());
}
}
}
if !buffer.is_empty() {
buffer.push('\n');
while let Some(line) = next_line(&mut buffer) {
if let Some(rest) = line.strip_prefix("event:") {
event_name = Some(rest.trim_start().to_string());
continue;
}
if let Some(rest) = line.strip_prefix("data:") {
data_lines.push(rest.trim_start().to_string());
}
}
}
if !data_lines.is_empty() {
yield SseEvent {
event: event_name.take(),
data: data_lines.join("\n"),
};
}
})
}
fn next_line(buffer: &mut String) -> Option<String> {
let index = buffer.find('\n')?;
let mut line: String = buffer.drain(..=index).collect();
if line.ends_with('\n') {
line.pop();
}
if line.ends_with('\r') {
line.pop();
}
Some(line)
}
#[cfg(test)]
mod tests {
use futures_util::{StreamExt, stream};
use super::decode;
#[tokio::test]
async fn decodes_basic_sse_frames() {
let input = stream::iter(vec![
Ok(bytes::Bytes::from_static(b"data: {\"a\":1}\n\n")),
Ok(bytes::Bytes::from_static(
b"event: update\ndata: {\"b\":2}\n\ndata: [DONE]\n\n",
)),
]);
let events = decode(input)
.collect::<Vec<_>>()
.await
.into_iter()
.map(|event| event.expect("event should decode"))
.collect::<Vec<_>>();
assert_eq!(events.len(), 3);
assert_eq!(events[0].data, "{\"a\":1}");
assert_eq!(events[1].event.as_deref(), Some("update"));
assert_eq!(events[1].data, "{\"b\":2}");
assert_eq!(events[2].data, "[DONE]");
}
}