Skip to main content

fhtmx_actix/
sse.rs

1//! # Server sent events
2
3use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Sets up Server-Sent Events state and routes.
///
/// `T` is the type of the optional per-session data stored in [`SseState`]. Use
/// `SseSetup::new()` for sessions without data, or `SseSetup::new_with_data::<MyData>()`
/// to attach data of type `MyData` to each session.
///
/// # Example
///
/// ```rust,ignore
/// use actix_web::App;
/// use fhtmx_actix::sse::SseSetup;
///
/// let sse_setup = SseSetup::new();
/// let sse_data = sse_setup.state_data();
/// App::new()
///     .configure(|cfg| sse_setup.setup_route("/sse", cfg))
///     .app_data(sse_data);
/// ```
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`/`Copy` are implemented by hand because `#[derive]` would require `T: Clone` /
// `T: Copy`, even though only a zero-sized `PhantomData<T>` is stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}
29
30/// Setup sse without session data
31#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35    pub fn new_with_data<T>() -> SseSetup<T> {
36        SseSetup {
37            session_data: PhantomData,
38        }
39    }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49    pub fn new() -> Self {
50        SseSetup {
51            session_data: PhantomData,
52        }
53    }
54}
55
56impl<T> SseSetup<T>
57where
58    T: DeserializeOwned + 'static,
59{
60    /// Gets a `SseState` instance for you to add it to your app
61    #[must_use]
62    pub fn state_data(&self) -> web::Data<SseState<T>> {
63        web::Data::new(SseState::default())
64    }
65
66    /// Setups the sse route
67    pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68        cfg.route(path, web::get().to(sse_handler::<T>));
69    }
70}
71
72impl<T> SseSetup<T>
73where
74    T: Send + Sync + 'static,
75{
76    /// Gets a `SseState` instance for you to add it to your app and launches a task to execute
77    /// `SseState::remove_stale_sessions`.
78    #[must_use]
79    pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80        let state = web::Data::new(SseState::default());
81        spawn_remove_stale_sessions_task(state.clone(), period);
82        state
83    }
84}
85
86/// SSE state
87pub struct SseState<T> {
88    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
89}
90
91impl<T> Default for SseState<T> {
92    fn default() -> Self {
93        Self {
94            sessions: Arc::new(DashMap::new()),
95        }
96    }
97}
98
99impl<T: Clone> SseState<T> {
100    pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101        self.sessions.get(&id).and_then(|x| x.data.clone())
102    }
103}
104
105impl<T> SseState<T> {
106    pub fn add_session(
107        &self,
108        id: Uuid,
109        data: Option<T>,
110        sender: mpsc::Sender<Event>,
111    ) -> Option<SseSession<T>> {
112        let session = SseSession { data, sender };
113        self.sessions.insert(id, session)
114    }
115
116    pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117        self.sessions.remove(&id)
118    }
119
120    pub fn remove_stale_sessions(&self) {
121        self.sessions
122            .retain(|_, o| o.sender.try_send(Event::Comment("ping".into())).is_ok());
123    }
124
125    /// Sends a message to session id
126    pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
127        let sender = self.sessions.get(&id)?.sender.clone();
128        if sender.try_send(Event::Data(data)).is_err() {
129            // Channel is closed so we remove the session
130            self.remove_session(id);
131        }
132        Some(())
133    }
134
135    /// Broadcast a message to all sessions and returns the number of sent messages
136    pub fn broadcast(&self, data: Data) -> usize {
137        let senders = self
138            .sessions
139            .iter()
140            .map(|o| o.value().sender.clone())
141            .collect::<Vec<_>>();
142        sse_broadcast(senders, data)
143    }
144
145    /// Broadcast a message to all sessions but one id and returns the number of sent messages
146    pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
147        let senders = self
148            .sessions
149            .iter()
150            .filter_map(|o| {
151                if *o.key() == id {
152                    None
153                } else {
154                    Some(o.value().sender.clone())
155                }
156            })
157            .collect::<Vec<_>>();
158        sse_broadcast(senders, data)
159    }
160}
161
162/// Launches task to clean disconnected sessions
163pub fn spawn_remove_stale_sessions_task<T>(
164    state: web::Data<SseState<T>>,
165    period: Duration,
166) -> JoinHandle<()>
167where
168    T: Send + Sync + 'static,
169{
170    tokio::spawn(async move {
171        let mut interval = tokio::time::interval(period);
172        loop {
173            interval.tick().await;
174            state.remove_stale_sessions();
175        }
176    })
177}
178
179pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
180    senders
181        .into_iter()
182        .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
183        .count()
184}
185
186pub struct SseSession<T> {
187    pub data: Option<T>,
188    pub sender: mpsc::Sender<Event>,
189}
190
191/// Identifier for the session
192#[derive(Clone, Debug, Serialize, Deserialize)]
193pub struct SseHandlerQuery {
194    pub id: Uuid,
195}
196
197/// Route to handle web sockets
198#[tracing::instrument(skip_all)]
199pub async fn sse_handler<T>(
200    web::Query(query): web::Query<SseHandlerQuery>,
201    state: web::Data<SseState<T>>,
202) -> impl Responder {
203    let (tx, rx) = mpsc::channel(8);
204    state.add_session(query.id, None, tx);
205    actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
206}