1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Zero-sized helper used to wire up SSE support for a session data type `T`.
///
/// The only field is a [`PhantomData`], so a value carries no runtime state;
/// `T` merely selects which `SseState<T>` / handler instantiation the setup
/// methods produce.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`/`Copy` are implemented by hand instead of derived: the derive would
// add `T: Clone` / `T: Copy` bounds, although no `T` value is ever stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}
29
30#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35 pub fn new_with_data<T>() -> SseSetup<T> {
36 SseSetup {
37 session_data: PhantomData,
38 }
39 }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49 pub fn new() -> Self {
50 SseSetup {
51 session_data: PhantomData,
52 }
53 }
54}
55
56impl<T> SseSetup<T>
57where
58 T: DeserializeOwned + 'static,
59{
60 #[must_use]
62 pub fn state_data(&self) -> web::Data<SseState<T>> {
63 web::Data::new(SseState::default())
64 }
65
66 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68 cfg.route(path, web::get().to(sse_handler::<T>));
69 }
70}
71
72impl<T> SseSetup<T>
73where
74 T: Send + Sync + 'static,
75{
76 #[must_use]
79 pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80 let state = web::Data::new(SseState::default());
81 spawn_remove_stale_sessions_task(state.clone(), period);
82 state
83 }
84}
85
/// Registry of the SSE sessions known to the application.
///
/// `T` is the type of the optional per-session data stored in each
/// [`SseSession`].
pub struct SseState<T> {
    /// Sessions keyed by their [`Uuid`], held in a concurrent map behind an
    /// [`Arc`].
    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
}
90
91impl<T> Default for SseState<T> {
92 fn default() -> Self {
93 Self {
94 sessions: Arc::new(DashMap::new()),
95 }
96 }
97}
98
99impl<T: Clone> SseState<T> {
100 pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101 self.sessions.get(&id).and_then(|x| x.data.clone())
102 }
103}
104
105impl<T> SseState<T> {
106 pub fn add_session(
107 &self,
108 id: Uuid,
109 data: Option<T>,
110 sender: mpsc::Sender<Event>,
111 ) -> Option<SseSession<T>> {
112 let session = SseSession { data, sender };
113 self.sessions.insert(id, session)
114 }
115
116 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117 self.sessions.remove(&id)
118 }
119
120 pub fn remove_stale_sessions(&self) -> usize {
121 let mut removed = 0;
122 self.sessions.retain(
123 |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
124 Ok(()) => true,
125 Err(_) => {
126 removed += 1;
127 false
128 }
129 },
130 );
131 removed
132 }
133
134 pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
136 let sender = self.sessions.get(&id)?.sender.clone();
137 if sender.try_send(Event::Data(data)).is_err() {
138 self.remove_session(id);
140 }
141 Some(())
142 }
143
144 pub fn broadcast(&self, data: Data) -> usize {
146 let senders = self
147 .sessions
148 .iter()
149 .map(|o| o.value().sender.clone())
150 .collect::<Vec<_>>();
151 sse_broadcast(senders, data)
152 }
153
154 pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
156 let senders = self
157 .sessions
158 .iter()
159 .filter_map(|o| {
160 if *o.key() == id {
161 None
162 } else {
163 Some(o.value().sender.clone())
164 }
165 })
166 .collect::<Vec<_>>();
167 sse_broadcast(senders, data)
168 }
169}
170
171pub fn spawn_remove_stale_sessions_task<T>(
173 state: web::Data<SseState<T>>,
174 period: Duration,
175) -> JoinHandle<()>
176where
177 T: Send + Sync + 'static,
178{
179 tokio::spawn(async move {
180 let mut interval = tokio::time::interval(period);
181 loop {
182 interval.tick().await;
183 let removed = state.remove_stale_sessions();
184 if removed > 0 {
185 tracing::info!("Removed {removed} staled sessions.");
186 }
187 }
188 })
189}
190
191pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
192 senders
193 .into_iter()
194 .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
195 .count()
196}
197
/// A single registered SSE session.
pub struct SseSession<T> {
    /// Optional data attached to the session.
    pub data: Option<T>,
    /// Sending half of the channel through which events are queued for this
    /// session.
    pub sender: mpsc::Sender<Event>,
}
202
/// Query-string parameters extracted by [`sse_handler`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SseHandlerQuery {
    /// Id under which the new session is stored.
    pub id: Uuid,
}
208
209#[tracing::instrument(skip_all)]
211pub async fn sse_handler<T>(
212 web::Query(query): web::Query<SseHandlerQuery>,
213 state: web::Data<SseState<T>>,
214) -> impl Responder {
215 let (tx, rx) = mpsc::channel(8);
216 state.add_session(query.id, None, tx);
217 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
218}