Skip to main content

fhtmx_actix/
sse.rs

1//! # Server sent events
2
3use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Sets up Server-Sent Events state and routes.
///
/// `T` is the per-session data type stored in [`SseState`]. Use [`SseSetup::new`] when no
/// session data is needed, or [`SseSetup::new_with_data`] to choose a custom `T`.
///
/// # Example
///
/// ```rust,ignore
/// use actix_web::App;
/// use fhtmx_actix::sse::SseSetup;
///
/// let sse_setup = SseSetup::new();
/// let sse_data = sse_setup.state_data();
/// App::new()
///     .configure(|cfg| sse_setup.setup_route("/sse", cfg))
///     .app_data(sse_data);
/// ```
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`/`Copy` are implemented by hand: `#[derive]` would require `T: Clone` / `T: Copy`,
// even though only a zero-sized `PhantomData<T>` is stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}
29
30/// Setup sse without session data
31#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35    pub fn new_with_data<T>() -> SseSetup<T> {
36        SseSetup {
37            session_data: PhantomData,
38        }
39    }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49    pub fn new() -> Self {
50        SseSetup {
51            session_data: PhantomData,
52        }
53    }
54}
55
56impl<T> SseSetup<T>
57where
58    T: DeserializeOwned + 'static,
59{
60    /// Gets a `SseState` instance for you to add it to your app
61    #[must_use]
62    pub fn state_data(&self) -> web::Data<SseState<T>> {
63        web::Data::new(SseState::default())
64    }
65
66    /// Setups the sse route
67    pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68        cfg.route(path, web::get().to(sse_handler::<T>));
69    }
70}
71
72impl<T> SseSetup<T>
73where
74    T: Send + Sync + 'static,
75{
76    /// Gets a `SseState` instance for you to add it to your app and launches a task to execute
77    /// `SseState::remove_stale_sessions`.
78    #[must_use]
79    pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80        let state = web::Data::new(SseState::default());
81        spawn_remove_stale_sessions_task(state.clone(), period);
82        state
83    }
84}
85
86/// SSE state
87pub struct SseState<T> {
88    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
89}
90
91impl<T> Default for SseState<T> {
92    fn default() -> Self {
93        Self {
94            sessions: Arc::new(DashMap::new()),
95        }
96    }
97}
98
99impl<T: Clone> SseState<T> {
100    pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101        self.sessions.get(&id).and_then(|x| x.data.clone())
102    }
103}
104
105impl<T> SseState<T> {
106    pub fn add_session(
107        &self,
108        id: Uuid,
109        data: Option<T>,
110        sender: mpsc::Sender<Event>,
111    ) -> Option<SseSession<T>> {
112        let session = SseSession { data, sender };
113        self.sessions.insert(id, session)
114    }
115
116    pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117        self.sessions.remove(&id)
118    }
119
120    pub fn remove_stale_sessions(&self) -> usize {
121        let mut removed = 0;
122        self.sessions.retain(
123            |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
124                Ok(()) => true,
125                Err(_) => {
126                    removed += 1;
127                    false
128                }
129            },
130        );
131        removed
132    }
133
134    /// Sends a message to session id
135    pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
136        let sender = self.sessions.get(&id)?.sender.clone();
137        if sender.try_send(Event::Data(data)).is_err() {
138            // Channel is closed so we remove the session
139            self.remove_session(id);
140        }
141        Some(())
142    }
143
144    /// Broadcast a message to all sessions and returns the number of sent messages
145    pub fn broadcast(&self, data: Data) -> usize {
146        let senders = self
147            .sessions
148            .iter()
149            .map(|o| o.value().sender.clone())
150            .collect::<Vec<_>>();
151        sse_broadcast(senders, data)
152    }
153
154    /// Broadcast a message to all sessions but one id and returns the number of sent messages
155    pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
156        let senders = self
157            .sessions
158            .iter()
159            .filter_map(|o| {
160                if *o.key() == id {
161                    None
162                } else {
163                    Some(o.value().sender.clone())
164                }
165            })
166            .collect::<Vec<_>>();
167        sse_broadcast(senders, data)
168    }
169}
170
171/// Launches task to clean disconnected sessions
172pub fn spawn_remove_stale_sessions_task<T>(
173    state: web::Data<SseState<T>>,
174    period: Duration,
175) -> JoinHandle<()>
176where
177    T: Send + Sync + 'static,
178{
179    tokio::spawn(async move {
180        let mut interval = tokio::time::interval(period);
181        loop {
182            interval.tick().await;
183            let removed = state.remove_stale_sessions();
184            if removed > 0 {
185                tracing::info!("Removed {removed} staled sessions.");
186            }
187        }
188    })
189}
190
191pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
192    senders
193        .into_iter()
194        .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
195        .count()
196}
197
198pub struct SseSession<T> {
199    pub data: Option<T>,
200    pub sender: mpsc::Sender<Event>,
201}
202
203/// Identifier for the session
204#[derive(Clone, Debug, Serialize, Deserialize)]
205pub struct SseHandlerQuery {
206    pub id: Uuid,
207}
208
209/// Route to handle web sockets
210#[tracing::instrument(skip_all)]
211pub async fn sse_handler<T>(
212    web::Query(query): web::Query<SseHandlerQuery>,
213    state: web::Data<SseState<T>>,
214) -> impl Responder {
215    let (tx, rx) = mpsc::channel(8);
216    state.add_session(query.id, None, tx);
217    actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
218}