use axum::{
Router,
extract::State,
response::{
Sse,
sse::{Event, KeepAlive},
},
routing::get,
};
use dashmap::DashMap;
use futures_util::Stream;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, marker::PhantomData, sync::Arc};
use tokio::sync::mpsc;
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use uuid::Uuid;
/// Zero-sized marker that selects the per-session data type `T` used by the SSE route
/// and its [`SseState`].
///
/// No value of `T` is stored; only a [`PhantomData`] marker.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add `T: Clone` /
// `T: Copy` bounds, even though only a `PhantomData<T>` is stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}
/// Placeholder session-data type for setups that attach no custom data to SSE sessions
/// (the `T` chosen by `SseSetup::new()`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Deserialize)]
pub struct FhtmxUiNoSessionData;
impl SseSetup<()> {
    /// Creates an `SseSetup` whose sessions carry data of type `T`.
    ///
    /// Invoke as `SseSetup::new_with_data::<MyData>()`; the `()` parameter on this `impl`
    /// only gives the constructor a concrete type to hang off.
    #[must_use]
    pub fn new_with_data<T>() -> SseSetup<T> {
        let marker = PhantomData;
        SseSetup {
            session_data: marker,
        }
    }
}
impl Default for SseSetup<FhtmxUiNoSessionData> {
fn default() -> Self {
Self::new()
}
}
impl SseSetup<FhtmxUiNoSessionData> {
    /// Creates a setup whose sessions carry no custom data (`FhtmxUiNoSessionData`).
    ///
    /// Use `SseSetup::new_with_data::<T>()` to attach a custom session-data type instead.
    #[must_use]
    pub fn new() -> Self {
        SseSetup {
            session_data: PhantomData,
        }
    }
}
impl<T> SseSetup<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Returns a fresh `SseState` with an empty session map, to be installed as the
    /// router state.
    #[must_use]
    pub fn state_data(&self) -> SseState<T> {
        Default::default()
    }

    /// Builds a router serving the SSE stream (`sse_handler`) on GET `path`.
    pub fn sse_route(path: &str) -> Router<SseState<T>> {
        let handler = get(sse_handler::<T>);
        Router::new().route(path, handler)
    }
}
/// One connected SSE client as tracked in the session map of `SseState`.
#[derive(Debug)]
pub struct SseSession<T> {
    /// Optional application data attached to this session.
    pub data: Option<T>,
    /// Sending half of the bounded channel that feeds this client's event stream.
    pub sender: mpsc::Sender<Event>,
}
/// Shared registry of live SSE sessions, keyed by the `Uuid` generated for each stream.
///
/// Cloning is cheap: all clones share the same underlying map through an [`Arc`].
pub struct SseState<T> {
    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
}

// Manual impl: `#[derive(Clone)]` would require `T: Clone`, although only the `Arc`
// is cloned. This mirrors the bound-free `Default` impl.
impl<T> Clone for SseState<T> {
    fn clone(&self) -> Self {
        Self {
            sessions: Arc::clone(&self.sessions),
        }
    }
}
impl<T> Default for SseState<T> {
fn default() -> Self {
Self {
sessions: Arc::new(DashMap::new()),
}
}
}
impl<T: Clone> SseState<T> {
    /// Returns a clone of the data attached to session `id`.
    ///
    /// Yields `None` when no session is registered under `id` or when it has no data.
    pub fn get_session_data(&self, id: Uuid) -> Option<T> {
        let session = self.sessions.get(&id)?;
        session.data.clone()
    }
}
impl<T> SseState<T> {
pub fn add_session(
&self,
id: Uuid,
data: Option<T>,
sender: mpsc::Sender<Event>,
) -> Option<SseSession<T>> {
let session = SseSession { data, sender };
self.sessions.insert(id, session)
}
pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
self.sessions.remove(&id)
}
pub fn send_message<D: AsRef<str>>(&self, id: Uuid, data: D) -> Option<()> {
let sender = self.sessions.get(&id)?.sender.clone();
if sender.try_send(Event::default().data(data)).is_err() {
self.remove_session(id);
}
Some(())
}
pub fn broadcast<D: AsRef<str>>(&self, data: D) -> usize {
let senders = self
.sessions
.iter()
.map(|o| o.value().sender.clone())
.collect::<Vec<_>>();
sse_broadcast(senders, data)
}
pub fn broadcast_all_but<D: AsRef<str>>(&self, id: Uuid, data: D) -> usize {
let senders = self
.sessions
.iter()
.filter_map(|o| {
if *o.key() == id {
None
} else {
Some(o.value().sender.clone())
}
})
.collect::<Vec<_>>();
sse_broadcast(senders, data)
}
}
/// Queues `data` as an SSE `data:` event on each of `senders` without waiting.
///
/// Returns the number of senders that accepted the event. Full or closed channels are
/// skipped silently.
pub fn sse_broadcast<D: AsRef<str>>(senders: Vec<mpsc::Sender<Event>>, data: D) -> usize {
    let payload = data.as_ref();
    let mut delivered = 0;
    for sender in senders {
        if sender.try_send(Event::default().data(payload)).is_ok() {
            delivered += 1;
        }
    }
    delivered
}
/// Axum handler that opens a new SSE stream and registers it as a session in `state`.
///
/// - A fresh `Uuid` identifies the session.
/// - The first event on the stream is an `sse_id` event whose data is that id.
/// - Events are buffered in a channel of capacity 8.
/// - The session is removed once every receiver of that channel has been dropped
///   (`tx.closed()` resolves).
#[tracing::instrument(skip_all)]
pub async fn sse_handler<T: Send + Sync + 'static>(
    State(state): State<SseState<T>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let (tx, rx) = mpsc::channel(8);
    let id = Uuid::new_v4();
    // Queue the id event *before* registering the session, so a concurrent `broadcast` or
    // `send_message` cannot slip an event in ahead of it. The channel is fresh and `rx` is
    // still alive, so this send neither waits nor fails.
    let _ = tx
        .send(Event::default().data(id.to_string()).event("sse_id"))
        .await;
    state.add_session(id, None, tx.clone());
    // Unregister the session once the stream (and with it `rx`) has been dropped.
    tokio::spawn(async move {
        tx.closed().await;
        state.remove_session(id);
    });
    let stream = ReceiverStream::new(rx).map(Ok);
    Sse::new(stream).keep_alive(KeepAlive::default())
}
/// (De)serializable query payload carrying a session id.
///
/// NOTE(review): presumably the id a client received in its `sse_id` event; no usage is
/// visible in this file, so confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SseHandlerQuery {
    /// Session id.
    pub id: Uuid,
}