Skip to main content

fhtmx_actix/
sse.rs

//! # Server-sent events
//!
//! Session registry, SSE route setup and broadcast helpers for actix-web.
2
3use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::{sync::mpsc, task::JoinHandle};
9use uuid::Uuid;
10
/// Sets up server-sent event state and routes.
///
/// `T` is the per-session data type stored in [`SseState`]. Use [`SseSetup::new`] when
/// sessions carry no data, or [`SseSetup::new_with_data`] to choose a custom type.
///
/// # Example
///
/// ```rust,ignore
/// use actix_web::App;
/// use fhtmx_actix::sse::SseSetup;
///
/// let sse_setup = SseSetup::new();
/// let sse_data = sse_setup.state_data();
/// App::new()
///     .configure(|cfg| sse_setup.setup_route("/sse", cfg))
///     .app_data(sse_data);
/// ```
pub struct SseSetup<T> {
    // `fn() -> T` keeps `T` a pure type-level marker: a setup value never owns a `T`, so it
    // is `Send`/`Sync` (and can be captured by server factory closures) regardless of `T`.
    session_data: PhantomData<fn() -> T>,
}

// Written by hand instead of `#[derive(Clone, Copy)]`: the derive would require
// `T: Clone` / `T: Copy` even though no `T` value is ever stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish_non_exhaustive()
    }
}
29
30/// Setup sse without session data
31#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35    pub fn new_with_data<T>() -> SseSetup<T> {
36        SseSetup {
37            session_data: PhantomData,
38        }
39    }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49    pub fn new() -> Self {
50        SseSetup {
51            session_data: PhantomData,
52        }
53    }
54}
55
56impl<T> SseSetup<T>
57where
58    T: DeserializeOwned + 'static,
59{
60    /// Gets a `SseState` instance for you to add it to your app
61    #[must_use]
62    pub fn state_data(&self) -> web::Data<SseState<T>> {
63        web::Data::new(SseState::default())
64    }
65
66    /// Setups the sse route
67    pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68        cfg.route(path, web::get().to(sse_handler::<T>));
69    }
70}
71
72impl<T> SseSetup<T>
73where
74    T: Send + Sync + 'static,
75{
76    /// Gets a `SseState` instance for you to add it to your app and launches a task to execute
77    /// `SseState::remove_stale_sessions`.
78    #[must_use]
79    pub fn state_data_and_spawn_cleaner(&self, period: Duration) -> web::Data<SseState<T>> {
80        let state = web::Data::new(SseState::default());
81        spawn_remove_stale_sessions_task(state.clone(), period);
82        state
83    }
84}
85
86/// SSE state
87pub struct SseState<T> {
88    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
89}
90
91impl<T> Default for SseState<T> {
92    fn default() -> Self {
93        Self {
94            sessions: Arc::new(DashMap::new()),
95        }
96    }
97}
98
99impl<T: Clone> SseState<T> {
100    pub fn get_session_data(&self, id: Uuid) -> Option<T> {
101        self.sessions.get(&id).and_then(|x| x.data.clone())
102    }
103}
104
105impl<T> SseState<T> {
106    pub fn add_session(
107        &self,
108        id: Uuid,
109        data: Option<T>,
110        sender: mpsc::Sender<Event>,
111    ) -> Option<SseSession<T>> {
112        let session = SseSession { data, sender };
113        self.sessions.insert(id, session)
114    }
115
116    pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
117        self.sessions.remove(&id)
118    }
119
120    pub fn remove_stale_sessions(&self) -> usize {
121        let mut removed = 0;
122        self.sessions.retain(
123            |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
124                Ok(()) => true,
125                Err(_) => {
126                    removed += 1;
127                    false
128                }
129            },
130        );
131        removed
132    }
133
134    /// Sends a message to session id
135    pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
136        let sender = self.sessions.get(&id)?.sender.clone();
137        sender.try_send(Event::Data(data)).ok()
138    }
139
140    /// Broadcast a message to all sessions and returns the number of sent messages
141    pub fn broadcast(&self, data: Data) -> usize {
142        let senders = self
143            .sessions
144            .iter()
145            .map(|o| o.value().sender.clone())
146            .collect::<Vec<_>>();
147        sse_broadcast(senders, data)
148    }
149
150    /// Broadcast a message to all sessions but one id and returns the number of sent messages
151    pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
152        let senders = self
153            .sessions
154            .iter()
155            .filter_map(|o| {
156                if *o.key() == id {
157                    None
158                } else {
159                    Some(o.value().sender.clone())
160                }
161            })
162            .collect::<Vec<_>>();
163        sse_broadcast(senders, data)
164    }
165}
166
167/// Launches task to clean disconnected sessions
168pub fn spawn_remove_stale_sessions_task<T>(
169    state: web::Data<SseState<T>>,
170    period: Duration,
171) -> JoinHandle<()>
172where
173    T: Send + Sync + 'static,
174{
175    tokio::spawn(async move {
176        let mut interval = tokio::time::interval(period);
177        loop {
178            interval.tick().await;
179            let removed = state.remove_stale_sessions();
180            if removed > 0 {
181                tracing::info!("Removed {removed} staled sessions.");
182            }
183        }
184    })
185}
186
187pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
188    senders
189        .into_iter()
190        .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
191        .count()
192}
193
194pub struct SseSession<T> {
195    pub data: Option<T>,
196    pub sender: mpsc::Sender<Event>,
197}
198
199/// Identifier for the session
200#[derive(Clone, Debug, Serialize, Deserialize)]
201pub struct SseHandlerQuery {
202    pub id: Uuid,
203}
204
205/// Route to handle web sockets
206#[tracing::instrument(skip_all)]
207pub async fn sse_handler<T>(
208    web::Query(query): web::Query<SseHandlerQuery>,
209    state: web::Data<SseState<T>>,
210) -> impl Responder {
211    let (tx, rx) = mpsc::channel(8);
212    state.add_session(query.id, None, tx);
213    actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
214}