pub use axum::response::sse::{Event, KeepAlive, Sse};
#[cfg(feature = "ws")]
use std::convert::Infallible;
#[cfg(feature = "ws")]
use std::future::Future;
use std::time::Duration;
/// Returns the [`KeepAlive`] configuration applied to every SSE response built in this module.
///
/// The keep-alive interval is fixed at 15 seconds.
pub fn keep_alive() -> KeepAlive {
    const INTERVAL: Duration = Duration::from_secs(15);
    KeepAlive::new().interval(INTERVAL)
}
/// Turns a channel [`crate::channels::Subscriber`] into an SSE response.
///
/// Each message from the subscriber's stream becomes one unnamed [`Event`]. The message's
/// `into_string()` text is used as the event's `data` field. The stream never yields an error
/// (`Infallible`), and it is wrapped with this module's [`keep_alive`] configuration.
///
/// Carriage returns cannot be sent inside an SSE `data` field, and axum's `Event::data` panics
/// on them. The message text is therefore normalized first: CRLF and lone CR both become LF.
/// SSE treats all three as line terminators, so clients see the same line breaks.
#[cfg(feature = "ws")]
pub fn from_subscriber(
    subscriber: crate::channels::Subscriber,
) -> Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>> + use<>> {
    use tokio_stream::StreamExt;
    let stream = subscriber.into_stream().map(|msg| {
        let text = msg.into_string();
        Ok(Event::default().data(normalize_sse_data(&text)))
    });
    Sse::new(stream).keep_alive(keep_alive())
}

/// Rewrites every CRLF and lone CR in `data` as LF, so the text can be used as an SSE `data` value.
///
/// Returns the input borrowed (no allocation) when it contains no carriage return.
// Not feature-gated so it can be unit-tested on its own; only `from_subscriber` (ws) calls it.
#[cfg_attr(not(feature = "ws"), allow(dead_code))]
fn normalize_sse_data(data: &str) -> std::borrow::Cow<'_, str> {
    if data.contains('\r') {
        // Replace CRLF first so it becomes a single LF rather than two.
        std::borrow::Cow::Owned(data.replace("\r\n", "\n").replace('\r', "\n"))
    } else {
        std::borrow::Cow::Borrowed(data)
    }
}
/// Subscribes to `topic` through `state.channels()` and returns the subscription as an SSE response.
///
/// Convenience wrapper: the resulting subscriber is handed to [`from_subscriber`].
#[cfg(feature = "ws")]
pub fn stream(
    state: &crate::AppState,
    topic: &str,
) -> Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>> + use<>> {
    let subscriber = state.channels().subscribe(topic);
    from_subscriber(subscriber)
}
/// Subscribes to `topic` only after `authorize` accepts it, then returns an SSE response.
///
/// `topic` and `authorize` are forwarded to `state.channels().subscribe_authorized`. On success,
/// the resulting subscriber is converted with [`from_subscriber`].
///
/// # Errors
///
/// Propagates (via `?`) the error returned by `subscribe_authorized`. That includes the `E`
/// produced by a rejecting `authorize` future. In that case no SSE response is built.
#[cfg(feature = "ws")]
pub async fn stream_authorized<E, F, Fut>(
    state: &crate::AppState,
    topic: &str,
    authorize: F,
) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>> + use<E, F, Fut>>, E>
where
    F: FnOnce(String) -> Fut,
    Fut: Future<Output = Result<(), E>>,
{
    let granted = state.channels().subscribe_authorized(topic, authorize).await?;
    let response = from_subscriber(granted);
    Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The keep-alive helper should yield a value whose `Debug` output names the type.
    #[test]
    fn test_keep_alive_default() {
        let rendered = format!("{:?}", keep_alive());
        assert!(rendered.contains("KeepAlive"));
    }

    /// Smoke test: building an SSE response from the test `AppState` must not panic.
    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn stream_helper_builds_sse_from_app_state_channels() {
        let app_state = crate::AppState::for_test();
        // Bound (not `let _ =`) so the response stays alive until the end of the test.
        let _sse = stream(&app_state, "lobby");
    }

    /// A rejecting authorizer must surface its own error value, and the rejected
    /// topic must not appear in the channel snapshot afterwards.
    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn stream_authorized_rejects_before_subscription() {
        let app_state = crate::AppState::for_test();
        let outcome = stream_authorized(&app_state, "private", |requested| async move {
            assert_eq!(requested, "private");
            Err::<(), &'static str>("denied")
        })
        .await;
        assert!(matches!(outcome, Err("denied")));
        assert!(!app_state.channels().snapshot().contains_key("private"));
    }
}