use futures_util::{Stream, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
#[cfg(feature = "http")]
use axum::{body::Body, http::Response, response::IntoResponse};
/// A stream of item results handed to a transport layer, plus a flag marking a terminal response.
///
/// Each item is either `Ok(T)` (a payload) or `Err(String)` (an error message). The stream ends
/// once every sender feeding the underlying channel has been dropped.
#[derive(Debug)]
pub struct StreamResponse<T> {
    /// Receiving half of the channel that produces this response's items.
    pub stream: ReceiverStream<Result<T, String>>,
    /// Set to `true` by [`StreamResponse::final_marker`] and `false` by the other constructors.
    /// NOTE(review): presumably read by callers to detect the end of a response sequence — confirm.
    pub is_final: bool,
}

impl<T: Send + 'static> StreamResponse<T> {
    /// Wraps an existing receiver stream in a non-final response.
    pub fn new(stream: ReceiverStream<Result<T, String>>) -> Self {
        Self {
            stream,
            is_final: false,
        }
    }

    /// Builds a non-final response that yields exactly one `Ok(item)` and then ends.
    ///
    /// The item is queued synchronously, so this does not need a Tokio runtime and does not
    /// require `T: Clone`.
    pub fn single(item: T) -> Self {
        let (tx, rx) = mpsc::channel(1);
        // A freshly created channel with capacity 1 and a live receiver always has a free slot,
        // so this cannot fail. Sending synchronously avoids `tokio::spawn`, which panics when
        // called outside a runtime.
        let queued = tx.try_send(Ok(item));
        debug_assert!(queued.is_ok(), "fresh one-slot channel rejected its first item");
        // `tx` is dropped on return, so the stream ends right after the queued item.
        Self::new(ReceiverStream::new(rx))
    }

    /// Builds an empty response with `is_final` set to `true`.
    ///
    /// The sender is dropped before returning, so the stream yields `None` on its first poll.
    pub fn final_marker() -> Self {
        let (_tx, rx) = mpsc::channel(1);
        Self {
            stream: ReceiverStream::new(rx),
            is_final: true,
        }
    }
}
/// Creates a bounded channel and wraps its receiving half in a non-final [`StreamResponse`].
///
/// Returns the sender, through which producers push `Ok(item)` or `Err(message)` values, and the
/// response that yields them. The response's stream ends once every clone of the sender has
/// been dropped.
///
/// # Panics
///
/// Tokio's `mpsc::channel` panics when `buffer_size` is 0.
pub fn create_stream_channel<T: Send + 'static>(
    buffer_size: usize,
) -> (mpsc::Sender<Result<T, String>>, StreamResponse<T>) {
    let (sender, receiver) = mpsc::channel(buffer_size);
    let response = StreamResponse::new(ReceiverStream::new(receiver));
    (sender, response)
}
/// One event of a Server-Sent-Events stream, serialized as a JSON object tagged by `"type"`.
///
/// The tag is the lowercased variant name; for example, `Complete` serializes as
/// `{"type":"complete"}`. `T` is the payload type of [`StreamEvent::Data`] and defaults to
/// [`serde_json::Value`].
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum StreamEvent<T = serde_json::Value> {
    /// A payload event. `id` and `event_name` serialize as `null` when absent.
    Data {
        /// Optional event identifier.
        id: Option<String>,
        /// Optional event name.
        event_name: Option<String>,
        /// The event payload.
        data: T,
    },
    /// A ping event; [`StreamEvent::ping`] fills `timestamp` from `chrono::Utc::now()`.
    Ping {
        /// Timestamp attached to the ping.
        timestamp: i64,
    },
    /// An error event carrying a human-readable message.
    Error {
        /// Description of the error.
        message: String,
    },
    /// A completion marker with no payload.
    Complete,
}
impl<T> StreamEvent<T> {
pub fn data(data: T) -> Self {
Self::Data {
id: None,
event_name: None,
data,
}
}
pub fn ping() -> Self {
Self::Ping {
timestamp: chrono::Utc::now().timestamp(),
}
}
pub fn error(message: String) -> Self {
Self::Error { message }
}
pub fn complete() -> Self {
Self::Complete
}
}
/// Default idle period after which [`stream_to_sse`] stops waiting for the next source item.
const SSE_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

/// Number of encoded SSE frames buffered between the encoder task and the consumer.
const SSE_CHANNEL_CAPACITY: usize = 32;

/// Frame emitted once the encoder stops consuming the source stream.
const SSE_COMPLETE_FRAME: &str = "event: complete\ndata: {}\n\n";

/// Converts `stream` into Server-Sent-Events text frames using a 30-second idle timeout.
///
/// Equivalent to [`stream_to_sse_with_idle_timeout`] with a 30-second timeout; see that function
/// for the framing and termination rules.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because the encoder runs on a spawned task.
pub fn stream_to_sse<S, T, F>(
    stream: S,
    mapper: F,
) -> impl Stream<Item = Result<String, std::convert::Infallible>> + Send + 'static
where
    S: Stream<Item = T> + Send + 'static,
    F: FnMut(T) -> StreamEvent + Send + 'static,
    T: Send + 'static,
{
    stream_to_sse_with_idle_timeout(stream, mapper, SSE_IDLE_TIMEOUT)
}

/// Converts `stream` into Server-Sent-Events text frames on a background Tokio task.
///
/// Each source item is passed through `mapper`, serialized with `serde_json`, and emitted as one
/// `data: <json>\n\n` frame. If serialization fails, an `error` event (or, should even that fail,
/// a fixed error object) is emitted in its place and the stream keeps going.
///
/// Encoding stops when the source stream ends **or** when no item arrives within `idle_timeout`
/// of the previous one. In both cases a final `event: complete\ndata: {}\n\n` frame is then
/// emitted, so a client cannot distinguish a timeout from a normal end. If the returned stream is
/// dropped (e.g. the client disconnected), the task stops at its next send and drops the source
/// stream without emitting the completion frame.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, because the encoder runs on a spawned task.
pub fn stream_to_sse_with_idle_timeout<S, T, F>(
    stream: S,
    mut mapper: F,
    idle_timeout: std::time::Duration,
) -> impl Stream<Item = Result<String, std::convert::Infallible>> + Send + 'static
where
    S: Stream<Item = T> + Send + 'static,
    F: FnMut(T) -> StreamEvent + Send + 'static,
    T: Send + 'static,
{
    let (tx, rx) =
        mpsc::channel::<Result<String, std::convert::Infallible>>(SSE_CHANNEL_CAPACITY);
    tokio::spawn(async move {
        let mut stream = Box::pin(stream);
        // The sleep is re-armed on every iteration, so this is a per-item idle timeout rather
        // than a cap on the total duration of the stream.
        while let Some(item) = tokio::select! {
            _ = tokio::time::sleep(idle_timeout) => None,
            next = stream.next() => next,
        } {
            let frame = encode_sse_data_frame(&mapper(item));
            if tx.send(Ok(frame)).await.is_err() {
                // Receiver dropped: nobody is left to read the completion frame either.
                return;
            }
        }
        let _ = tx.send(Ok(SSE_COMPLETE_FRAME.to_string())).await;
    });
    ReceiverStream::new(rx)
}

/// Serializes `event` to JSON and wraps it in one SSE `data:` frame.
///
/// `serde_json::to_string` produces compact JSON with control characters escaped, so the payload
/// never contains a raw newline that would split the frame.
fn encode_sse_data_frame(event: &StreamEvent) -> String {
    let json = match serde_json::to_string(event) {
        Ok(json) => json,
        Err(e) => {
            #[cfg(feature = "logging")]
            tracing::error!(error = %e, "Failed to serialize SSE event");
            serde_json::to_string(&StreamEvent::<()>::error(format!(
                "Serialization error: {}",
                e
            )))
            // Keep the last-resort fallback in the same `{"type": ...}` shape as every other event.
            .unwrap_or_else(|_| r#"{"type":"error","message":"Serialization failed"}"#.to_string())
        }
    };
    format!("data: {}\n\n", json)
}
#[cfg(feature = "http")]
impl<T> IntoResponse for StreamResponse<T>
where
    T: serde::Serialize + Send + 'static,
{
    /// Streams the response items to the client as Server-Sent Events (`text/event-stream`).
    ///
    /// `Ok` items become `data` events and `Err` messages become `error` events. An item that
    /// fails to serialize is reported as an `error` event rather than being sent as `null`.
    /// Caching is disabled with `Cache-Control: no-cache`, and nginx proxy buffering is turned
    /// off with `X-Accel-Buffering: no`.
    fn into_response(self) -> Response<Body> {
        use axum::http::header::{CACHE_CONTROL, CONTENT_TYPE};

        let sse_stream = stream_to_sse(self.stream, |item| match item {
            Ok(data) => match serde_json::to_value(data) {
                Ok(value) => StreamEvent::data(value),
                // Surface the failure instead of silently substituting `null`.
                Err(e) => StreamEvent::error(format!("Serialization error: {}", e)),
            },
            Err(err) => StreamEvent::error(err),
        });

        Response::builder()
            .status(axum::http::StatusCode::OK)
            .header(CONTENT_TYPE, "text/event-stream")
            .header(CACHE_CONTROL, "no-cache")
            .header("Connection", "keep-alive")
            .header("X-Accel-Buffering", "no")
            .body(Body::from_stream(sse_stream))
            .expect("static status and header values are always valid")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;

    #[tokio::test]
    async fn test_stream_response() {
        let (tx, rx) = mpsc::channel(32);
        let mut response = StreamResponse::new(ReceiverStream::new(rx));
        tokio::spawn(async move {
            let _ = tx.send(Ok("test")).await;
        });
        assert!(!response.is_final);
        // The item must actually arrive, and the stream must end once the sender is gone.
        assert_eq!(response.stream.next().await, Some(Ok("test")));
        assert_eq!(response.stream.next().await, None);
    }

    #[tokio::test]
    async fn test_single_yields_exactly_one_item() {
        let mut response = StreamResponse::single(7_u32);
        assert!(!response.is_final);
        assert_eq!(response.stream.next().await, Some(Ok(7)));
        assert_eq!(response.stream.next().await, None);
    }

    #[tokio::test]
    async fn test_final_marker_is_empty_and_final() {
        let mut response = StreamResponse::<u32>::final_marker();
        assert!(response.is_final);
        assert_eq!(response.stream.next().await, None);
    }

    #[tokio::test]
    async fn test_create_stream_channel_forwards_items_and_errors() {
        let (tx, mut response) = create_stream_channel::<u32>(4);
        tx.send(Ok(1)).await.unwrap();
        tx.send(Err("boom".to_string())).await.unwrap();
        drop(tx);
        assert!(!response.is_final);
        assert_eq!(response.stream.next().await, Some(Ok(1)));
        assert_eq!(response.stream.next().await, Some(Err("boom".to_string())));
        assert_eq!(response.stream.next().await, None);
    }

    #[test]
    fn test_stream_event_data() {
        let event = StreamEvent::data(serde_json::json!({"key": "value"}));
        match event {
            StreamEvent::Data {
                id,
                event_name: _,
                data,
            } => {
                assert!(id.is_none());
                assert_eq!(data, serde_json::json!({"key": "value"}));
            }
            _ => panic!("Expected Data event"),
        }
    }

    #[test]
    fn test_stream_event_complete() {
        let event: StreamEvent<()> = StreamEvent::complete();
        match event {
            StreamEvent::Complete => {}
            _ => panic!("Expected Complete event"),
        }
    }

    #[test]
    fn test_stream_event_serialization_tags() {
        assert_eq!(
            serde_json::to_value(StreamEvent::<()>::complete()).unwrap(),
            serde_json::json!({"type": "complete"})
        );
        assert_eq!(
            serde_json::to_value(StreamEvent::<()>::error("bad".to_string())).unwrap(),
            serde_json::json!({"type": "error", "message": "bad"})
        );
        assert_eq!(
            serde_json::to_value(StreamEvent::data(serde_json::json!(1))).unwrap(),
            serde_json::json!({"type": "data", "id": null, "event_name": null, "data": 1})
        );
    }

    #[tokio::test]
    async fn test_stream_to_sse_frames() {
        let source = futures_util::stream::iter(vec![1, 2]);
        let frames: Vec<String> =
            stream_to_sse(source, |n| StreamEvent::data(serde_json::json!(n)))
                .map(|frame| frame.unwrap())
                .collect()
                .await;
        assert_eq!(
            frames,
            vec![
                "data: {\"type\":\"data\",\"id\":null,\"event_name\":null,\"data\":1}\n\n"
                    .to_string(),
                "data: {\"type\":\"data\",\"id\":null,\"event_name\":null,\"data\":2}\n\n"
                    .to_string(),
                "event: complete\ndata: {}\n\n".to_string(),
            ]
        );
    }
}