1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Entry point for wiring the SSE endpoint into an actix-web app.
///
/// `T` is the per-session data type stored in [`SseState`]. It is tracked only at the type
/// level (through `PhantomData`), so an `SseSetup` value is zero-sized.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`, `Copy` and `Debug` are written by hand instead of derived: a derive would require
// `T: Clone` / `T: Copy` / `T: Debug`, even though no `T` value is ever stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish_non_exhaustive()
    }
}
29
30#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35 pub fn new_with_data<T>() -> SseSetup<T> {
36 SseSetup {
37 session_data: PhantomData,
38 }
39 }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49 pub fn new() -> Self {
50 SseSetup {
51 session_data: PhantomData,
52 }
53 }
54}
55
56impl<T> SseSetup<T>
57where
58 T: DeserializeOwned + 'static,
59{
60 #[must_use]
62 pub fn state_data(&self) -> web::Data<SseState<T>> {
63 web::Data::new(SseState::default())
64 }
65
66 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68 cfg.route(path, web::get().to(sse_handler::<T>));
69 }
70}
71
72impl<T> SseSetup<T>
73where
74 T: Send + Sync + 'static,
75{
76 #[must_use]
79 pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80 let state = web::Data::new(SseState::default());
81 spawn_remove_stale_sessions_task(state.clone(), period);
82 state
83 }
84}
85
86pub struct SseState<T> {
88 pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
89}
90
91impl<T> Default for SseState<T> {
92 fn default() -> Self {
93 Self {
94 sessions: Arc::new(DashMap::new()),
95 }
96 }
97}
98
99impl<T: Clone> SseState<T> {
100 pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101 self.sessions.get(&id).and_then(|x| x.data.clone())
102 }
103}
104
105impl<T> SseState<T> {
106 pub fn add_session(
107 &self,
108 id: Uuid,
109 data: Option<T>,
110 sender: mpsc::Sender<Event>,
111 ) -> Option<SseSession<T>> {
112 let session = SseSession { data, sender };
113 self.sessions.insert(id, session)
114 }
115
116 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117 self.sessions.remove(&id)
118 }
119
120 pub fn remove_stale_sessions(&self) {
121 self.sessions
122 .retain(|_, o| o.sender.try_send(Event::Comment("ping".into())).is_ok());
123 }
124
125 pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
127 let sender = self.sessions.get(&id)?.sender.clone();
128 if sender.try_send(Event::Data(data)).is_err() {
129 self.remove_session(id);
131 }
132 Some(())
133 }
134
135 pub fn broadcast(&self, data: Data) -> usize {
137 let senders = self
138 .sessions
139 .iter()
140 .map(|o| o.value().sender.clone())
141 .collect::<Vec<_>>();
142 sse_broadcast(senders, data)
143 }
144
145 pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
147 let senders = self
148 .sessions
149 .iter()
150 .filter_map(|o| {
151 if *o.key() == id {
152 None
153 } else {
154 Some(o.value().sender.clone())
155 }
156 })
157 .collect::<Vec<_>>();
158 sse_broadcast(senders, data)
159 }
160}
161
162pub fn spawn_remove_stale_sessions_task<T>(
164 state: web::Data<SseState<T>>,
165 period: Duration,
166) -> JoinHandle<()>
167where
168 T: Send + Sync + 'static,
169{
170 tokio::spawn(async move {
171 let mut interval = tokio::time::interval(period);
172 loop {
173 interval.tick().await;
174 state.remove_stale_sessions();
175 }
176 })
177}
178
179pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
180 senders
181 .into_iter()
182 .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
183 .count()
184}
185
186pub struct SseSession<T> {
187 pub data: Option<T>,
188 pub sender: mpsc::Sender<Event>,
189}
190
191#[derive(Clone, Debug, Serialize, Deserialize)]
193pub struct SseHandlerQuery {
194 pub id: Uuid,
195}
196
197#[tracing::instrument(skip_all)]
199pub async fn sse_handler<T>(
200 web::Query(query): web::Query<SseHandlerQuery>,
201 state: web::Data<SseState<T>>,
202) -> impl Responder {
203 let (tx, rx) = mpsc::channel(8);
204 state.add_session(query.id, None, tx);
205 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
206}