use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::{self, Interval};
use crate::message::HubEvent;
/// Owned byte buffer holding one encoded SSE frame (event, heartbeat, retry or comment).
///
/// Note: this is a plain `Vec<u8>` alias, not the `bytes` crate's `Bytes` type.
pub type Bytes = Vec<u8>;
/// A [`futures_core::Stream`] of SSE-encoded frames.
///
/// It yields one frame per [`HubEvent`] received on the channel. It yields a
/// `: heartbeat` comment frame whenever the heartbeat timer fires while no event is
/// pending. The stream ends once every sender of the channel has been dropped.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct SseStream {
    /// Source of hub events to forward to the client.
    receiver: mpsc::Receiver<HubEvent>,
    /// Timer that drives keep-alive comments; it is reset whenever an event is sent.
    heartbeat: Interval,
}
impl SseStream {
    /// Creates a stream that forwards each [`HubEvent`] from `receiver` as an SSE frame.
    /// It emits a `: heartbeat` comment after `heartbeat_interval` passes with no event.
    ///
    /// The first heartbeat is due one full `heartbeat_interval` after construction.
    /// Each delivered event pushes the next heartbeat back by a full interval. If the
    /// consumer stops polling for several intervals, only one heartbeat is emitted when
    /// it resumes, not one per missed interval.
    ///
    /// Like every Tokio timer, this must be called from inside a Tokio runtime with the
    /// time driver enabled.
    ///
    /// # Panics
    ///
    /// Panics if `heartbeat_interval` is zero.
    pub fn new(receiver: mpsc::Receiver<HubEvent>, heartbeat_interval: Duration) -> Self {
        // `time::interval` completes its first tick immediately. That would make a
        // brand-new, idle stream send a heartbeat straight away, so the first tick is
        // scheduled one period out instead.
        let first_tick = time::Instant::now() + heartbeat_interval;
        let mut heartbeat = time::interval_at(first_tick, heartbeat_interval);
        // The default `Burst` policy replays every missed tick back-to-back once a stalled
        // consumer resumes polling. A keep-alive needs only one, so re-anchor the schedule.
        heartbeat.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
        Self {
            receiver,
            heartbeat,
        }
    }

    /// Serializes `event` into its SSE wire form via `HubEvent::to_sse_string`.
    fn format_event(event: &HubEvent) -> Bytes {
        event.to_sse_string().into_bytes()
    }

    /// Returns the keep-alive frame: an SSE comment line (`: heartbeat`) followed by the
    /// blank line that terminates the frame.
    fn heartbeat_bytes() -> Bytes {
        b": heartbeat\n\n".to_vec()
    }
}
impl futures_core::Stream for SseStream {
    type Item = Result<Bytes, std::io::Error>;

    /// Yields the next SSE frame.
    ///
    /// Queued events always take priority. A heartbeat frame is produced only when the
    /// channel has nothing ready and the heartbeat timer has fired. The stream finishes
    /// (`Ready(None)`) once the channel is closed and drained.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        match this.receiver.poll_recv(cx) {
            Poll::Ready(Some(event)) => {
                // Real traffic keeps the connection alive, so postpone the next heartbeat.
                this.heartbeat.reset();
                Poll::Ready(Some(Ok(Self::format_event(&event))))
            }
            // Every sender is gone: end the stream.
            Poll::Ready(None) => Poll::Ready(None),
            // Nothing queued: let the heartbeat timer decide whether to emit a keep-alive.
            Poll::Pending => this
                .heartbeat
                .poll_tick(cx)
                .map(|_tick| Some(Ok(Self::heartbeat_bytes()))),
        }
    }
}
/// Builds an SSE `retry:` frame asking the client to wait `reconnect_ms` milliseconds
/// before reconnecting after the connection drops.
pub fn sse_retry_directive(reconnect_ms: u64) -> Bytes {
    let frame = format!("retry: {reconnect_ms}\n\n");
    frame.into_bytes()
}
/// Builds an SSE comment frame carrying `text`. Clients ignore comment lines, so this
/// is useful for keep-alives and connection notices.
///
/// SSE treats CRLF, LF and lone CR as line terminators. Every line of `text` is
/// therefore emitted as its own `: `-prefixed comment line. Without this, a newline in
/// `text` would end the comment, and the remainder would be parsed as real fields
/// (e.g. `data:`), injecting an event. Single-line input produces exactly
/// `": {text}\n\n"`.
pub fn sse_comment(text: &str) -> Bytes {
    // Collapse CRLF first so it is not split into two separate line breaks below.
    let normalized = text.replace("\r\n", "\n");
    let mut frame = String::with_capacity(normalized.len() + 4);
    for line in normalized.split(|c: char| c == '\n' || c == '\r') {
        frame.push_str(": ");
        frame.push_str(line);
        frame.push('\n');
    }
    // Blank line terminates the frame.
    frame.push('\n');
    frame.into_bytes()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::HubEvent;
    use chrono::Utc;
    use futures_core::Stream;
    use serde_json::json;
    use std::task::Waker;
    use tokio::sync::mpsc;

    /// Builds a `HubEvent` with the given id and topic and a fixed `{"key": "value"}` payload.
    fn sample_event(id: u64, topic: &str) -> HubEvent {
        HubEvent {
            id,
            topic: topic.to_string(),
            data: json!({"key": "value"}),
            timestamp: Utc::now(),
        }
    }

    /// Polls `stream` exactly once with a no-op waker and returns the raw poll result.
    fn poll_once(stream: &mut SseStream) -> Poll<Option<Result<Bytes, std::io::Error>>> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(stream).poll_next(&mut cx)
    }

    /// Decodes a frame produced by this module back into text for assertions.
    fn as_text(bytes: Bytes) -> String {
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn format_event_produces_valid_sse() {
        let text = as_text(SseStream::format_event(&sample_event(1, "test/topic")));
        assert!(text.starts_with("event: test/topic\n"));
        assert!(text.contains("data: "));
        assert!(text.contains("id: 1\n"));
        assert!(text.ends_with("\n\n"));
    }

    #[test]
    fn heartbeat_bytes_format() {
        assert_eq!(as_text(SseStream::heartbeat_bytes()), ": heartbeat\n\n");
    }

    #[test]
    fn sse_retry_directive_format() {
        assert_eq!(as_text(sse_retry_directive(3000)), "retry: 3000\n\n");
    }

    #[test]
    fn sse_comment_format() {
        assert_eq!(as_text(sse_comment("connected")), ": connected\n\n");
    }

    #[tokio::test]
    async fn stream_delivers_events() {
        let (tx, rx) = mpsc::channel(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        tx.try_send(sample_event(10, "app/deploy")).unwrap();
        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                let text = as_text(bytes);
                assert!(text.contains("event: app/deploy"));
                assert!(text.contains("id: 10"));
            }
            other => panic!("expected Ready(Some(Ok)), got {:?}", other),
        }
    }

    #[tokio::test]
    async fn stream_ends_when_channel_closed() {
        let (tx, rx) = mpsc::channel::<HubEvent>(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        drop(tx);
        match poll_once(&mut stream) {
            Poll::Ready(None) => {}
            other => panic!("expected Ready(None), got {:?}", other),
        }
    }

    #[tokio::test]
    async fn stream_emits_heartbeat_when_idle() {
        // Keep the sender alive so the channel stays open (and therefore idle).
        let (_tx, rx) = mpsc::channel::<HubEvent>(256);
        let mut stream = SseStream::new(rx, Duration::from_millis(1));
        tokio::time::sleep(Duration::from_millis(10)).await;
        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                assert_eq!(as_text(bytes), ": heartbeat\n\n");
            }
            other => panic!("expected heartbeat, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn multiple_events_delivered_in_order() {
        let (tx, rx) = mpsc::channel(256);
        let mut stream = SseStream::new(rx, Duration::from_secs(60));
        tx.try_send(sample_event(1, "a")).unwrap();
        tx.try_send(sample_event(2, "b")).unwrap();
        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                assert!(as_text(bytes).contains("id: 1"));
            }
            other => panic!("expected first event, got {:?}", other),
        }
        match poll_once(&mut stream) {
            Poll::Ready(Some(Ok(bytes))) => {
                assert!(as_text(bytes).contains("id: 2"));
            }
            other => panic!("expected second event, got {:?}", other),
        }
    }
}