1use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use futures::{StreamExt, stream};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{marker::PhantomData, sync::Arc, time::Duration};
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
/// Handle used to wire the SSE endpoint and its state into an actix-web app.
///
/// `T` is the per-session data type: it parameterises the [`SseState`] produced
/// by `state_data` and the `sse_handler` registered by `setup_route`. The setup
/// itself stores no data, only the type marker.
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// Manual impls instead of `#[derive]`: a derive would add `T: Clone` / `T: Copy` /
// `T: Debug` bounds even though only a `PhantomData<T>` is stored, which made e.g.
// `SseSetup<String>` non-`Copy` for no reason.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish_non_exhaustive()
    }
}
32
33#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
35pub struct FhtmxUiNoSessionData;
36
37impl SseSetup<()> {
38 pub fn new_with_data<T>() -> SseSetup<T> {
39 SseSetup {
40 session_data: PhantomData,
41 }
42 }
43}
44
45impl Default for SseSetup<FhtmxUiNoSessionData> {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SseSetup<FhtmxUiNoSessionData> {
52 pub fn new() -> Self {
53 SseSetup {
54 session_data: PhantomData,
55 }
56 }
57}
58
59impl<T> SseSetup<T>
60where
61 T: DeserializeOwned + 'static,
62{
63 #[must_use]
65 pub fn state_data(&self) -> web::Data<SseState<T>> {
66 web::Data::new(SseState::default())
67 }
68
69 pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
71 cfg.route(path, web::get().to(sse_handler::<T>));
72 }
73}
74
75pub struct SseState<T> {
77 pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
78}
79
80impl<T> Default for SseState<T> {
81 fn default() -> Self {
82 Self {
83 sessions: Arc::new(DashMap::new()),
84 }
85 }
86}
87
88impl<T: Clone> SseState<T> {
89 pub fn get_session_data(&self, id: Uuid) -> Option<T> {
90 self.sessions.get(&id).and_then(|x| x.data.clone())
91 }
92}
93
94impl<T> SseState<T> {
95 pub fn add_session(
96 &self,
97 id: Uuid,
98 data: Option<T>,
99 sender: mpsc::Sender<Event>,
100 ) -> Option<SseSession<T>> {
101 let session = SseSession { data, sender };
102 self.sessions.insert(id, session)
103 }
104
105 pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
106 self.sessions.remove(&id)
107 }
108
109 pub async fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
111 let sender = self.sessions.get(&id)?.sender.clone();
112 if sender.send(Event::Data(data)).await.is_err() {
113 self.remove_session(id);
115 }
116 Some(())
117 }
118
119 pub async fn broadcast(&self, data: Data) -> usize {
121 let senders = self
122 .sessions
123 .iter()
124 .map(|o| o.value().sender.clone())
125 .collect::<Vec<_>>();
126 sse_broadcast(senders, data).await
127 }
128
129 pub async fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
131 let senders = self
132 .sessions
133 .iter()
134 .filter_map(|o| {
135 if *o.key() == id {
136 None
137 } else {
138 Some(o.value().sender.clone())
139 }
140 })
141 .collect::<Vec<_>>();
142 sse_broadcast(senders, data).await
143 }
144}
145
146pub async fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
147 stream::iter(senders)
148 .filter_map(|o| {
149 let data = data.clone();
150 async move { o.send(Event::Data(data)).await.ok() }
151 })
152 .count()
153 .await
154}
155
/// One registered SSE connection as stored in [`SseState`].
pub struct SseSession<T> {
    /// Optional per-session payload (`sse_handler` registers sessions with `None`).
    pub data: Option<T>,
    /// Sending half of the channel whose receiver feeds this connection's stream.
    pub sender: mpsc::Sender<Event>,
}
160
161#[derive(Clone, Debug, Serialize, Deserialize)]
163pub struct SseHandlerQuery {
164 pub id: Uuid,
165}
166
167#[tracing::instrument(skip_all)]
169pub async fn sse_handler<T>(
170 web::Query(query): web::Query<SseHandlerQuery>,
171 state: web::Data<SseState<T>>,
172) -> impl Responder {
173 let (tx, rx) = mpsc::channel(8);
174 state.add_session(query.id, None, tx);
175 actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
176}