1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::sync::mpsc;
9use uuid::Uuid;
10
/// Type-level configuration for the SSE endpoint; `T` is the per-session data type.
///
/// It holds no data at runtime. `T` only selects which `SseState<T>` and
/// `sse_handler::<T>` instantiation the setup methods produce.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`/`Copy`/`Debug` are implemented by hand. `#[derive]` would add
// `T: Clone` / `T: Copy` / `T: Debug` bounds, even though only a zero-sized
// `PhantomData<T>` marker is stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish_non_exhaustive()
    }
}
29
30#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35 pub fn new_with_data<T>() -> SseSetup<T> {
36 SseSetup {
37 session_data: PhantomData,
38 }
39 }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49 pub fn new() -> Self {
50 SseSetup {
51 session_data: PhantomData,
52 }
53 }
54}
55
56impl<T> SseSetup<T>
57where
58 T: DeserializeOwned + Send + Sync + 'static,
59{
60 #[must_use]
62 pub fn state_data(&self) -> web::Data<SseState<T>> {
63 web::Data::new(SseState::default())
64 }
65
66 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68 cfg.route(path, web::get().to(sse_handler::<T>));
69 }
70}
71
72pub struct SseState<T> {
74 pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
75}
76
77impl<T> Default for SseState<T> {
78 fn default() -> Self {
79 Self {
80 sessions: Arc::new(DashMap::new()),
81 }
82 }
83}
84
85impl<T: Clone> SseState<T> {
86 pub fn get_session_data(&self, id: Uuid) -> Option<T> {
87 self.sessions.get(&id).and_then(|x| x.data.clone())
88 }
89}
90
91impl<T> SseState<T> {
92 pub fn add_session(
93 &self,
94 id: Uuid,
95 data: Option<T>,
96 sender: mpsc::Sender<Event>,
97 ) -> Option<SseSession<T>> {
98 let session = SseSession { data, sender };
99 self.sessions.insert(id, session)
100 }
101
102 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
103 self.sessions.remove(&id)
104 }
105
106 pub fn remove_stale_sessions(&self) -> usize {
107 let mut removed = 0;
108 self.sessions.retain(
109 |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
110 Ok(()) => true,
111 Err(_) => {
112 removed += 1;
113 false
114 }
115 },
116 );
117 removed
118 }
119
120 pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
122 let sender = self.sessions.get(&id)?.sender.clone();
123 if sender.try_send(Event::Data(data)).is_err() {
124 self.remove_session(id);
126 }
127 Some(())
128 }
129
130 pub fn broadcast(&self, data: Data) -> usize {
132 let senders = self
133 .sessions
134 .iter()
135 .map(|o| o.value().sender.clone())
136 .collect::<Vec<_>>();
137 sse_broadcast(senders, data)
138 }
139
140 pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
142 let senders = self
143 .sessions
144 .iter()
145 .filter_map(|o| {
146 if *o.key() == id {
147 None
148 } else {
149 Some(o.value().sender.clone())
150 }
151 })
152 .collect::<Vec<_>>();
153 sse_broadcast(senders, data)
154 }
155}
156
157pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
158 senders
159 .into_iter()
160 .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
161 .count()
162}
163
164pub struct SseSession<T> {
165 pub data: Option<T>,
166 pub sender: mpsc::Sender<Event>,
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize)]
171pub struct SseHandlerQuery {
172 pub id: Uuid,
173}
174
175#[tracing::instrument(skip_all)]
177pub async fn sse_handler<T: Send + Sync + 'static>(
178 web::Query(query): web::Query<SseHandlerQuery>,
179 state: web::Data<SseState<T>>,
180) -> impl Responder {
181 let (tx, rx) = mpsc::channel(8);
182 state.add_session(query.id, None, tx.clone());
183
184 let sessions = state.sessions.clone();
185 tokio::spawn(async move {
186 tx.closed().await;
187 sessions.remove(&query.id);
188 });
189
190 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
191}