1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use futures::{StreamExt, stream};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{marker::PhantomData, sync::Arc, time::Duration};
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
/// Zero-sized builder that wires the SSE endpoint and its shared state into an actix-web app.
///
/// `T` is the type of the optional per-session data a client may pass in the `data`
/// query parameter when it opens the stream.
pub struct SseSetup<T> {
    // `fn() -> T` rather than `T`: the setup never owns a `T`, so its auto traits
    // (`Send`/`Sync`) should not depend on whether `T` has them.
    session_data: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy`
// bounds even though no `T` is ever stored, making e.g. `SseSetup<NonCopyData>`
// impossible to copy into an app factory closure.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish()
    }
}
32
33#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
35pub struct FhtmxUiNoSessionData;
36
37impl SseSetup<()> {
38 pub fn new_with_data<T>() -> SseSetup<T> {
39 SseSetup {
40 session_data: PhantomData,
41 }
42 }
43}
44
45impl Default for SseSetup<FhtmxUiNoSessionData> {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SseSetup<FhtmxUiNoSessionData> {
52 pub fn new() -> Self {
53 SseSetup {
54 session_data: PhantomData,
55 }
56 }
57}
58
59impl<T> SseSetup<T>
60where
61 T: DeserializeOwned + 'static,
62{
63 #[must_use]
65 pub fn state_data(&self) -> web::Data<SseState<T>> {
66 web::Data::new(SseState::default())
67 }
68
69 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
71 cfg.route(path, web::get().to(sse_handler::<T>));
72 }
73}
74
75pub struct SseState<T> {
77 pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
78}
79
80impl<T> Default for SseState<T> {
81 fn default() -> Self {
82 Self {
83 sessions: Arc::new(DashMap::new()),
84 }
85 }
86}
87
88impl<T> SseState<T> {
89 pub fn add_session(
90 &self,
91 id: Uuid,
92 data: Option<T>,
93 sender: mpsc::Sender<Event>,
94 ) -> Option<SseSession<T>> {
95 let session = SseSession { data, sender };
96 self.sessions.insert(id, session)
97 }
98
99 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
100 self.sessions.remove(&id)
101 }
102
103 pub async fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
105 let sender = self.sessions.get(&id)?.sender.clone();
106 if sender.send(Event::Data(data)).await.is_err() {
107 self.remove_session(id);
109 }
110 Some(())
111 }
112
113 pub async fn broadcast(&self, data: Data) -> usize {
115 let senders = self
116 .sessions
117 .iter()
118 .map(|o| o.value().sender.clone())
119 .collect::<Vec<_>>();
120 sse_broadcast(senders, data).await
121 }
122
123 pub async fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
125 let senders = self
126 .sessions
127 .iter()
128 .filter_map(|o| {
129 if *o.key() == id {
130 None
131 } else {
132 Some(o.value().sender.clone())
133 }
134 })
135 .collect::<Vec<_>>();
136 sse_broadcast(senders, data).await
137 }
138}
139
140pub async fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
141 stream::iter(senders)
142 .filter_map(|o| {
143 let data = data.clone();
144 async move { o.send(Event::Data(data)).await.ok() }
145 })
146 .count()
147 .await
148}
149
150pub struct SseSession<T> {
151 pub data: Option<T>,
152 pub sender: mpsc::Sender<Event>,
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize)]
157pub struct SseQueryParams<T> {
158 pub id: Uuid,
159 pub data: Option<T>,
160}
161
162#[tracing::instrument(skip_all)]
164pub async fn sse_handler<T>(
165 web::Query(query): web::Query<SseQueryParams<T>>,
166 state: web::Data<SseState<T>>,
167) -> impl Responder {
168 let (tx, rx) = mpsc::channel(8);
169 state.add_session(query.id, query.data, tx);
170 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
171}