Skip to main content

fhtmx_actix/
sse.rs

//! # Server-sent events
2
3use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use std::{marker::PhantomData, sync::Arc, time::Duration};
8use tokio::sync::mpsc;
9use uuid::Uuid;
10
/// Sets up server-sent event (SSE) state and routes.
///
/// `T` is the type of the optional per-session data stored in each [`SseSession`].
///
/// # Example
///
/// ```rust,ignore
/// use actix_web::{App, HttpServer};
/// use fhtmx_actix::sse::SseSetup;
///
/// let sse_setup = SseSetup::new();
/// // Create the state once, outside the server factory, so every worker shares it.
/// let sse_data = sse_setup.state_data();
/// HttpServer::new(move || {
///     App::new()
///         .configure(|cfg| sse_setup.setup_route("/sse", cfg))
///         .app_data(sse_data.clone())
/// });
/// ```
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// Implemented by hand: `#[derive]` would require `T: Clone` / `T: Copy` / `T: Debug`, even
// though only a zero-sized `PhantomData<T>` is stored.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}

impl<T> std::fmt::Debug for SseSetup<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SseSetup").finish()
    }
}
29
30/// Setup sse without session data
31#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
32pub struct FhtmxUiNoSessionData;
33
34impl SseSetup<()> {
35    pub fn new_with_data<T>() -> SseSetup<T> {
36        SseSetup {
37            session_data: PhantomData,
38        }
39    }
40}
41
42impl Default for SseSetup<FhtmxUiNoSessionData> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SseSetup<FhtmxUiNoSessionData> {
49    pub fn new() -> Self {
50        SseSetup {
51            session_data: PhantomData,
52        }
53    }
54}
55
56impl<T> SseSetup<T>
57where
58    T: DeserializeOwned + Send + Sync + 'static,
59{
60    /// Gets a `SseState` instance for you to add it to your app
61    #[must_use]
62    pub fn state_data(&self) -> web::Data<SseState<T>> {
63        web::Data::new(SseState::default())
64    }
65
66    /// Setups the sse route
67    pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
68        cfg.route(path, web::get().to(sse_handler::<T>));
69    }
70}
71
72/// SSE state
73pub struct SseState<T> {
74    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
75}
76
77impl<T> Default for SseState<T> {
78    fn default() -> Self {
79        Self {
80            sessions: Arc::new(DashMap::new()),
81        }
82    }
83}
84
85impl<T: Clone> SseState<T> {
86    pub fn get_session_data(&self, id: Uuid) -> Option<T> {
87        self.sessions.get(&id).and_then(|x| x.data.clone())
88    }
89}
90
91impl<T> SseState<T> {
92    pub fn add_session(
93        &self,
94        id: Uuid,
95        data: Option<T>,
96        sender: mpsc::Sender<Event>,
97    ) -> Option<SseSession<T>> {
98        let session = SseSession { data, sender };
99        self.sessions.insert(id, session)
100    }
101
102    pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
103        self.sessions.remove(&id)
104    }
105
106    pub fn remove_stale_sessions(&self) -> usize {
107        let mut removed = 0;
108        self.sessions.retain(
109            |_, o| match o.sender.try_send(Event::Comment("ping".into())) {
110                Ok(()) => true,
111                Err(_) => {
112                    removed += 1;
113                    false
114                }
115            },
116        );
117        removed
118    }
119
120    /// Sends a message to session id
121    pub fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
122        let sender = self.sessions.get(&id)?.sender.clone();
123        if sender.try_send(Event::Data(data)).is_err() {
124            // Channel is closed so we remove the session
125            self.remove_session(id);
126        }
127        Some(())
128    }
129
130    /// Broadcast a message to all sessions and returns the number of sent messages
131    pub fn broadcast(&self, data: Data) -> usize {
132        let senders = self
133            .sessions
134            .iter()
135            .map(|o| o.value().sender.clone())
136            .collect::<Vec<_>>();
137        sse_broadcast(senders, data)
138    }
139
140    /// Broadcast a message to all sessions but one id and returns the number of sent messages
141    pub fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
142        let senders = self
143            .sessions
144            .iter()
145            .filter_map(|o| {
146                if *o.key() == id {
147                    None
148                } else {
149                    Some(o.value().sender.clone())
150                }
151            })
152            .collect::<Vec<_>>();
153        sse_broadcast(senders, data)
154    }
155}
156
157pub fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
158    senders
159        .into_iter()
160        .filter_map(|o| o.try_send(Event::Data(data.clone())).ok())
161        .count()
162}
163
164pub struct SseSession<T> {
165    pub data: Option<T>,
166    pub sender: mpsc::Sender<Event>,
167}
168
169/// Identifier for the session
170#[derive(Clone, Debug, Serialize, Deserialize)]
171pub struct SseHandlerQuery {
172    pub id: Uuid,
173}
174
175/// Route to handle web sockets
176#[tracing::instrument(skip_all)]
177pub async fn sse_handler<T: Send + Sync + 'static>(
178    web::Query(query): web::Query<SseHandlerQuery>,
179    state: web::Data<SseState<T>>,
180) -> impl Responder {
181    let (tx, rx) = mpsc::channel(8);
182    state.add_session(query.id, None, tx.clone());
183
184    let sessions = state.sessions.clone();
185    tokio::spawn(async move {
186        tx.closed().await;
187        sessions.remove(&query.id);
188    });
189
190    actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
191}