use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use asupersync::stream::Stream;
use crate::response::{Response, ResponseBody, StatusCode};
/// A single Server-Sent Event, serialized with [`SseEvent::to_bytes`].
///
/// Build a data event with [`SseEvent::new`] / [`SseEvent::message`], or a
/// comment-only event with [`SseEvent::comment`]. Then optionally chain
/// [`SseEvent::event_type`], [`SseEvent::id`], [`SseEvent::retry_ms`] or
/// [`SseEvent::retry`].
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// Payload. Every line becomes its own `data:` line.
    data: Option<String>,
    /// Value of the `event:` field.
    event_type: Option<String>,
    /// Value of the `id:` field.
    id: Option<String>,
    /// Value of the `retry:` field, in milliseconds.
    retry: Option<u64>,
    /// Comment text. Every line becomes its own `: ` comment line.
    comment: Option<String>,
}
impl SseEvent {
    /// Creates an event carrying `data`.
    #[must_use]
    pub fn new(data: impl Into<String>) -> Self {
        Self {
            data: Some(data.into()),
            event_type: None,
            id: None,
            retry: None,
            comment: None,
        }
    }

    /// Alias for [`SseEvent::new`].
    #[must_use]
    pub fn message(data: impl Into<String>) -> Self {
        Self::new(data)
    }

    /// Creates a comment-only event (no `data` field).
    #[must_use]
    pub fn comment(comment: impl Into<String>) -> Self {
        Self {
            data: None,
            event_type: None,
            id: None,
            retry: None,
            comment: Some(comment.into()),
        }
    }

    /// Sets the `event:` field.
    ///
    /// The field must fit on one line, so any `\r`/`\n` in `event_type` is
    /// removed during serialization.
    #[must_use]
    pub fn event_type(mut self, event_type: impl Into<String>) -> Self {
        self.event_type = Some(event_type.into());
        self
    }

    /// Sets the `id:` field.
    ///
    /// Any `\r`/`\n` in `id` is removed during serialization.
    #[must_use]
    pub fn id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `retry:` field (reconnection delay) in milliseconds.
    #[must_use]
    pub fn retry_ms(mut self, milliseconds: u64) -> Self {
        self.retry = Some(milliseconds);
        self
    }

    /// Sets the `retry:` field from a [`Duration`].
    ///
    /// The duration is truncated to whole milliseconds and saturates at
    /// `u64::MAX`.
    #[must_use]
    pub fn retry(self, duration: Duration) -> Self {
        // `as_millis` returns u128; clamp first so the cast cannot silently
        // drop high bits.
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        self.retry_ms(millis)
    }

    /// Serializes the event into the SSE wire format.
    ///
    /// Fields are written in this order: comment, `event`, `id`, `retry`,
    /// `data`. A blank line follows to terminate the event.
    ///
    /// Multi-line `comment` and `data` values are split on every line
    /// terminator the SSE grammar recognises (`\r\n`, `\n`, or a lone `\r`).
    /// Each segment gets its own prefixed line, including empty and trailing
    /// segments, so the client reassembles the same lines and embedded
    /// terminators cannot start new fields.
    ///
    /// The single-line fields (`event`, `id`) have `\r`/`\n` stripped for the
    /// same reason.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut output = Vec::with_capacity(256);
        if let Some(comment) = self.comment.as_deref() {
            Self::push_multiline(&mut output, b": ", comment);
        }
        if let Some(event_type) = self.event_type.as_deref() {
            Self::push_single_line(&mut output, b"event: ", event_type);
        }
        if let Some(id) = self.id.as_deref() {
            Self::push_single_line(&mut output, b"id: ", id);
        }
        if let Some(retry) = self.retry {
            Self::push_single_line(&mut output, b"retry: ", &retry.to_string());
        }
        if let Some(data) = self.data.as_deref() {
            Self::push_multiline(&mut output, b"data: ", data);
        }
        // Blank line dispatches the event on the client.
        output.push(b'\n');
        output
    }

    /// Writes `prefix` + `value` + `\n`, dropping any CR/LF bytes from `value`.
    fn push_single_line(out: &mut Vec<u8>, prefix: &[u8], value: &str) {
        out.extend_from_slice(prefix);
        // CR and LF never occur inside multi-byte UTF-8 sequences, so filtering
        // bytes keeps the remaining text valid UTF-8.
        out.extend(value.bytes().filter(|&b| b != b'\r' && b != b'\n'));
        out.push(b'\n');
    }

    /// Writes one `prefix`-ed line per segment of `text`.
    ///
    /// Segments are separated by `\r\n`, `\n` or a lone `\r`. Always emits at
    /// least one line, so `""` yields a single empty line.
    fn push_multiline(out: &mut Vec<u8>, prefix: &[u8], text: &str) {
        let bytes = text.as_bytes();
        let mut start = 0;
        let mut i = 0;
        while i < bytes.len() {
            let terminator_len = match bytes[i] {
                b'\r' if bytes.get(i + 1) == Some(&b'\n') => 2,
                b'\r' | b'\n' => 1,
                _ => {
                    i += 1;
                    continue;
                }
            };
            Self::push_line(out, prefix, &bytes[start..i]);
            i += terminator_len;
            start = i;
        }
        // Final segment: may be empty, e.g. when `text` ends with a terminator.
        Self::push_line(out, prefix, &bytes[start..]);
    }

    /// Writes `prefix` + `line` + `\n`.
    fn push_line(out: &mut Vec<u8>, prefix: &[u8], line: &[u8]) {
        out.extend_from_slice(prefix);
        out.extend_from_slice(line);
        out.push(b'\n');
    }
}
impl From<&str> for SseEvent {
fn from(data: &str) -> Self {
Self::new(data)
}
}
impl From<String> for SseEvent {
fn from(data: String) -> Self {
Self::new(data)
}
}
/// Adapts a stream of [`SseEvent`]s into a stream of encoded byte chunks.
///
/// Every event is serialized with [`SseEvent::to_bytes`], so each yielded
/// `Vec<u8>` holds exactly one event's bytes.
pub struct SseStream<S> {
    /// The wrapped event source.
    events: S,
}
impl<S> SseStream<S> {
    /// Wraps `stream` so that its events are emitted as SSE-encoded bytes.
    pub fn new(stream: S) -> Self {
        SseStream { events: stream }
    }
}
impl<S> Stream for SseStream<S>
where
    S: Stream<Item = SseEvent> + Unpin,
{
    type Item = Vec<u8>;

    /// Polls the wrapped stream once.
    ///
    /// A ready event is encoded; `Pending` and end-of-stream pass through
    /// unchanged.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `S: Unpin` makes re-pinning the field by `Pin::new` sound.
        let events = Pin::new(&mut self.events);
        events
            .poll_next(cx)
            .map(|next| next.map(|event| event.to_bytes()))
    }
}
/// Settings for an SSE response.
///
/// NOTE(review): nothing visible in this file reads these fields yet.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Keep-alive interval in seconds; [`SseConfig::disable_keep_alive`] sets it to `0`.
    pub keep_alive_secs: u64,
    /// Comment text presumably intended for keep-alive events (confirm once wired up).
    pub keep_alive_comment: String,
}
impl Default for SseConfig {
    /// A 30-second interval with the comment text `"keep-alive"`.
    fn default() -> Self {
        SseConfig {
            keep_alive_comment: String::from("keep-alive"),
            keep_alive_secs: 30,
        }
    }
}
impl SseConfig {
    /// Equivalent to [`SseConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the keep-alive interval (seconds).
    #[must_use]
    pub fn keep_alive_secs(self, seconds: u64) -> Self {
        Self {
            keep_alive_secs: seconds,
            ..self
        }
    }

    /// Sets the keep-alive interval to `0`.
    #[must_use]
    pub fn disable_keep_alive(self) -> Self {
        self.keep_alive_secs(0)
    }

    /// Replaces the keep-alive comment text.
    #[must_use]
    pub fn keep_alive_comment(self, comment: impl Into<String>) -> Self {
        Self {
            keep_alive_comment: comment.into(),
            ..self
        }
    }
}
/// Builder that wraps a stream of [`SseEvent`]s in an HTTP [`Response`] with
/// SSE headers.
///
/// NOTE(review): the [`SseConfig`] given to [`SseResponse::with_config`] is
/// stored but never read by [`SseResponse::into_response`], so keep-alive
/// settings currently have no effect.
pub struct SseResponse<S> {
    stream: S,
    /// Kept for future use; unread (see the note on the type).
    _config: SseConfig,
}
impl<S> SseResponse<S>
where
    S: Stream<Item = SseEvent> + Send + Unpin + 'static,
{
    /// Creates a builder using [`SseConfig::default`].
    pub fn new(stream: S) -> Self {
        Self::with_config(stream, SseConfig::default())
    }

    /// Creates a builder with an explicit configuration.
    pub fn with_config(stream: S, config: SseConfig) -> Self {
        SseResponse {
            stream,
            _config: config,
        }
    }

    /// Builds a `200 OK` response whose body streams the encoded events.
    ///
    /// Sets `content-type: text/event-stream`, `cache-control: no-cache`,
    /// `connection: keep-alive` and `x-accel-buffering: no`. The last header
    /// is presumably there to stop nginx-style proxies from buffering the
    /// stream.
    #[must_use]
    pub fn into_response(self) -> Response {
        let Self { stream, .. } = self;
        let encoded = SseStream::new(stream);
        Response::with_status(StatusCode::OK)
            .header("content-type", b"text/event-stream".to_vec())
            .header("cache-control", b"no-cache".to_vec())
            .header("connection", b"keep-alive".to_vec())
            .header("x-accel-buffering", b"no".to_vec())
            .body(ResponseBody::stream(encoded))
    }
}
pub fn sse_response<S>(stream: S) -> Response
where
S: Stream<Item = SseEvent> + Send + Unpin + 'static,
{
SseResponse::new(stream).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `event` and decodes it (lossily) for substring assertions.
    fn render(event: SseEvent) -> String {
        String::from_utf8_lossy(&event.to_bytes()).into_owned()
    }

    #[test]
    fn event_simple_message() {
        let text = render(SseEvent::message("Hello, World!"));
        assert!(text.contains("data: Hello, World!\n"));
        assert!(text.ends_with("\n\n"));
    }

    #[test]
    fn event_with_type() {
        let text = render(SseEvent::new("user joined").event_type("join"));
        assert!(text.contains("event: join\n"));
        assert!(text.contains("data: user joined\n"));
    }

    #[test]
    fn event_with_id() {
        let text = render(SseEvent::new("data").id("12345"));
        assert!(text.contains("id: 12345\n"));
    }

    #[test]
    fn event_with_retry() {
        let text = render(SseEvent::new("data").retry_ms(5000));
        assert!(text.contains("retry: 5000\n"));
    }

    #[test]
    fn event_multiline_data() {
        let text = render(SseEvent::new("line1\nline2\nline3"));
        for expected in ["data: line1\n", "data: line2\n", "data: line3\n"].iter() {
            assert!(text.contains(expected));
        }
    }

    #[test]
    fn event_comment() {
        let text = render(SseEvent::comment("keep-alive"));
        assert!(text.contains(": keep-alive\n"));
    }

    #[test]
    fn event_full_format() {
        let text = render(
            SseEvent::new("payload")
                .event_type("update")
                .id("42")
                .retry_ms(3000),
        );
        // Each field must appear, in exactly this order.
        let positions: Vec<usize> = ["event:", "id:", "retry:", "data:"]
            .iter()
            .map(|&field| text.find(field).unwrap())
            .collect();
        assert!(positions.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn event_from_str() {
        let event: SseEvent = "Hello".into();
        assert!(render(event).contains("data: Hello\n"));
    }

    #[test]
    fn event_from_string() {
        let event: SseEvent = String::from("World").into();
        assert!(render(event).contains("data: World\n"));
    }

    #[test]
    fn config_defaults() {
        let config = SseConfig::default();
        assert_eq!(config.keep_alive_secs, 30);
        assert_eq!(config.keep_alive_comment, "keep-alive");
    }

    #[test]
    fn config_custom() {
        let config = SseConfig::new()
            .keep_alive_secs(60)
            .keep_alive_comment("heartbeat");
        assert_eq!(config.keep_alive_secs, 60);
        assert_eq!(config.keep_alive_comment, "heartbeat");
    }

    #[test]
    fn config_disable_keepalive() {
        assert_eq!(SseConfig::new().disable_keep_alive().keep_alive_secs, 0);
    }

    #[test]
    fn event_empty_data() {
        assert!(render(SseEvent::new("")).contains("data: \n"));
    }

    #[test]
    fn retry_from_duration() {
        let text = render(SseEvent::new("data").retry(Duration::from_secs(10)));
        assert!(text.contains("retry: 10000\n"));
    }
}