1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Type-level configuration for wiring SSE sessions into an actix-web app.
///
/// `T` is the per-session data type stored in [`SseState`]. No value of `T` is held;
/// the struct only carries the type parameter.
// `Clone`, `Copy` and `Debug` are implemented by hand below: `#[derive]` would add
// `T: Clone` / `T: Copy` / `T: Debug` bounds, making `SseSetup<T>` non-copyable for
// ordinary (non-`Copy`) session types even though no `T` is ever stored.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish_non_exhaustive()
    }
}
29
/// Marker session-data type for SSE setups that keep no per-session data.
///
/// It is the type parameter of the setup returned by `SseSetup::new()` and
/// `SseSetup::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35 pub fn new_with_data<T>() -> SseSetup<T> {
36 SseSetup {
37 session_data: PhantomData,
38 }
39 }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49 pub fn new() -> Self {
50 SseSetup {
51 session_data: PhantomData,
52 }
53 }
54}
55
56impl<T> SseSetup<T>
57where
58 T: DeserializeOwned + 'static,
59{
60 #[must_use]
62 pub fn state_data(&self) -> web::Data<SseState<T>> {
63 web::Data::new(SseState::default())
64 }
65
66 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68 cfg.route(path, web::get().to(sse_handler::<T>));
69 }
70}
71
72impl<T> SseSetup<T>
73where
74 T: Send + Sync + 'static,
75{
76 #[must_use]
79 pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80 let state = web::Data::new(SseState::default());
81 spawn_remove_stale_sessions_task(state.clone(), period);
82 state
83 }
84}
85
/// Shared registry of SSE sessions, keyed by session id (the `id` query
/// parameter received by `sse_handler`).
pub struct SseState<T> {
    /// Sessions by id; each holds optional per-session data and the event sender.
    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
}
90
91impl<T> Default for SseState<T> {
92 fn default() -> Self {
93 Self {
94 sessions: Arc::new(DashMap::new()),
95 }
96 }
97}
98
99impl<T: Clone> SseState<T> {
100 pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101 self.sessions.get(&id).and_then(|x| x.data.clone())
102 }
103}
104
105impl<T> SseState<T> {
106 pub fn add_session(
107 &self,
108 id: Uuid,
109 data: Option<T>,
110 sender: mpsc::Sender<Event>,
111 ) -> Option<SseSession<T>> {
112 let session = SseSession { data, sender };
113 self.sessions.insert(id, session)
114 }
115
116 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117 self.sessions.remove(&id)
118 }
119
120 pub fn remove_stale_sessions(&self) -> usize {
121 let mut removed = 0;
122 self.sessions.retain(
123 |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
124 Ok(()) => true,
125 Err(_) => {
126 removed += 1;
127 false
128 }
129 },
130 );
131 removed
132 }
133
134 pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
136 let sender = self.sessions.get(&id)?.sender.clone();
137 sender.try_send(Event::Data(data)).ok()
138 }
139
140 pub fn broadcast(&self, data: Data) -> usize {
142 let senders = self
143 .sessions
144 .iter()
145 .map(|o| o.value().sender.clone())
146 .collect::<Vec<_>>();
147 sse_broadcast(senders, data)
148 }
149
150 pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
152 let senders = self
153 .sessions
154 .iter()
155 .filter_map(|o| {
156 if *o.key() == id {
157 None
158 } else {
159 Some(o.value().sender.clone())
160 }
161 })
162 .collect::<Vec<_>>();
163 sse_broadcast(senders, data)
164 }
165}
166
167pub fn spawn_remove_stale_sessions_task<T>(
169 state: web::Data<SseState<T>>,
170 period: Duration,
171) -> JoinHandle<()>
172where
173 T: Send + Sync + 'static,
174{
175 tokio::spawn(async move {
176 let mut interval = tokio::time::interval(period);
177 loop {
178 interval.tick().await;
179 let removed = state.remove_stale_sessions();
180 if removed > 0 {
181 tracing::info!("Removed {removed} staled sessions.");
182 }
183 }
184 })
185}
186
187pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
188 senders
189 .into_iter()
190 .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
191 .count()
192}
193
/// One registered SSE connection.
pub struct SseSession<T> {
    /// Optional per-session data (`sse_handler` registers new sessions with `None`).
    pub data: Option<T>,
    /// Sending half of the channel whose receiver feeds this session's SSE response.
    pub sender: mpsc::Sender<Event>,
}
198
/// Query parameters accepted by `sse_handler`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SseHandlerQuery {
    /// Session id supplied by the requester; `sse_handler` replaces any existing
    /// session stored under the same id.
    pub id: Uuid,
}
204
205#[tracing::instrument(skip_all)]
207pub async fn sse_handler<T>(
208 web::Query(query): web::Query<SseHandlerQuery>,
209 state: web::Data<SseState<T>>,
210) -> impl Responder {
211 let (tx, rx) = mpsc::channel(8);
212 state.add_session(query.id, None, tx);
213 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
214}