use super::codec::encode_event;
use super::event::SseEvent;
use crate::error::Result;
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures_util::{Stream, StreamExt};
use serde_json::Value;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
type EventStream = Pin<Box<dyn Stream<Item = std::io::Result<SseEvent>> + Send>>;
type ByteStream = Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send>>;
pub type ObserverSender = mpsc::UnboundedSender<ObserverMsg>;
pub type ObserverReceiver = mpsc::UnboundedReceiver<ObserverMsg>;
#[derive(Debug)]
pub enum ObserverMsg {
From(Bytes),
Parsed(Option<Value>),
Transformed(Option<Value>),
To(Bytes),
Done,
Error(String),
}
pub fn observer_channel() -> (ObserverSender, ObserverReceiver) {
mpsc::unbounded_channel()
}
pub trait EventTransformer: Send {
fn transform(&mut self, event: SseEvent) -> Result<Vec<SseEvent>>;
fn finish(&mut self) -> Result<Vec<SseEvent>> {
Ok(Vec::new())
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TapKinds {
pub from: bool,
pub parsed: bool,
pub transformed: bool,
pub to: bool,
}
impl TapKinds {
pub const fn all() -> Self {
Self {
from: true,
parsed: true,
transformed: true,
to: true,
}
}
pub const fn parsed_only() -> Self {
Self {
from: false,
parsed: true,
transformed: false,
to: false,
}
}
}
pub struct SsePipeline {
source: ByteStream,
tap: Option<Arc<ObserverSender>>,
tap_kinds: TapKinds,
transformers: Vec<Box<dyn EventTransformer>>,
}
impl SsePipeline {
pub fn from_stream<S>(source: S) -> Self
where
S: Stream<Item = io::Result<Bytes>> + Send + 'static,
{
Self {
source: Box::pin(source),
tap: None,
tap_kinds: TapKinds::default(),
transformers: Vec::new(),
}
}
pub fn from_response(resp: reqwest::Response) -> Self {
Self::from_stream(resp.bytes_stream().map(|item| item.map_err(io::Error::other)))
}
pub fn with_transformer<T>(mut self, transformer: T) -> Self
where
T: EventTransformer + 'static,
{
self.transformers.push(Box::new(transformer));
self
}
pub fn with_tap(mut self, tap: ObserverSender, kinds: TapKinds) -> Self {
self.tap = Some(Arc::new(tap));
self.tap_kinds = kinds;
self
}
pub fn with_tap_all(self, tap: ObserverSender) -> Self {
self.with_tap(tap, TapKinds::all())
}
pub fn with_tap_parsed(self, tap: ObserverSender) -> Self {
self.with_tap(tap, TapKinds::parsed_only())
}
pub fn run(self) -> ByteStream {
let source: EventStream = if self.tap_kinds.from {
let tap = self.tap.clone();
Box::pin(self.source.map(move |result| match result {
Ok(bytes) => {
if let Some(ref tap) = tap {
let _ = tap.send(ObserverMsg::From(bytes.clone()));
}
Ok(bytes)
}
Err(err) => Err(err),
}))
.eventsource()
.map(|item| match item {
Ok(event) => Ok(SseEvent::from(event)),
Err(err) => Err(io::Error::other(err.to_string())),
})
.boxed()
} else {
self
.source
.eventsource()
.map(|item| match item {
Ok(event) => Ok(SseEvent::from(event)),
Err(err) => Err(io::Error::other(err.to_string())),
})
.boxed()
};
Box::pin(StreamWithFinalizer::new(
PipelineStream::new(source, self.transformers, self.tap, self.tap_kinds),
finalize_tap,
))
}
}
struct PipelineStream {
source: EventStream,
transformers: Vec<Box<dyn EventTransformer>>,
tap: Option<Arc<ObserverSender>>,
tap_kinds: TapKinds,
pending: VecDeque<std::io::Result<Bytes>>,
source_done: bool,
}
impl PipelineStream {
fn new(
source: EventStream,
transformers: Vec<Box<dyn EventTransformer>>,
tap: Option<Arc<ObserverSender>>,
tap_kinds: TapKinds,
) -> Self {
Self {
source,
transformers,
tap,
tap_kinds,
pending: VecDeque::new(),
source_done: false,
}
}
#[inline]
fn send_tap(&self, msg: ObserverMsg) {
if let Some(ref tap) = self.tap {
let _ = tap.send(msg);
}
}
#[inline]
fn send_tap_parsed(&self, value: Option<Value>) {
if self.tap_kinds.parsed {
self.send_tap(ObserverMsg::Parsed(value));
}
}
#[inline]
fn send_tap_transformed(&self, value: Option<Value>) {
if self.tap_kinds.transformed {
self.send_tap(ObserverMsg::Transformed(value));
}
}
#[inline]
fn send_tap_to(&self, bytes: Bytes) {
if self.tap_kinds.to {
self.send_tap(ObserverMsg::To(bytes));
}
}
fn process_event(&mut self, event: SseEvent) -> std::io::Result<()> {
self.send_tap_parsed(event.json.clone());
let transformed = self.apply_transformers(vec![event], 0)?;
for event in transformed {
self.send_tap_transformed(event.json.clone());
let encoded = encode_event(&event);
if !encoded.is_empty() {
self.send_tap_to(encoded.clone());
self.pending.push_back(Ok(encoded));
}
}
Ok(())
}
fn apply_transformers(&mut self, mut events: Vec<SseEvent>, start: usize) -> std::io::Result<Vec<SseEvent>> {
for idx in start..self.transformers.len() {
let mut next = Vec::new();
for event in events {
next.extend(self.transformers[idx].transform(event).map_err(std::io::Error::other)?);
}
events = next;
}
Ok(events)
}
fn finish_transformers(&mut self) -> std::io::Result<()> {
for idx in 0..self.transformers.len() {
let events = self.transformers[idx].finish().map_err(std::io::Error::other)?;
for event in self.apply_transformers(events, idx + 1)? {
self.send_tap_transformed(event.json.clone());
let encoded = encode_event(&event);
if !encoded.is_empty() {
self.send_tap_to(encoded.clone());
self.pending.push_back(Ok(encoded));
}
}
}
Ok(())
}
fn signal_error(&self, err: &std::io::Error) {
self.send_tap(ObserverMsg::Error(err.to_string()));
}
}
impl Stream for PipelineStream {
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if let Some(item) = self.pending.pop_front() {
return Poll::Ready(Some(item));
}
if self.source_done {
return Poll::Ready(None);
}
match self.source.as_mut().poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(event))) => {
if let Err(err) = self.process_event(event) {
self.signal_error(&err);
self.pending.push_back(Err(err));
self.source_done = true;
}
}
Poll::Ready(Some(Err(err))) => {
self.signal_error(&err);
self.pending.push_back(Err(err));
self.source_done = true;
}
Poll::Ready(None) => {
if let Err(err) = self.finish_transformers() {
self.signal_error(&err);
self.pending.push_back(Err(err));
}
self.source_done = true;
}
}
}
}
}
fn finalize_tap(stream: &mut PipelineStream) {
stream.send_tap(ObserverMsg::Done);
}
struct StreamWithFinalizer<S, F>
where
S: Stream + Unpin,
F: FnOnce(&mut S) + Send + 'static,
{
inner: S,
fin: Option<F>,
}
impl<S, F> StreamWithFinalizer<S, F>
where
S: Stream + Unpin,
F: FnOnce(&mut S) + Send + 'static,
{
fn new(inner: S, fin: F) -> Self {
Self { inner, fin: Some(fin) }
}
}
impl<S, F> Stream for StreamWithFinalizer<S, F>
where
S: Stream + Unpin,
F: FnOnce(&mut S) + Send + 'static + Unpin,
{
type Item = S::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let poll = Pin::new(&mut self.inner).poll_next(cx);
if let Poll::Ready(None) = &poll {
if let Some(fin) = self.fin.take() {
fin(&mut self.inner);
}
}
poll
}
}
impl<S, F> Drop for StreamWithFinalizer<S, F>
where
S: Stream + Unpin,
F: FnOnce(&mut S) + Send + 'static,
{
fn drop(&mut self) {
if let Some(fin) = self.fin.take() {
fin(&mut self.inner);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use bytes::BytesMut;
use futures_util::{stream, StreamExt};
struct AppendTransformer(&'static str);
impl EventTransformer for AppendTransformer {
fn transform(&mut self, mut event: SseEvent) -> Result<Vec<SseEvent>> {
if !event.is_done() {
event.data.push_str(self.0);
}
Ok(vec![event])
}
}
#[test]
fn pipeline_applies_transformers_in_order() {
let (tx, mut rx) = observer_channel();
let body = futures::executor::block_on(async move {
let body = SsePipeline::from_stream(stream::iter(vec![
Ok(Bytes::from_static(b"data: hello\n\n")),
Ok(Bytes::from_static(b"data: [DONE]\n\n")),
]))
.with_transformer(AppendTransformer("-a"))
.with_transformer(AppendTransformer("-b"))
.with_tap_all(tx)
.run()
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.unwrap()
.into_iter()
.fold(BytesMut::new(), |mut out, chunk| {
out.extend_from_slice(&chunk);
out
})
.freeze();
let mut msgs = Vec::new();
while let Ok(msg) = rx.try_recv() {
msgs.push(msg);
}
assert!(msgs.iter().any(|m| matches!(m, ObserverMsg::Done)));
let to_count = msgs.iter().filter(|m| matches!(m, ObserverMsg::To(_))).count();
assert_eq!(to_count, 2);
body
});
let text = std::str::from_utf8(&body).unwrap();
assert!(text.contains("data: hello-a-b"));
assert!(text.contains("data: [DONE]"));
}
#[test]
fn pipeline_tees_raw_bytes_when_enabled() {
let (tx, mut rx) = observer_channel();
futures::executor::block_on(async move {
let _ = SsePipeline::from_stream(stream::iter(vec![
Ok(Bytes::from_static(b"data: hello\n\n")),
Ok(Bytes::from_static(b"data: [DONE]\n\n")),
]))
.with_tap_all(tx)
.run()
.collect::<Vec<_>>()
.await;
});
let mut from_count = 0;
while let Ok(msg) = rx.try_recv() {
if matches!(msg, ObserverMsg::From(_)) {
from_count += 1;
}
}
assert_eq!(from_count, 2);
}
#[test]
fn pipeline_parsed_only_mask_suppresses_other_kinds() {
let (tx, mut rx) = observer_channel();
futures::executor::block_on(async move {
let _ = SsePipeline::from_stream(stream::iter(vec![
Ok(Bytes::from_static(b"data: {\"x\":1}\n\n")),
Ok(Bytes::from_static(b"data: [DONE]\n\n")),
]))
.with_tap_parsed(tx)
.run()
.collect::<Vec<_>>()
.await;
});
let mut from = 0;
let mut parsed = 0;
let mut transformed = 0;
let mut to = 0;
let mut done = 0;
while let Ok(msg) = rx.try_recv() {
match msg {
ObserverMsg::From(_) => from += 1,
ObserverMsg::Parsed(_) => parsed += 1,
ObserverMsg::Transformed(_) => transformed += 1,
ObserverMsg::To(_) => to += 1,
ObserverMsg::Done => done += 1,
ObserverMsg::Error(_) => {}
}
}
assert_eq!(from, 0);
assert_eq!(transformed, 0);
assert_eq!(to, 0);
assert_eq!(done, 1);
assert!(parsed >= 1);
}
}